题目链接 New Year Tree
考虑到 $c_k \le 60$,可以用一个 64 位整数的二进制位表示颜色集合(第 $c$ 位为 1 表示出现了颜色 $c$),集合合并就是按位或,颜色种数就是二进制中 1 的个数。
对每个点按 DFS 序重新标号,并求出它的进入时间和离开时间。这样以某点为根的子树恰好对应 DFS 序上的一段连续区间 $[\mathrm{in}_x, \mathrm{out}_x]$,子树染色和子树查询就变成了区间更新和区间查询。
用带懒标记(区间赋值)的线段树来维护。
// New Year Tree: answer "repaint a whole subtree" and "count distinct colours in a
// subtree" queries.
//
// Idea: colours are small (the statement guarantees 1 <= c <= 60), so a set of
// colours fits in one 64-bit mask; union of sets is bitwise OR and the number of
// distinct colours is the popcount. An Euler tour turns every subtree into a
// contiguous interval [tin, tout], so both operations become range operations on a
// segment tree with a lazy "assign colour" tag.
//
// Input : n m, then n colours, then n-1 undirected edges, then m queries
//         ("1 v c" = paint subtree of v with colour c, "2 v" = count colours in it).
// Output: one integer per type-2 query, each followed by a single space.
#include <bitset>
#include <cstdio>
#include <vector>

namespace {

using Mask = unsigned long long;  // unsigned so that shifting up to bit 63 is well-defined

constexpr int kNoLazy = -1;  // "no pending assignment"; never a valid colour

int n = 0;
std::vector<std::vector<int>> adj;  // adjacency lists, vertices are 1-based
std::vector<int> tin;               // tin[v]  = Euler-tour position of v (1-based)
std::vector<int> tout;              // tout[v] = last position inside v's subtree
std::vector<int> colourAt;          // colourAt[pos] = initial colour of vertex at pos
std::vector<Mask> seg;              // seg[i] = OR of colour bits over node i's range
std::vector<int> lazy;              // pending colour assignment for node i, or kNoLazy

// Single-colour mask; assumes 0 <= colour <= 63 (the statement guarantees 1..60).
Mask bitOf(int colour) { return Mask{1} << colour; }

// Iterative DFS from root 1 computing tin/tout and colourAt. Children are visited in
// adjacency-list order, so the numbering matches the classic recursive DFS, but deep
// (path-like) trees with ~4e5 vertices can no longer overflow the call stack.
void eulerTour(const std::vector<int>& colour) {
    std::vector<int> parent(n + 1, 0);
    std::vector<std::size_t> nextEdge(n + 1, 0);  // next adjacency index to try per vertex
    std::vector<int> stack;
    stack.reserve(n);

    int timer = 0;
    tin[1] = ++timer;
    colourAt[timer] = colour[1];
    stack.push_back(1);

    while (!stack.empty()) {
        const int v = stack.back();
        if (nextEdge[v] < adj[v].size()) {
            const int u = adj[v][nextEdge[v]++];
            if (u == parent[v]) continue;  // root's parent is 0, which is no real vertex
            parent[u] = v;
            tin[u] = ++timer;
            colourAt[timer] = colour[u];
            stack.push_back(u);
        } else {
            tout[v] = timer;  // all descendants numbered: subtree is [tin[v], tout[v]]
            stack.pop_back();
        }
    }
}

// Paint the whole range of node i with `colour` and remember it for the children.
void applyColour(int i, int colour) {
    seg[i] = bitOf(colour);
    lazy[i] = colour;
}

// Forward a pending assignment of node i to its two children.
void pushDown(int i) {
    if (lazy[i] != kNoLazy) {
        applyColour(i << 1, lazy[i]);
        applyColour(i << 1 | 1, lazy[i]);
        lazy[i] = kNoLazy;
    }
}

// Build node i over Euler positions [l, r] from colourAt.
void build(int i, int l, int r) {
    lazy[i] = kNoLazy;
    if (l == r) {
        seg[i] = bitOf(colourAt[l]);
        return;
    }
    const int mid = (l + r) >> 1;
    build(i << 1, l, mid);
    build(i << 1 | 1, mid + 1, r);
    seg[i] = seg[i << 1] | seg[i << 1 | 1];
}

// Assign `colour` to positions [l, r]; node i covers [L, R].
void assignRange(int i, int L, int R, int l, int r, int colour) {
    if (l <= L && R <= r) {
        applyColour(i, colour);
        return;
    }
    pushDown(i);
    const int mid = (L + R) >> 1;
    if (l <= mid) assignRange(i << 1, L, mid, l, r, colour);
    if (r > mid) assignRange(i << 1 | 1, mid + 1, R, l, r, colour);
    seg[i] = seg[i << 1] | seg[i << 1 | 1];
}

// Return the OR of colour bits over positions [l, r]; node i covers [L, R].
Mask queryRange(int i, int L, int R, int l, int r) {
    if (l <= L && R <= r) return seg[i];
    pushDown(i);
    const int mid = (L + R) >> 1;
    Mask result = 0;
    if (l <= mid) result |= queryRange(i << 1, L, mid, l, r);
    if (r > mid) result |= queryRange(i << 1 | 1, mid + 1, R, l, r);
    return result;
}

}  // namespace

int main() {
    int m = 0;
    if (std::scanf("%d%d", &n, &m) != 2 || n <= 0) return 0;  // n == 0 would make build() recurse forever

    std::vector<int> colour(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%d", &colour[i]) != 1) return 0;
    }

    adj.assign(n + 1, {});
    for (int i = 1; i < n; ++i) {
        int x = 0, y = 0;
        if (std::scanf("%d%d", &x, &y) != 2) return 0;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    tin.assign(n + 1, 0);
    tout.assign(n + 1, 0);
    colourAt.assign(n + 1, 0);
    eulerTour(colour);

    seg.assign(4 * static_cast<std::size_t>(n), 0);
    lazy.assign(4 * static_cast<std::size_t>(n), kNoLazy);
    build(1, 1, n);

    for (int q = 0; q < m; ++q) {
        int op = 0, v = 0;
        if (std::scanf("%d%d", &op, &v) != 2) break;
        if (op == 1) {
            int c = 0;
            if (std::scanf("%d", &c) != 1) break;
            assignRange(1, 1, n, tin[v], tout[v], c);
        } else {
            const Mask mask = queryRange(1, 1, n, tin[v], tout[v]);
            std::printf("%d ", static_cast<int>(std::bitset<64>(mask).count()));
        }
    }
    return 0;
}