多维护两个信息表示最左边的和最右边的两个点是什么颜色的。
在更新的时候注意合并
#include<bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<ll,ll> pll; const int N=1e6+10; const int M=2e6+10; const int inf=0x3f3f3f3f; const ll mod=998244353; int h[N],ne[N],e[N],idx; int times,pre[N],LC,RC; int n,m,val[N],sz[N],son[N],depth[N],dfn[N],top[N],father[N]; struct node{ int l,r; int cnt; int lazy; int ls,rs; int num; }tr[N<<1]; void add(int a,int b){ e[idx]=b,ne[idx]=h[a],h[a]=idx++; } void dfs1(int u,int fa){ sz[u]=1; int i; for(i=h[u];i!=-1;i=ne[i]){ int j=e[i]; if(j==fa) continue; father[j]=u; depth[j]=depth[u]+1; dfs1(j,u); sz[u]+=sz[j]; if(sz[j]>sz[son[u]]){ son[u]=j; } } } void dfs(int u,int x){ int i; dfn[u]=++times,pre[times]=u; top[u]=x; if(!son[u]) return ; dfs(son[u],x); for(i=h[u];i!=-1;i=ne[i]){ int j=e[i]; if(j==father[u]||j==son[u]) continue; dfs(j,j); } } void pushup(int u){ tr[u].num=tr[u<<1].num+tr[u<<1|1].num; tr[u].ls=tr[u<<1].ls,tr[u].rs=tr[u<<1|1].rs; if(tr[u<<1].rs==tr[u<<1|1].ls) tr[u].num--; } void build(int u,int l,int r){ if(l==r){ tr[u]={l,l,val[pre[l]],0,val[pre[l]],val[pre[l]],1}; } else{ tr[u]={l,r}; int mid=l+r>>1; build(u<<1,l,mid); build(u<<1|1,mid+1,r); pushup(u); } } void pushdown(int u){ if(tr[u].lazy){ int x=tr[u].lazy; tr[u<<1].num=tr[u<<1|1].num=1; tr[u<<1].lazy=tr[u<<1|1].lazy=x; tr[u<<1].ls=tr[u<<1].rs=tr[u<<1|1].rs=tr[u<<1|1].ls=x; tr[u].lazy=0; } } void modify(int u,int l,int r,int k){ if(tr[u].l>=l&&tr[u].r<=r){ tr[u].lazy=k; tr[u].ls=tr[u].rs=k; tr[u].num=1; return ; } pushdown(u); int mid=tr[u].l+tr[u].r>>1; if(l<=mid) modify(u<<1,l,r,k); if(r>mid) modify(u<<1|1,l,r,k); pushup(u); } int query(int u,int l,int r){ if(tr[u].l>=l&&tr[u].r<=r){ if(tr[u].l==l) LC=tr[u].ls; if(tr[u].r==r) RC=tr[u].rs; return tr[u].num; } pushdown(u); int mid=tr[u].l+tr[u].r>>1; ll sum=0; if(r<=mid) return query(u<<1,l,r); if(l>mid) return query(u<<1|1,l,r); sum=query(u<<1,l,r)+query(u<<1|1,l,r); if(tr[u<<1].rs==tr[u<<1|1].ls) sum--; return sum; } int querypath(int x,int y){ int ans1=-1,ans2=-1; int sum=0; while(top[x]!=top[y]){ if(depth[top[x]]<depth[top[y]]){ swap(x,y); swap(ans1,ans2); } sum+=query(1,dfn[top[x]],dfn[x]); if(ans1==RC) sum--; ans1=LC; x=father[top[x]]; } if(depth[x]>depth[y]){ swap(x,y); swap(ans1,ans2); } sum+=query(1,dfn[x],dfn[y]); if(LC==ans1) sum--; if(RC==ans2) sum--; return sum; } void change(int x,int y,int c){ while(top[x]!=top[y]){ if(depth[top[x]]<depth[top[y]]) swap(x,y); modify(1,dfn[top[x]],dfn[x],c); x=father[top[x]]; } if(depth[x]>depth[y]) swap(x,y); modify(1,dfn[x],dfn[y],c); } int main(){ ios::sync_with_stdio(false); //freopen("1.in","r",stdin); cin>>n>>m; int i; memset(h,-1,sizeof h); for(i=1;i<=n;i++){ cin>>val[i]; } for(i=1;i<n;i++){ int a,b; cin>>a>>b; add(a,b); add(b,a); } depth[1]=1; dfs1(1,-1); dfs(1,1); build(1,1,n); while(m--){ string s; int a,b,c; cin>>s; if(s=="Q"){ cin>>a>>b; cout<<querypath(a,b)<<endl; } else{ cin>>a>>b>>c; change(a,b,c); } } return 0; }