树链剖分
用区间修改线段树维护
对于颜色段的计算:sum[o]=sum[lc]+sum[rc]
因为可能重复计算,即左子树的右端点和右子树的左端点可能颜色相同
多开2个数组lx,rx记录左/右端点的颜色,重复的话 sum[o]- - 即可
(我的错误模板坑了我2h)
#include<iostream> #include<cstdio> #include<cstring> #include<cctype> using namespace std; template <typename T> inline T min(T &a,T &b) {return a<b ?a:b;} template <typename T> inline T max(T &a,T &b) {return a>b ?a:b;} template <typename T> inline void read(T &x){ char c=getchar(); x=0; bool f=1; while(!isdigit(c)) f= !f||c=='-' ? 0:1,c=getchar(); while(isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar(); x= f ? x:-x; } template <typename T> inline void output(T x){ if(!x) {putchar(48); return ;} if(x<0) putchar('-'),x=-x; int wt[50],l=0; while(x) wt[++l]=x%10,x/=10; while(l) putchar(wt[l--]+48); } typedef int arr[100005]; int n,m,cnt,tot; arr d,fa,siz,bgs,tp,tmp,val,id,hd,ed; int nxt[200005],poi[200005]; int lx[400005],rx[400005],sum[400005],tag[400005]; inline void add_(int x,int y){ nxt[ed[x]]=++cnt; hd[x]= hd[x] ? hd[x]:cnt; ed[x]=cnt; poi[cnt]=y; } inline void pushdown(int o){ if(tag[o]==-1) return ; int lc=o<<1,rc=o<<1|1; sum[lc]=sum[rc]=1; lx[lc]=rx[lc]=lx[rc]=rx[rc]=tag[o]; tag[lc]=tag[rc]=tag[o]; tag[o]=-1; } inline void maintain(int o){ int lc=o<<1,rc=o<<1|1; lx[o]=lx[lc],rx[o]=rx[rc],sum[o]=sum[lc]+sum[rc]; if(rx[lc]==lx[rc]) --sum[o]; //重复减掉 } inline void build(int o,int l,int r){ tag[o]=-1; if(l==r) {tag[o]=lx[o]=rx[o]=tmp[l]; sum[o]=1; return ;} int lc=o<<1,rc=o<<1|1,mid=l+((r-l)>>1); build(lc,l,mid); build(rc,mid+1,r); maintain(o); } inline void modify(int o,int l,int r,int x1,int x2,int v){ if(x1<=l&&r<=x2){ lx[o]=rx[o]=tag[o]=v; sum[o]=1; return ; }pushdown(o); int lc=o<<1,rc=o<<1|1,mid=l+((r-l)>>1); if(x1<=mid) modify(lc,l,mid,x1,x2,v); if(x2>mid) modify(rc,mid+1,r,x1,x2,v); maintain(o); } inline int query(int o,int l,int r,int x1,int x2){ if(x1<=l&&r<=x2) return sum[o]; pushdown(o); int res=0; int lc=o<<1,rc=o<<1|1,mid=l+((r-l)>>1); if(x1<=mid) res+=query(lc,l,mid,x1,x2); if(x2>mid){ res+=query(rc,mid+1,r,x1,x2); if(x1<=mid&&rx[lc]==lx[rc]) --res; }return res; } inline void dfs1(int x,int _fa){ d[x]=d[_fa]+1,fa[x]=_fa,siz[x]=1; for(int i=hd[x];i;i=nxt[i]) if(poi[i]!=_fa){ dfs1(poi[i],x); siz[x]+=siz[poi[i]]; if(siz[bgs[x]]<siz[poi[i]]) bgs[x]=poi[i]; } } inline void dfs2(int x,int _top){ id[x]=++tot,tmp[tot]=val[x],tp[x]=_top; if(siz[x]==1) return; dfs2(bgs[x],_top); for(int i=hd[x];i;i=nxt[i]) if(poi[i]!=fa[x]&&poi[i]!=bgs[x]) dfs2(poi[i],poi[i]); } inline int findcol(int o,int l,int r,int x){ if(l==r) return tag[o]; pushdown(o); int lc=o<<1,rc=o<<1|1,mid=l+((r-l)>>1); if(x<=mid) return findcol(lc,l,mid,x); else return findcol(rc,mid+1,r,x); } inline void change(int x,int y,int v){ while(tp[x]!=tp[y]){ //™模板原来这里多套了个d数组上去,变成50pts if(d[tp[x]]<d[tp[y]]) swap(x,y); modify(1,1,n,id[tp[x]],id[x],v); x=fa[tp[x]]; }if(d[x]>d[y]) swap(x,y); modify(1,1,n,id[x],id[y],v); } inline int ask(int x,int y){ int res=0,r1,r2; while(tp[x]!=tp[y]){ if(d[tp[x]]<d[tp[y]]) swap(x,y); res+=query(1,1,n,id[tp[x]],id[x]); r1=findcol(1,1,n,id[tp[x]]); r2=findcol(1,1,n,id[fa[tp[x]]]); if(r1==r2) --res; x=fa[tp[x]]; }if(d[x]>d[y]) swap(x,y); res+=query(1,1,n,id[x],id[y]); return max(res,1); } int main(){ read(n); read(m); int q1,q2,q3; char opt[5]; for(int i=1;i<=n;++i) read(val[i]); for(int i=1;i<n;++i) read(q1),read(q2),add_(q1,q2),add_(q2,q1); dfs1(1,0); dfs2(1,1); build(1,1,n); for(int i=1;i<=m;++i){ scanf("%s",opt); read(q1),read(q2); if(opt[0]=='Q') output(ask(q1,q2)),putchar(' '); else read(q3),change(q1,q2,q3); }return 0; }