http://www.lydsy.com/JudgeOnline/problem.php?id=3589
树链剖分
用线段数维护扫描线的方式来写,标记只打不下传
#include<cstdio> #include<iostream> #include<algorithm> #define N 200001 using namespace std; int n; int front[N],to[N<<1],nxt[N<<1],tot; int siz[N],dep[N],fa[N]; int bl[N],in[N],out[N]; int sum[N<<2],f[N<<2]; int ans[N<<2],tag[N<<2]; int op[6][2]; void read(int &x) { x=0; char c=getchar(); while(!isdigit(c)) c=getchar(); while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); } } void add(int u,int v) { to[++tot]=v; nxt[tot]=front[u]; front[u]=tot; to[++tot]=u; nxt[tot]=front[v]; front[v]=tot; } void dfs1(int x,int y) { siz[x]=1; for(int i=front[x];i;i=nxt[i]) { if(to[i]==fa[x]) continue; fa[to[i]]=x; dep[to[i]]=dep[x]+1; dfs1(to[i],x); siz[x]+=siz[to[i]]; } } void dfs2(int x,int top) { bl[x]=top; in[x]=++tot; int y=0; for(int i=front[x];i;i=nxt[i]) if(to[i]!=fa[x] && siz[to[i]]>siz[y]) y=to[i]; if(!y) { out[x]=tot; return;} dfs2(y,top); for(int i=front[x];i;i=nxt[i]) if(to[i]!=fa[x] && to[i]!=y) dfs2(to[i],to[i]); out[x]=tot; } void down(int k,int l,int mid,int r) { sum[k<<1]+=f[k]*(mid-l+1); sum[k<<1|1]+=f[k]*(r-mid); f[k<<1]+=f[k]; f[k<<1|1]+=f[k]; f[k]=0; } void add(int k,int l,int r,int opl,int opr,int w) { if(l>=opl && r<=opr) { sum[k]+=(r-l+1)*w; f[k]+=w; return; } int mid=l+r>>1; if(f[k]) down(k,l,mid,r); if(opl<=mid) add(k<<1,l,mid,opl,opr,w); if(opr>mid) add(k<<1|1,mid+1,r,opl,opr,w); sum[k]=sum[k<<1]+sum[k<<1|1]; } void add_tag(int k,int l,int r,int opl,int opr,bool w) { if(l>=opl && r<=opr) { if(w) tag[k]++,ans[k]=sum[k]; else tag[k]--,ans[k]=0; return; } int mid=l+r>>1; if(f[k]) down(k,l,mid,r); if(opl<=mid) add_tag(k<<1,l,mid,opl,opr,w); if(opr>mid) add_tag(k<<1|1,mid+1,r,opl,opr,w); if(!tag[k]) ans[k]=ans[k<<1]+ans[k<<1|1]; else ans[k]=sum[k]; } void solve(int u,int v,bool ty) { while(bl[u]!=bl[v]) { if(dep[bl[u]]<dep[bl[v]]) swap(u,v); add_tag(1,1,n,in[bl[u]],in[u],ty); u=fa[bl[u]]; } if(dep[u]>dep[v]) swap(u,v); add_tag(1,1,n,in[u],in[v],ty); } int main() { freopen("data.in","r",stdin); freopen("my.out","w",stdout); read(n); int u,v; for(int i=1;i<n;++i) { read(u); read(v); add(u,v); } dfs1(1,0); tot=0; dfs2(1,1); int m; read(m); int ty; while(m--) { read(ty); if(!ty) { read(u); read(v); add(1,1,n,in[u],out[u],v); } else { read(ty); for(int i=1;i<=ty;++i) { read(op[i][0]); read(op[i][1]); solve(op[i][0],op[i][1],1); } if(ans[1]<0) ans[1]+=(1LL<<31); cout<<ans[1]<<' '; for(int i=1;i<=ty;++i) solve(op[i][0],op[i][1],0); } } }