Link:
Solution:
基础树剖,但要注意的就是链合并时的边界问题
每次查询时发现当前区间为目标区间的边界时直接记录边界的值即可
注意最后一次两个边界都要考虑!
Code:
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back #define mid ((l+r)>>1) #define lc k<<1,l,mid #define rc k<<1|1,mid+1,r typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=1e5+10; struct edge{int nxt,to;}e[MAXN<<2]; int tag[MAXN<<2],seg[MAXN<<2],lb[MAXN<<2],rb[MAXN<<2]; int LC,RC; int n,m,l,r,k,col[MAXN],dat[MAXN],head[MAXN],tot,cnt; int sz[MAXN],top[MAXN],f[MAXN][20],dep[MAXN],pos[MAXN]; void add_edge(int x,int y) { e[++tot]=(edge){head[x],y};head[x]=tot; e[++tot]=(edge){head[y],x};head[y]=tot; } void dfs1(int x) { sz[x]=1; for(int i=1;(1<<i)<=dep[x];i++) f[x][i]=f[f[x][i-1]][i-1]; for(int i=head[x];i;i=e[i].nxt) { if(e[i].to==f[x][0]) continue; f[e[i].to][0]=x;dep[e[i].to]=dep[x]+1; dfs1(e[i].to);sz[x]+=sz[e[i].to]; } } void dfs2(int x,int up) { int bs=0; pos[x]=++cnt;top[x]=up; for(int i=head[x];i;i=e[i].nxt) if(e[i].to!=f[x][0]&&sz[e[i].to]>sz[bs]) bs=e[i].to; if(!bs) return; dfs2(bs,up); for(int i=head[x];i;i=e[i].nxt) if(e[i].to!=f[x][0]&&e[i].to!=bs) dfs2(e[i].to,e[i].to); } int LCA(int u,int v) { while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); u=f[top[u]][0]; } if(dep[u]>dep[v]) swap(u,v); return u; } void pushup(int k) { lb[k]=lb[k<<1];rb[k]=rb[k<<1|1]; seg[k]=seg[k<<1]+seg[k<<1|1]-(rb[k<<1]==lb[k<<1|1]); } void pushdown(int k) { if(!tag[k]) return; seg[k<<1]=seg[k<<1|1]=1; tag[k<<1]=lb[k<<1]=rb[k<<1]=tag[k]; tag[k<<1|1]=lb[k<<1|1]=rb[k<<1|1]=tag[k]; tag[k]=0; } void build(int k,int l,int r) { if(l==r) {lb[k]=rb[k]=col[l];seg[k]=1;return;} build(lc);build(rc); pushup(k); } void Update(int a,int b,int x,int k,int l,int r) { if(a<=l&&r<=b) {tag[k]=lb[k]=rb[k]=x;seg[k]=1;return;} pushdown(k); if(a<=mid) Update(a,b,x,lc); if(b>mid) Update(a,b,x,rc); pushup(k); } int Query(int a,int b,int k,int l,int r) {//直接记录当次的最左/右处的颜色 if(l==a) LC=lb[k]; if(r==b) RC=rb[k]; if(a<=l&&r<=b) return seg[k]; pushdown(k); int ret=0,cnt=0; if(a<=mid) ret+=Query(a,b,lc),cnt++; if(b>mid) ret+=Query(a,b,rc),cnt++; if(cnt==2) ret-=rb[k<<1]==lb[k<<1|1]; return ret; } void solve(int u,int v,int k) { while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); Update(pos[top[u]],pos[u],k,1,1,n); u=f[top[u]][0]; } if(pos[u]>pos[v]) swap(u,v); Update(pos[u],pos[v],k,1,1,n); } int query(int u,int v) { int ret=0,lstu=0,lstv=0; while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v),swap(lstu,lstv); ret+=Query(pos[top[u]],pos[u],1,1,n); ret-=RC==lstu; u=f[top[u]][0];lstu=LC; } if(pos[u]>pos[v]) swap(u,v),swap(lstu,lstv); ret+=Query(pos[u],pos[v],1,1,n); //注意最后一次要处理两个边界 ret-=RC==lstv;ret-=LC==lstu; return ret; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&dat[i]); for(int i=1;i<n;i++) scanf("%d%d",&l,&r),add_edge(l,r); dfs1(1);dfs2(1,1); for(int i=1;i<=n;i++) col[pos[i]]=dat[i]; build(1,1,n); while(m--) { char s[10]; scanf("%s%d%d",s,&l,&r); if(s[0]=='C') scanf("%d",&k),solve(l,r,k); else printf("%d ",query(l,r)); } return 0; }