Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <assert.h>
- #include <bits/stdc++.h>
using namespace std;

using ll = long long;
// Not referenced by the code visible in this file; kept for anything that may rely on it.
using pii = pair<int, int>;

// Capacity of every per-element array below: 300000 elements plus 3 spare slots.
// Spelled with integer literals so the value never passes through a
// floating-point literal (the previous `3e5 + 3` relied on double -> int conversion).
constexpr int MAXN = 300000 + 3;

// Number of elements; main() reads it from standard input.
int n;

// Running total accumulated by sort() and printed by main().
ll ans;

// Per-element inputs, filled by main(): c[i] is the label compared for
// equality inside sort(), x[i] is the value sort() orders by.
// NOTE(review): c presumably holds a colour/category id - confirm against the task statement.
int c[MAXN];
int x[MAXN];

// Permutation of element indices; sort() rearranges it into ascending x[] order.
int idx[MAXN];

// Scratch buffer used by sort() while merging two sorted halves of idx[].
int tmp_mg[MAXN];
- void sort(int beg, int end)
- {
- if (end - beg <= 1) {
- return;
- }
- int mid = (beg + end) / 2;
- sort(beg, mid);
- sort(mid, end);
- unordered_map<int, int> ccnt;
- int now = beg;
- int i, j;
- for (i = beg, j = mid; i < mid && j < end;) {
- if (x[idx[i]] <= x[idx[j]]) {
- ans += (j - mid) - ccnt[c[idx[i]]];
- tmp_mg[now++] = idx[i++];
- } else {
- ccnt[c[idx[j]]]++;
- tmp_mg[now++] = idx[j++];
- }
- }
- while (i < mid) {
- ans += (j - mid) - ccnt[c[idx[i]]];
- tmp_mg[now++] = idx[i++];
- }
- while (j < end) {
- tmp_mg[now++] = idx[j++];
- }
- assert(now == end);
- for (int i = beg; i < end; ++i) {
- idx[i] = tmp_mg[i];
- }
- }
- int main(int argc, char **argv)
- {
- scanf("%d", &n);
- for (int i = 0; i < n; ++i)
- idx[i] = i;
- for (int i = 0; i < n; ++i) {
- scanf("%d", c + i);
- }
- for (int i = 0; i < n; ++i) {
- scanf("%d", x + i);
- }
- ans = 0;
- sort(0, n);
- printf("%lld\n", ans);
- return 0;
- };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement