题目链接:戳我
其实也是一个比较常规的树链剖分的题目,主要不同是多记录一个区间的左端颜色,右端颜色,如果左右区间颜色相同就-1.
update:因为还有一个树链剖分,所以还要注意一下,上次划分出来的区域的左端点和当前处理区间的右端点是否颜色一样qwq,一样的话要-1(所以需要记录一下左端点的颜色)
至于为什么要记录的是左端点呢?因为我们每次从深度大往上跳,在线段树中,深度越大,区间编号越大.
代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define MAXN 400010
using namespace std;
int n,m,tt,tot;
int top[MAXN],fa[MAXN],dep[MAXN],son[MAXN],siz[MAXN],id[MAXN],a[MAXN];
int head[MAXN<<1],w[MAXN];
struct Node{int lc,rc,tag,tot,l,r;}t[MAXN<<2];
struct Edge{int nxt,to;}edge[MAXN<<1];
inline int ls(int x){return x<<1;}
inline int rs(int x){return x<<1|1;}
inline void add(int from,int to){edge[++tt].nxt=head[from],edge[tt].to=to,head[from]=tt;}
inline void dfs1(int x,int pre)
{
fa[x]=pre;
siz[x]=1;
dep[x]=dep[pre]+1;
int maxx=-1;
for(int i=head[x];i;i=edge[i].nxt)
{
int v=edge[i].to;
if(v==fa[x]) continue;
dfs1(v,x);
siz[x]+=siz[v];
if(siz[v]>maxx) maxx=siz[v],son[x]=v;
}
}
inline void dfs2(int x,int topf)
{
top[x]=topf;
id[x]=++tot;
a[tot]=w[x];
if(son[x]) dfs2(son[x],topf);
for(int i=head[x];i;i=edge[i].nxt)
{
int v=edge[i].to;
if(v==fa[x]||v==son[x]) continue;
dfs2(v,v);
}
}
inline void push_up(int x)
{
t[x].lc=t[ls(x)].lc;
t[x].rc=t[rs(x)].rc;
if(t[ls(x)].rc==t[rs(x)].lc) t[x].tot=t[ls(x)].tot+t[rs(x)].tot-1;
else t[x].tot=t[ls(x)].tot+t[rs(x)].tot;
}
inline void push_down(int x)
{
if(t[x].tag)
{
t[ls(x)].tag=t[rs(x)].tag=t[x].tag;
t[ls(x)].lc=t[ls(x)].rc=t[x].tag;
t[rs(x)].lc=t[rs(x)].rc=t[x].tag;
t[ls(x)].tot=t[rs(x)].tot=1;
t[x].tag=0;
}
}
inline void build(int x,int l,int r)
{
t[x].l=l;t[x].r=r;
if(l==r)
{
t[x].lc=t[x].rc=a[l];
t[x].tot=1;
return;
}
int mid=(l+r)>>1;
build(ls(x),l,mid);
build(rs(x),mid+1,r);
push_up(x);
}
inline void update(int x,int ll,int rr,int k)
{
int l=t[x].l,r=t[x].r;
if(l==ll&&r==rr)
{
t[x].lc=t[x].rc=t[x].tag=k;
t[x].tot=1;
return;
}
int mid=(l+r)>>1;
push_down(x);
if(rr<=mid) update(ls(x),ll,rr,k);
else if(ll>mid) update(rs(x),ll,rr,k);
else update(ls(x),ll,mid,k),update(rs(x),mid+1,rr,k);
push_up(x);
}
inline int query(int x,int ll,int rr)
{
int l=t[x].l,r=t[x].r;
if(l==ll&&r==rr) return t[x].tot;
int mid=(l+r)>>1;
push_down(x);
if(rr<=mid) return query(ls(x),ll,rr);
else if(mid<ll) return query(rs(x),ll,rr);
else
{
if(t[ls(x)].rc==t[rs(x)].lc)
return query(ls(x),ll,mid)+query(rs(x),mid+1,rr)-1;
else return query(ls(x),ll,mid)+query(rs(x),mid+1,rr);
}
}
inline int calc(int x,int pos)
{
int l=t[x].l,r=t[x].r;
int mid=(l+r)>>1;
push_down(x);
if(l==r) return t[x].lc;
if(pos<=mid) return calc(ls(x),pos);
else return calc(rs(x),pos);
}
inline void UPDATE(int x,int y,int k)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,id[x],id[y],k);
}
inline int QUERY(int x,int y)
{
int cur_ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
int tmp=0;
if(calc(1,id[top[x]])==calc(1,id[fa[top[x]]])) tmp=1;
cur_ans+=query(1,id[top[x]],id[x])-tmp;
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
cur_ans+=query(1,id[x],id[y]);
return cur_ans;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("ce.in","r",stdin);
freopen("ce.out","w",stdout);
#endif
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs1(1,1);
dfs2(1,1);
build(1,1,n);
for(int i=1;i<=m;i++)
{
char op;
int a,b,c;
cin>>op;
if(op=='C')
{
scanf("%d%d%d",&a,&b,&c);
UPDATE(a,b,c);
}
else
{
scanf("%d%d",&a,&b);
printf("%d
",QUERY(a,b));
}
}
return 0;
}