题解:
和在线的点分治（动态点分治 / 点分树）差不多，就是要把每一层的信息都存下来。
然后对于每一层记录上一层的重心是哪个。
求和时，从自己所在的那一层出发，暴力往上爬，沿途计算答案。
修改时，同样暴力地往上爬，修改每一层所对应的信息。
用树状数组来统计同一层中、到该层重心不同距离的权值前缀和。
本来想用线段树，结果 TLE 了，时限非常卡。
然后用另一棵树状数组（按到上一层重心的距离统计）来容斥，减去重复计算的部分。
代码:
#include<bits/stdc++.h>
using namespace std;

// ---- Contest-template leftovers (kept for compatibility; mostly unused here) ----
#define Fopen freopen("_in.txt","r",stdin); freopen("_out.txt","w",stdout);
#define LL long long
#define ULL unsigned LL
#define fi first
#define se second
#define pb push_back
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lch(x) tr[x].son[0]
#define rch(x) tr[x].son[1]
#define max3(a,b,c) max(a,max(b,c))
#define min3(a,b,c) min(a,min(b,c))
typedef pair<int,int> pll;
const int inf = 0x3f3f3f3f;
const int _inf = 0xc0c0c0c0;
const LL INF = 0x3f3f3f3f3f3f3f3f;
const LL _INF = 0xc0c0c0c0c0c0c0c0;
const LL mod = (int)1e9+7;

const int N = 5e5;   // capacity for node ids and Euler-tour positions

// Adjacency list: head[u] -> first edge index, nt[e] -> next edge, to[e] -> endpoint.
int head[N], to[N<<1], nt[N<<1], tot, dtot;
// in[u]: first index of u in the Euler tour; deep[u]: depth, with the root at depth 1.
int in[N], out[N], dfn[N], deep[N];
// Log[i] = floor(log2(i)) for i >= 1 (Log[0] = -1 as a seed).
int Log[N];

// Sparse table over the Euler tour, answering LCA queries in O(1).
struct ST {
    int dp[N][20], a[N];   // a[1..n]: Euler tour; dp[i][j]: shallowest node in a[i .. i+2^j-1]

    void init(int n) {
        Log[0] = -1;
        for (int i = 1; i <= n; ++i) Log[i] = Log[i - 1] + ((i & (i - 1)) == 0);
        for (int i = 1; i <= n; ++i) dp[i][0] = a[i];
        for (int j = 1; j <= Log[n]; ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
                int x = dp[i][j - 1], y = dp[i + (1 << (j - 1))][j - 1];
                dp[i][j] = deep[x] < deep[y] ? x : y;
            }
    }

    // Returns the lowest common ancestor (a node id) of nodes l and r.
    inline int lca(int l, int r) {
        l = in[l], r = in[r];
        if (l > r) swap(l, r);
        int k = Log[r - l + 1];
        int x = dp[l][k], y = dp[r - (1 << k) + 1][k];
        // Fix: the original returned deep[x] in this branch, i.e. a depth instead of
        // a node id, which made dis() index deep[] with a depth value.
        return deep[x] < deep[y] ? x : y;
    }
} st;

// Adds the directed edge u -> v to the adjacency list.
inline void add(int u, int v){
    to[tot] = v;
    nt[tot] = head[u];
    head[u] = tot++;
}

// Euler tour from u (parent o): records first visits in in[], depths in deep[],
// and appends u to the tour again after returning from each child.
void dfs(int o, int u){
    in[u] = ++dtot;
    st.a[dtot] = u;
    deep[u] = deep[o] + 1;
    for(int i = head[u]; ~i; i = nt[i]){
        int v = to[i];
        if(o ^ v){
            dfs(u, v);
            st.a[++dtot] = u;
        }
    }
}

// Number of edges on the tree path between u and v.
inline int dis(int u, int v){
    return deep[u] + deep[v] - 2 * deep[st.lca(u, v)];
}

int w[N];          // current weight of each node
int a[N];          // scratch buffer: weight totals bucketed by distance (filled by Add, read by Run)
int atot;
int Prt[N][2], Pr[N][2];
// vc[c][0]: Fenwick tree (layout built by Run) over c's component, indexed by distance to c.
// vc[c][1]: the same weights indexed by distance to fa[c]; subtracted in Query to avoid
//           counting c's component twice when climbing through fa[c].
vector<int> vc[N][2];
int vis[N], tsz;   // vis[c]: c already chosen as a centroid; tsz: size of the current component
int sz[N];         // subtree sizes within the current component
int rt, rtnum;     // best centroid candidate so far and the size of its largest remaining part
int fa[N];         // fa[c]: centroid one level above c (0 for the top-level centroid)

// Finds the centroid of the current unvisited component (size tsz) containing u,
// storing it in rt; also fills sz[] for that component rooted at the start node.
void Get_root(int o, int u){
    sz[u] = 1;
    int Max = 0;
    for(int i = head[u]; ~i; i = nt[i]){
        int v = to[i];
        if(vis[v] || v == o) continue;
        Get_root(u, v);
        Max = max(Max, sz[v]);
        sz[u] += sz[v];
    }
    Max = max(Max, tsz - sz[u]);   // the part "above" u in this component
    if(Max < rtnum){
        rt = u;
        rtnum = Max;
    }
}

int Maxdep;   // largest distance recorded by the most recent dfs(g, o, u) walk

// Walks the unvisited component from u (parent o), recomputing sz[] and
// updating Maxdep with the largest distance from g.
void dfs(int g, int o, int u){
    sz[u] = 1;
    Maxdep = max(Maxdep, dis(g, u));
    for(int i = head[u]; ~i; i = nt[i]){
        int v = to[i];
        if(vis[v] || v == o) continue;
        dfs(g, u, v);
        sz[u] += sz[v];
    }
}

// Accumulates w[x] into a[dis(g, x)] for every node x in the unvisited component of u.
void Add(int g, int o, int u){
    a[dis(g, u)] += w[u];
    for(int i = head[u]; ~i; i = nt[i]){
        int v = to[i];
        if(vis[v] || v == o) continue;
        Add(g, u, v);
    }
}

// Builds a Fenwick tree in v from a[0..Maxdep]: v[0] holds a[0] on its own and
// v[1..Maxdep] is a 1-indexed BIT over a[1..Maxdep].
void Run(vector<int> & v){
    for(int i = 0; i <= Maxdep; ++i) v.push_back(a[i]);
    for(int i = Maxdep; i >= 1; --i){
        // v[i] already holds a[i]; push a[i] into every BIT node that covers i.
        for(int j = i + (i & -i); j <= Maxdep; j += j & -j) v[j] += a[i];
    }
}

// Centroid-decomposes the component containing u (size num). o is the centroid
// of the enclosing level (0 at the top level).
void solve(int o, int u, int num){
    tsz = num;
    rtnum = num + 1;
    Get_root(0, u);
    fa[rt] = o;
    vis[rt] = 1;

    // vc[rt][0]: this component's weights bucketed by distance to rt.
    Maxdep = 0;
    dfs(rt, 0, rt);   // also recomputes sz[] rooted at rt for the recursion below
    for(int i = 0; i <= Maxdep; ++i) a[i] = 0;
    Add(rt, 0, rt);
    Run(vc[rt][0]);

    if(o){
        // vc[rt][1]: the same weights bucketed by distance to the parent centroid o.
        Maxdep = 0;
        dfs(o, 0, rt);
        for(int i = 0; i <= Maxdep; ++i) a[i] = 0;
        Add(o, 0, rt);
        Run(vc[rt][1]);
    }

    int nrt = rt;   // rt is overwritten by the recursive calls
    for(int i = head[nrt]; ~i; i = nt[i]){
        int v = to[i];
        if(vis[v]) continue;
        solve(nrt, v, sz[v]);
    }
}

// Point update on a Run-built Fenwick tree: adds val at distance x.
// Negative x is ignored; x == 0 goes to the separate v[0] slot.
void Updata(vector<int> &v, int x, int val){
    int lim = v.size();
    if(x < 0) return;
    if(x == 0){
        v[0] += val;
        return;
    }
    while(x < lim){
        v[x] += val;
        x += x & (-x);
    }
}

// Sum over distances 0..k of a Run-built Fenwick tree. k is clamped to the largest
// stored distance; a negative k (or an empty tree) yields 0.
int Query(vector<int> &v, int k){
    int lim = v.size();
    k = min(k, lim - 1);
    if(k < 0) return 0;
    int ret = v[0];
    while(k > 0){
        ret += v[k];
        k -= k & (-k);
    }
    return ret;
}

// Adds val to node x's weight in every centroid level whose component contains x.
void Updata(int x, int val){
    for(int i = x; i; i = fa[i]){
        Updata(vc[i][0], dis(x, i), val);
        if(fa[i]) Updata(vc[i][1], dis(fa[i], x), val);
    }
}

// Sum of weights of nodes within distance k of x: climb the centroid tree, adding
// each level's count and subtracting what was already counted via the child level.
int Query(int x, int k){
    int ret = 0;
    for(int i = x; i; i = fa[i]){
        ret += Query(vc[i][0], k - dis(x, i));
        if(fa[i]) ret -= Query(vc[i][1], k - dis(fa[i], x));
    }
    return ret;
}

// Buffered stdin reader built on fread. The write buffer (wbuf/wpos) is never
// filled in this program; the destructor only flushes it when non-empty.
struct FastIO {
    static const int S = 1310720;
    int wpos;
    char wbuf[S];
    FastIO() : wpos(0) { }

    // Next input byte, or -1 once the input is exhausted.
    inline int xchar() {
        static char buf[S];
        static int len = 0, pos = 0;
        if (pos == len) pos = 0, len = fread(buf, 1, S, stdin);
        if (pos == len) return -1;
        return buf[pos++];
    }

    // Reads an optionally negative decimal integer, skipping leading whitespace.
    inline int xint() {
        int c = xchar(), x = 0, s = 1;
        // Fix: stop at EOF (-1). The original `c <= 32` test spun forever on truncated input.
        while (c != -1 && c <= 32) c = xchar();
        if (c == '-') s = -1, c = xchar();
        for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
        return x * s;
    }

    ~FastIO() { if (wpos) fwrite(wbuf, 1, wpos, stdout), wpos = 0; }
} io;

int main(){
    // Read all input through the same buffered reader (the original mixed scanf with io).
    int n = io.xint(), m = io.xint();
    for(int i = 1; i <= n; ++i){
        w[i] = io.xint();
    }
    memset(head, -1, sizeof head);
    for(int i = 1, u, v; i < n; ++i){
        u = io.xint(), v = io.xint();
        add(u, v);
        add(v, u);
    }
    dfs(0, 1);          // Euler tour from node 1 for LCA / distance queries
    st.init(dtot);
    solve(0, 1, n);     // build every centroid level
    int lastans = 0;
    int op, x, y;
    for(int i = 1; i <= m; ++i){
        op = io.xint(), x = io.xint(), y = io.xint();
        // Query arguments are XOR-encoded with the previous answer.
        x ^= lastans, y ^= lastans;
        if(op){
            // op != 0: set w[x] = y.
            Updata(x, y - w[x]);
            w[x] = y;
        } else{
            // op == 0: total weight of nodes within distance y of x.
            lastans = Query(x, y);
            // NOTE(review): answers are space-separated; confirm the judge accepts that.
            printf("%d ", lastans);
        }
    }
    return 0;
}