题目描述
`LYK` 在森林里找到了一棵树。这棵树非常神奇,每条边都有其边权,每个点也有它的点权 $a_i$ 。我们令 $dis(i,j)$ 表示点 $i$ 与点 $j$ 之间的最短路的距离。
`LYK` 每次选择树上两个点 $x,y(x<y)$ ,它将会得到 $(a_x wedge a_y)dis(x,y)$ 的金币(其中 $wedge$ 表示异或)。它认为这个值太容易计算了,于是它想把所有关于 $x,y$ 的点对( $frac{n(n-1)}{2}$ 对)能得到的金币都求出来,并计算它能得到的金币总和。
这时人群中跳出一个熊孩子来搞破坏,它每次会修改 `LYK` 的树中的某个点权。 `LYK` 想知道每次被搞破坏后它能得到的金币总和。
题解
考虑拆位,即我们计算每一位 $x$ 为 $0$ 且 $y$ 为 $1$ 的 $dis(x,y)$ 的总和。
看到 $dis$ 可以想到点分治,题目又有点类似强制在线,所以我们考虑动态点分,每个点分中心维护每一位为 $0/1$ 的个数和 $dis$ 的总和,记得在上一级的点分中心更新的时候要在这一级上进行容斥即可。
效率: $O(14(n+q)logn)$ 。
代码
#include <bits/stdc++.h> #define LL long long using namespace std; const int N=30005,N2=N<<1; int n,a[N],hd[N],V[N2],nx[N2],W[N2],sz[N],son[N]; int f[N2][16],Lg[N2],c,d[N],e[N],up[N],o,rt,t,q; bool vis[N];LL s1[2][14][N2],s2[2][14][N2],ans; void add(int u,int v,int w){ nx[++t]=hd[u];V[hd[u]=t]=v;W[t]=w; } void Sz(int u,int fr){ sz[u]=1; for (int v,i=hd[u];i;i=nx[i]) if (!vis[v=V[i]] && V[i]!=fr) Sz(v,u),sz[u]+=sz[v]; } void Rt(int u,int fr){ son[u]=o-sz[u]; for (int v,i=hd[u];i;i=nx[i]) if (!vis[v=V[i]] && V[i]!=fr) Rt(v,u),son[u]=max(son[u],sz[v]); if (son[rt]>son[u]) rt=u; } void work(int u,int fr){ Sz(u,0);o=sz[u];rt=0; Rt(u,0);up[rt]=fr;vis[fr=rt]=1; for (int i=hd[rt];i;i=nx[i]) if (!vis[V[i]]) work(V[i],fr); } void dfs(int u,int fr){ f[e[u]=++c][0]=d[u]; for (int i=hd[u],v;i;i=nx[i]) if ((v=V[i])!=fr) d[v]=d[u]+W[i], dfs(v,u),f[++c][0]=d[u]; } int lca(int l,int r){ if (l>r) swap(l,r);int i=Lg[r-l+1]; return min(f[l][i],f[r-(1<<i)+1][i]); } int dis(int u,int v){ return d[u]+d[v]-(lca(e[u],e[v])<<1); } void Upd(int x,int u,int v){ int y=dis(x,u>n?up[u-n]:u); for (int i=0;i<14;i++) s1[(a[x]>>i)&1][i][u]+=(1ll<<i)*v*y, s2[(a[x]>>i)&1][i][u]+=v; } void upd(int u,int v){ Upd(u,u,v); for (int x=u;up[u];u=up[u]) Upd(x,up[u],v),Upd(x,u+n,-v); } void Qry(int x,int u,int v){ int y=dis(x,u>n?up[u-n]:u); for (int w,i=0;i<14;i++) w=(a[x]>>i)&1, ans+=(1ll<<i)*v*s2[!w][i][u]*y+s1[!w][i][u]*v; } void qry(int u,int v){ Qry(u,u,v); for (int x=u;up[u];u=up[u]) Qry(x,up[u],v),Qry(x,u+n,v); } int main(){ cin>>n;son[0]=1e9; for (int i=1;i<=n;i++) scanf("%d",&a[i]); for (int u,v,w,i=1;i<n;i++) scanf("%d%d%d",&u,&v,&w), add(u,v,w),add(v,u,w); work(1,0);dfs(1,0); for (int i=2;i<=c;i++) Lg[i]=Lg[i>>1]+1; for (int i=c;i;i--) for (int j=1;i+(1<<j)<=c+1;j++) f[i][j]=min(f[i][j-1],f[i+(1<<(j-1))][j-1]); for (int i=1;i<=n;i++) qry(i,1),upd(i,1);cin>>q; for (int x,y;q--;) scanf("%d%d",&x,&y),upd(x,-1),qry(x,-1), a[x]=y,qry(x,1),upd(x,1),printf("%lld ",ans); return 0; }