Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <algorithm>
- #include <array>
- #include <bitset>
- #include <cassert>
- #include <chrono>
- #include <complex>
- #include <cstdio>
- #include <cstring>
- #include <deque>
- #include <iomanip>
- #include <iostream>
- #include <iterator>
- #include <list>
- #include <map>
- #include <memory>
- #include <numeric>
- #include <queue>
- #include <random>
- #include <set>
- #include <stack>
- #include <string>
- #include <tuple>
- #include <vector>
using namespace std;
using namespace std::chrono;

// Shorthand for the full iterator range of a container.
#define all(x) begin(x), end(x)
// Shorthand for a container's size as a signed int.
#define sz(x) (int) (x).size()

// Capacity of the per-vertex global arrays (vertices are indexed 1..n).
const int mxn = 100001;
// Sentinel "infinite" value used for distances and depths.
const int inf = 1000000000;
// Sparse table for range-minimum queries that answers with the INDEX of the
// minimum. Build is O(m log m); each query is O(1). On ties, the right-hand
// candidate wins (same as the original comparison `a[x] < a[y] ? x : y`).
struct sparsemin {
    vector<int> a;              // the values being queried
    int h = 0;                  // highest level index allocated in st
    vector<vector<int> > st;    // st[i][j] = index of min of a[j .. j + 2^i - 1]
    vector<int> lg;             // lg[len] = floor(log2(len)) for 1 <= len <= a.size()

    sparsemin() {}

    // Builds the table over arr. Accepts an empty array, where the floating
    // log2(0) of the old version was undefined behaviour.
    sparsemin(const vector<int>& arr) : a(arr) {
        const int m = (int) a.size();
        // Exact integer logarithms: no floating-point rounding surprises.
        lg.assign(m + 1, 0);
        for (int len = 2; len <= m; len++) lg[len] = lg[len / 2] + 1;
        h = m > 0 ? lg[m] + 1 : 0;
        st.assign(h + 1, vector<int>());
        st[0].resize(m);
        iota(st[0].begin(), st[0].end(), 0);
        for (int i = 1; i <= h; i++) {
            const int half = 1 << (i - 1);
            for (int j = 0; j + (1 << i) <= m; j++) {
                const int x = st[i - 1][j];
                const int y = st[i - 1][j + half];
                st[i].push_back(a[x] < a[y] ? x : y);
            }
        }
    }

    // Index of the minimum of a[l..r]. Requires 0 <= l <= r < a.size().
    int rmq(int l, int r) const {
        const int i = lg[r - l + 1];
        const int x = st[i][l];
        const int y = st[i][r - (1 << i) + 1];
        return a[x] < a[y] ? x : y;
    }
};
- int n, q, sick[mxn], t;
- int dep[mxn], mxd[mxn], tin[mxn], tot[mxn], sub[mxn], par[mxn];
- int ccrt, is_centroid[mxn], cpar[mxn];
- vector<int> etour, dtour, adj[mxn], cdj[mxn];
- sparsemin st;
- //begin tree
- void dfs(int v, int p, int d) {
- dep[v] = d, tin[v] = sz(etour);
- par[v] = p;
- etour.push_back(v);
- dtour.push_back(dep[v]);
- sub[v] = 1;
- if (v > 1) {
- adj[v].erase(find(all(adj[v]), p));
- }
- for (int u : adj[v]) {
- dfs(u, v, d+1);
- tot[v] = sz(etour);
- etour.push_back(v);
- dtour.push_back(dep[v]);
- sub[v] += sub[u];
- }
- tot[v] = sz(etour);
- etour.push_back(v);
- dtour.push_back(dep[v]);
- }
- int lca(int l, int r) {
- int mnv = min(tin[l], tin[r]);
- int mxv = max(tin[l], tin[r]);
- return etour[st.rmq(mnv,mxv)];
- }
- int dist(int u, int v) {
- return dep[v] + dep[u] - 2 * dep[lca(u,v)];
- }
- int get_centroid(int pos, int csz) {
- if (csz <= 2) {
- is_centroid[pos] = 1;
- return pos;
- }
- int mxu = -1, val = -1;
- for (int u : adj[pos]) {
- if (!is_centroid[u]) {
- if (sub[u] >= val) {
- mxu = u, val = sub[u];
- }
- }
- }
- if (!is_centroid[par[pos]]) {
- if (csz - sub[pos] >= val) {
- mxu = par[pos], val = csz - sub[pos];
- }
- }
- if (val <= csz/2) {
- is_centroid[pos] = 1;
- return pos;
- }
- return get_centroid(mxu, csz);
- }
- int centroid_decompose(int v, int csz) {
- int cen = get_centroid(v, csz);
- if (csz == 1) return cen;
- int pv = par[cen];
- while (!is_centroid[pv]) {
- sub[pv] -= sub[cen];
- if (pv == 1) break;
- pv = par[pv];
- }
- for (int u : adj[cen]) {
- if (is_centroid[u]) continue;
- int nc = centroid_decompose(u, sub[u]);
- cdj[cen].push_back(nc);
- cpar[nc] = cen;
- }
- if (!is_centroid[par[cen]]) {
- int nc = centroid_decompose(par[cen], csz - sub[cen]);
- cdj[cen].push_back(nc);
- cpar[nc] = cen;
- }
- return cen;
- }
- //end tree
- int sick1[mxn], mxdus[mxn];
- vector<vector<int>> adt[mxn]; //coords with distance i
- vector<int> mdwd[mxn]; //min depth with dist
- bool ccsig(int a, int b) {
- return (dep[a] > dep[b]);
- }
- int get_high(int v) {
- int bv = inf, bi = -1;
- int v1 = v;
- while (1) {
- int d = dist(v, v1);
- if (d > t) {
- if (v1 == ccrt) break;
- v1 = cpar[v1];
- continue;
- }
- int dl = t - d;
- if (sz(mdwd[v1]) <= dl) {
- dl = sz(mdwd[v1]) - 1;
- }
- if (dl >= 0 && mdwd[v1][dl]) {
- int cp = mdwd[v1][dl];
- int cv = dep[cp];
- if (cv < bv) {
- bv = cv;
- bi = cp;
- }
- }
- if (v1 == ccrt) break;
- v1 = cpar[v1];
- }
- if (bi == -1) {
- return 0;
- }
- return bi;
- }
- void fillin(int vv) {
- int v1 = vv;
- while (1) {
- int dl = t - dist(vv, v1);
- if (dl < 0) {
- if (v1 == ccrt) break;
- v1 = cpar[v1];
- continue;
- }
- for (int i = 0; i <= dl; i++) {
- if (sz(adt[v1]) <= i) break;
- for (int j : adt[v1][i]) {
- sick1[j] = 0;
- // cout << v1 << " " << j << endl;
- }
- }
- if (v1 == ccrt) break;
- v1 = cpar[v1];
- continue;
- }
- }
- void solve() {
- cin >> t;
- for (int i = 1; i <= n; i++) sick1[i] = sick[i];
- for (int i = 1; i <= n; i++) {
- int ml = sz(mdwd[i]);
- fill(mdwd[i].begin(), mdwd[i].begin()+ml, 0);
- }
- //populate mdwd
- for (int i = 1; i <= n; i++) {
- if (mxd[i] <= t) continue;
- int v = i;
- while (1) {
- int dt = dist(v, i);
- int dep1 = dep[mdwd[v][dt]];
- int dep2 = dep[i];
- if (mdwd[v][dt] == 0 || dep2 < dep1) {
- mdwd[v][dt] = i;
- }
- if (v == ccrt) break;
- v = cpar[v];
- }
- }
- for (int i = 1; i <= n; i++) {
- for (int j = 1; j <= t; j++) {
- if (sz(mdwd[i]) <= j) break;
- if (!mdwd[i][j-1]) continue;
- if (!dep[mdwd[i][j]] || dep[mdwd[i][j-1]] < dep[mdwd[i][j]]) {
- mdwd[i][j] = mdwd[i][j-1];
- }
- }
- }
- vector<int> sig;
- for (int i = 1; i <= n; i++) sig.push_back(i);
- sort(all(sig), ccsig);
- int ans = 0;
- // for (int i = 0; i <= n; i++) {
- // cout << i << ": ";
- // for (auto j : mdwd[i]) cout << j << ' ';
- // cout << "\n";
- // }
- for (auto i : sig) {
- if (!sick1[i]) continue;
- ans++;
- int vv = get_high(i);
- if (!vv) {
- cout << "-1\n"; return;
- }
- fillin(vv);
- }
- cout << ans << "\n";
- }
- signed main() {
- ios::sync_with_stdio(false); cin.tie(nullptr);
- auto start = high_resolution_clock::now();
- cin >> n;
- for (int i = 0; i < n; i++) {
- char c; cin >> c;
- sick[i+1] = c-'0';
- }
- for (int i = 1; i < n; i++) {
- int u, v; cin >> u >> v;
- adj[u].push_back(v); adj[v].push_back(u);
- }
- dfs(1,1,0);
- st = sparsemin(dtour);
- for (int i = 1; i <= n; i++) mxd[i] = inf;
- queue<array<int,2> > bfs;
- for (int i = 1; i <= n; i++) {
- if (!sick[i]) {bfs.push({i,0});}
- }
- while (!bfs.empty()) {
- int v = bfs.front()[0], d = bfs.front()[1];
- bfs.pop();
- if (mxd[v] <= d) continue;
- mxd[v] = d;
- for (int u : adj[v]) bfs.push({u,d+1});
- bfs.push({par[v], d+1});
- }
- ccrt = centroid_decompose(1,n);
- for (int i = 1; i <= n; i++) {
- int v = i;
- while (1) {
- int d = dist(i, v);
- if (sz(adt[v]) <= d) adt[v].resize(d+1);
- adt[v][d].push_back(i);
- if (sz(mdwd[v]) <= d) mdwd[v].resize(d+1);
- if (v == ccrt) break;
- v = cpar[v];
- }
- }
- cin >> q;
- while (q--) {
- solve();
- }
- // cout << "1\n1\n1\n1\n2\n5";
- auto stop = high_resolution_clock::now();
- auto duration = duration_cast<microseconds>(stop - start);
- double d = duration.count()*0.001;
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement