import random as rnd
import threading
import time
from multiprocessing import Pool

def generateSquareMatrix(size):
    return [[rnd.randint(0, 100) for i in range(size)] for j in range(size)]

def printMatrix(matrix):
    for row in matrix:
        print(*row, sep="\t")


# Перемножение без использования потоков
def matrixMultiplyStandard(matrix1, matrix2):
    l1 = len(matrix1)
    l2 = len(matrix2)
    global result_matrix
    result = result_matrix
    for i in range(l1):
        for j in range(l2):
            for k in range(l2):
                result[i][j] += matrix1[i][k] * matrix2[k][j]

    return result

result_matrix = [[0 for i in range(500)] for j in range(500)]

# Перемножение в отдельном потоке
def matrixMultiplySingleThread(args):
    matrix1, matrix2, start_i, end_i = args
    global result_matrix

    result = result_matrix

    for i in range(start_i, end_i):
        for j in range(len(matrix2[0])):
            for k in range(len(matrix2)):
                result[i][j] += matrix1[i - start_i][k] * matrix2[k][j]

# Параллельное перемножение, использует ф-ю выше для каждого потока
def matrixMultiplyWithThreads(matrix1, matrix2, thread_count):
    l1 = len(matrix1)
    l2 = len(matrix2)

    # Кол-во строк на последний поток, если деление по потокам будет неточным
    last_rows_count = 0

    if l1 % thread_count == 0:
        rows_per_thread = l1 // thread_count
    else:
        rows_per_thread = l1 // thread_count
        last_rows_count = l1 % thread_count

    for i in range(thread_count):
        start_i = i * rows_per_thread

        if (i - 1) == thread_count and last_rows_count > 0:
            end_i = start_i + last_rows_count
        else:
            end_i = start_i + rows_per_thread

    args = []
    args.append((matrix1[start_i:end_i], matrix2, start_i, end_i))
    with Pool(processes = thread_count) as pool:
        pool.map(matrixMultiplySingleThread, args)


if __name__ == "__main__":

    sizes = [100, 300, 500]
    num_threads = [1, 5, 8, 12]

    for size in sizes:
        matrix1 = generateSquareMatrix(size)
        matrix2 = generateSquareMatrix(size)
        start_time = time.time()
        matrixMultiplyStandard(matrix1, matrix2)
        end_time = time.time()
        print(f"Standard size {size}: {end_time - start_time}s")

        for threads in num_threads:
            start_time = time.time()
            matrixMultiplyWithThreads(matrix1, matrix2, threads)
            end_time = time.time()
            print(f"Parallel size {size}, {threads} thread(s): {end_time - start_time}s")

        print("-" * 100)