https://www.luogu.org/problemnew/show/P3676
这题被我当成动态dp去做了,码了4k,搞了一个换根的动态dp
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 using namespace std; 5 typedef long long ll; 6 struct E 7 { 8 int to,nxt; 9 }e[400011]; 10 int f1[200011],ne; 11 struct P1 12 { 13 int len;ll a,b,c,d,e,f; 14 //长度,(点权)和,后缀和之和,后缀和的平方之和,(答案)和 15 //前缀和之和,前缀和的平方之和 16 }; 17 struct P2 18 { 19 ll a,b; 20 //点权和,答案和 21 }; 22 ll a[200101]; 23 int sz[200101],hson[200101],ff[200101]; 24 int b[200101],pl[200101]; 25 int n,m; 26 inline void merge(P1 &c,const P1 &a,const P1 &b) 27 { 28 c.len=a.len+b.len; 29 c.a=a.a+b.a; 30 c.b=b.b+a.b+b.a*a.len; 31 c.c=b.c+b.a*b.a*a.len+2*a.b*b.a+a.c; 32 c.d=a.d+b.d; 33 c.e=a.e+b.e+a.a*b.len; 34 c.f=a.f+a.a*a.a*b.len+2*b.e*a.a+b.f; 35 } 36 inline void initnode(P1 &c,const P2 &a) 37 { 38 c.len=1;c.a=c.b=c.e=a.a;c.c=c.f=a.a*a.a;c.d=a.b; 39 } 40 namespace S 41 { 42 #define lc (num<<1) 43 #define rc (num<<1|1) 44 P1 d[800101]; 45 inline void upd(int num){merge(d[num],d[lc],d[rc]);} 46 P1 x;int L; 47 void _setx(int l,int r,int num) 48 { 49 if(l==r) 50 { 51 d[num]=x; 52 return; 53 } 54 int mid=(l+r)>>1; 55 if(L<=mid) _setx(l,mid,lc); 56 else _setx(mid+1,r,rc); 57 upd(num); 58 } 59 P1 getx(int L,int R,int l,int r,int num) 60 { 61 if(L<=l&&r<=R) return d[num]; 62 int mid=(l+r)>>1; 63 if(L<=mid&&mid<R) 64 { 65 P1 x; 66 merge(x,getx(L,R,l,mid,lc),getx(L,R,mid+1,r,rc)); 67 return x; 68 } 69 else if(L<=mid) 70 return getx(L,R,l,mid,lc); 71 else if(mid<R) 72 return getx(L,R,mid+1,r,rc); 73 else 74 exit(-1); 75 } 76 } 77 void dfs1(int u,int fa) 78 { 79 sz[u]=1; 80 for(int v,k=f1[u];k;k=e[k].nxt) 81 if(e[k].to!=fa) 82 { 83 v=e[k].to; 84 ff[v]=u; 85 dfs1(v,u); 86 sz[u]+=sz[v]; 87 if(sz[v]>sz[hson[u]]) hson[u]=v; 88 } 89 } 90 P2 d1[200101];//d1[i]维护i节点及其轻儿子的贡献 91 P2 d2[200101];//d2[i]维护i节点(是重链顶)所在重链的值 92 int tp[200101],dwn[200101];//链顶,链底 93 inline void updd1(int x) 94 { 95 initnode(S::x,d1[x]);S::L=pl[x];S::_setx(1,n,1); 96 } 97 void dfs2(int u,int fa) 98 { 99 d1[u].a=a[u]; 100 b[++b[0]]=u;pl[u]=b[0]; 101 tp[u]=(u==hson[fa])?tp[fa]:u; 102 if(hson[u]) dfs2(hson[u],u); 103 dwn[u]=hson[u]?dwn[hson[u]]:u; 104 int v,k; 105 for(k=f1[u];k;k=e[k].nxt) 106 if(e[k].to!=fa&&e[k].to!=hson[u]) 107 { 108 v=e[k].to; 109 dfs2(v,u); 110 d1[u].b+=d2[v].b; 111 d1[u].a+=d2[v].a; 112 } 113 updd1(u); 114 if(u==tp[u]) 115 { 116 P1 t=S::getx(pl[u],pl[dwn[u]],1,n,1); 117 d2[u].a=t.a;d2[u].b=t.d+t.c; 118 } 119 } 120 inline ll getsize(int x) 121 { 122 return S::getx(pl[x],pl[dwn[x]],1,n,1).a; 123 } 124 int main() 125 { 126 int i,x,y,idx;ll z,ans,szall;P1 t; 127 scanf("%d%d",&n,&m); 128 for(i=1;i<n;++i) 129 { 130 scanf("%d%d",&x,&y); 131 e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne; 132 e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne; 133 } 134 for(i=1;i<=n;++i) scanf("%lld",a+i); 135 dfs1(1,0); 136 dfs2(1,0); 137 while(m--) 138 { 139 scanf("%d",&idx); 140 if(idx==1) 141 { 142 scanf("%d%lld",&x,&z); 143 d1[x].a-=a[x];a[x]=z;d1[x].a+=z; 144 while(x) 145 { 146 updd1(x); 147 x=tp[x];y=ff[x]; 148 t=S::getx(pl[x],pl[dwn[x]],1,n,1); 149 d1[y].a-=d2[x].a;d1[y].b-=d2[x].b; 150 d2[x].a=t.a;d2[x].b=t.d+t.c; 151 d1[y].a+=d2[x].a;d1[y].b+=d2[x].b; 152 x=y; 153 } 154 //printf("3t%d ",d2[1].b); 155 } 156 else 157 { 158 scanf("%d",&x); 159 ans=d2[1].b; 160 szall=getsize(1); 161 if(x!=tp[x]) 162 { 163 y=tp[x]; 164 z=d1[y].a; 165 d1[y].a+=szall-getsize(y); 166 updd1(y); 167 if(y!=dwn[y]) 168 { 169 t=S::getx(pl[y]+1,pl[dwn[y]],1,n,1); 170 ans-=t.c; 171 } 172 if(x!=dwn[y]) 173 { 174 t=S::getx(pl[x]+1,pl[dwn[y]],1,n,1); 175 ans+=t.c; 176 } 177 t=S::getx(pl[y],pl[x]-1,1,n,1); 178 ans+=t.f; 179 d1[y].a=z; 180 updd1(y); 181 x=y; 182 } 183 while(x!=1) 184 { 185 y=ff[x]; 186 z=getsize(x); 187 ans-=z*z; 188 z=szall-z; 189 ans+=z*z; 190 x=y; 191 if(x!=tp[x]) 192 { 193 y=tp[x]; 194 z=d1[y].a; 195 d1[y].a+=szall-getsize(y); 196 updd1(y); 197 if(y!=dwn[y]) 198 { 199 t=S::getx(pl[y]+1,pl[dwn[y]],1,n,1); 200 ans-=t.c; 201 } 202 if(x!=dwn[y]) 203 { 204 t=S::getx(pl[x]+1,pl[dwn[y]],1,n,1); 205 ans+=t.c; 206 } 207 t=S::getx(pl[y],pl[x]-1,1,n,1); 208 ans+=t.f; 209 d1[y].a=z; 210 updd1(y); 211 x=y; 212 } 213 } 214 printf("%lld ",ans); 215 } 216 } 217 return 0; 218 }
码完一看题解,???好像画风不太对??
所以还是无视上面那个代码吧...
正常得多的做法:
1 #include<cstdio> 2 #include<algorithm> 3 using namespace std; 4 typedef long long ll; 5 struct E 6 { 7 int to,nxt; 8 }e[400011]; 9 int f1[200011],ne; 10 int n,m; 11 struct S 12 { 13 #define lowbit(x) ((x)&(-x)) 14 ll d1[200011],d2[200011]; 15 void _add(int p,ll x,ll *d) 16 { 17 for(;p<=n;p+=lowbit(p)) 18 d[p]+=x; 19 } 20 ll _sum(int p,ll *d) 21 { 22 ll ans=0; 23 for(;p>0;p-=lowbit(p)) 24 ans+=d[p]; 25 return ans; 26 } 27 void add(int l,int r,ll x) 28 { 29 _add(l,x,d1); 30 _add(r+1,-x,d1); 31 _add(l,x*l,d2); 32 _add(r+1,-x*(r+1),d2); 33 } 34 ll sum(int l,int r) 35 { 36 return (r+1)*_sum(r,d1)-_sum(r,d2) 37 -l*_sum(l-1,d1)+_sum(l-1,d2); 38 } 39 }s1; 40 int b[200011],pl[200011]; 41 ll a[200011],a2[200011]; 42 int sz[200011],hson[200011],tp[200011]; 43 ll dep[200011]; 44 int ff[200011]; 45 void dfs1(int u,int fa) 46 { 47 sz[u]=1; 48 for(int k=f1[u];k;k=e[k].nxt) 49 if(e[k].to!=fa) 50 { 51 ff[e[k].to]=u; 52 dep[e[k].to]=dep[u]+1; 53 dfs1(e[k].to,u); 54 sz[u]+=sz[e[k].to]; 55 if(sz[e[k].to]>sz[hson[u]]) hson[u]=e[k].to; 56 } 57 } 58 void dfs2(int u,int fa) 59 { 60 b[++b[0]]=u;pl[u]=b[0]; 61 tp[u]=u==hson[fa]?tp[fa]:u; 62 a2[u]=a[u]; 63 if(hson[u]) 64 { 65 dfs2(hson[u],u); 66 a2[u]+=a2[hson[u]]; 67 } 68 for(int k=f1[u];k;k=e[k].nxt) 69 if(e[k].to!=fa&&e[k].to!=hson[u]) 70 { 71 dfs2(e[k].to,u); 72 a2[u]+=a2[e[k].to]; 73 } 74 } 75 inline ll gsum1(int x)//x到1的路径和 76 { 77 int y;ll an=0; 78 for(;x;x=ff[y]) 79 { 80 y=tp[x]; 81 an+=s1.sum(pl[y],pl[x]); 82 } 83 return an; 84 } 85 inline void add1(int x,ll z)//x到1加上z 86 { 87 int y; 88 for(;x;x=ff[y]) 89 { 90 y=tp[x]; 91 s1.add(pl[y],pl[x],z); 92 } 93 } 94 ll anss; 95 int main() 96 { 97 ll ans,z,t; 98 int i,x,y,idx; 99 scanf("%d%d",&n,&m); 100 for(i=1;i<n;++i) 101 { 102 scanf("%d%d",&x,&y); 103 e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne; 104 e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne; 105 } 106 for(i=1;i<=n;++i) 107 scanf("%lld",a+i); 108 dfs1(1,0); 109 dfs2(1,0); 110 for(i=1;i<=n;++i) 111 { 112 s1.add(pl[i],pl[i],a2[i]); 113 anss+=a2[i]*a2[i]; 114 } 115 while(m--) 116 { 117 scanf("%d",&idx); 118 if(idx==1) 119 { 120 scanf("%d%lld",&x,&z); 121 z=z-a[x];a[x]+=z; 122 anss+=z*z*(dep[x]+1); 123 anss+=2*gsum1(x)*z; 124 add1(x,z); 125 } 126 else 127 { 128 scanf("%d",&x); 129 ans=anss; 130 t=gsum1(1); 131 ans+=dep[x]*t*t; 132 ans-=2*t*(gsum1(x)-t); 133 printf("%lld ",ans); 134 } 135 } 136 return 0; 137 }