动态$dp$好题
考虑用树链剖分将整棵树剖成若干条链。
设x的重儿子为$son[x]$,设$x$所在链链头为$top[x]$
对于重链上的每个节点(不妨设该节点编号为$x$)令$f[x]$表示以$x$为根的子树内(除以$son[x]$为根的子树),包含节点$x$的联通块的最大权值和。
我们求出一条重链上每个节点的f值后,考虑如何求出以$top[x]$为根的子树内的最大联通快。
我们考虑用线段树来合并每一个f值。我们用线段树维护四个值:
$sum$,该区间内所有$f$值的总和
$suml$,以该区间左端点为起点的所有区间中,权值最大区间权值。
$sumr$,以该区间右端点为七点的所有区间中,权值最大区间权值。
$ans$,该区间内所有区间的最大值
简单pushup一下就可以维护了。
考虑如何询问以x为根子树内的最大值,我们通过一遍dfs求出该树的dfs序,直接在线段树上查询即可。
注意n个INF相加可能会爆long long
#include<bits/stdc++.h> #define M 400005 #define mid ((a[x].l+a[x].r)>>1) #define L long long #define INF (1LL<<50) using namespace std; struct edge{int u,next;}e[M*2]={0}; int head[M]={0},use=0; void add(L x,L y){use++;e[use].u=y;e[use].next=head[x];head[x]=use;} L val[M]={0},f[M]={0},g[M]={0}; int fa[M]={0},siz[M]={0},son[M]={0},dfn[M]={0},low[M]={0},top[M]={0},dn[M]={0},rec[M]={0},t=0; void dfs(L x){ siz[x]=1; f[x]=val[x]; for(L i=head[x];i;i=e[i].next) if(e[i].u!=fa[x]){ fa[e[i].u]=x; dfs(e[i].u); f[x]+=f[e[i].u]; g[x]=max(g[x],g[e[i].u]); siz[x]+=siz[e[i].u]; if(siz[son[x]]<siz[e[i].u]) son[x]=e[i].u; } f[x]=max(f[x],0LL); g[x]=max(g[x],f[x]); } void dfs(L x,L Top){ top[x]=Top; dfn[x]=++t; rec[t]=x; if(son[x]) dfs(son[x],Top),dn[x]=dn[son[x]]; else dn[x]=x,t++; for(L i=head[x];i;i=e[i].next) if(e[i].u!=fa[x]&&e[i].u!=son[x]) dfs(e[i].u,e[i].u); low[x]=t; } struct mat{ L ans,suml,sumr,sum; mat(){suml=sumr=ans=sum=0;} mat(L Ans,L Suml,L Sumr,L Sum){ans=Ans; suml=Suml; sumr=Sumr; sum=Sum;} friend mat operator *(mat a,mat b){ mat c; c.ans=max(a.sumr+b.suml,max(a.ans,b.ans)); c.suml=max(a.suml,a.sum+b.suml); c.sumr=max(a.sumr+b.sum,b.sumr); c.sum=a.sum+b.sum; c.sum=max(c.sum,-INF); return c; } }wei[M]; struct seg{L l,r; mat a;}a[M<<2]; void pushup(L x){a[x].a=a[x<<1].a*a[x<<1|1].a;} void build(L x,L l,L r){ a[x].l=l; a[x].r=r; if(l==r){ L u=rec[l],sum=val[u]; if(u==0){ a[x].a=mat(0,0,0,-INF); return; } for(L i=head[u];i;i=e[i].next) if(e[i].u!=fa[u]&&e[i].u!=son[u]){ sum+=f[e[i].u]; } a[x].a=wei[l]=mat(max(sum,0LL),max(sum,0LL),max(sum,0LL),sum); return; } build(x<<1,l,mid); build(x<<1|1,mid+1,r); pushup(x); } mat query(L x,L l,L r){ if(l<=a[x].l&&a[x].r<=r) return a[x].a; if(r<=mid) return query(x<<1,l,r); if(mid<l) return query(x<<1|1,l,r); return query(x<<1,l,r)*query(x<<1|1,l,r); } mat query(L x){return query(1,dfn[top[x]],dfn[dn[x]]);} void updata(L x,L k){ if(a[x].l==a[x].r) return void(a[x].a=wei[k]); if(k<=mid) updata(x<<1,k); else updata(x<<1|1,k); pushup(x); } void Updata(L x,L Val){ L cha=Val-val[x]; val[x]=Val; L hh=(wei[dfn[x]].sum+=cha); wei[dfn[x]]=mat(max(hh,0LL),max(hh,0LL),max(hh,0LL),hh); while(x){ mat last=query(x); updata(1,dfn[x]); mat now=query(x); x=fa[top[x]]; if(!x) return; cha=now.suml-last.suml; hh=(wei[dfn[x]].sum+=cha); wei[dfn[x]]=mat(max(hh,0LL),max(hh,0LL),max(hh,0LL),hh); } } L n,m; main(){ //freopen("in.txt","r",stdin); //freopen("out.txt","w",stdout); scanf("%lld%lld",&n,&m); for(L i=1;i<=n;i++) scanf("%lld",val+i); for(L i=1,x,y;i<n;i++) scanf("%lld%lld",&x,&y),add(x,y),add(y,x); dfs(1); dfs(1,1); build(1,1,t); while(m--){ char op[10]; L x,y; scanf("%s%lld",op,&x); if(op[0]=='Q'){ mat hh=query(1,dfn[x],low[x]); printf("%lld ",hh.ans); }else{ scanf("%lld",&y); Updata(x,y); } } }