树链剖分即可。注意线段树上颜色段的合并:两段区间拼接时,若左段最右端颜色与右段最左端颜色相同,段数要减一;沿重链向上跳时,若链顶与其父亲颜色相同,答案同样要减一。
Code
// Heavy-light decomposition + segment tree.
//   "C u v x" paints every vertex on the tree path u..v with colour x.
//   "Q u v"   prints the number of maximal same-colour segments on path u..v.
// Each path splits into O(log n) heavy-chain pieces; every piece is a
// contiguous range of the heavy-first dfs order indexed by the segment tree.
#include <cstdio>
#include <algorithm>

const int N = 100010;  // capacity: vertices are numbered 1..n with n < N

// Summary of a contiguous range of dfs positions.
//   lc, rc : colour of the leftmost / rightmost position (-1 marks an empty range)
//   sum    : number of maximal runs of equal colour inside the range
//   tag    : pending "paint the whole range" colour (-1 means no pending paint)
// NOTE(review): -1 is used as a sentinel, so colours are assumed non-negative.
struct tree {
    int lc, rc, sum, tag;
    tree() : lc(-1), rc(-1), sum(0), tag(-1) {}

    // Concatenates range a (on the left) with range b (on the right). When the
    // colours meeting at the seam are equal, the two boundary runs fuse into one.
    friend tree operator+(const tree &a, const tree &b) {
        if (a.lc == -1) return b;
        if (b.lc == -1) return a;
        tree c;
        c.lc = a.lc;
        c.rc = b.rc;
        c.sum = a.sum + b.sum - (a.rc == b.lc ? 1 : 0);
        return c;
    }
} T[N * 4];

// Adjacency-list edge; each undirected edge is stored once per direction.
struct info {
    int to, nex;
} e[N * 2];

int n, m, tot, head[N], cnt, A[N];  // A[u]: initial colour of vertex u
// dep/fa/sz/son: depth, parent, subtree size, heavy child (0 when u is a leaf).
// tp[u]: top vertex of u's heavy chain; tid[u]: u's position in the dfs order;
// tw[p]: initial colour of the vertex at dfs position p.
int tid[N], dep[N], son[N], fa[N], sz[N], tp[N], tw[N];

// Reads a decimal integer (an optional '-' may precede it) from stdin,
// skipping any other characters in front. Returns 0 if input ends first,
// instead of spinning forever on EOF.
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();  // int, not char, so EOF is distinguishable
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// Adds the directed edge u -> v to the adjacency list.
inline void Link(int u, int v) {
    e[++tot].nex = head[u];
    head[u] = tot;
    e[tot].to = v;
}

// First pass: computes fa, dep, sz and the heavy child son[] of every vertex
// in the subtree of u (pre is u's parent).
// NOTE(review): recursive; depth equals the tree height, which may require a
// large stack for path-shaped trees near the N limit.
void dfs(int u, int pre) {
    sz[u] = 1;
    for (int i = head[u], mx = 0; i; i = e[i].nex) {
        int v = e[i].to;
        if (v == pre) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs(v, u);
        sz[u] += sz[v];
        if (sz[v] > mx) {
            son[u] = v;
            mx = sz[v];
        }
    }
}

// Second pass: assigns chain tops and dfs positions. The heavy child is
// visited first so that every heavy chain occupies a contiguous range.
void dddfs(int u, int top) {
    tp[u] = top;
    tid[u] = ++cnt;
    tw[cnt] = A[u];
    if (!son[u]) return;
    dddfs(son[u], top);
    for (int i = head[u]; i; i = e[i].nex) {
        int v = e[i].to;
        if (v == fa[u] || v == son[u]) continue;
        dddfs(v, v);  // a light child starts its own chain
    }
}

// Builds node id covering positions [l, r] from the initial colours tw[].
void build(int l, int r, int id) {
    if (l == r) {
        T[id].sum = 1;
        T[id].lc = T[id].rc = tw[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, id << 1);
    build(mid + 1, r, id << 1 | 1);
    T[id] = T[id << 1] + T[id << 1 | 1];
}

// Reads n, m, the n initial colours and the n-1 tree edges, then prepares the
// decomposition (rooted at vertex 1) and the segment tree.
void Init() {
    n = read(), m = read();
    for (int i = 1; i <= n; ++i) A[i] = read();
    for (int i = 1; i < n; ++i) {
        int u = read(), v = read();
        Link(u, v), Link(v, u);
    }
    dfs(1, 0);
    dddfs(1, 1);
    build(1, n, 1);
}

// Pushes node id's pending paint (if any) down to both children.
// l and r are not needed; they are kept for a uniform call signature.
inline void pushdown(int l, int r, int id) {
    (void)l;
    (void)r;
    int &tag = T[id].tag;
    if (tag == -1) return;
    int ls = id << 1, rs = id << 1 | 1;
    T[ls].lc = T[ls].rc = T[rs].lc = T[rs].rc = tag;
    T[ls].sum = T[rs].sum = 1;
    T[ls].tag = T[rs].tag = tag;
    tag = -1;
}

// Number of colour segments in positions [L, R], within node id covering [l, r].
int query(int l, int r, int id, int L, int R) {
    if (L <= l && r <= R) return T[id].sum;
    pushdown(l, r, id);
    int mid = (l + r) >> 1, ls = id << 1, rs = id << 1 | 1;
    if (R <= mid) return query(l, mid, ls, L, R);
    if (L > mid) return query(mid + 1, r, rs, L, R);
    // The query straddles mid: the left part ends at mid and the right part
    // starts at mid + 1, so the seam colours are exactly T[ls].rc and T[rs].lc.
    int res = query(l, mid, ls, L, R);
    res += query(mid + 1, r, rs, L, R);
    if (T[ls].rc == T[rs].lc) --res;
    return res;
}

// Current colour at dfs position x (1 <= x <= n).
int qDot(int l, int r, int id, int x) {
    if (l == r) return T[id].lc;
    pushdown(l, r, id);
    int mid = (l + r) >> 1;
    if (x <= mid) return qDot(l, mid, id << 1, x);
    return qDot(mid + 1, r, id << 1 | 1, x);
}

// Number of colour segments on the tree path u..v.
inline int qRange(int u, int v) {
    int res = 0;
    while (tp[u] != tp[v]) {
        if (dep[tp[u]] < dep[tp[v]]) std::swap(u, v);
        res += query(1, n, 1, tid[tp[u]], tid[u]);
        // The chain top and its parent are adjacent on the path; if they share
        // a colour, the boundary segments of the two pieces are really one.
        int topColour = qDot(1, n, 1, tid[tp[u]]);
        int parentColour = qDot(1, n, 1, tid[fa[tp[u]]]);
        if (topColour == parentColour) --res;
        u = fa[tp[u]];
    }
    if (dep[u] > dep[v]) std::swap(u, v);
    res += query(1, n, 1, tid[u], tid[v]);
    return res;
}

// Paints positions [L, R] with colour x, within node id covering [l, r].
void update(int l, int r, int id, int L, int R, int x) {
    if (L <= l && r <= R) {
        T[id].sum = 1;
        T[id].lc = T[id].rc = T[id].tag = x;
        return;
    }
    pushdown(l, r, id);
    int mid = (l + r) >> 1, ls = id << 1, rs = id << 1 | 1;
    if (L <= mid) update(l, mid, ls, L, R, x);
    if (R > mid) update(mid + 1, r, rs, L, R, x);
    T[id] = T[ls] + T[rs];
}

// Paints every vertex on the tree path u..v with colour x.
void updRange(int u, int v, int x) {
    while (tp[u] != tp[v]) {
        if (dep[tp[u]] < dep[tp[v]]) std::swap(u, v);
        update(1, n, 1, tid[tp[u]], tid[u], x);
        u = fa[tp[u]];
    }
    if (dep[u] > dep[v]) std::swap(u, v);
    update(1, n, 1, tid[u], tid[v], x);
}

// Processes the m operations and prints one answer per line for each 'Q'.
inline void solve() {
    while (m--) {
        int ch = getchar();
        while (ch != 'C' && ch != 'Q' && ch != EOF) ch = getchar();
        if (ch == EOF) return;  // truncated input: stop instead of looping forever
        if (ch == 'Q') {
            int u = read(), v = read();
            printf("%d\n", qRange(u, v));
        } else {
            int u = read(), v = read(), x = read();
            updRange(u, v, x);
        }
    }
}

int main() {
    Init();
    solve();
    return 0;
}