package main

import (
	"fmt"
	"math/rand"
	"sync"
	"time"
)

// sequentialMatrixMultiply returns the product matrixA × matrixB, computed
// entirely on the calling goroutine.
//
// matrixA is treated as rowsA×colsA, where colsA is the length of its first
// row, and matrixB must have exactly colsA rows. colsB is taken from the
// length of matrixB's first row. The result is a freshly allocated
// rowsA×colsB matrix.
//
// Edge cases:
//   - An empty matrixA yields an empty result.
//   - An empty inner dimension (colsA == 0) yields rowsA zero-length rows,
//     because the column count of an empty matrixB cannot be known.
//
// It panics with a descriptive message if the inner dimensions disagree.
func sequentialMatrixMultiply(matrixA, matrixB [][]float64) [][]float64 {
	rowsA := len(matrixA)
	if rowsA == 0 {
		return make([][]float64, 0)
	}
	colsA := len(matrixA[0])
	if len(matrixB) != colsA {
		panic(fmt.Sprintf("sequentialMatrixMultiply: inner dimensions differ (%d columns vs %d rows)", colsA, len(matrixB)))
	}
	colsB := 0
	if colsA > 0 {
		colsB = len(matrixB[0])
	}

	result := make([][]float64, rowsA)
	for i := range result {
		result[i] = make([]float64, colsB)
	}

	// i-k-j loop order: the innermost loop walks matrixB[k] and result[i]
	// contiguously, which is far more cache-friendly than i-j-k. Each
	// result[i][j] still accumulates its terms in increasing k, so the sums
	// are formed in the same order as the textbook formulation.
	for i := 0; i < rowsA; i++ {
		rowA, rowOut := matrixA[i], result[i]
		for k := 0; k < colsA; k++ {
			a := rowA[k]
			rowB := matrixB[k]
			for j := 0; j < colsB; j++ {
				rowOut[j] += a * rowB[j]
			}
		}
	}
	return result
}

// parallelMatrixMultiply returns the product matrixA × matrixB. It splits the
// rows of the result into contiguous bands and computes them concurrently on
// up to numProcesses goroutines.
//
// Shape conventions:
//   - matrixA is rowsA×colsA, where colsA is the length of its first row.
//   - matrixB must have exactly colsA rows; colsB is the length of its first row.
//   - The result is a freshly allocated rowsA×colsB matrix.
//   - An empty matrixA yields an empty result.
//
// numProcesses is clamped to the range [1, rowsA]:
//   - A non-positive value runs on one goroutine. Previously it returned an
//     all-zero matrix or panicked in WaitGroup.Add.
//   - Workers that would receive no rows are never started.
//
// Each goroutine writes only to the rows in its own band, so no locking is
// required. It panics with a descriptive message if the inner dimensions
// disagree.
func parallelMatrixMultiply(matrixA, matrixB [][]float64, numProcesses int) [][]float64 {
	rowsA := len(matrixA)
	if rowsA == 0 {
		return make([][]float64, 0)
	}
	colsA := len(matrixA[0])
	if len(matrixB) != colsA {
		panic(fmt.Sprintf("parallelMatrixMultiply: inner dimensions differ (%d columns vs %d rows)", colsA, len(matrixB)))
	}
	colsB := 0
	if colsA > 0 {
		colsB = len(matrixB[0])
	}

	result := make([][]float64, rowsA)
	for i := range result {
		result[i] = make([]float64, colsB)
	}

	workers := numProcesses
	if workers < 1 {
		workers = 1
	}
	if workers > rowsA {
		workers = rowsA
	}

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		// Balanced split: band sizes differ by at most one row and together
		// cover [0, rowsA) exactly.
		go func(lo, hi int) {
			defer wg.Done()
			// i-k-j order keeps the inner loop on contiguous memory; each
			// element still accumulates its terms in increasing k.
			for i := lo; i < hi; i++ {
				rowA, rowOut := matrixA[i], result[i]
				for k := 0; k < colsA; k++ {
					a := rowA[k]
					rowB := matrixB[k]
					for j := 0; j < colsB; j++ {
						rowOut[j] += a * rowB[j]
					}
				}
			}
		}(w*rowsA/workers, (w+1)*rowsA/workers)
	}

	wg.Wait()
	return result
}

// runTest multiplies two freshly generated matrixSize×matrixSize random
// matrices, first sequentially and then in parallel with numProcesses
// workers. It prints the wall-clock time of each run.
//
// It then checks that the two products agree and prints a warning if they
// do not, so a broken parallel implementation cannot masquerade as a
// speedup.
func runTest(matrixSize, numProcesses int) {
	matrixA := generateRandomMatrix(matrixSize, matrixSize)
	matrixB := generateRandomMatrix(matrixSize, matrixSize)

	startTime := time.Now()
	sequentialResult := sequentialMatrixMultiply(matrixA, matrixB)
	sequentialTime := time.Since(startTime)
	fmt.Printf("Sequential matrix multiplication took (%dx%d): %s\n", matrixSize, matrixSize, sequentialTime)

	startTime = time.Now()
	parallelResult := parallelMatrixMultiply(matrixA, matrixB, numProcesses)
	parallelTime := time.Since(startTime)
	fmt.Printf("Parallel matrix multiplication with %d threads took (%dx%d): %s\n", numProcesses, matrixSize, matrixSize, parallelTime)

	// Verification happens outside the timed regions. A small relative
	// tolerance is used in case the two implementations ever sum in
	// different orders.
	if len(parallelResult) != len(sequentialResult) {
		fmt.Printf("WARNING: parallel result has %d rows, sequential has %d\n", len(parallelResult), len(sequentialResult))
		return
	}
	for i, seqRow := range sequentialResult {
		parRow := parallelResult[i]
		if len(parRow) != len(seqRow) {
			fmt.Printf("WARNING: parallel row %d has %d columns, sequential has %d\n", i, len(parRow), len(seqRow))
			return
		}
		for j, want := range seqRow {
			got := parRow[j]
			diff := got - want
			if diff < 0 {
				diff = -diff
			}
			scale := want
			if scale < 0 {
				scale = -scale
			}
			if diff > 1e-9*(1+scale) {
				fmt.Printf("WARNING: parallel result differs from sequential at [%d][%d]: %g vs %g\n", i, j, got, want)
				return
			}
		}
	}
}

// generateRandomMatrix allocates a rows×cols matrix and fills it in row-major
// order with values from the package-level math/rand source via
// rand.Float64. Each value lies in the half-open interval [0.0, 1.0).
func generateRandomMatrix(rows, cols int) [][]float64 {
	out := make([][]float64, rows)
	for r := 0; r < rows; r++ {
		row := make([]float64, cols)
		for c := range row {
			row[c] = rand.Float64()
		}
		out[r] = row
	}
	return out
}

// main seeds the global random source, then benchmarks sequential versus
// parallel multiplication for 100×100, 300×300 and 500×500 matrices. Each
// size is run with 2 and then 4 parallel workers.
func main() {
	// NOTE(review): rand.Seed has been deprecated since Go 1.20, which seeds
	// the global source automatically. It is retained here for older
	// toolchains.
	rand.Seed(time.Now().UnixNano())

	sizes := []int{100, 300, 500}
	workerCounts := []int{2, 4}
	for _, size := range sizes {
		for _, workers := range workerCounts {
			runTest(size, workers)
		}
	}
}