放假之前十分钟问LLJ大佬有没有什么水题可做,他看了看指了一道树剖水题;
我:喵喵喵?
然后被无情地嘲笑了十分钟打不完一道树剖。
并没有什么想说的,反正打得超级慢。别人家大佬一分钟写完匈牙利,半小时写完可持久化平衡树。我太弱啦。
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<queue>
#include<vector>
typedef long long LL;
const int maxn=30000+5;
using namespace std;
int n,m,u,v,fir[maxn],nxt[maxn*2],to[maxn*2],sz[maxn],top[maxn],sgmax[maxn<<2],sgsum[maxn<<2],tid[maxn],w[maxn];
int l,r,ecnt,tot,R[maxn],fa[maxn];
char s[10];
void add(int u,int v) {
nxt[++ecnt]=fir[u];fir[u]=ecnt;to[ecnt]=v;
nxt[++ecnt]=fir[v];fir[v]=ecnt;to[ecnt]=u;
}
#define lc x<<1
#define rc x<<1|1
void change(int x,int l,int r,int pos,int val) {
if(l==r) {sgmax[x]=sgsum[x]=val; return ;}
int mid=(l+r)>>1;
if(pos<=mid) change(lc,l,mid,pos,val);
else change(rc,mid+1,r,pos,val);
sgmax[x]=max(sgmax[lc],sgmax[rc]);
sgsum[x]=sgsum[lc]+sgsum[rc];
}
int query(int x,int l,int r,int ql,int qr,int op){
if(l>=ql&&r<=qr) {
if(op==1) return sgmax[x];
else return sgsum[x];
}
int mid=(l+r)>>1;
if(qr<=mid) return query(lc,l,mid,ql,qr,op);
else if(ql>mid) return query(rc,mid+1,r,ql,qr,op);
return op?max(query(lc,l,mid,ql,qr,op),query(rc,mid+1,r,ql,qr,op)):query(lc,l,mid,ql,qr,op)+query(rc,mid+1,r,ql,qr,op);
}
int squery(int l,int r,int op) {
int res; op?res=-1000000:res=0;
while(top[l]!=top[r]) {
if(R[top[l]]<R[top[r]]) swap(l,r);
op?res=max(res,query(1,1,n,tid[top[l]],tid[l],op)):res+=query(1,1,n,tid[top[l]],tid[l],op);
l=fa[top[l]];
}
if(tid[l]>tid[r]) swap(l,r);
op?res=max(res,query(1,1,n,tid[l],tid[r],op)):res+=query(1,1,n,tid[l],tid[r],op);
return res;
}
void dfs(int x,int f) {
sz[x]=1; R[x]=R[f]+1;
fa[x]=f;
for(int i=fir[x];i;i=nxt[i]) if(to[i]!=f){
dfs(to[i],x);
sz[x]+=sz[to[i]];
}
}
void DFS(int x,int tt) {
tid[x]=++tot;
top[x]=tt;
change(1,1,n,tot,w[x]);
int mson=0;
for(int i=fir[x];i;i=nxt[i]) if(to[i]!=fa[x]){
if(!mson||sz[to[i]]>sz[mson]) mson=to[i];
}
if(!mson) return ;
DFS(mson,tt);
for(int i=fir[x];i;i=nxt[i])
if(to[i]!=fa[x]&&to[i]!=mson) DFS(to[i],to[i]);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++) {
scanf("%d%d",&u,&v);
add(u,v);
}
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
memset(sgmax,128,sizeof(sgmax));
dfs(1,0);
DFS(1,1);
scanf("%d",&m);
while(m--) {
scanf("%s%d%d",&s,&l,&r);
if(s[0]=='C') change(1,1,n,tid[l],r);
else if(s[1]=='M') printf("%d
",squery(l,r,1)); //max
else printf("%d
",squery(l,r,0));
}
return 0;
}