题目大意:给定一棵无根树,每个点有一个颜色,有2种操作:1.将某两点之间路径上的点(包括该两点)的颜色修改为某个颜色。2.询问某两个点之间路径上的连续颜色段的数量。如果一条路径上的点的颜色依次为“1 2 2 3 3 4”,则颜色段的数量为4。如果一条路径上的点的颜色依次为“1 1 2 2 2 1”,则颜色段的数量为3,而不是2。
做法:树链剖分的写法没有什么太大区别,需要注意的是用线段树维护时的操作。对于一个区间,维护4个值:l,r,sum,cov。l表示该区间最左边的颜色段的颜色,r表示该区间最右边的颜色段的颜色,sum表示该区间内的颜色段的数量,cov为标记,如果其不为-1,则表示该区间内的所有点颜色都为cov的值。在pushup时,对于一个区间的左右两个子区间,这个区间的l为它的左子区间的l,它的r为它的右子区间的r,它的sum为它的左右区间的sum的和,特殊地,如果它左子区间的r和它右子区间的l相等,即表示左子区间的最右边的颜色段和右子区间的最左边的颜色段合成了一个颜色段,这时该区间的sum要减去1。
特别要注意的是询问时的处理。我们知道,在查询a,b两点间路径上的点时,要分别从a,b两点向根部扩展,如果把这条路径压成一条链,就可以看出它实际上是从路径的两端逐渐向内扩展。我们用alast和blast分别表示从a,b开始扩展到的最靠近根部的点的颜色。又由于我们把这条路径分成几条链来进行询问,那么每询问一条链,用fir和last分别记录该链在线段树中编号最前和最后的点的颜色,求出这个区间的颜色段的数量,再判断特殊的合并情况,最后就可以得到正确的结果了。
(本来我以为这么复杂要调几个小时的,结果调了一下交上去一次就过了,开心^_^)
以下是本人代码:
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
int n,m,tot,c[100010],first[100010],fa[100010],dep[100010],siz[100010],son[100010];
int top[100010],pos[100010],fp[100010],alast,blast,fir,last;
struct {int v,next;} e[200010];
struct {int l,r,sum,cov;} seg[400010];
void insert(int a,int b)
{
e[++tot].v=b;
e[tot].next=first[a];
first[a]=tot;
}
void dfs1(int v)
{
siz[v]=1;son[v]=0;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa[v])
{
fa[e[i].v]=v;
dep[e[i].v]=dep[v]+1;
dfs1(e[i].v);
if (siz[e[i].v]>siz[son[v]]) son[v]=e[i].v;
siz[v]+=siz[e[i].v];
}
}
void dfs2(int v,int chain)
{
top[v]=chain;pos[v]=++tot;
if (son[v]) dfs2(son[v],chain);
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa[v]&&e[i].v!=son[v])
dfs2(e[i].v,e[i].v);
}
void pushdown(int no)
{
if (seg[no].cov)
{
seg[no<<1].cov=seg[(no<<1)+1].cov=1;
seg[no<<1].l=seg[no<<1].r=seg[(no<<1)+1].l=seg[(no<<1)+1].r=seg[no].l;
seg[no<<1].sum=seg[(no<<1)+1].sum=1;
seg[no].cov=0;
}
}
void pushup(int no)
{
seg[no].l=seg[no<<1].l;
seg[no].r=seg[(no<<1)+1].r;
seg[no].sum=seg[no<<1].sum+seg[(no<<1)+1].sum;
if (seg[no<<1].r==seg[(no<<1)+1].l) seg[no].sum--;
}
void buildtree(int no,int l,int r)
{
int mid=(l+r)>>1;
seg[no].cov=0;
if (l==r) {seg[no].sum=1;seg[no].l=seg[no].r=c[fp[l]];return;}
buildtree(no<<1,l,mid);
buildtree((no<<1)+1,mid+1,r);
pushup(no);
}
void segcov(int no,int l,int r,int s,int t,int w)
{
int mid=(l+r)>>1;
if (l>=s&&r<=t)
{
seg[no].l=seg[no].r=w;
seg[no].sum=1;
seg[no].cov=1;
return;
}
pushdown(no);
if (s<=mid) segcov(no<<1,l,mid,s,t,w);
if (t>mid) segcov((no<<1)+1,mid+1,r,s,t,w);
pushup(no);
}
int getsum(int no,int l,int r,int s,int t)
{
int mid=(l+r)>>1;
if (l>=s&&r<=t)
{
int k=0;
if (fir==-1) fir=seg[no].l;
if (seg[no].l==last) k=-1;
last=seg[no].r;
return seg[no].sum+k;
}
int sum=0;
pushdown(no);
if (s<=mid) sum+=getsum(no<<1,l,mid,s,t);
if (t>mid) sum+=getsum((no<<1)+1,mid+1,r,s,t);
pushup(no);
return sum;
}
void cover(int a,int b,int w)
{
while(top[a]!=top[b])
{
if (dep[top[a]]<dep[top[b]]) swap(a,b);
segcov(1,1,n,pos[top[a]],pos[a],w);
a=fa[top[a]];
}
if (dep[a]>dep[b]) swap(a,b);
segcov(1,1,n,pos[a],pos[b],w);
}
int query(int a,int b)
{
int ans=0;
alast=blast=-1;
while(top[a]!=top[b])
{
fir=last=-1;
if (dep[top[a]]<dep[top[b]])
{
ans+=getsum(1,1,n,pos[top[b]],pos[b]);
if (blast==last) ans--;
b=fa[top[b]];
blast=fir;
}
else
{
ans+=getsum(1,1,n,pos[top[a]],pos[a]);
if (alast==last) ans--;
a=fa[top[a]];
alast=fir;
}
}
fir=last=-1;
if (dep[a]>dep[b])
{
ans+=getsum(1,1,n,pos[b],pos[a]);
if (last==alast) ans--;
if (fir==blast) ans--;
}
else
{
ans+=getsum(1,1,n,pos[a],pos[b]);
if (fir==alast) ans--;
if (last==blast) ans--;
}
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&c[i]);
siz[0]=fa[1]=dep[1]=tot=0;
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
insert(a,b);insert(b,a);
}
dfs1(1);
tot=0;dfs2(1,1);
for(int i=1;i<=n;i++) fp[pos[i]]=i;
buildtree(1,1,n);
for(int i=1;i<=m;i++)
{
char op=' ';
while(op<'A'||op>'Z') scanf("%c",&op);
if (op=='C')
{
int a,b,w;
scanf("%d%d%d",&a,&b,&w);
cover(a,b,w);
}
if (op=='Q')
{
int a,b;
scanf("%d%d",&a,&b);
printf("%d
",query(a,b));
}
}
return 0;
}