Advertisement
dipBRUR

Persistent Segment Tree - COT Count on a tree

Oct 3rd, 2018
143
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.87 KB | None | 0 0
  1. /**
  2.   * Persistent Segment Tree + Binary Search + Fast IO = O(n logn logn) - TLE
  3.   * Persistent Segment Tree + Query + Fast IO         = O(n logn) - AC
  4.   * Note            : Works for duplicate values
  5. **/
  6. #include <bits/stdc++.h>
  7.  
  8. using namespace std;
  9.  
  10. static const int MAXN = 1e5 + 5;
  11. static const int LOGN = 19;
  12.  
  13. map <int, int> valTOind;
  14. int indTOval[MAXN];
  15.  
  16. int tNode, tQuery, n;
  17. int nodeVal[MAXN];// temp[MAXN];
  18. vector <int> graph[MAXN];
  19.  
  20. int readInt()
  21. {
  22.     bool minus = false;
  23.     int result = 0;
  24.     char ch;
  25.     ch = getchar();
  26.     while (true)
  27.     {
  28.         if (ch == '-')
  29.             break;
  30.         if (ch >= '0' && ch <= '9') break;
  31.             ch = getchar();
  32.     }
  33.     if (ch == '-')
  34.         minus = true; else result = ch-'0';
  35.     while (true)
  36.     {
  37.         ch = getchar();
  38.         if (ch < '0' || ch > '9')
  39.             break;
  40.         result = result*10 + (ch - '0');
  41.     }
  42.     if (minus)
  43.         return -result;
  44.     else
  45.         return result;
  46. }
  47.  
  48. struct node
  49. {
  50.     int val;
  51.     node *left, *right;
  52.     node(int val)
  53.     {
  54.         this->val = val;
  55.         this->left = NULL;
  56.         this->right = NULL;
  57.     }
  58. } *version[MAXN<<1];
  59.  
  60. void build(node *root, int a, int b)
  61. {
  62.     if (a > b)
  63.         return;
  64.     if (a == b)
  65.     {
  66.         root->val = 0;
  67.         return;
  68.     }
  69.     int mid = (a+b)>>1;
  70.     root->left = new node(0);
  71.     root->right = new node(0);
  72.     build(root->left, a, mid);
  73.     build(root->right, mid+1, b);
  74.     root->val = root->left->val + root->right->val;
  75. }
  76.  
  77. void update(node *proot, node *root, int a, int b, int pos)
  78. {
  79.     if (a > b || a > pos || b < pos)
  80.         return;
  81.     if (a >= pos && b <= pos)
  82.     {
  83.         root->val += 1;
  84.         return;
  85.     }
  86.     int mid = (a+b)>>1;
  87.     if (pos <= mid)
  88.     {
  89.         root->left = new node(0);
  90.         root->right = proot->right;
  91.         update(proot->left, root->left, a, mid, pos);
  92.     }
  93.     else
  94.     {
  95.         root->left = proot->left;
  96.         root->right = new node(0);
  97.         update(proot->right, root->right, mid+1, b, pos);
  98.     }
  99.     root->val = root->left->val + root->right->val;
  100. }
  101. /**
  102. int query(node *root, int a, int b, int i, int j)
  103. {
  104.     if (a > b || a > j || b < i)
  105.         return 0;
  106.     if (a >= i && b <= j)
  107.         return root->val;
  108.     int mid = (a+b)>>1;
  109.     int p1 = query(root->left, a, mid, i, j);
  110.     int p2 = query(root->right, mid+1, b, i, j);
  111.     return p1+p2;
  112. }
  113. int Count(node *root_u, node *root_v, node *root_lca, node *root_plca, int k)
  114. {
  115.     int sum = 0;
  116.         sum += query(root_u, 1, n, 1, k);
  117.         sum += query(root_v, 1, n, 1, k);
  118.         sum -= query(root_lca, 1, n, 1, k);
  119.         sum -= query(root_plca, 1, n, 1, k);
  120.     return sum;
  121. }
  122. int binarySearch(node *root_u, node *root_v, node *root_lca, node *root_plca,
  123.                  int k)  // l : version[u], r : version[v]
  124. {
  125.     int low = 1;
  126.     int high = n;
  127.     int ans;
  128.     while (low <= high)
  129.     {
  130.         int mid = (low+high)>>1;
  131.         int cnt = Count(root_u, root_v, root_lca, root_plca, mid);
  132.         if (cnt >= k)
  133.         {
  134.             ans = mid;
  135.             high = mid-1;
  136.         }
  137.  
  138.         else
  139.         {
  140.             low = mid+1;
  141.         }
  142.     }
  143.     return ans;
  144. }
  145. **/
  146. int query1(node *root_u, node *root_v, node *root_lca, node *root_plca, int a,
  147.            int b, int k)
  148. {
  149.     if (a == b)
  150.         return a;
  151.  
  152.     int sum = 0;
  153.         sum += root_u->left->val;
  154.         sum += root_v->left->val;
  155.         sum -= root_lca->left->val;
  156.         sum -= root_plca->left->val;
  157.     int mid = (a+b)>>1;
  158.     if (sum >= k)
  159.         return query1(root_u->left, root_v->left, root_lca->left, root_plca->left,
  160.                       a, mid, k);
  161.     else
  162.         return query1(root_u->right, root_v->right, root_lca->right,
  163.                       root_plca->right, mid+1, b, k - sum);
  164. }
  165. int father[MAXN][LOGN];
  166. int depth[MAXN];
  167.  
  168. void dfs(int u, int p = -1)
  169. {
  170.     for (int i = 1; i < LOGN; i++)
  171.         father[u][i] = father[father[u][i-1]][i-1];
  172.     for (int v : graph[u])
  173.     {
  174.         if (v == p)
  175.             continue;
  176.         father[v][0] = u;
  177.         depth[v] = depth[u] + 1;
  178.         dfs(v, u);
  179.     }
  180. }
  181.  
  182. int LCA(int u, int v)
  183. {
  184.     if (depth[u] < depth[v])
  185.         swap(u, v);
  186.     for (int i = LOGN-1; i >= 0; i--)
  187.     {
  188.         if (depth[father[u][i]] >= depth[v])
  189.         {
  190.             u = father[u][i];
  191.         }
  192.     }
  193.     if (u == v)
  194.         return u;
  195.     for (int i = LOGN-1; i >= 0; i--)
  196.     {
  197.         if (father[u][i] != father[v][i])
  198.         {
  199.             u = father[u][i];
  200.             v = father[v][i];
  201.         }
  202.     }
  203.     return father[u][0];
  204. }
  205.  
  206. int ver[MAXN];
  207. bool vis[MAXN];
  208.  
  209. void bfs(int src)
  210. {
  211.     memset(vis, 0, sizeof vis);
  212.     queue <int> PQ;
  213.     PQ.push(src);
  214.     int vr = 1;
  215.     ver[src] = vr;
  216.     vis[src] = 1;
  217.     version[ver[src]] = new node(0);
  218.     update(version[0], version[ver[src]], 1, n, nodeVal[src]);
  219.     while (!PQ.empty())
  220.     {
  221.         int u = PQ.front(); PQ.pop();
  222.         for (int v : graph[u])
  223.         {
  224.             if (vis[v])
  225.                 continue;
  226.             vis[v] = 1;
  227.             vr++;
  228.             ver[v] = vr;
  229.             version[ver[v]] = new node(0);
  230.             update(version[ver[u]], version[ver[v]], 1, n, nodeVal[v]);
  231.             PQ.push(v);
  232.         }
  233.     }
  234.  
  235. }
  236.  
  237. struct structure
  238. {
  239.     int val, ind;
  240.     structure() {}
  241.     structure(int val, int ind)
  242.     {
  243.         this->val = val;
  244.         this->ind = ind;
  245.     }
  246.     friend bool operator<(structure A, structure B)
  247.     {
  248.         if (A.val == B.val)
  249.             return A.ind < B.ind;
  250.         else
  251.             return A.val < B.val;
  252.     }
  253. } temp[MAXN];
  254.  
  255. int main()
  256. {
  257.     tNode = readInt();
  258.     tQuery = readInt();
  259.  
  260.     for (int i = 1; i <= tNode; i++)
  261.     {
  262.         nodeVal[i] = readInt();
  263.         temp[i] = {nodeVal[i], i};
  264.     }
  265.     for (int e = 1; e < tNode; e++)
  266.     {
  267.         int u, v;
  268.         u = readInt();
  269.         v = readInt();
  270.         graph[u].push_back(v);
  271.         graph[v].push_back(u);
  272.     }
  273.     sort(temp+1, temp+tNode+1);
  274.     int id = 0;
  275.     for (int i = 1; i <= tNode; i++)
  276.     {
  277.         int num = temp[i].val;
  278.         int ind = temp[i].ind;
  279.         id++;
  280.         nodeVal[ind] = id;
  281.         indTOval[id] = num;
  282.     }
  283.     n = id;
  284.     version[0] = new node(0);
  285.     build(version[0], 1, n);
  286.     bfs(1);
  287.     depth[1] = 1;
  288.     dfs(1);
  289.     while (tQuery--)
  290.     {
  291.         int u, v, k;
  292.         u = readInt();
  293.         v = readInt();
  294.         k = readInt();
  295.         int lca = LCA(u, v);
  296.         int plca = father[lca][0];
  297.         //int ansb = binarySearch(version[ver[u]], version[ver[v]],
  298.                                   version[ver[lca]], version[ver[plca]], k);
  299.         int ans = query1(version[ver[u]], version[ver[v]], version[ver[lca]],
  300.                          version[ver[plca]], 1, n, k);
  301.         printf("%d\n", indTOval[ans]);
  302.     }
  303. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement