推荐博客: http://www.cnblogs.com/Mychael/p/9257242.html
感觉还挺好玩的
首先考虑以1为根,把每一个点子树的权值和都算出来,记为$val_{i}$,那么在所有操作都没有开始的时候(以$1$为根的)$ans_{1} = sum_{i= 1}^{n}val_{i}^{2}$
考虑到一个修改的操作只会对修改的点$x$到根($1$)链上的点产生影响,那么一次修改只要修对这条树链上的点增加$v - a_{x}$(假设修改后的值为$v$)就好了。
链剖之后线段树维护一下$val_{i}$,区间修改就很简单。
然后考虑换根:
我们发现当以$x$为根的时候,$x$原来的子树显然不会受到影响,而变化了的是原来的根$1$到$x$的链上的点,不妨设有$k$个结点,换根前(以$1$为根)的每个结点子树$val$值和为$a_{i}$,换根后(以$x$为根)的每个结点子树$val$值和为$b_{i}$
有一条显然的性质:$a_{i + 1} + b_{i} = a_{1} = b_{k}$都等于原来全部结点的$val$值和
那么换根之后的答案 $ans_{x} = ans_{1} - sum_{i = 1}^{k}a_{i}^{2} + sum_{i = 1}^{k}b_{i}^{2}$
代入上面的那条性质消掉$b$,发现$ans_{x} = ans_{1} + (k - 1)a_{1}^{2} - 2a_{1}sum_{i = 2}^{k}a_{i}$
设$s_{i}$表示$i$的子树中所有$val$值和,那么$ans_{x} = ans_{1} + s_{1}((k + 1) s_{1} - 2sum_{i = 1}^{k}s_{i})$。
容易发现这个$k$即为$dep_{x}$,而这个$sum_{i = 1}^{k}s_{i}$ 和 $s_{1}$显然可以用线段树维护出来。
考虑一下, 一次修改还会对$ans_{1}$产生影响,$ans_{1} += sum_{i = 1}^{tot}(val_{i}+ Delta v)^{2} - sum_{i = 1}^{tot}val_{i}^{2} = totDelta v^{2} + 2Delta vsum_{i = 1}^{tot}val_{i}$。
因为每次发生变化的只有一条树链上的点,所以$tot = dep_{x}$,这个原来的$sum_{i = 1}^{tot}val_{i}$可以在跳轻重链的过程中算出来。
时间复杂度$O(nlog^{2}n)$。
Code:
#include <cstdio> #include <cstring> using namespace std; typedef long long ll; const int N = 2e5 + 5; int n, qn, dfsc = 0, dep[N], siz[N], id[N]; int tot = 0, head[N], top[N], fa[N], son[N]; ll a[N], ans = 0LL, nowSum = 0LL, w[N], val[N]; struct Edge { int to, nxt; } e[N << 1]; inline void add(int from, int to) { e[++tot].to = to; e[tot].nxt = head[from]; head[from] = tot; } template <typename T> inline void read(T &X) { X = 0; char ch = 0; T op = 1; for(; ch > '9'|| ch < '0'; ch = getchar()) if(ch == '-') op = -1; for(; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } void dfs1(int x, int fat, int depth) { siz[x] = 1, fa[x] = fat, dep[x] = depth, val[x] = a[x]; int maxson = -1; for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(y == fat) continue; dfs1(y, x, depth + 1); siz[x] += siz[y], val[x] += val[y]; if(siz[y] > maxson) maxson = siz[y], son[x] = y; } } void dfs2(int x, int topf) { w[id[x] = ++dfsc] = val[x], top[x] = topf; if(!son[x]) return; dfs2(son[x], topf); for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(y == fa[x] || y == son[x]) continue; dfs2(y, y); } } namespace SegT { ll s[N << 2], tag[N << 2]; #define lc p << 1 #define rc p << 1 | 1 #define mid ((l + r) >> 1) inline void up(int p) { if(p) s[p] = s[lc] + s[rc]; } inline void down(int p, int l, int r) { if(!tag[p]) return; s[lc] += 1LL * (mid - l + 1) * tag[p]; s[rc] += 1LL * (r - mid) * tag[p]; tag[lc] += tag[p], tag[rc] += tag[p]; tag[p] = 0LL; } void build(int p, int l, int r) { tag[p] = 0LL; if(l == r) { s[p] = w[l]; return; } build(lc, l, mid); build(rc, mid + 1, r); up(p); } void modify(int p, int l, int r, int x, int y, ll v) { if(x <= l && y >= r) { s[p] += 1LL * (r - l + 1) * v; tag[p] += v; return; } down(p, l, r); if(x <= mid) modify(lc, l, mid, x, y, v); if(y > mid) modify(rc, mid + 1, r, x, y, v); up(p); } ll qSum(int p, int l, int r, int x, int y) { if(x <= l && y >= r) return s[p]; down(p, l, r); ll res = 0LL; if(x <= mid) res += qSum(lc, l, mid, x, y); if(y > mid) res += qSum(rc, mid + 1, r, x, y); return res; } } using namespace SegT; inline void mTree(int x) { ll v, sum = 0LL, len = (ll)dep[x]; read(v); v -= a[x], a[x] += v; for(; x != 0; x = fa[top[x]]) { sum += qSum(1, 1, n, id[top[x]], id[x]); modify(1, 1, n, id[top[x]], id[x], v); } ans += 2LL * v * sum + 1LL * v * v * len; nowSum += v; } inline ll qTree(int x) { ll res = 0LL; for(; x != 0; x = fa[top[x]]) res += qSum(1, 1, n, id[top[x]], id[x]); return res; } inline void solve(int x) { ll k = (ll)dep[x], sum = qTree(x); printf("%lld ", ans + nowSum * ((k + 1) * nowSum - 2 * sum)); } int main() { read(n), read(qn); for(int x, y, i = 1; i < n; i++) { read(x), read(y); add(x, y), add(y, x); } for(int i = 1; i <= n; i++) read(a[i]); dfs1(1, 0, 1); dfs2(1, 1); build(1, 1, n); /* for(int i = 1; i <= n; i++) printf("%d ", dep[i]); printf(" "); for(int i = 1; i <= n; i++) printf("%d ", top[i]); printf(" "); for(int i = 1; i <= n; i++) printf("%d ", w[i]); printf(" "); */ for(int i = 1; i <= n; i++) { nowSum += a[i]; ans += val[i] * val[i]; } // printf("%lld ", ans); for(int op, x; qn--; ) { read(op), read(x); if(op == 1) mTree(x); else solve(x); } return 0; }