DAS_2024_1/afanasev_dmitry_lab_5/main/src/MatrixMultiplier.java

121 lines
3.9 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
/**
 * Multiplies two square integer matrices either sequentially or with a fixed-size thread pool
 * that splits the work by contiguous row ranges, and benchmarks the two approaches.
 *
 * <p>Both multiplication methods overwrite the result matrix, so they can be called repeatedly
 * on the same instance. An instance must not be used by several callers concurrently.
 */
public class MatrixMultiplier {
    private final int[][] matrixA;
    private final int[][] matrixB;
    private final int[][] result;
    private final int size;

    /**
     * Creates a multiplier for two randomly generated {@code size x size} matrices with entries
     * in the range [0, 9].
     *
     * @param size the matrix dimension
     * @throws IllegalArgumentException if {@code size} is negative
     */
    public MatrixMultiplier(int size) {
        if (size < 0) {
            throw new IllegalArgumentException("size must be non-negative: " + size);
        }
        this.size = size;
        this.matrixA = generateMatrix(size);
        this.matrixB = generateMatrix(size);
        this.result = new int[size][size];
    }

    /**
     * Creates a multiplier for the given square matrices. Both inputs are copied, so later
     * changes to the caller's arrays do not affect this instance.
     *
     * @param matrixA left operand, must be square
     * @param matrixB right operand, must be square with the same dimension as {@code matrixA}
     * @throws NullPointerException if either matrix is {@code null}
     * @throws IllegalArgumentException if a matrix is not square or the dimensions differ
     */
    public MatrixMultiplier(int[][] matrixA, int[][] matrixB) {
        this.size = requireSquare(matrixA, "matrixA");
        int sizeB = requireSquare(matrixB, "matrixB");
        if (sizeB != size) {
            throw new IllegalArgumentException(
                    "matrix dimensions differ: " + size + " vs " + sizeB);
        }
        this.matrixA = copyOf(matrixA);
        this.matrixB = copyOf(matrixB);
        this.result = new int[size][size];
    }

    /** Validates that {@code matrix} is non-null and square, returning its dimension. */
    private static int requireSquare(int[][] matrix, String name) {
        Objects.requireNonNull(matrix, name);
        for (int[] row : matrix) {
            if (row == null || row.length != matrix.length) {
                throw new IllegalArgumentException(name + " must be a square matrix");
            }
        }
        return matrix.length;
    }

    /** Returns a deep copy of a rectangular {@code int} matrix. */
    private static int[][] copyOf(int[][] matrix) {
        int[][] copy = new int[matrix.length][];
        for (int i = 0; i < matrix.length; i++) {
            copy[i] = matrix[i].clone();
        }
        return copy;
    }

    /** Builds a {@code size x size} matrix filled with random integers in [0, 9]. */
    private static int[][] generateMatrix(int size) {
        int[][] matrix = new int[size][size];
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                matrix[i][j] = (int) (Math.random() * 10);
            }
        }
        return matrix;
    }

    /**
     * Returns a copy of the result matrix: the most recently computed product, or all zeros if
     * no multiplication has run yet.
     *
     * @return a fresh {@code size x size} copy of the result
     */
    public int[][] getResult() {
        return copyOf(result);
    }

    /** Computes {@code matrixA * matrixB} on the calling thread, overwriting the result. */
    public void multiplySequential() {
        multiplyRows(0, size);
    }

    /**
     * Computes {@code matrixA * matrixB} using a pool of {@code numThreads} threads. Each task
     * computes one contiguous range of result rows. With a single thread, this delegates to
     * {@link #multiplySequential()} to avoid pool overhead.
     *
     * @param numThreads number of worker threads, must be at least 1
     * @throws IllegalArgumentException if {@code numThreads} is less than 1
     * @throws InterruptedException if interrupted while waiting for the workers; remaining work
     *     is cancelled and the result is left partially computed
     */
    public void multiplyParallel(int numThreads) throws InterruptedException {
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be positive: " + numThreads);
        }
        if (numThreads == 1) {
            multiplySequential();
            return;
        }
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        try {
            int chunkSize = (int) Math.ceil((double) size / numThreads);
            List<Future<?>> tasks = new ArrayList<>(numThreads);
            // Only submit non-empty ranges; there are at most numThreads of them.
            for (int startRow = 0; startRow < size; startRow += chunkSize) {
                final int from = startRow;
                final int to = Math.min(startRow + chunkSize, size);
                tasks.add(executor.submit(() -> multiplyRows(from, to)));
            }
            // Future.get() surfaces task failures, and it guarantees the workers' writes to
            // `result` are visible to this thread.
            for (Future<?> task : tasks) {
                awaitTask(task);
            }
            executor.shutdown();
            executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        } finally {
            // No-op after a clean termination; otherwise cancels outstanding work.
            executor.shutdownNow();
        }
    }

    /** Computes result rows {@code [from, to)} by assigning, so repeated runs do not accumulate. */
    private void multiplyRows(int from, int to) {
        for (int i = from; i < to; i++) {
            int[] rowA = matrixA[i];
            int[] rowResult = result[i];
            for (int j = 0; j < size; j++) {
                int sum = 0;
                for (int k = 0; k < size; k++) {
                    sum += rowA[k] * matrixB[k][j];
                }
                rowResult[j] = sum;
            }
        }
    }

    /** Waits for {@code task}, rethrowing any failure raised inside it. */
    private static void awaitTask(Future<?> task) throws InterruptedException {
        try {
            task.get();
        } catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new IllegalStateException("Matrix multiplication task failed", cause);
        }
    }

    /** Zeroes every cell of the result matrix. */
    private void resetResult() {
        for (int i = 0; i < size; i++) {
            Arrays.fill(result[i], 0);
        }
    }

    /** One benchmark entry: a thread count and its average duration in nanoseconds. */
    static class Result {
        final int threads;
        final long time;

        Result(int threads, long time) {
            this.threads = threads;
            this.time = time;
        }
    }

    /**
     * Benchmarks multiplication for several matrix sizes and thread counts. For each matrix size,
     * it prints the average duration per thread count, sorted from fastest to slowest.
     *
     * @param args unused
     * @throws InterruptedException if interrupted while waiting for worker threads
     */
    public static void main(String[] args) throws InterruptedException {
        int[] matrixSizes = {100, 300, 500};
        int[] threadCounts = {1, 2, 4, 6, 8, 10};
        int runs = 5; // number of timed runs averaged per configuration
        for (int size : matrixSizes) {
            System.out.println("\nРазмер матрицы: " + size + "x" + size);
            MatrixMultiplier multiplier = new MatrixMultiplier(size);
            List<Result> results = new ArrayList<>();
            for (int threads : threadCounts) {
                long totalDuration = 0;
                for (int run = 0; run < runs; run++) {
                    multiplier.resetResult();
                    long startTime = System.nanoTime();
                    multiplier.multiplyParallel(threads);
                    long endTime = System.nanoTime();
                    totalDuration += (endTime - startTime);
                }
                long averageDuration = totalDuration / runs;
                results.add(new Result(threads, averageDuration));
            }
            // Sort by average execution time, fastest first.
            results.sort(Comparator.comparingLong(r -> r.time));
            System.out.println("Результаты (среднее время за " + runs + " прогонов):");
            for (Result entry : results) {
                System.out.printf("Потоки: %d, среднее время: %d нс\n", entry.threads, entry.time);
            }
        }
    }
}