Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
第一眼看上去肯定是树链剖分,然后就是想怎么用线段树维护区间色段。
我们用线段树维护一个区间最左边的颜色,最右边的颜色,和颜色段数。如果一个节点的左儿子的右颜色和右儿子的左颜色相同,那么它的色段数是左+右-1,否则是左+右。
但是在查询时一定要注意,跑完每一条重链,和下一条重链中的轻链时,他们在线段树上并不是一起查询的。我们需要单点找出当前重链的顶端和下一个重链的底端的颜色,如果颜色相同,那么ans-1.
#include <iostream> #include <cstdio> #include <algorithm> #include <cstdlib> #include <cstring> #define in(a) a=read() #define REP(i,k,n) for(int i=k;i<=n;i++) #define MAXN 100010 using namespace std; inline int read(){ int x=0,f=1; char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-1; for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0'; return x*f; } int n,m,a,b,d; char c; int input[MAXN]; int total,head[MAXN],nxt[MAXN<<1],to[MAXN<<1]; int depth[MAXN],size[MAXN],son[MAXN],f[MAXN]; int cnt,dfn[MAXN],top[MAXN],link[MAXN]; struct node{ int l,r,lc,rc,s,lt; }tree[MAXN<<2]; inline void adl(int a,int b){ total++; to[total]=b; nxt[total]=head[a]; head[a]=total; return ; } inline void getson(int u,int fa){//得到重儿子 size[u]=1; for(int e=head[u];e;e=nxt[e]) if(to[e]!=fa){ depth[to[e]]=depth[u]+1; f[to[e]]=u; getson(to[e],u); size[u]+=size[to[e]]; if(!son[u] || size[to[e]]>size[son[u]]) son[u]=to[e]; } return ; } inline void getdfn(int u,int t){//得到重边 top[u]=t; dfn[u]=++cnt; link[cnt]=u; if(!son[u]) return ; getdfn(son[u],t); for(int e=head[u];e;e=nxt[e]) if(to[e]!=f[u] && to[e]!=son[u]) getdfn(to[e],to[e]); return ; } inline void build(int i,int l,int r){//建树 tree[i].l=l; tree[i].r=r; if(l==r){ tree[i].s=1,tree[i].lc=tree[i].rc=input[link[l]]; return ; } int mid=(l+r)>>1; build(i<<1,l,mid); build(i<<1|1,mid+1,r); if(tree[i<<1].rc==tree[i<<1|1].lc) tree[i].s=tree[i<<1].s+tree[i<<1|1].s-1; else tree[i].s=tree[i<<1].s+tree[i<<1|1].s; tree[i].lc=tree[i<<1].lc; tree[i].rc=tree[i<<1|1].rc; } inline void pushdown(int i){//下传懒标记 if(!tree[i].lt) return ; int k=tree[i].lt; tree[i<<1].s=tree[i<<1|1].s=1; tree[i<<1].lc=tree[i<<1].rc=tree[i<<1|1].lc=tree[i<<1|1].rc=k; tree[i<<1].lt=tree[i<<1|1].lt=k; tree[i].lt=0; return ; } inline void add(int i,int l,int r,int k){//修改颜色 if(tree[i].l>=l && tree[i].r<=r){ tree[i].s=1; tree[i].lt=tree[i].lc=tree[i].rc=k; return ; } pushdown(i); if(tree[i<<1].r>=l) add(i<<1,l,r,k); if(tree[i<<1|1].l<=r) add(i<<1|1,l,r,k); if(tree[i<<1].rc==tree[i<<1|1].lc) tree[i].s=tree[i<<1].s+tree[i<<1|1].s-1; else tree[i].s=tree[i<<1].s+tree[i<<1|1].s; tree[i].lc=tree[i<<1].lc; tree[i].rc=tree[i<<1|1].rc; return ; } inline void updates(int x,int y,int z){//枚举两点间每一条重边 int tx=top[x],ty=top[y]; while(tx!=ty){ if(depth[tx]<depth[ty]) swap(tx,ty),swap(x,y); add(1,dfn[tx],dfn[x],z); x=f[tx]; tx=top[x],ty=top[y]; } if(depth[x]<depth[y]) swap(x,y); add(1,dfn[y],dfn[x],z); } inline int query(int i,int l,int r){//区间查询 int sum=0; if(tree[i].l>=l && tree[i].r<=r) return tree[i].s; pushdown(i); if(tree[i<<1].r>=l) sum+=query(i<<1,l,r); if(tree[i<<1|1].l<=r) sum+=query(i<<1|1,l,r); if(tree[i<<1].r>=l && tree[i<<1|1].l<=r && tree[i<<1].rc==tree[i<<1|1].lc) sum--; return sum; } inline int getcolor(int i,int dis){//查询单点颜色 if(tree[i].l==tree[i].r) return tree[i].lc; pushdown(i); int mid=(tree[i].l+tree[i].r)>>1; if(dis<=mid) return getcolor(i<<1,dis); else return getcolor(i<<1|1,dis); } inline int getsum(int x,int y){//枚举查询时两点间的重边 int tx=top[x],ty=top[y],ans=0; while(tx!=ty){ if(depth[tx]<depth[ty]) swap(tx,ty),swap(x,y); ans+=query(1,dfn[tx],dfn[x]); if(getcolor(1,dfn[tx])==getcolor(1,dfn[f[tx]])) ans--;//看轻边两点的颜色是否相同 x=f[tx]; tx=top[x],ty=top[y]; } if(depth[x]<depth[y]) swap(x,y); ans+=query(1,dfn[y],dfn[x]); return ans; } int main(){ in(n),in(m); REP(i,1,n) in(input[i]); REP(i,1,n-1) in(a),in(b),adl(a,b),adl(b,a); depth[1]=1; getson(1,0); getdfn(1,1); build(1,1,n); REP(i,1,m){ cin>>c; if(c=='C') in(a),in(b),in(d),updates(a,b,d); if(c=='Q') in(a),in(b),printf("%d ",getsum(a,b)); } }