感觉这题做下来心态有点崩……$RMQ$求$LCA$没有树剖快我可以理解为是常数太大……然而我明明用了自以为不会退化的点分然而为什么比会退化的点分跑得反而更慢啊啊啊啊~~~
先膜一波zsy大佬
讲讲做法。题目的要求是给定一个根$p$,求$sum _{i=1}^ns_i^2$,其中$s_i$表示子树中的点权和
我们设$sum=sum _{i=1}^n val_i$,即整棵树的点权和。先考虑一下$sum _{i=1}^ns_i$怎么求。考虑一下每一个点的贡献,每一个点都会对被计算$dep_i+1$次(其中$dep_i$表示$dist(i,p)$),那么很显然$sum _{i=1}^ns_i=sum_{i=1}^nval_i*(dep_i+1)=sum_{i=1}^nval_i*dep_i+sum$。然后考虑一下$val_i*dep_i$,如何动态维护?->幻想乡战略游戏……
简单来说,就是建好点分树,然后每一次及时修改和查询
然后我们令$calc(p)$以$p$为根时的$sum _{i=1}^nval_i*dep_i$
然后考虑如下式子$$sum_{i=1}^nsum_{j=1}^nval_i*val_j*dist(i,j)$$
是不是可以理解为在所有的点对$(i,j)$之间的所有边上加上权值$val_i*val_j$(刚好有$dist(i,j)$条边),然后再求整棵树的权值?
然后我们考虑一下每条边的权值,肯定等于两侧的子树点权和的乘积。那么,不论是以哪一个点$p$为根,它的权值都等于$s_i*(sum-s_i)$,其中$s_i$表示这条边指向的儿子的子树的点权和
那么,上面的式子就可以变成这样$$sum_{i=1}^nsum_{j=1}^nval_i*val_j*dist(i,j)=sum_{i=1}^ns_i*(sum-s_i)$$
又因为上式左边是不变的,所以不管选取哪一个$p$为根,右边都是不变的
令$W=sum_{i=1}^ns_i*(sum-s_i)$,然后可以直接$O(n)dp$出$W$,然后考虑对点的修改对$W$造成的影响
$W=sum_{i=1}^nsum{j=1}^nval_i*val_j*dist(i,j)$,设点$u$的变化量为$Δv$,那么$ΔW=Δv*sum_{j=1}^nval_j*dist(i,j)$,相当于$Δv*calc(i)$,然后可以考虑和一般的动态点分一样计算
然后最后询问的答案就是$$W=sum_{i=1}^ns_i*(sum-s_i)$$
$$sum_{i=1}^ns_i^2=sum_{i=1}^ns_i*sum-W$$
$$sum_{i=1}^ns_i^2=sum(calc(i)+sum)-W$$
1 // luogu-judger-enable-o2 2 //minamoto 3 #include<iostream> 4 #include<cstdio> 5 #include<algorithm> 6 #define ll long long 7 using namespace std; 8 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) 9 char buf[1<<21],*p1=buf,*p2=buf; 10 template<class T>inline bool cmax(T&a,const T&b){return a<b?a=b,1:0;} 11 inline int read(){ 12 #define num ch-'0' 13 char ch;bool flag=0;int res; 14 while(!isdigit(ch=getc())) 15 (ch=='-')&&(flag=true); 16 for(res=num;isdigit(ch=getc());res=res*10+num); 17 (flag)&&(res=-res); 18 #undef num 19 return res; 20 } 21 char sr[1<<21],z[20];int C=-1,Z; 22 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;} 23 inline void print(ll x){ 24 if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x; 25 while(z[++Z]=x%10+48,x/=10); 26 while(sr[++C]=z[Z],--Z);sr[++C]=' '; 27 } 28 const int N=200005; 29 int ver[N<<1],head[N],Next[N<<1]; 30 int val[N],fa[N],pa[N],d[N],sz[N],son[N],top[N]; 31 int n,q,tot; 32 void dfs1(int u,int fa){ 33 pa[u]=fa,d[u]=d[fa]+1,sz[u]=1; 34 for(int i=head[u];i;i=Next[i]){ 35 int v=ver[i];if(v==fa) continue; 36 dfs1(v,u); 37 sz[u]+=sz[v];if(sz[v]>sz[son[u]]) son[u]=v; 38 } 39 } 40 void dfs2(int u,int fa){ 41 top[u]=fa; 42 if(son[u]) dfs2(son[u],fa);else return; 43 for(int i=head[u];i;i=Next[i]) 44 if(ver[i]!=pa[u]&&ver[i]!=son[u]) 45 dfs2(ver[i],ver[i]); 46 } 47 int LCA(int u,int v){ 48 while(top[u]^top[v]){ 49 if(d[top[u]]<d[top[v]]) swap(u,v); 50 u=pa[top[u]]; 51 } 52 return d[u]<d[v]?u:v; 53 } 54 int dis(int u,int v){return d[u]+d[v]-(d[LCA(u,v)]<<1);} 55 int size,rt,vis[N]; 56 ll sum[N],sum1[N],sum2[N],sigma,omega,ans; 57 void getrt(int u,int fa){ 58 sz[u]=1,son[u]=0; 59 for(int i=head[u];i;i=Next[i]){ 60 int v=ver[i];if(v==fa||vis[v]) continue; 61 getrt(v,u),sz[u]+=sz[v],cmax(son[u],sz[v]); 62 } 63 cmax(son[u],size-sz[u]); 64 if(son[u]<son[rt]) rt=u; 65 } 66 void solve(int u,int f){ 67 fa[u]=f,vis[u]=1;int totsz=size; 68 for(int i=head[u];i;i=Next[i]){ 69 int v=ver[i]; 70 if(!vis[v]){ 71 size=sz[v]>sz[rt]?totsz-sz[rt]:sz[v]; 72 rt=0; 73 getrt(v,0),solve(rt,u); 74 } 75 } 76 } 77 inline void modify(int u,int v){ 78 sum[u]+=v; 79 for(int i=u;fa[i];i=fa[i]){ 80 int dist=dis(u,fa[i]); 81 sum[fa[i]]+=v; 82 sum1[fa[i]]+=dist*v; 83 sum2[i]+=dist*v; 84 } 85 } 86 inline ll calc(int u){ 87 ll res=sum1[u]; 88 for(int i=u;fa[i];i=fa[i]){ 89 int dist=dis(fa[i],u); 90 res+=(ll)dist*(sum[fa[i]]-sum[i]); 91 res+=sum1[fa[i]]-sum2[i]; 92 } 93 return res; 94 } 95 void DP(int u,int fa){ 96 sz[u]=val[u]; 97 for(int i=head[u];i;i=Next[i]){ 98 int v=ver[i]; 99 if(v!=fa) DP(v,u),sz[u]+=sz[v]; 100 } 101 omega+=1ll*sz[u]*(sigma-sz[u]); 102 } 103 int main(){ 104 n=read(),q=read(); 105 for(int i=1;i<n;++i){ 106 int u=read(),v=read(); 107 ver[++tot]=v,Next[tot]=head[u],head[u]=tot; 108 ver[++tot]=u,Next[tot]=head[v],head[v]=tot; 109 } 110 dfs1(1,0),dfs2(1,1); 111 size=n,son[rt=0]=n+1; 112 getrt(1,0),solve(rt,0); 113 for(int i=1;i<=n;++i) 114 val[i]=read(),modify(i,val[i]),sigma+=val[i]; 115 DP(1,0); 116 while(q--){ 117 int opt=read(),x=read(); 118 if(opt&1){ 119 int y=read();y-=val[x]; 120 modify(x,y),sigma+=y,omega+=y*calc(x); 121 val[x]+=y; 122 } 123 else print((calc(x)+sigma)*sigma-omega); 124 } 125 Ot(); 126 return 0; 127 }