import itertools
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import Tuple, List
import numpy as np

_SUPLYER_TYPE = Callable[[], int | float] | int | float
_QUEUE_TYPE = Queue[Tuple[List[float | int], List[float | int], int]]


class Matrix:
    def __init__(self, size: int, suplyer: _SUPLYER_TYPE = 0):
        self.__size = size
        self.__matrix = self._generate_matrix(suplyer)

    def _generate_matrix(self, suplyer: _SUPLYER_TYPE):
        if suplyer:
            match suplyer:
                case int() | float():
                    return [[suplyer for _ in range(self.__size)] for _ in range(self.__size)]
                case Callable():
                    return [[suplyer() for _ in range(self.__size)] for _ in range(self.__size)]
        return [[0 for _ in range(self.__size)] for _ in range(self.__size)]

    def from_flat(self, numbers: List[int | float]):
        if len(numbers) != self.__size ** 2:
            raise Exception(f"Invalid matrix size {self.__size} ^ 2 != {len(numbers)}")
        x, y = 0, 0
        for number in numbers:
            self.__matrix[y][x] = number
            x += 1
            if x >= self.__size:
                x = 0
                y += 1

    @property
    def rows(self):
        return self.__matrix

    @property
    def columns(self):
        return [[self.__matrix[i][j] for i in range(self.__size)] for j in range(self.__size)]

    @property
    def size(self):
        return self.__size

    @staticmethod
    def random(*, size: int):
        import random
        return Matrix(size=size, suplyer=random.random)

    def to_numpy(self):
        return np.array(self.__matrix)

    def __eq__(self, other):
        return (isinstance(other, Matrix)
                and self.__size == other.__size)

    def __str__(self):
        return f"Matrix {self.__size}x{self.__size} \n" + "\n".join([str(
            " ".join([f"{element:.5f}" for element in row])
        ) for row in self.__matrix])

    def __iter__(self):
        return iter(self.__matrix)

    def __getitem__(self, index):
        return self.__matrix[index]

    def __mul__(self, other):
        match other:
            case Matrix():
                return mul_matrixs(self, other)
            case tuple():
                other_matrix, count_threads = other
                return mul_matrixs(self, other_matrix, count_threads)
        return None


def mul_row_and_column_in_thread(queue: _QUEUE_TYPE) -> list[tuple[int | float, int]]:
    """Drain *queue*, computing the dot product of every work item.

    :param queue: work items ``(row, column, place)``; ``row`` and ``column``
        are numeric sequences of equal length, ``place`` is passed through.
    :return: ``[(dot(row, column), place), ...]`` in the order items were taken.
    :raises ValueError: if a row and its column differ in length.
    """
    from queue import Empty

    result = []
    while True:
        # Non-blocking get instead of "check qsize(), then get()": a positive
        # qsize() does not guarantee the following get() won't block if another
        # consumer empties the queue in between.
        try:
            row, column, place = queue.get_nowait()
        except Empty:
            break
        local_result = 0
        for row_value, column_value in zip(row, column, strict=True):
            local_result += row_value * column_value
        result.append((local_result, place))

    return result


def mul_matrixs(m1: Matrix, m2: Matrix, threads: int = 0):
    """Multiply two square matrices, spreading the dot products over threads.

    Every (row of *m1*, column of *m2*) pair is tagged with its row-major
    position in the result and dealt round-robin onto one queue per thread;
    each queue is drained by ``mul_row_and_column_in_thread`` inside a
    ``ThreadPoolExecutor``. Note that on a GIL build of CPython the
    pure-Python arithmetic does not actually run in parallel.

    :param m1: left operand.
    :param m2: right operand.
    :param threads: number of worker threads; values below 1 are treated as 1.
    :return: the product as a new Matrix, or None if the sizes differ.
    """
    if m1.size != m2.size:
        return None

    threads = max(threads, 1)
    result = Matrix(size=m1.size, suplyer=0)

    thread_queues = [Queue() for _ in range(threads)]
    # product(rows, columns) yields (row i, column j) in row-major order, so the
    # enumerate index is exactly the flat position of result[i][j].
    pairs = itertools.product(m1.rows, m2.columns)
    for place, (row_m1, column_m2) in enumerate(pairs):
        thread_queues[place % threads].put((row_m1, column_m2, place))

    with ThreadPoolExecutor(max_workers=threads) as executor:
        flat = [
            item
            for chunk in executor.map(mul_row_and_column_in_thread, thread_queues)
            for item in chunk
        ]

    # Each worker returns its own interleaved slice; sorting by the position
    # tag restores row-major order before refilling the result.
    flat.sort(key=lambda item: item[1])
    result.from_flat([value for value, _ in flat])

    return result