Advertisement
Pro808

Untitled

Jan 26th, 2025 (edited)
143
0
3 hours
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 5.22 KB | Source Code | 0 0
  1. // [dependencies]
  2. // rayon = "1.10.0"
  3. // nalgebra = "0.33.2"
  4.  
  5. use rayon::prelude::*;
  6. use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd};
  7. use nalgebra::SMatrix;
  8.  
/// Number of consecutive `f64` elements that `multiply_matrix_simd` hands to
/// `compute_block_sum` per step; equals the four `f64` lanes of a `__m256d`.
pub const BLOCK_SIZE: usize = 4;
  10.  
  11. pub fn transpose_matrix(matrix: &Vec<f64>, rows: usize, cols: usize) -> Vec<f64> {
  12.     let mut transposed = vec![0.0; rows * cols];
  13.     for row in 0..rows {
  14.         for col in 0..cols {
  15.             transposed[col * rows + row] = matrix[row * cols + col];
  16.         }
  17.     }
  18.     transposed
  19. }
  20.  
  21. pub unsafe fn multiply_matrix_rayon(
  22.     matrix_a: &Vec<f64>,
  23.     matrix_b: &Vec<f64>,
  24.     cols_a_rows_b: usize,
  25.     rows_a: usize,
  26.     cols_b: usize,
  27. ) -> Vec<f64> {
  28.     let mut matrix_c = vec![0.0; rows_a * cols_b];
  29.     matrix_c
  30.         .par_chunks_mut(cols_b)
  31.         .enumerate()
  32.         .for_each(|(row_idx, row_c)| {
  33.             for col_idx in 0..cols_b {
  34.                 let mut sum = 0.0;
  35.                 for k in 0..cols_a_rows_b {
  36.                     sum += matrix_a[row_idx * cols_a_rows_b + k] * matrix_b[k * cols_b + col_idx];
  37.                 }
  38.                 row_c[col_idx] = sum;
  39.             }
  40.         });
  41.  
  42.     matrix_c
  43. }
  44.  
  45. pub unsafe fn multiply_matrix_rayon_simd(
  46.     matrix_a: &Vec<f64>,
  47.     matrix_b: &Vec<f64>,
  48.     rows_a: usize,
  49.     cols_a_rows_b: usize,
  50.     cols_b: usize,
  51. ) -> Vec<f64> {
  52.     let mut matrix_c = vec![0.0; rows_a * cols_b];
  53.     matrix_c
  54.         .par_chunks_mut(cols_b)
  55.         .enumerate()
  56.         .for_each(|(row_idx, row_c)| {
  57.             for col_idx in 0..cols_b {
  58.                 let mut sum = 0.0;
  59.                 for k in 0..cols_a_rows_b {
  60.                     sum += matrix_a[row_idx * cols_a_rows_b + k] * matrix_b[k * cols_b + col_idx];
  61.                 }
  62.                 row_c[col_idx] = sum;
  63.             }
  64.         });
  65.  
  66.     matrix_c
  67. }
  68.  
  69.  
  70. pub unsafe fn multiply_matrix_simd(
  71.     matrix_a: &Vec<f64>,
  72.     matrix_b: &Vec<f64>,
  73.     cols_a_rows_b: usize,
  74.     rows_a: usize,
  75.     cols_b: usize,
  76. ) -> Vec<f64> {
  77.     let mut matrix_c = vec![0.0; rows_a * cols_b];
  78.     let b_t = transpose_matrix(matrix_b, cols_a_rows_b, cols_b);
  79.     if cols_a_rows_b % BLOCK_SIZE != 0 {
  80.         for row_idx in 0..rows_a {
  81.             for col_idx in 0..cols_b {
  82.                 let mut sum = 0.0;
  83.                 for block_start in (0..cols_a_rows_b).step_by(BLOCK_SIZE) {
  84.                     let block_end = std::cmp::min(block_start + BLOCK_SIZE, cols_a_rows_b);
  85.  
  86.                     sum += unsafe {
  87.                         compute_block_sum(
  88.                             matrix_a,
  89.                             &b_t,
  90.                             row_idx,
  91.                             col_idx,
  92.                             cols_a_rows_b,
  93.                             block_start,
  94.                             block_end,
  95.                         )
  96.                     };
  97.                 }
  98.                 matrix_c[row_idx * cols_b + col_idx] = sum;
  99.             }
  100.         }
  101.     } else {
  102.         for row_idx in 0..rows_a {
  103.             for col_idx in 0..cols_b {
  104.                 let mut sum = 0.0;
  105.                 for block_start in (0..cols_a_rows_b).step_by(BLOCK_SIZE) {
  106.                     let block_end = block_start + BLOCK_SIZE;
  107.  
  108.                     sum += unsafe {
  109.                         compute_block_sum(
  110.                             matrix_a,
  111.                             &b_t,
  112.                             row_idx,
  113.                             col_idx,
  114.                             cols_a_rows_b,
  115.                             block_start,
  116.                             block_end,
  117.                         )
  118.                     };
  119.                 }
  120.                 matrix_c[row_idx * cols_b + col_idx] = sum;
  121.             }
  122.         }
  123.     }
  124.     matrix_c
  125. }
  126.  
  127. #[inline(always)]
  128. pub unsafe fn compute_block_sum(
  129.     matrix_a: &Vec<f64>,
  130.     matrix_b: &Vec<f64>,
  131.     row_idx: usize,
  132.     col_idx: usize,
  133.     cols_a_rows_b: usize,
  134.     block_start: usize,
  135.     block_end: usize,
  136. ) -> f64 {
  137.     let a_ptr = matrix_a.as_ptr().add(row_idx * cols_a_rows_b + block_start);
  138.     let b_ptr = matrix_b.as_ptr().add(col_idx * cols_a_rows_b + block_start);
  139.  
  140.     let a_vec: __m256d = _mm256_loadu_pd(a_ptr);
  141.     let b_vec: __m256d = _mm256_loadu_pd(b_ptr);
  142.  
  143.     let product = _mm256_mul_pd(a_vec, b_vec);
  144.  
  145.     let mut block_sum = [0.0; BLOCK_SIZE];
  146.     _mm256_storeu_pd(block_sum.as_mut_ptr(), product);
  147.  
  148.     block_sum.iter().take(block_end - block_start).sum()
  149. }
  150.  
  151.  
  152. unsafe fn multiply_matrix_nalgebra_as_test() {
  153.     type Matrix128x128 = SMatrix<f64,128,128>;
  154.  
  155.     let mut matrix1 = Matrix128x128::zeros();
  156.     let mut matrix2 = Matrix128x128::zeros();
  157.     matrix1.fill(1.1);
  158.     matrix2.fill(2.2);
  159.     let c: Matrix128x128 = matrix1 * matrix2;
  160. }
  161.  
  162. #[cfg(test)]
  163. mod test {
  164.     use super::*;
  165.  
  166.     #[test]
  167.     fn test_matrix_simd() {
  168.         let a: Vec<f64> = vec![1.0, 2.0];
  169.         let b: Vec<f64> = vec![-3.0, 5.0, 4.0, -6.0];
  170.         let cols_a_rows_b = 2;
  171.         let rows_a = 1;
  172.         let cols_b = 2;
  173.         let c = unsafe { multiply_matrix_simd(&a, &b, cols_a_rows_b, rows_a, cols_b) };
  174.         assert_eq!(c, vec![5.0, -7.0]);
  175.         let c = unsafe { multiply_matrix_rayon(&a, &b, cols_a_rows_b, rows_a, cols_b) };
  176.         assert_eq!(c, vec![5.0, -7.0]);
  177.     }
  178. }
  179.  
  180.  
Tags: rust
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement