Advertisement
igoryanchik

SegTree with Eigen

Dec 5th, 2023 (edited)
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 8.10 KB | None | 0 0
  1. #include <iostream>
  2. #include <vector>
  3. #include <algorithm>
  4. #include <Eigen/Dense>
  5.  
  6. using namespace Eigen;
  7. using ll = long long;
  8.  
  9. template <class T>
  10. class SegmentTree
  11. {
  12. private:
  13.  
  14.     struct Node
  15.     {
  16.         T modify;
  17.         T query;
  18.     };
  19.  
  20.     int size;
  21.     T q_neutral;
  22.     T m_neutral;
  23.     T(*query)(const T&, const T&);
  24.     T(*modify)(const T&, const T&);
  25.     std::vector<Node> tree;
  26.  
  27.     void build(const std::vector<T>& a, int v, int left, int right)
  28.     {
  29.         if (left + 1 == right)
  30.             tree[v].query = a[left];
  31.         else
  32.         {
  33.             int mid = (left + right) >> 1;
  34.             build(a, 2 * v, left, mid);
  35.             build(a, 2 * v + 1, mid, right);
  36.             tree[v].query = query(tree[2 * v].query, tree[2 * v + 1].query);
  37.         }
  38.     }
  39.  
  40.     void propagate(int v, int left, int right)
  41.     {
  42.         if (left + 1 == right || tree[v].modify == m_neutral) return;
  43.  
  44.         tree[2 * v].modify = modify(tree[2 * v].modify, tree[v].modify);
  45.         tree[2 * v].query = modify(tree[2 * v].query, tree[v].modify);
  46.         tree[2 * v + 1].modify = modify(tree[2 * v + 1].modify, tree[v].modify);
  47.         tree[2 * v + 1].query = modify(tree[2 * v + 1].query, tree[v].modify);
  48.  
  49.         tree[v].modify = m_neutral;
  50.     }
  51.  
  52.  
  53.     T getQuery(int v, int start, int end, int left, int right)
  54.     {
  55.         propagate(v, left, right);
  56.         if (end <= left || right <= start)
  57.             return q_neutral;
  58.         if (start <= left && right <= end)
  59.             return tree[v].query;
  60.         else
  61.         {
  62.             int mid = (left + right) >> 1;
  63.             return query(getQuery(2 * v, start, end, left, mid),
  64.                 getQuery(2 * v + 1, start, end, mid, right));
  65.         }
  66.     }
  67.  
  68.     void update(int v, const T& x, int start, int end, int left, int right)
  69.     {
  70.         propagate(v, left, right);
  71.         if (end <= left || right <= start)
  72.             return;
  73.         if (start <= left && right <= end)
  74.         {
  75.             tree[v].modify = modify(tree[v].modify, x);
  76.             tree[v].query = modify(tree[v].query, x);
  77.             return;
  78.         }
  79.         else
  80.         {
  81.             int mid = (left + right) >> 1;
  82.             update(2 * v, x, start, end, left, mid);
  83.             update(2 * v + 1, x, start, end, mid, right);
  84.             tree[v].query = query(tree[2 * v].query, tree[2 * v + 1].query);
  85.         }
  86.  
  87.     }
  88.  
  89.     void get_init_vec_help(std::vector<T>& vec, int v, int left, int right)
  90.     {
  91.         if (left + 1 == right)
  92.             vec[left] = tree[v].query;
  93.         else
  94.         {
  95.             int mid = (left + right) >> 1;
  96.  
  97.             get_init_vec_help(vec, 2 * v, left, mid);
  98.             get_init_vec_help(vec, 2 * v + 1, mid, right);
  99.         }
  100.     }
  101.  
  102.     void full_propagate(int v, int left, int right)
  103.     {
  104.         if (left + 1 == right) return;
  105.  
  106.         propagate(v, left, right);
  107.  
  108.         int mid = (left + right) >> 1;
  109.         full_propagate(2 * v, left, mid);
  110.         full_propagate(2 * v + 1, mid, right);
  111.     }
  112.  
  113. public:
  114.  
  115.    
  116.     SegmentTree(int n, T(*_query)(const T&, const T&), T(*_modify)(const T&, const T&), const T& _q_neutral = T(), const T& _m_neutral = T())
  117.         : size(n)
  118.         , q_neutral(_q_neutral)
  119.         , m_neutral(_m_neutral)
  120.         , query(_query)
  121.         , modify(_modify)
  122.         , tree(n << 2, { m_neutral , q_neutral })
  123.     {
  124.         if (!_query || !_modify) throw "Nullptr";
  125.     }
  126.  
  127.     SegmentTree(const std::vector<T>& vec, T(*_query)(const T&, const T&), T(*_modify)(const T&, const T&), const T& _q_neutral = T(), const T& _m_neutral = T())
  128.         : SegmentTree(vec.size(), _query, _modify, _q_neutral, _m_neutral)
  129.     {
  130.         if (!_query || !_modify) throw "Nullptr";
  131.         build(vec, 1, 0, size);
  132.     }
  133.  
  134.     SegmentTree<T> operator =(const SegmentTree<T>& Tree)
  135.     {
  136.         if (this == Tree) return *this;
  137.  
  138.         size = Tree.size;
  139.         tree = Tree.tree;
  140.         q_neutral = Tree.q_neutral;
  141.         m_neutral = Tree.m_neutral;
  142.         modify = Tree.modify;
  143.         query = Tree.query;
  144.  
  145.         return *this;
  146.     }
  147.  
  148.     SegmentTree<T> merge(SegmentTree<T>& Tree)
  149.     {
  150.         if (q_neutral != Tree.q_neutral || m_neutral != Tree.m_neutral) return *this;
  151.  
  152.         Tree.full_propagate(1, 0, Tree.size);
  153.         full_propagate(1, 0, size);
  154.  
  155.         std::vector<T> init_vec = get_init_vec();
  156.         std::vector<T> Tree_init_vec = Tree.get_init_vec();
  157.  
  158.         init_vec.insert(init_vec.end(), Tree_init_vec.begin(), Tree_init_vec.end());
  159.         SegmentTree<T> rez(init_vec, query, modify, q_neutral, m_neutral);
  160.  
  161.         return rez;
  162.     }
  163.  
  164.     void update(int start, int end, const T& x)
  165.     {
  166.         update(1, x, start, end + 1, 0, size);
  167.     }
  168.  
  169.     T getQuery(int start, int end)
  170.     {
  171.         return getQuery(1, start, end + 1, 0, size);
  172.     }
  173.  
  174.     T getQueryNeutral() const
  175.     {
  176.         return q_neutral;
  177.     }
  178.  
  179.     T getModifyNeutral() const
  180.     {
  181.         return m_neutral;
  182.     }
  183.  
  184.     std::vector<T> get_init_vec()
  185.     {
  186.         std::vector<T> result(size, T());
  187.         get_init_vec_help(result, 1, 0, size);
  188.  
  189.         return result;
  190.     }
  191.  
  192.     template <class T>
  193.     friend std::ostream& operator << (std::ostream& out, const SegmentTree<T>& Tree);
  194. };
  195.  
  196. template <class T>
  197. std::ostream& operator << (std::ostream& out, const SegmentTree<T>& Tree)
  198. {
  199.     if (out)
  200.     {
  201.         for (auto it : Tree.tree)
  202.         {
  203.             std::cout << it.query << " " << it.modify << '\n';
  204.         }
  205.     }
  206.  
  207.     return out;
  208. }
  209.  
  210. template <class T>
  211. T modify_mat(const T& a, const T& b)
  212. {
  213.     MatrixXi maxtrix_sum_neutral(2, 2);
  214.     maxtrix_sum_neutral.setZero();
  215.  
  216.     if (b == maxtrix_sum_neutral) return a;
  217.     return b;
  218. }
  219.  
  220. template <class T>
  221. T query(const T& a, const T& b)
  222. {
  223.     return a + b;
  224. }
  225.  
  226. template <class T>
  227. T multiply(const T& a, const T& b)
  228. {
  229.     return a * b;
  230. }
  231.  
  232.  
  233. int main()
  234. {
  235.     setlocale(LC_ALL, "Russian");
  236.     /*std::vector<ll> vec = { 2,3,4,1,5,6 };
  237.  
  238.     ll(*ptr_modify_vec)(const ll&, const ll&) = &multiply;
  239.     ll(*ptr_query_vec)(const ll&, const ll&) = &query;
  240.  
  241.     SegmentTree<ll> tree_vec(vec, ptr_query_vec, ptr_modify_vec, 0, 1);
  242.  
  243.     std::cout << "Сумма всего отрезка: " << tree_vec.getQuery(0, 5) << '\n';
  244.     tree_vec.update(0, 5, 2);
  245.     std::cout << "Сумма всего отрезка после умножения всего отрезка на 2: " << tree_vec.getQuery(0, 5) << '\n';
  246.  
  247.  
  248.     std::cout << "\n////////////////////////////////////////////////////////\n";*/
  249.  
  250.     //Проверим работу ДО на матрицах 2х2
  251.     MatrixXi m1(2, 2), m2(2,2), m3(2,2), m4(2,2);
  252.     m1 << 1, 0, 0, 1;
  253.     m2 << 2, 2, 2, 2;
  254.     m3 << 3, 3, 3, 3;
  255.     m4 << 1, 3, 2, 3;
  256.  
  257.     MatrixXi maxtrix_sum_neutral(2, 2);
  258.     maxtrix_sum_neutral.setZero();
  259.    
  260.     std::vector<MatrixXi> vec_mat = { m1,m2,m3,m4 };
  261.  
  262.     MatrixXi(*ptr_modify)(const MatrixXi&, const MatrixXi&) = &modify_mat;
  263.     MatrixXi(*ptr_query)(const MatrixXi&, const MatrixXi&) = &query;
  264.     MatrixXi(*ptr_mult)(const MatrixXi&, const MatrixXi&) = &multiply;
  265.  
  266.     SegmentTree<MatrixXi> tree(vec_mat, ptr_query, ptr_modify, maxtrix_sum_neutral, maxtrix_sum_neutral);
  267.     std::vector<MatrixXi> init = tree.get_init_vec();
  268.  
  269.    
  270.     //std::cout <<"Четыре матрицы 2х2 для проверки работы ДО на объектах сложной структруы:\n"
  271.         //<< m1 << "\n\n" << m2 << "\n\n" << m3 << "\n\n" << m4 << '\n';
  272.  
  273.  
  274.     //std::cout << "Сумма всех четырех матриц:\n" << tree.getQuery(0,3) << '\n';
  275.  
  276.  
  277.     //std::cout << "\n////////////////////////////////////////////////////////\n";
  278.  
  279.     std::vector<MatrixXi> matrices(19); //Вектор на котором будем строить дерево отрезков
  280.  
  281.     MatrixXi MULT_NEUTRAL(6, 6), SUM_NEUTRAL(6, 6);
  282.     MULT_NEUTRAL.setIdentity();
  283.     SUM_NEUTRAL.setZero();
  284.  
  285.     // Заполнение каждой матрицы случайными значениями
  286.     for (int i = 0; i < matrices.size(); i++)
  287.         matrices[i] = MatrixXi::Random(6, 6);
  288.    
  289.     SegmentTree<MatrixXi> tree2(matrices, ptr_query, ptr_mult, SUM_NEUTRAL, MULT_NEUTRAL);
  290.     std::vector<MatrixXi> init_vec = tree2.get_init_vec();
  291.     MatrixXi all_sum = tree2.getQuery(0, 18);
  292.  
  293.     //Проверка корректности операции запроса
  294.     std::cout << "Проверяем верно ли выполняется запрос на сумму матриц: " << (tree2.getQuery(0, 1) == (init_vec[0] + init_vec[1])) << '\n';
  295.  
  296.     //Умножим первые четыре элемента на нулевую матрицу и проверим, что отложенные операции работают
  297.  
  298.     tree2.update(0, 3, SUM_NEUTRAL);
  299.  
  300.     for (int i = 0; i < 4; ++i)
  301.         all_sum = all_sum - init_vec[i];
  302.  
  303.     MatrixXi all_sum_new = tree2.getQuery(0, 18);
  304.  
  305.     std::cout << "Проверяем верно ли работают отложенные операции: " << (all_sum_new == all_sum) << '\n';
  306.  
  307.     return 0;
  308. }
  309.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement