Advertisement
mquinlan

MatrixMultiplyTest

Oct 1st, 2024
139
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 4.46 KB | Source Code | 0 0
  1. using System.Diagnostics;
  2. using BenchmarkDotNet.Attributes;
  3. using BenchmarkDotNet.Running;
  4.  
  5. namespace MatrixMultiplyTest;
  6.  
  7. public class Program
  8. {
  9.     private const int MatrixSize = 1000;
  10.  
  11.     private static readonly double[,] TestMatrixA = new double[MatrixSize, MatrixSize];
  12.     private static readonly double[,] TestMatrixB = new double[MatrixSize, MatrixSize];
  13.  
  14.     static Program()
  15.     {
  16.         var stopwatch = Stopwatch.StartNew();
  17.         var rand = new Random(0);
  18.         for (var i = 0; i < MatrixSize; i++)
  19.         {
  20.             for (var j = 0; j < MatrixSize; j++)
  21.             {
  22.                 TestMatrixA[i, j] = rand.NextDouble();
  23.                 TestMatrixB[i, j] = rand.NextDouble();
  24.             }
  25.         }
  26.         stopwatch.Stop();
  27.         Console.WriteLine($"Initialized in {stopwatch.ElapsedMilliseconds} ms.");
  28.     }
  29.    
  30.     private static double[,] ParallelMultiplicationTest(double[,] matrixA, double[,] matrixB)
  31.     {
  32.         var rowsA = matrixA.GetLength(0);
  33.         var colsA = matrixA.GetLength(1);
  34.         var rowsB = matrixB.GetLength(0);
  35.         var colsB = matrixB.GetLength(1);
  36.         Debug.Assert(colsA == rowsB);
  37.        
  38.         var result = new double[rowsA, colsB];
  39.         Parallel.For(0, rowsA, i =>
  40.         {
  41.             for (var j = 0; j < colsB; j++)
  42.             {
  43.                 result[i, j] = 0.0;
  44.                 for (var k = 0; k < colsA; k++)
  45.                 {
  46.                     result[i, j] += matrixA[i, k] * matrixB[k, j];
  47.                 }
  48.             }
  49.         });
  50.         return result;
  51.     }
  52.    
  53.     private static double[,] SerialMultiplicationTest(double[,] matrixA, double[,] matrixB)
  54.     {
  55.         var rowsA = matrixA.GetLength(0);
  56.         var colsA = matrixA.GetLength(1);
  57.         var rowsB = matrixB.GetLength(0);
  58.         var colsB = matrixB.GetLength(1);
  59.         Debug.Assert(colsA == rowsB);
  60.        
  61.         var result = new double[rowsA, colsB];
  62.         for (var i = 0; i < rowsA; i++)
  63.         {
  64.             for (var j = 0; j < colsB; j++)
  65.             {
  66.                 result[i, j] = 0.0;
  67.                 for (var k = 0; k < colsA; k++)
  68.                 {
  69.                     result[i, j] += matrixA[i, k] * matrixB[k, j];
  70.                 }
  71.             }
  72.         }
  73.         return result;
  74.     }
  75.    
  76.     [Benchmark]
  77.     public void ParallelMultiplication() => ParallelMultiplicationTest(TestMatrixA, TestMatrixB);
  78.    
  79.     [Benchmark]
  80.     public void SerialMultiplication() => SerialMultiplicationTest(TestMatrixA, TestMatrixB);
  81.    
  82.     static void Main()
  83.     {
  84.         Validate();
  85. #if RELEASE
  86.         _ = BenchmarkRunner.Run<Program>();
  87. #endif
  88.     }
  89.  
  90.     private static void Validate()
  91.     {
  92.         var sStopwatch = Stopwatch.StartNew();
  93.         var sResult = SerialMultiplicationTest(TestMatrixA, TestMatrixB);
  94.         sStopwatch.Stop();
  95.  
  96.         var pStopwatch = Stopwatch.StartNew();
  97.         var pResult = ParallelMultiplicationTest(TestMatrixA, TestMatrixB);
  98.         pStopwatch.Stop();
  99.        
  100.         Console.WriteLine($"Parallel time: {pStopwatch.ElapsedMilliseconds} ms. " +
  101.                           $"Serial time: {sStopwatch.ElapsedMilliseconds} ms.");
  102.        
  103.         if (pResult.GetLength(0) != sResult.GetLength(0))
  104.         {
  105.             throw new Exception($"pResult.GetLength(0) {pResult.GetLength(0)} != sResult.GetLength(0) {sResult.GetLength(0)}");
  106.         }
  107.  
  108.         if (pResult.GetLength(1) != sResult.GetLength(1))
  109.         {
  110.             throw new Exception($"pResult.GetLength(1) {pResult.GetLength(1)} != sResult.GetLength(1) {sResult.GetLength(1)}");
  111.         }
  112.  
  113.         var maxDiff = double.MinValue;
  114.         for (var i = 0; i < pResult.GetLength(0); i++)
  115.         {
  116.             for (var j = 0; j < pResult.GetLength(1); j++)
  117.             {
  118.                 var diff = Math.Abs(pResult[i, j] - sResult[i, j]);
  119.                 if (diff > maxDiff) maxDiff = diff;
  120.             }
  121.         }
  122.         Console.WriteLine($"maxDiff = {maxDiff}");
  123.  
  124.         for (var i = 0; i < pResult.GetLength(0); i++)
  125.         {
  126.             for (var j = 0; j < pResult.GetLength(1); j++)
  127.             {
  128.                 // ReSharper disable once CompareOfFloatsByEqualityOperator
  129.                 if (pResult[i, j] != sResult[i, j])
  130.                 {
  131.                     throw new Exception($"pResult[{i}, {j}] {pResult[i, j]} != sResult[{i}, {j}] {sResult[i, j]}");
  132.                 }
  133.             }
  134.         }
  135.        
  136.         Console.WriteLine("Validated");
  137.     }
  138. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement