package main

import (
	"fmt"
	"strconv"
	"sync"
	"time"
	"zhimolostnova_anna_lab_5/util"
)

// Параллельное умножение матриц
func multiplyMatricesParallel(a, b [][]int, threads int) [][]int {
	size := len(a)
	result := make([][]int, size)
	for i := range result {
		result[i] = make([]int, size)
	}

	// Функция для обработки части работы потока
	worker := func(startRow, endRow int, wg *sync.WaitGroup) {
		defer wg.Done()
		for i := startRow; i < endRow; i++ {
			for j := 0; j < size; j++ {
				for k := 0; k < size; k++ {
					result[i][j] += a[i][k] * b[k][j]
				}
			}
		}
	}

	// Запуск потоков
	var wg sync.WaitGroup
	rowsPerThread := size / threads
	for i := 0; i < threads; i++ {
		startRow := i * rowsPerThread
		endRow := startRow + rowsPerThread
		if i == threads-1 {
			endRow = size
		}
		wg.Add(1)
		go worker(startRow, endRow, &wg)
	}

	wg.Wait()
	return result
}

func benchmarkMatrixMultiplication(sizes []int, threadsList []int) {
	for _, size := range sizes {
		for _, threads := range threadsList {
			matrixA := util.CreateMatrix(size)
			matrixB := util.CreateMatrix(size)

			start := time.Now()
			_ = multiplyMatricesParallel(matrixA, matrixB, threads)
			elapsed := time.Since(start)

			fmt.Printf("Parallel multiplication of matrix %sx%s with %d threads took %s\n", strconv.Itoa(size), strconv.Itoa(size), threads, elapsed)
		}
	}
}

func main() {
	// Список размерностей матриц
	sizes := []int{100, 300, 500}

	// Список количества потоков для тестирования
	threadsList := []int{2, 4, 6, 8}

	// Запуск бенчмарка
	benchmarkMatrixMultiplication(sizes, threadsList)
}