说实话这题写树剖 $LCT$ 什么的真的思想又不难又好实现的样子,但是我还是选择自虐选择了动态点分治
那就两种做法都稍微提一下:
树链剖分 / $LCT$
很容易可以发现一个换根操作只会对当前根在原树(根为 $1$ )上的祖先一条链造成影响,也就是将它们的子树变成除当前链方向其它与之相连的点集,那么用树剖跳,用线段树维护一下原树上从上面来的和从下面来的,再将所有涉及的节点合并,并且删去算重复的和不该算的即可(虽然我没实现但是这个思路应该是对的
动态点分治
因为我们要算 $sumlimits_{i = 1}^n s_i^2$ ,可以先想一下 $sumlimits_{i = 1}^n s_i$ 怎么算,因为每个点的贡献只会被祖先计算到,那么易知 $sumlimits_{i = 1}^n s_i = sumlimits_{i = 1}^n value_i * (depth_i + 1)$ ,这个直接用动态点分治维护三个变量 $sumo_i, sumt_i, sumfa_i$ (sumo -> $p$ 子节点权值之和, $sumt$ -> 子节点权值与距离的乘积到 $p$ 之和, $sumfa$ -> 子节点权值与距离的乘积到 $fa$ (点分树上)之和)即可得到
接下来有个很重要的结论(反正我是肯定想不到
- 不论根如何换, $sumlimits_{i = 1}^n s_i (sum - s_i)$ 一定是一个定值(说实话一开始一直想着如何化简 $sumlimits_{i = 1}^n s_i$ ,所以是真的没有想到可以通过构造定值的方法来解出 $sumlimits_{i = 1}^n s_i^2$ )
先来意会一下这个结论:就是每一条边连接的两个点在他们的子树中各自选两个点让它们权值相乘,求总权值
那么就比较容易知道证明了:每条边的边权为所有对应路径经过这条边的两个点的权值和,求总边权,即求的是 $sumlimits_{i = 1}^n sumlimits_{j = 1}^n value_i * value_j * dist (i, j)$ ,所以有 $sumlimits_{i = 1}^n s_i (sum - s_i) = sumlimits_{i = 1}^n sumlimits_{j = 1}^n value_i * value_j * dist (i, j)$
因为该式是个定值,所以求出 $sumlimits_{i = 1}^n s_i$ 后直接解出 $Ans$ 就完成了查询操作
对于修改操作,若修改完后与原权值的差值为 $Delta value$ ,那么 $Delta total = Delta value sumlimits_{j = 1}^n value_j * dist (p, j)$ ( $p$ 为修改点)
代码
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 #include <cmath> 6 7 using namespace std; 8 9 typedef long long LL; 10 11 const int MAXN = 2e05 + 10; 12 const int MAXM = 2e05 + 10; 13 14 const int INF = 0x7fffffff; 15 16 struct LinkedForwardStar { 17 int to; 18 19 int next; 20 } ; 21 22 LinkedForwardStar Link[MAXM << 1]; 23 int Head[MAXN]= {0}; 24 int size = 0; 25 26 void Insert (int u, int v) { 27 Link[++ size].to = v; 28 Link[size].next = Head[u]; 29 30 Head[u] = size; 31 } 32 33 int N, Q; 34 int value[MAXN]; 35 36 LL Deep[MAXN]= {0}; 37 int Dfn[MAXN]= {0}; 38 int val[MAXN << 1]= {0}, belong[MAXN << 1]= {0}; 39 int dfsord = 0; 40 LL sum = 0; 41 LL subtree[MAXN]= {0}; 42 LL s1 = 0; 43 void DFS (int root, int fa) { 44 Dfn[root] = ++ dfsord; 45 val[dfsord] = Deep[root], belong[dfsord] = root; 46 sum += value[root]; 47 subtree[root] = value[root]; 48 for (int i = Head[root]; i; i = Link[i].next) { 49 int v = Link[i].to; 50 if (v == fa) 51 continue; 52 Deep[v] = Deep[root] + 1; 53 DFS (v, root); 54 val[++ dfsord] = Deep[root], belong[dfsord] = root; 55 subtree[root] += subtree[v]; 56 } 57 } 58 pair<int, int> ST[MAXN << 1][25]; 59 void RMQ () { 60 for (int i = 1; i <= dfsord; i ++) 61 ST[i][0] = make_pair (val[i], belong[i]); 62 for (int j = 1; j <= 20; j ++) 63 for (int i = 1; i + (1 << j) - 1 <= dfsord; i ++) 64 ST[i][j] = ST[i][j - 1].first < ST[i + (1 << (j - 1))][j - 1].first ? ST[i][j - 1] : ST[i + (1 << (j - 1))][j - 1]; 65 } 66 int LCA (int x, int y) { 67 int L = Dfn[x], R = Dfn[y]; 68 if (L > R) 69 swap (L, R); 70 int k = log2 (R - L + 1); 71 return ST[L][k].first < ST[R - (1 << k) + 1][k].first ? ST[L][k].second : ST[R - (1 << k) + 1][k].second; 72 } 73 LL dist (int x, int y) { 74 int lca = LCA (x, y); 75 return Deep[x] + Deep[y] - (Deep[lca] << 1); 76 } 77 78 int father[MAXN]= {0}; 79 bool Vis[MAXN]= {false}; 80 int Size[MAXN]= {0}; 81 int minv = INF, grvy; 82 int total; 83 void Grvy_Acqu (int root, int fa) { 84 Size[root] = 1; 85 int maxpart = 0; 86 for (int i = Head[root]; i; i = Link[i].next) { 87 int v = Link[i].to; 88 if (v == fa || Vis[v]) 89 continue; 90 Grvy_Acqu (v, root); 91 Size[root] += Size[v]; 92 maxpart = max (maxpart, Size[v]); 93 } 94 maxpart = max (maxpart, total - Size[root]); 95 if (maxpart < minv) 96 minv = maxpart, grvy = root; 97 } 98 LL sumo[MAXN]= {0}, sumt[MAXN]= {0}, sumfa[MAXN]= {0}; 99 // sumo -> p子节点权值之和, sumt -> 子节点权值与距离的乘积到p之和, sumfa -> 子节点权值与距离的乘积到fa(点分树上)之和 100 void sums_Acqu (int root, int fa) { 101 sumo[grvy] += value[root], sumt[grvy] += value[root] * dist (root, grvy); 102 if (father[grvy]) 103 sumfa[grvy] += value[root] * dist (root, father[grvy]); 104 for (int i = Head[root]; i; i = Link[i].next) { 105 int v = Link[i].to; 106 if (v == fa || Vis[v]) 107 continue; 108 sums_Acqu (v, root); 109 } 110 } 111 void point_DAC (int p, int pre) { 112 minv = INF, grvy = p, total = Size[p]; 113 Grvy_Acqu (p, 0); 114 Vis[grvy] = true, father[grvy] = pre; 115 sums_Acqu (grvy, 0); 116 int fgrvy = grvy; 117 for (int i = Head[fgrvy]; i; i = Link[i].next) { 118 int v = Link[i].to; 119 if (Vis[v]) 120 continue; 121 point_DAC (v, fgrvy); 122 } 123 } 124 125 LL Query (int op) { 126 LL tsum = 0; 127 for (int p = op; p; p = father[p]) { 128 tsum += sumt[p]; 129 if (p != op) 130 tsum += sumo[p] * dist (p, op); 131 if (father[p]) 132 tsum -= sumo[p] * dist (father[p], op) + sumfa[p]; 133 } 134 return tsum; 135 } 136 void Modify (int op, int delta) { 137 for (int p = op; p; p = father[p]) { 138 sumo[p] -= value[op] - delta; 139 sumt[p] -= value[op] * dist (p, op) - delta * dist (p, op); 140 if (father[p]) 141 sumfa[p] -= value[op] * dist (father[p], op) - delta * dist (father[p], op); 142 } 143 sum -= value[op] - delta; 144 LL s = Query (op); 145 s1 += (delta - value[op]) * s; 146 value[op] = delta; 147 } 148 149 int getnum () { 150 int num = 0; 151 char ch = getchar (); 152 int isneg = 0; 153 154 while (! isdigit (ch)) { 155 if (ch == '-') 156 isneg = 1; 157 ch = getchar (); 158 } 159 while (isdigit (ch)) 160 num = (num << 3) + (num << 1) + ch - '0', ch = getchar (); 161 162 return isneg ? - num : num; 163 } 164 165 int main () { 166 N = getnum (), Q = getnum (); 167 for (int i = 1; i < N; i ++) { 168 int u = getnum (), v = getnum (); 169 Insert (u, v), Insert (v, u); 170 } 171 for (int i = 1; i <= N; i ++) 172 value[i] = getnum (); 173 DFS (1, 0), RMQ (); 174 for (int i = 1; i <= N; i ++) 175 s1 += subtree[i] * (sum - subtree[i]); 176 Size[1] = N, point_DAC (1, 0); 177 /*cout << "Next----------------------" << endl; 178 for (int i = 1; i <= N; i ++) 179 cout << sumo[i] << ' ' << sumt[i] << ' ' << sumfa[i] << endl; 180 cout << "End-----------------------" << endl;*/ 181 for (int Case = 1; Case <= Q; Case ++) { 182 int opt = getnum (); 183 if (opt == 1) { 184 int p = getnum (), delta = getnum (); 185 Modify (p, delta); 186 } 187 else if (opt == 2) { 188 int p = getnum (); 189 LL ans = (Query (p) + sum) * sum - s1; 190 printf ("%lld ", ans); 191 } 192 } 193 194 return 0; 195 } 196 197 /* 198 4 5 199 1 2 200 2 3 201 2 4 202 4 3 2 1 203 2 2 204 1 1 3 205 2 3 206 1 2 4 207 2 4 208 */ 209 210 /* 211 4 1 212 1 2 213 2 3 214 2 4 215 4 3 2 1 216 2 1 217 */