题解:
本题其实就是:树上单点修改点权,以及查询两点之间路径上点权的最大值与点权和。
思路很直接,用树链剖分 + 线段树维护即可:每次路径查询最多跳过 O(log n) 条重链,每条链在线段树上做一次 O(log n) 的区间查询,总复杂度为 O(q log² n)。
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

// Tree with point-assignment updates and path max / path sum queries,
// solved with heavy-light decomposition (HLD) over a segment tree.
//
// Input as read below: n, then n-1 undirected edges "f t", then n node
// weights, then q commands "CHANGE u t" / "QMAX u v" / "QSUM u v".
// Every query answer is printed on its own line.

const int N = 30050;        // array capacity (node ids are 1-indexed)
typedef long long ll;

// Reads one signed decimal integer from stdin, skipping any non-digit
// characters before it. Uses an int for the character so EOF is detected
// (a plain char loops forever at end of input); returns the value parsed so
// far (0 if nothing) when input ends.
inline int rd()
{
    int sign = 1, val = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        val = 10 * val + ch - '0';
        ch = getchar();
    }
    return sign * val;
}

const ll inf = 0x3f3f3f3f3f3f3f3fll;   // identity element for max queries

int n, q, hed[N], cnt;
ll w[N];                    // initial weight of each node

// Forward-star adjacency list: hed[u] is the index of u's first edge,
// e[j].nxt the next edge from the same node, 0 terminates the list.
struct EG
{
    int to, nxt;
} e[2 * N];

// Appends the directed edge f -> t.
void ae(int f, int t)
{
    e[++cnt].to = t;
    e[cnt].nxt = hed[f];
    hed[f] = cnt;
}

int fa[N], siz[N], top[N], son[N], dep[N];

// First HLD pass: records parent, depth and subtree size of u, and picks
// its heavy child son[u] (largest subtree; 0 if u is a leaf).
void dfs1(int u, int f)
{
    siz[u] = 1;
    fa[u] = f;
    dep[u] = dep[f] + 1;
    for (int j = hed[u]; j; j = e[j].nxt)
    {
        int to = e[j].to;
        if (to == f) continue;
        dfs1(to, u);
        siz[u] += siz[to];
        if (siz[to] > siz[son[u]]) son[u] = to;
    }
}

int tin[N], pla[N], tim;

// Second HLD pass: assigns DFS positions tin[] visiting the heavy child
// first, so each heavy chain occupies a contiguous range of positions.
// pla[] is the inverse of tin[]; top[u] is the head of u's chain.
void dfs2(int u, int tp)
{
    top[u] = tp;
    tin[u] = ++tim;
    pla[tim] = u;
    if (!son[u]) return;               // leaf: no children to visit
    dfs2(son[u], tp);                  // heavy child continues the chain
    for (int j = hed[u]; j; j = e[j].nxt)
    {
        int to = e[j].to;
        if (to == fa[u] || to == son[u]) continue;
        dfs2(to, to);                  // each light child starts a new chain
    }
}

// Lowest common ancestor via chain jumping: repeatedly lift whichever
// endpoint has the deeper chain head until both share a chain.
int get_lca(int x, int y)
{
    while (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

// Segment tree over DFS positions 1..n. Node u covers a range and stores
// its max (vx) and sum (vs); its children are u*2 and u*2+1.
struct segtree
{
    ll vx[N << 2], vs[N << 2];

    // Recomputes node u from its two children.
    void update(int u)
    {
        vx[u] = max(vx[u << 1], vx[u << 1 | 1]);
        vs[u] = vs[u << 1] + vs[u << 1 | 1];
    }

    // Builds range [l, r]; the leaf at position l holds w[pla[l]].
    void build(int l, int r, int u)
    {
        if (l == r)
        {
            vs[u] = vx[u] = w[pla[l]];
            return;
        }
        int mid = (l + r) >> 1;
        build(l, mid, u << 1);
        build(mid + 1, r, u << 1 | 1);
        update(u);
    }

    // Point assignment: sets position qx to d.
    void insert(int l, int r, int u, int qx, ll d)
    {
        if (l == r)
        {
            vs[u] = vx[u] = d;
            return;
        }
        int mid = (l + r) >> 1;
        if (qx <= mid) insert(l, mid, u << 1, qx, d);
        else insert(mid + 1, r, u << 1 | 1, qx, d);
        update(u);
    }

    // Max over [ql, qr]; requires l <= ql <= qr <= r.
    ll query1(int l, int r, int u, int ql, int qr)
    {
        if (l == ql && r == qr) return vx[u];
        int mid = (l + r) >> 1;
        if (qr <= mid) return query1(l, mid, u << 1, ql, qr);
        if (ql > mid) return query1(mid + 1, r, u << 1 | 1, ql, qr);
        return max(query1(l, mid, u << 1, ql, mid),
                   query1(mid + 1, r, u << 1 | 1, mid + 1, qr));
    }

    // Sum over [ql, qr]; requires l <= ql <= qr <= r.
    ll query2(int l, int r, int u, int ql, int qr)
    {
        if (l == ql && r == qr) return vs[u];
        int mid = (l + r) >> 1;
        if (qr <= mid) return query2(l, mid, u << 1, ql, qr);
        if (ql > mid) return query2(mid + 1, r, u << 1 | 1, ql, qr);
        return query2(l, mid, u << 1, ql, mid)
             + query2(mid + 1, r, u << 1 | 1, mid + 1, qr);
    }

    // Max weight on the path from u up to its ancestor lim, lim INCLUDED.
    // Node 0 (parent of the root) has dep 0 and stops the climb.
    ll q1(int u, int lim)
    {
        ll ret = -inf;
        int now = top[u];
        // Whole chain segments [top[u], u] lying at or below lim.
        while (dep[now] >= dep[lim])
        {
            ret = max(ret, query1(1, n, 1, tin[now], tin[u]));
            u = fa[now];
            now = top[u];
        }
        // u now shares lim's chain, or has already passed above lim.
        if (dep[u] >= dep[lim])
            ret = max(ret, query1(1, n, 1, tin[lim], tin[u]));
        return ret;
    }

    // Sum of weights on the path from u up to its ancestor lim, lim EXCLUDED
    // (the caller adds lim's weight once for the whole u-v path).
    ll q2(int u, int lim)
    {
        ll ret = 0;
        int now = top[u];
        while (dep[now] > dep[lim])
        {
            ret += query2(1, n, 1, tin[now], tin[u]);
            u = fa[now];
            now = top[u];
        }
        if (dep[u] > dep[lim])
            ret += query2(1, n, 1, tin[lim] + 1, tin[u]);
        return ret;
    }
} tr;

char ch[10];

int main()
{
    n = rd();
    for (int i = 1; i < n; i++)
    {
        int f = rd();
        int t = rd();
        ae(f, t);
        ae(t, f);
    }
    dfs1(1, 0);     // root at node 1; node 0 is the sentinel parent
    dfs2(1, 1);
    for (int i = 1; i <= n; i++) w[i] = rd();
    tr.build(1, n, 1);
    q = rd();
    for (int i = 1; i <= q; i++)
    {
        scanf("%8s", ch + 1);          // ch+1 has room for 8 chars + NUL
        int u = rd();
        int v = rd();
        if (ch[1] == 'C')              // CHANGE u t: assign weight t to u
        {
            tr.insert(1, n, 1, tin[u], v);
        }
        else if (ch[2] == 'M')         // QMAX u v
        {
            int lca = get_lca(u, v);
            ll ans = max(tr.q1(u, lca), tr.q1(v, lca));
            printf("%lld\n", ans);
        }
        else                           // QSUM u v
        {
            int lca = get_lca(u, v);
            ll ans = tr.query2(1, n, 1, tin[lca], tin[lca]);
            ans += tr.q2(u, lca) + tr.q2(v, lca);
            printf("%lld\n", ans);
        }
    }
    return 0;
}