解:首先有一个常用套路:令每条边的权值为 [两端点颜色不同](颜色不同为 1,相同为 0),这样一条路径上的颜色段数 = 路径边权和 + 1。边权用树链剖分 + 线段树直接维护(线段树结点记录区间内相邻位置颜色不同的次数,以及区间左、右端点的颜色),可以支持路径染色修改。
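每次询问对关键点建虚树,虚树上每条边的权值就是原树对应路径上的边权和(即颜色变化次数),用树剖查询得到。对每个关键点 i,所求为 Σ_j (dis(i,j) + 1),其中 dis(i,j) = dep(i) + dep(j) − 2·dep(lca(i,j))。Σ_j dep(lca(i,j)) 的思路与「开店」一题的"链加链查"相同(每个关键点到根的链加一,再查询 i 到根的链上带权和)。这里直接在虚树上离线做树形 DP:先求出子树内关键点个数 siz,再自顶向下递推 D(y) = D(x) + siz(y)·len(x,y)。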
// Path recoloring + "sum of color-segment counts to every other key vertex" queries.
//
// Edge-weight trick: edge (u, fa[u]) weighs 1 iff its endpoints have different
// colors, so #color segments on a path = (sum of edge weights on it) + 1.
// Vertex colors live in a segment tree over heavy-light (HLD) positions that
// supports range assignment and counts adjacent color changes in a range.
//
// A type-2 query builds a virtual tree of the K key vertices. Each virtual edge
// is weighted by the number of color changes on the corresponding real path.
// For every key vertex i the program prints
//   sum_j (dis(i, j) + 1) = SUM + K * DEEP[i] - 2 * sum_j DEEP[lca(i, j)] + K,
// where DEEP is the weighted depth below the virtual-tree root and SUM is the
// sum of DEEP over all key vertices. sum_j DEEP[lca(i, j)] is produced by a
// top-down DP that uses the number of key vertices in each virtual subtree.

#include <algorithm>
#include <cstdio>
#include <cstring>

/// Iterate over the adjacency list of vertex x in the original tree.
#define forson(x, i) for(int i = e[x]; i; i = edge[i].nex)

typedef long long LL;
const int N = 100010;

struct Edge {
    int nex, v;
    LL len;
} edge[N << 1], EDGE[N];  // edge: original tree (both directions); EDGE: virtual tree (parent -> child)
int tp, TP;               // edge counters for edge[] / EDGE[]

// Original tree + heavy-light decomposition.
int e[N], top[N], fa[N], son[N], siz[N], d[N], pos[N], id[N], num, val[N], n, imp2[N];
// Segment tree over HLD positions:
//   sum[o]         number of adjacent positions with different colors inside node o
//   lc[o] / rc[o]  color of the leftmost / rightmost position of node o
//   tag[o]         pending "assign this color to the whole range" (-1 = none)
int sum[N << 2], lc[N << 2], rc[N << 2], tag[N << 2];
// Virtual-tree state. vis/use are stamped with Time, so nothing is cleared between queries.
int imp[N], K, stk[N], Top, RT, Time, E[N], vis[N], use[N];
// DEEP is LL (it was int) so that K * DEEP[x] in main() is computed in 64 bits:
// K and DEEP can each be about 1e5, and their product overflows int.
LL SIZ[N], ans[N], D[N], DEEP[N];

/// Append the directed edge x -> y to the original tree's adjacency list.
inline void add(int x, int y) {
    tp++;
    edge[tp].v = y;
    edge[tp].nex = e[x];
    e[x] = tp;
}

/// ------------------- tree 1 (heavy-light decomposition) -------------------

/// First HLD pass: fill fa, siz, d (depth) and son (heavy child) for x's subtree.
void DFS_1(int x, int f) {
    fa[x] = f;
    siz[x] = 1;
    d[x] = d[f] + 1;
    forson(x, i) {
        int y = edge[i].v;
        if(y == f) continue;
        DFS_1(y, x);
        siz[x] += siz[y];
        if(siz[y] > siz[son[x]]) {
            son[x] = y;
        }
    }
}

/// Second HLD pass: set chain heads (top), DFS positions (pos) and the inverse map id.
/// The heavy child is visited first, so each heavy chain occupies a contiguous pos range.
void DFS_2(int x, int f) {
    top[x] = f;
    pos[x] = ++num;
    id[num] = x;
    if(son[x]) DFS_2(son[x], f);
    forson(x, i) {
        int y = edge[i].v;
        if(y == fa[x] || y == son[x]) continue;
        DFS_2(y, y);  // a light child starts a new chain
    }
}

/// ------------------ segment tree over HLD positions -----------------------

#define ls (o << 1)
#define rs (o << 1 | 1)

/// Merge the children of node o. One extra color change is counted across the
/// midpoint when the left child's rightmost color differs from the right child's leftmost.
inline void pushup(int o) {
    lc[o] = lc[ls];
    rc[o] = rc[rs];
    sum[o] = sum[ls] + sum[rs] + (rc[ls] != lc[rs]);
}

/// Push a pending whole-range color assignment of node o down to both children.
/// A uniformly colored range has no internal color changes, hence sum = 0.
inline void pushdown(int o) {
    if(tag[o] != -1) {
        lc[ls] = rc[ls] = tag[ls] = tag[o];
        lc[rs] = rc[rs] = tag[rs] = tag[o];
        sum[ls] = sum[rs] = 0;
        tag[o] = -1;
    }
}

#undef ls
#undef rs

/// Build node o covering [l, r] from the initial colors val[] in HLD order.
void build(int l, int r, int o) {
    if(l == r) {
        lc[o] = rc[o] = val[id[r]];
        sum[o] = 0;
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, o << 1);
    build(mid + 1, r, o << 1 | 1);
    pushup(o);
}

/// Assign color v to every position in [L, R]. Node o covers [l, r].
void change(int L, int R, int v, int l, int r, int o) {
    if(L <= l && r <= R) {
        lc[o] = rc[o] = tag[o] = v;
        sum[o] = 0;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(o);
    if(L <= mid) change(L, R, v, l, mid, o << 1);
    if(mid < R) change(L, R, v, mid + 1, r, o << 1 | 1);
    pushup(o);
}

/// Return the color at position p. Node o covers [l, r].
int ask(int p, int l, int r, int o) {
    if(l == r) return lc[o];
    int mid = (l + r) >> 1;
    pushdown(o);
    if(p <= mid) return ask(p, l, mid, o << 1);
    return ask(p, mid + 1, r, o << 1 | 1);
}

/// Count adjacent color changes strictly inside [L, R]. Node o covers [l, r].
/// When the query spans both halves, the boundary pair (mid, mid + 1) is counted too.
int getSum(int L, int R, int l, int r, int o) {
    if(L <= l && r <= R) {
        return sum[o];
    }
    pushdown(o);
    int mid = (l + r) >> 1;
    if(R <= mid) return getSum(L, R, l, mid, o << 1);
    if(mid < L) return getSum(L, R, mid + 1, r, o << 1 | 1);
    return getSum(L, R, l, mid, o << 1) + getSum(L, R, mid + 1, r, o << 1 | 1)
         + (rc[o << 1] != lc[o << 1 | 1]);
}

/// Lowest common ancestor of x and y via heavy-light chain jumping.
inline int lca(int x, int y) {
    while(top[x] != top[y]) {
        if(d[top[x]] < d[top[y]])
            y = fa[top[y]];
        else
            x = fa[top[x]];
    }
    return d[x] < d[y] ? x : y;
}

/// Number of color changes on the vertical path from x up to z.
/// z must be an ancestor of x (or x itself).
inline int getLen(int x, int z) {
    int col = ask(pos[x], 1, n, 1), ans = 0;  // col: color just below the current chain segment
    while(top[x] != top[z]) {
        ans += (col != ask(pos[x], 1, n, 1));  // light edge joining the previous segment to this chain
        ans += getSum(pos[top[x]], pos[x], 1, n, 1);
        col = ask(pos[top[x]], 1, n, 1);
        x = fa[top[x]];
    }
    ans += (col != ask(pos[x], 1, n, 1));
    ans += getSum(pos[z], pos[x], 1, n, 1);
    return ans;
}

/// Assign color v to every vertex on the tree path x - y.
inline void Change(int x, int y, int v) {
    while(top[x] != top[y]) {
        if(d[top[x]] > d[top[y]]) {
            change(pos[top[x]], pos[x], v, 1, n, 1);
            x = fa[top[x]];
        }
        else {
            change(pos[top[y]], pos[y], v, 1, n, 1);
            y = fa[top[y]];
        }
    }
    if(d[x] < d[y]) std::swap(x, y);
    change(pos[y], pos[x], v, 1, n, 1);
}

/// ------------------- tree 2 (virtual tree) --------------------------------

/// Add virtual edge x -> y, where x is an ancestor of y. The edge is weighted by
/// the number of color changes on the real path between them.
inline void ADD(int x, int y) {
    TP++;
    EDGE[TP].v = y;
    EDGE[TP].len = getLen(y, x);
    EDGE[TP].nex = E[x];
    E[x] = TP;
}

/// Order vertices by DFS position (the order required to build a virtual tree).
inline bool cmp(const int &a, const int &b) {
    return pos[a] < pos[b];
}

/// Lazily reset the per-query state of x the first time it is touched in this query.
inline void work(int x) {
    if(vis[x] == Time) return;
    vis[x] = Time;
    D[x] = E[x] = 0;
}

/// Build the virtual tree of imp2[1..K] with the usual monotone-stack method.
/// The tree root is stored in RT.
/// NOTE(review): assumes the K key vertices are distinct. A duplicate is pushed
/// twice onto the stack and produces a self-loop edge, which makes dfs_1 recurse forever.
inline void build_t() {
    TP = 0;
    memcpy(imp + 1, imp2 + 1, K * sizeof(int));
    std::sort(imp + 1, imp + K + 1, cmp);
    stk[Top = 1] = imp[1];
    work(imp[1]);
    for(int i = 2; i <= K; i++) {
        int x = imp[i], y = lca(x, stk[Top]);
        work(x); work(y);
        // Pop and link every stack vertex that lies strictly below y.
        while(Top > 1 && d[y] <= d[stk[Top - 1]]) {
            ADD(stk[Top - 1], stk[Top]);
            Top--;
        }
        if(y != stk[Top]) {
            ADD(y, stk[Top]);
            stk[Top] = y;
        }
        stk[++Top] = x;
    }
    while(Top > 1) {
        ADD(stk[Top - 1], stk[Top]);
        Top--;
    }
    RT = stk[Top];
}

/// DP 1: SIZ[x] = number of key vertices in the virtual subtree of x.
void dfs_1(int x) {
    SIZ[x] = (use[x] == Time);
    for(int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        dfs_1(y);
        SIZ[x] += SIZ[y];
    }
}

/// DP 2 (top-down), for every virtual vertex x:
///   DEEP[x] = weighted depth of x below RT;
///   D[x]    = sum over key vertices j of DEEP[lca(x, j)]: every virtual edge (p, c)
///             on the RT -> x path contributes len(p, c) * SIZ[c].
/// ans[x] stores D[x] for the key vertices.
void dfs_2(int x) {
    if(use[x] == Time) {
        ans[x] = D[x];
    }
    for(int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        D[y] = D[x] + SIZ[y] * EDGE[i].len;
        DEEP[y] = DEEP[x] + EDGE[i].len;
        dfs_2(y);
    }
}

/// Answer one type-2 query: build the virtual tree, then run both DPs.
inline void cal() {
    build_t();
    dfs_1(RT);
    DEEP[RT] = 0;
    dfs_2(RT);
}

int main() {
    memset(tag, -1, sizeof(tag));  // -1 = no pending color assignment
    int q;
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) {
        scanf("%d", &val[i]);
    }
    for(int i = 1, x, y; i < n; i++) {
        scanf("%d%d", &x, &y);
        add(x, y); add(y, x);
    }
    DFS_1(1, 0);
    DFS_2(1, 1);
    build(1, n, 1);

    for(int i = 1, f, x, y, z; i <= q; i++) {
        scanf("%d%d", &f, &x);
        if(f == 1) {
            // Recolor every vertex on path x - y with color z.
            scanf("%d%d", &y, &z);
            Change(x, y, z);
        }
        else {
            // x key vertices follow. For each one, print the sum over all key
            // vertices j of (color changes on the path to j) + 1.
            Time++;
            K = x;
            for(int j = 1; j <= K; j++) {
                scanf("%d", &imp2[j]);
                use[imp2[j]] = Time;
            }
            cal();
            LL SUM = 0;
            for(int j = 1; j <= K; j++) {
                SUM += DEEP[imp2[j]];
            }
            for(int j = 1; j <= K; j++) {
                printf("%lld ", SUM + K * DEEP[imp2[j]] - 2 * ans[imp2[j]] + K);
            }
            puts("");
        }
    }
    return 0;
}