import random as rnd
import threading
import time
import concurrent.futures
from copy import deepcopy

def generateSquareMatrix(size):
    """Build a ``size`` x ``size`` matrix of random integers.

    Every entry is drawn with ``rnd.randint(0, 100)``, so values lie in the
    inclusive range 0..100. The result is a list of ``size`` row lists.
    """
    matrix = []
    for _ in range(size):
        # Rows are filled left to right, one after another.
        row = [rnd.randint(0, 100) for _ in range(size)]
        matrix.append(row)
    return matrix

def printMatrix(matrix):
    """Print ``matrix`` to stdout, one row per line, entries tab-separated."""
    for row in matrix:
        line = "\t".join(str(entry) for entry in row)
        print(line)

# Module-level 500 x 500 random matrix, built as a side effect of importing
# this module.
# NOTE(review): nothing in the visible part of this file reads `testmatrix` --
# confirm whether it is still needed before keeping the import-time cost.
testmatrix = generateSquareMatrix(500)

def process_row(args):
    """Run one Gaussian-elimination step on a single row.

    Args:
        args: A 4-tuple ``(i, j, m, n)`` where ``m`` is the matrix (list of
            row lists), ``i`` the pivot row/column index, ``j`` the index of
            the row to update and ``n`` the exclusive upper column bound.

    Returns:
        Row ``m[j]`` after ``factor * m[i]`` has been subtracted from it over
        columns ``i..n-1``. The row is updated in place inside ``m``.

    Raises:
        ZeroDivisionError: If the pivot ``m[i][i]`` is zero.
    """
    pivot_idx, row_idx, mat, width = args

    target = mat[row_idx]
    pivot = mat[pivot_idx]

    # Multiplier chosen so that the entry below the pivot is cancelled.
    scale = target[pivot_idx] / pivot[pivot_idx]

    col = pivot_idx
    while col < width:
        target[col] = target[col] - scale * pivot[col]
        col += 1

    return target


def parallel_det(matrix, threadss):
    """Compute a determinant by Gaussian elimination using a thread pool.

    For every pivot column the elimination of each lower row is submitted as
    a separate task to a ``ThreadPoolExecutor``. Each task writes only to its
    own row and reads only the pivot row, so tasks of one pivot step do not
    touch the same data.

    NOTE: on CPython builds with the GIL, threads do not execute this
    pure-Python arithmetic in parallel, so no speed-up over ``det`` should be
    expected. For large matrices the product of pivots may also exceed the
    float range.

    Args:
        matrix: Square matrix as a list of row lists. It is not modified;
            the elimination works on a row-wise copy.
        threadss: Passed as ``max_workers`` to ``ThreadPoolExecutor``.

    Returns:
        The determinant (product of the pivots, sign-adjusted for row swaps).
        ``0`` is returned as soon as a column has no non-zero entry on or
        below the diagonal; ``1`` is returned for an empty matrix.

    Raises:
        Exception: Any exception raised while eliminating a row is re-raised
            here instead of being silently dropped.
    """
    n = len(matrix)
    if n == 0:
        # Empty product; also avoids creating a pool that would do no work.
        return 1

    # A per-row slice copy is enough for a list of number lists and matches
    # what det() does; deepcopy is far slower for the same result.
    m = [row[:] for row in matrix]

    det_value = 1

    def _eliminate_row(pivot_row, row, i):
        # Subtract the multiple of the pivot row that cancels row[i].
        factor = row[i] / pivot_row[i]
        for k in range(i, n):
            row[k] -= factor * pivot_row[k]

    # One pool for the whole computation: creating and tearing down a pool
    # for each of the n pivot steps only adds thread start-up overhead.
    with concurrent.futures.ThreadPoolExecutor(max_workers=threadss) as executor:
        for i in range(n):
            if m[i][i] == 0:
                # Look for a lower row with a usable pivot and swap it in;
                # every swap flips the sign of the determinant.
                for j in range(i + 1, n):
                    if m[j][i] != 0:
                        m[i], m[j] = m[j], m[i]
                        det_value *= -1
                        break
                else:
                    return 0

            pivot_row = m[i]
            futures = [
                executor.submit(_eliminate_row, pivot_row, m[j], i)
                for j in range(i + 1, n)
            ]
            # result() both waits for the task and re-raises any exception
            # from the worker; all tasks must finish before the next pivot.
            for future in futures:
                future.result()

            det_value *= pivot_row[i]

    return det_value


def det(matrix):
    """Return the determinant of a square matrix via Gaussian elimination.

    The input is left untouched; all work happens on a row-wise copy. A row
    swap is performed only when a diagonal entry equals zero, and each swap
    flips the sign of the result. ``0`` is returned as soon as a column has
    no non-zero entry on or below the diagonal.
    """
    size = len(matrix)

    rows = []
    for source_row in matrix:
        rows.append(source_row[:])

    determinant = 1

    for col in range(size):
        if rows[col][col] == 0:
            # First lower row whose entry in this column can act as pivot.
            swap_with = next(
                (r for r in range(col + 1, size) if rows[r][col] != 0),
                None,
            )
            if swap_with is None:
                return 0
            rows[col], rows[swap_with] = rows[swap_with], rows[col]
            determinant *= -1

        pivot_row = rows[col]

        # Clear the entries below the pivot.
        for r in range(col + 1, size):
            current = rows[r]
            scale = current[col] / pivot_row[col]
            for k in range(col, size):
                current[k] = current[k] - scale * pivot_row[k]

        determinant *= pivot_row[col]

    return determinant


if __name__ == "__main__":
    # Benchmark: time parallel_det on random matrices of several sizes, once
    # for each thread count, and print the elapsed seconds per run.
    sizes = [100, 300, 500]
    num_threads = [1, 5, 8, 12]

    for size in sizes:
        # The same matrix is reused for every thread count so the timings
        # within one size are comparable.
        matrix1 = generateSquareMatrix(size)
        for threads in num_threads:
            # perf_counter is the clock intended for measuring short
            # intervals; time.time() follows the wall clock and can jump.
            start_time = time.perf_counter()
            parallel_det(matrix1, threads)
            elapsed = time.perf_counter() - start_time
            print(f"Parallel size {size}, {threads} thread(s): {elapsed}s")

        print("-" * 100)