这大概是NOIP前最后一道题了,顺便复习一些模板。
考虑到不是真正要进行“换根”这个操作,我们发现在查询子树和修改子树的时候只要按照遥远的国度这题的方法分类讨论一下就好了,我们考虑一下如何换根意义下的$lca$。
先分类讨论,假设当前结点是$(u, v)$根是$rt$。
1、$u, v$都在$rt$的子树内,那么$lca$不变。
2、$u, v$有一个在$rt$的子树内(假设为$u$),那么$lca$为$(lca, v)$。
3、$u, v$都不在$rt$的子树内,那么$lca$是$lca(u, rt)$和$lca(v, rt)$中最接近$rt$的那一个,也就是深度较大的那一个。
可以画个图脑补一下。
综合一下发现只要找到$lca(u, v)$和$lca(u, rt)$和$lca(v, rt)$中深度最大的就可以了。
时间复杂度$O(nlogn)$。
练习了一下树状数组维护区间的做法,算了一下似乎要爆的样子。
Code:
#include <cstdio> #include <cstring> using namespace std; typedef long long ll; const int N = 1e5 + 5; const int Lg = 20; int n, qn, tot = 0, head[N]; int dfsc = 0, id[N], siz[N], fa[N][Lg], dep[N]; ll a[N], w[N]; struct Edge { int to, nxt; } e[N << 1]; inline void add(int from, int to) { e[++tot].to = to; e[tot].nxt = head[from]; head[from] = tot; } template <typename T> inline void read(T &X) { X = 0; char ch = 0; T op = 1; for(; ch > '9' || ch < '0'; ch = getchar()) if(ch == '-') op = -1; for(; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } template <typename T> inline void swap(T &x, T &y) { T t = x; x = y; y = t; } void dfs(int x, int fat, int depth) { fa[x][0] = fat, dep[x] = depth; w[id[x] = ++dfsc] = a[x], siz[x] = 1; for(int i = 1; i <= 18; i++) fa[x][i] = fa[fa[x][i - 1]][i - 1]; for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(y == fat) continue; dfs(y, x, depth + 1); siz[x] += siz[y]; } } inline int getLca(int x, int y) { if(dep[x] < dep[y]) swap(x, y); for(int i = 18; i >= 0; i--) if(dep[fa[x][i]] >= dep[y]) x = fa[x][i]; if(x == y) return x; for(int i = 18; i >= 0; i--) if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i]; return fa[x][0]; } inline int getPos(int x, int stp) { int res = x; for(int i = 18; i >= 0; i--) if((stp >> i) & 1) res = fa[res][i]; return res; } namespace Bit { ll s1[N], s2[N]; #define lowbit(p) (p & (-p)) inline void modify(int p, ll v) { for(int i = p; i <= n; i += lowbit(i)) { s1[i] += v; s2[i] += v * p; } } inline ll query(int p) { ll res1 = 0LL, res2 = 0LL; for(int i = p; i > 0; i -= lowbit(i)) { res1 += s1[i]; res2 += s2[i]; } return res1 * (p + 1) - res2; } inline ll getSum(int l, int r) { return query(r) - query(l - 1); } } using namespace Bit; int main() { read(n), read(qn); for(int i = 1; i <= n; i++) read(a[i]); for(int x, y, i = 1; i < n; i++) { read(x), read(y); add(x, y), add(y, x); } int rt; dfs(rt = 1, 0, 1); for(int i = 1; i <= n; i++) modify(i, w[i] - w[i - 1]); for(int op; qn--; ) { read(op); if(op == 1) read(rt); if(op == 2) { int x, y, z = 0; ll v; read(x), read(y), read(v); int tmp = getLca(x, y); if(dep[tmp] > dep[z]) z = tmp; tmp = getLca(x, rt); if(dep[tmp] > dep[z]) z = tmp; tmp = getLca(y, rt); if(dep[tmp] > dep[z]) z = tmp; if(z == rt) { modify(1, v); continue; } if(id[z] <= id[rt] && id[rt] <= id[z] + siz[z] - 1) { int pos = getPos(rt, dep[rt] - dep[z] - 1); modify(1, v), modify(id[pos], -v), modify(id[pos] + siz[pos], v); } else modify(id[z], v), modify(id[z] + siz[z], -v); } if(op == 3) { int x; read(x); if(x == rt) { printf("%lld ", getSum(1, n)); continue; } if(id[x] <= id[rt] && id[rt] <= id[x] + siz[x] - 1) { int y = getPos(rt, dep[rt] - dep[x] - 1); printf("%lld ", getSum(1, n) - getSum(id[y], id[y] + siz[y] - 1)); } else printf("%lld ", getSum(id[x], id[x] + siz[x] - 1)); } } return 0; }