Advertisement
slash0t

matrix lib

Oct 19th, 2024 (edited)
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 8.31 KB | None | 0 0
  1. // source http://lib.cp-algorithms.com/cp-algo/linalg/matrix.hpp
  2. // TODO: removedependencies
  3.  
  4. #include "../random/rng.hpp"
  5. #include "../math/common.hpp"
  6. #include "vector.hpp"
  7. #include <optional>
  8. #include <cassert>
  9. #include <array>
  10. namespace cp_algo::linalg {
  11.     enum gauss_mode {normal, reverse};
  12.     template<typename base_t>
  13.     struct matrix: valarray_base<matrix<base_t>, vec<base_t>> {
  14.         using base = base_t;
  15.         using Base = valarray_base<matrix<base>, vec<base>>;
  16.         using Base::Base;
  17.  
  18.         matrix(size_t n): Base(vec<base>(n), n) {}
  19.         matrix(size_t n, size_t m): Base(vec<base>(m), n) {}
  20.  
  21.         size_t n() const {return size(*this);}
  22.         size_t m() const {return n() ? size(row(0)) : 0;}
  23.         auto dim() const {return std::array{n(), m()};}
  24.  
  25.         auto& row(size_t i) {return (*this)[i];}
  26.         auto const& row(size_t i) const {return (*this)[i];}
  27.  
  28.         matrix& operator *=(base t) {for(auto &it: *this) it *= t; return *this;}
  29.         matrix operator *(base t) const {return matrix(*this) *= t;}
  30.  
  31.         // Make sure the result is matrix, not Base
  32.         matrix& operator*=(matrix const& t) {return *this = *this * t;}
  33.  
  34.         void read() {
  35.             for(auto &it: *this) {
  36.                 it.read();
  37.             }
  38.         }
  39.         void print() const {
  40.             for(auto const& it: *this) {
  41.                 it.print();
  42.             }
  43.         }
  44.  
  45.         static matrix block_diagonal(std::vector<matrix> const& blocks) {
  46.             size_t n = 0;
  47.             for(auto &it: blocks) {
  48.                 assert(it.n() == it.m());
  49.                 n += it.n();
  50.             }
  51.             matrix res(n);
  52.             n = 0;
  53.             for(auto &it: blocks) {
  54.                 for(size_t i = 0; i < it.n(); i++) {
  55.                     res[n + i][std::slice(n, it.n(), 1)] = it[i];
  56.                 }
  57.                 n += it.n();
  58.             }
  59.             return res;
  60.         }
  61.         static matrix random(size_t n, size_t m) {
  62.             matrix res(n, m);
  63.             std::ranges::generate(res, std::bind(vec<base>::random, m));
  64.             return res;
  65.         }
  66.         static matrix random(size_t n) {
  67.             return random(n, n);
  68.         }
  69.         static matrix eye(size_t n) {
  70.             matrix res(n);
  71.             for(size_t i = 0; i < n; i++) {
  72.                 res[i][i] = 1;
  73.             }
  74.             return res;
  75.         }
  76.  
  77.         // Concatenate matrices
  78.         matrix operator |(matrix const& b) const {
  79.             assert(n() == b.n());
  80.             matrix res(n(), m()+b.m());
  81.             for(size_t i = 0; i < n(); i++) {
  82.                 res[i] = row(i) | b[i];
  83.             }
  84.             return res;
  85.         }
  86.         matrix submatrix(auto slicex, auto slicey) const {
  87.             matrix res = (*this)[slicex];
  88.             for(auto &row: res) {
  89.                 row = vec<base>(row[slicey]);
  90.             }
  91.             return res;
  92.         }
  93.  
  94.         matrix T() const {
  95.             matrix res(m(), n());
  96.             for(size_t i = 0; i < n(); i++) {
  97.                 for(size_t j = 0; j < m(); j++) {
  98.                     res[j][i] = row(i)[j];
  99.                 }
  100.             }
  101.             return res;
  102.         }
  103.  
  104.         matrix operator *(matrix const& b) const {
  105.             assert(m() == b.n());
  106.             matrix res(n(), b.m());
  107.             for(size_t i = 0; i < n(); i++) {
  108.                 for(size_t j = 0; j < m(); j++) {
  109.                     res[i].add_scaled(b[j], row(i)[j]);
  110.                 }
  111.             }
  112.             return res.normalize();
  113.         }
  114.  
  115.         vec<base> apply(vec<base> const& x) const {
  116.             return (matrix(x) * *this)[0];
  117.         }
  118.  
  119.         matrix pow(uint64_t k) const {
  120.             assert(n() == m());
  121.             return bpow(*this, k, eye(n()));
  122.         }
  123.  
  124.         matrix& normalize() {
  125.             for(auto &it: *this) {
  126.                 it.normalize();
  127.             }
  128.             return *this;
  129.         }
  130.         template<gauss_mode mode = normal>
  131.         void eliminate(size_t i, size_t k) {
  132.             auto kinv = base(1) / row(i).normalize()[k];
  133.             for(size_t j = (mode == normal) * i; j < n(); j++) {
  134.                 if(j != i) {
  135.                     row(j).add_scaled(row(i), -row(j).normalize(k) * kinv);
  136.                 }
  137.             }
  138.         }
  139.         template<gauss_mode mode = normal>
  140.         void eliminate(size_t i) {
  141.             row(i).normalize();
  142.             for(size_t j = (mode == normal) * i; j < n(); j++) {
  143.                 if(j != i) {
  144.                     row(j).reduce_by(row(i));
  145.                 }
  146.             }
  147.         }
  148.         template<gauss_mode mode = normal>
  149.         matrix& gauss() {
  150.             for(size_t i = 0; i < n(); i++) {
  151.                 eliminate<mode>(i);
  152.             }
  153.             return normalize();
  154.         }
  155.         template<gauss_mode mode = normal>
  156.         auto echelonize(size_t lim) {
  157.             return gauss<mode>().sort_classify(lim);
  158.         }
  159.         template<gauss_mode mode = normal>
  160.         auto echelonize() {
  161.             return echelonize<mode>(m());
  162.         }
  163.  
  164.         size_t rank() const {
  165.             if(n() > m()) {
  166.                 return T().rank();
  167.             }
  168.             return size(matrix(*this).echelonize()[0]);
  169.         }
  170.  
  171.         base det() const {
  172.             assert(n() == m());
  173.             matrix b = *this;
  174.             b.echelonize();
  175.             base res = 1;
  176.             for(size_t i = 0; i < n(); i++) {
  177.                 res *= b[i][i];
  178.             }
  179.             return res;
  180.         }
  181.  
  182.         std::optional<matrix> inv() const {
  183.             assert(n() == m());
  184.             matrix b = *this | eye(n());
  185.             if(size(b.echelonize<reverse>(n())[0]) < n()) {
  186.                 return std::nullopt;
  187.             }
  188.             for(size_t i = 0; i < n(); i++) {
  189.                 b[i] *= base(1) / b[i][i];
  190.             }
  191.             return b.submatrix(std::slice(0, n(), 1), std::slice(n(), n(), 1));
  192.         }
  193.  
  194.         // Can also just run gauss on T() | eye(m)
  195.         // but it would be slower :(
  196.         auto kernel() const {
  197.             auto A = *this;
  198.             auto [pivots, free] = A.template echelonize<reverse>();
  199.             matrix sols(size(free), m());
  200.             for(size_t j = 0; j < size(pivots); j++) {
  201.                 base scale = base(1) / A[j][pivots[j]];
  202.                 for(size_t i = 0; i < size(free); i++) {
  203.                     sols[i][pivots[j]] = A[j][free[i]] * scale;
  204.                 }
  205.             }
  206.             for(size_t i = 0; i < size(free); i++) {
  207.                 sols[i][free[i]] = -1;
  208.             }
  209.             return sols;
  210.         }
  211.  
  212.         // [solution, basis], transposed
  213.         std::optional<std::array<matrix, 2>> solve(matrix t) const {
  214.             matrix sols = (*this | t).kernel();
  215.             if(sols.n() < t.m() || sols.submatrix(
  216.                 std::slice(sols.n() - t.m(), t.m(), 1),
  217.                 std::slice(m(), t.m(), 1)
  218.             ) != -eye(t.m())) {
  219.                 return std::nullopt;
  220.             } else {
  221.                 return std::array{
  222.                     sols.submatrix(std::slice(sols.n() - t.m(), t.m(), 1),
  223.                                    std::slice(0, m(), 1)),
  224.                     sols.submatrix(std::slice(0, sols.n() - t.m(), 1),
  225.                                    std::slice(0, m(), 1))
  226.                 };
  227.             }
  228.         }
  229.     private:
  230.         // To be called after a gaussian elimination run
  231.         // Sorts rows by pivots and classifies
  232.         // variables into pivots and free
  233.         auto sort_classify(size_t lim) {
  234.             size_t rk = 0;
  235.             std::vector<size_t> free, pivots;
  236.             for(size_t j = 0; j < lim; j++) {
  237.                 for(size_t i = rk + 1; i < n() && row(rk)[j] == base(0); i++) {
  238.                     if(row(i)[j] != base(0)) {
  239.                         std::swap(row(i), row(rk));
  240.                         row(rk) = -row(rk);
  241.                     }
  242.                 }
  243.                 if(rk < n() && row(rk)[j] != base(0)) {
  244.                     pivots.push_back(j);
  245.                     rk++;
  246.                 } else {
  247.                     free.push_back(j);
  248.                 }
  249.             }
  250.             return std::array{pivots, free};
  251.         }
  252.     };
  253.     template<typename base_t>
  254.     auto operator *(base_t t, matrix<base_t> const& A) {return A * t;}
  255. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement