import math
from multiprocessing import Pool
import numpy as np
from datetime import datetime

def check_zeros_string(arr, start=0, end=0):
    """Find the row with the most zeros among rows ``start`` .. ``end - 1``.

    Args:
        arr: 2-D array-like matrix.
        start: First row index to examine.
        end: One past the last row index to examine; ``0`` means "up to the
            last row".

    Returns:
        ``(row_index, zero_count)`` for the winning row. On ties the row
        with the highest index wins.

    Raises:
        ValueError: if the selected row range is empty.
    """
    arr = np.asarray(arr)
    if end == 0:
        end = len(arr)
    if start >= end:
        raise ValueError(f"empty row range [{start}, {end})")
    # Zeros per row, counted in one vectorised pass.
    zeros = arr.shape[1] - np.count_nonzero(arr[start:end], axis=1)
    # Last index among the maxima, so ties resolve to the highest row.
    best = int(np.flatnonzero(zeros == zeros.max())[-1])
    return start + best, int(zeros[best])


def check_zeros_column(arr, start=0, end=0):
    """Find the column with the most zeros among columns ``start`` .. ``end - 1``.

    Args:
        arr: 2-D array-like matrix.
        start: First column index to examine.
        end: One past the last column index to examine; ``0`` means "up to
            the last column".

    Returns:
        ``(column_index, zero_count)`` for the winning column. On ties the
        column with the highest index wins.

    Raises:
        ValueError: if the selected column range is empty.
    """
    arr = np.asarray(arr)
    if end == 0:
        # Bound by the column count (not the row count) so that
        # non-square input is scanned completely.
        end = arr.shape[1]
    if start >= end:
        raise ValueError(f"empty column range [{start}, {end})")
    # Zeros per column, counted in one vectorised pass.
    zeros = arr.shape[0] - np.count_nonzero(arr[:, start:end], axis=0)
    # Last index among the maxima, so ties resolve to the highest column.
    best = int(np.flatnonzero(zeros == zeros.max())[-1])
    return start + best, int(zeros[best])

def delta_string(arr, id, start=0, end=0):
    """Cofactor (Laplace) expansion of a square matrix along row ``id``.

    Only the terms for columns ``start`` .. ``end - 1`` are summed. Partial
    results over disjoint column ranges that together cover the row therefore
    add up to the full determinant.

    Args:
        arr: Square 2-D array-like matrix.
        id: Index of the row to expand along. The name is kept for
            compatibility with existing callers.
        start: First column of the expansion range.
        end: One past the last column; ``0`` means "up to the last column".

    Returns:
        The (partial) determinant. Integer input yields an exact Python
        ``int``, because entries are converted with ``.item()`` /
        ``.tolist()`` before any arithmetic, so fixed-width numpy integers
        cannot overflow.
    """
    arr = np.asarray(arr)
    n_cols = arr.shape[1]
    if end == 0:
        end = n_cols
    # The closed forms are only valid when the whole row is requested.
    whole_row = start == 0 and end == n_cols
    if whole_row and arr.shape == (1, 1):
        return arr.item()
    if whole_row and arr.shape == (2, 2):
        (a, b), (c, d) = arr.tolist()
        return a * d - b * c
    # Row `id` is removed from every minor, so delete it only once.
    without_row = np.delete(arr, id, 0)
    result = 0
    for j in range(start, end):
        value = arr[id, j].item()
        if value == 0:
            continue  # zero entries contribute nothing
        minor = np.delete(without_row, j, 1)
        if minor.shape[0] <= 2:
            # Closed form applies; choosing a pivot line would be wasted work.
            minor_det = delta_string(minor, 0)
        else:
            # Expand the minor along its most zero-rich row or column.
            best_row = check_zeros_string(minor)
            best_col = check_zeros_column(minor)
            if best_row[1] >= best_col[1]:
                minor_det = delta_string(minor, best_row[0])
            else:
                minor_det = delta_column(minor, best_col[0])
        sign = -1 if (id + j) % 2 else 1
        result += sign * value * minor_det
    return result

def delta_column(arr, id, start=0, end=0):
    """Cofactor (Laplace) expansion of a square matrix along column ``id``.

    Only the terms for rows ``start`` .. ``end - 1`` are summed. Partial
    results over disjoint row ranges that together cover the column therefore
    add up to the full determinant.

    Args:
        arr: Square 2-D array-like matrix.
        id: Index of the column to expand along. The name is kept for
            compatibility with existing callers.
        start: First row of the expansion range.
        end: One past the last row; ``0`` means "up to the last row".

    Returns:
        The (partial) determinant. Integer input yields an exact Python
        ``int``, because entries are converted with ``.item()`` /
        ``.tolist()`` before any arithmetic, so fixed-width numpy integers
        cannot overflow.
    """
    arr = np.asarray(arr)
    n_rows = arr.shape[0]
    if end == 0:
        end = n_rows
    # The closed forms are only valid when the whole column is requested.
    whole_column = start == 0 and end == n_rows
    if whole_column and arr.shape == (1, 1):
        return arr.item()
    if whole_column and arr.shape == (2, 2):
        (a, b), (c, d) = arr.tolist()
        return a * d - b * c
    # Column `id` is removed from every minor, so delete it only once.
    without_column = np.delete(arr, id, 1)
    result = 0
    for i in range(start, end):
        value = arr[i, id].item()
        if value == 0:
            continue  # zero entries contribute nothing
        minor = np.delete(without_column, i, 0)
        if minor.shape[0] <= 2:
            # Closed form applies; choosing a pivot line would be wasted work.
            minor_det = delta_column(minor, 0)
        else:
            # Expand the minor along its most zero-rich row or column.
            best_row = check_zeros_string(minor)
            best_col = check_zeros_column(minor)
            if best_row[1] >= best_col[1]:
                minor_det = delta_string(minor, best_row[0])
            else:
                minor_det = delta_column(minor, best_col[0])
        sign = -1 if (i + id) % 2 else 1
        result += sign * value * minor_det
    return result


if __name__ == '__main__':
    # Benchmark: for each matrix size and worker count, find the most
    # zero-rich row/column and compute the determinant by cofactor expansion,
    # splitting the work across a process pool. Only the pooled work is timed.
    print("Start")
    sizes = [6, 8, 11]
    threads_counts = [1, 2, 4]
    for size in sizes:
        fst = np.random.randint(0, 5, size=(size, size))
        for thread_count in threads_counts:
            # Split [0, size) into contiguous chunks; the first
            # `size % thread_count` chunks get one extra index. Empty chunks
            # are dropped: the worker functions cannot take an empty range,
            # and to them end=0 means "up to the end".
            base, extra = divmod(size, thread_count)
            bounds = []
            i = 0
            for k in range(thread_count):
                step = base + (1 if k < extra else 0)
                if step > 0:
                    bounds.append((i, i + step))
                i += step

            # `with` shuts the worker processes down after each configuration.
            with Pool(thread_count) as pool:
                startTime = datetime.now()
                # The matrix is square, so the same chunks serve as row
                # ranges and as column ranges.
                max_zero_string = pool.starmap(
                    check_zeros_string, [(fst, s, e) for s, e in bounds])
                max_zero_column = pool.starmap(
                    check_zeros_column, [(fst, s, e) for s, e in bounds])

                # Combine per-chunk winners; iterating in reverse makes ties
                # go to the last chunk.
                mzs = max(reversed(max_zero_string), key=lambda mz: mz[1])
                mzc = max(reversed(max_zero_column), key=lambda mz: mz[1])

                # Expand along whichever line has more zeros. Each worker
                # sums the cofactor terms of its own chunk; the partial sums
                # add up to the determinant.
                if mzs[1] >= mzc[1]:
                    partial_sums = pool.starmap(
                        delta_string, [(fst, mzs[0], s, e) for s, e in bounds])
                else:
                    partial_sums = pool.starmap(
                        delta_column, [(fst, mzc[0], s, e) for s, e in bounds])
                determinant = sum(partial_sums)
                endTime = datetime.now()
            print(f"Size: {size}")
            print(f"Count of threads: {thread_count}")
            print(f"Determinant: {determinant}")
            print(f"Work time: {endTime-startTime}")
            print("_-_-_-_-_-_-_-_-_-")
        print("-------------------------------------------")