还算好想的一个树链剖分+线段树...
修改一段路径的颜色直接区间修改即可。
询问颜色段数量时,线段树的每个节点记录这一段有多少个颜色段。
pushup和query时要检查mid和mid+1是否颜色相同,
例如sum[now] = sum[ls] + sum[rs] - (col[mid]==col[mid+1])
沿重链向上跳时,要检查top[x]和fa[top[x]](两条重链的交界处)是否颜色相同;这里需要单独写一个函数getcol()来查询这两个点当前的颜色,
即ans += query(dfn[top[x]],dfn[x]...) - ( getcol(...,dfn[top[x]]) == getcol(...,dfn[fa[top[x]]]) )
初始建树时,定义col[l] = w[po[l]]
(po即dfn所对应的原来序号)
modify的时候也需要pushdown一下
当然也可以不用col数组,改为在每个线段树节点上用lc,rc记录该节点所管区间左、右端点的颜色,不过劣势是这两个数组都要开4倍空间。
代码如下
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#define MogeKo qwq
using namespace std;
// Capacity shared by every array below (adjacency entries, per-vertex data
// and segment-tree nodes all use the same bound).
const int maxn = 400010;

// Vertex count and operation count, read in main.
int n;
int m;
// Scratch operands for edge endpoints and operation arguments.
int x;
int y;
int z;
// Number of adjacency entries allocated so far (advanced by add()).
int cnt;
// Last DFS index handed out (advanced by dfs2()).
int num;

// Buffer for the operation word read in main.
char op[10];

// Linked adjacency lists: head[u] is the first entry of u, to[e] its target
// vertex, nxt[e] the next entry of the same vertex (0 terminates).
int head[maxn];
int to[maxn];
int nxt[maxn];

// w[v]: initial colour of vertex v.
int w[maxn];
// Segment-tree data: sum[node] = number of colour runs in the node's range,
// lazy[node] = pending paint colour for the node.
int sum[maxn];
// col[i]: colour stored for DFS position i (consulted at range boundaries).
int col[maxn];
int lazy[maxn];

// Per-vertex tree data filled by dfs1(): depth, subtree size, parent and
// heavy child (0 when the vertex has no child).
int dpth[maxn];
int siz[maxn];
int fa[maxn];
int hson[maxn];

// Per-vertex data filled by dfs2(): DFS index and top vertex of the heavy
// chain; po[] maps a DFS index back to its vertex.
int dfn[maxn];
int top[maxn];
int po[maxn];
// Inserts a directed adjacency entry x -> y at the front of x's list.
void add(int x,int y) {
    const int slot = cnt + 1;   // entries are numbered from 1; 0 means "end of list"
    cnt = slot;
    to[slot] = y;
    nxt[slot] = head[x];
    head[x] = slot;
}
// First traversal: fills dpth[], siz[], fa[] and hson[] for the subtree of u.
// fa[u] must already be set by the caller (it is 0 for the root, and dpth[0]
// is 0, so the root gets depth 1).
void dfs1(int u) {
    dpth[u] = 1 + dpth[fa[u]];
    siz[u] = 1;
    int edge = head[u];
    while(edge != 0) {
        const int child = to[edge];
        edge = nxt[edge];
        if(child == fa[u]) continue;    // do not walk back to the parent
        fa[child] = u;
        dfs1(child);
        siz[u] += siz[child];
        // Keep the child with the strictly largest subtree as the heavy child.
        if(siz[hson[u]] < siz[child]) hson[u] = child;
    }
}
// Second traversal: assigns DFS indices so that each heavy chain occupies a
// contiguous index range. t is the top vertex of the chain containing u.
void dfs2(int u,int t) {
    const int index = ++num;
    dfn[u] = index;
    po[index] = u;
    top[u] = t;

    const int heavy = hson[u];
    if(heavy != 0) {
        // Heavy child first, so it continues the current chain contiguously.
        dfs2(heavy,t);
        // Every other child starts a new chain with itself as the top.
        for(int e = head[u]; e != 0; e = nxt[e]) {
            const int child = to[e];
            if(child != fa[u] && child != heavy) dfs2(child,child);
        }
    }
}
// Unused alias: nothing in the visible source refers to `sora`.
#define sora suki
// Midpoint of the current segment; expands using locals named l and r.
#define Mid ((l+r)>>1)
// Indices of the left and right child of segment-tree node `now`
// (expands using a local named now).
#define lson (now<<1)
#define rson (now<<1|1)
// Recomputes the colour-run count of node `now` from its two children.
// mid is the last position of the left child's range; if positions mid and
// mid+1 share a colour, the runs touching the boundary merge into one.
void pushup(int mid,int now) {
    const int left = now * 2;
    const int right = left + 1;
    const int merged = (col[mid] == col[mid+1]) ? 1 : 0;
    sum[now] = sum[left] + sum[right] - merged;
    lazy[now] = 0;
}
// Builds the segment tree for DFS positions [l, r] rooted at node `now`.
// A leaf at position l takes the initial colour of vertex po[l].
void build(int l,int r,int now) {
    if(l != r) {
        const int mid = (l + r) >> 1;
        build(l, mid, now * 2);
        build(mid + 1, r, now * 2 + 1);
        pushup(mid, now);
        return;
    }
    // Single position: exactly one colour run.
    sum[now] = 1;
    col[l] = w[po[l]];
}
void pushdown(int mid,int now) {
if(!lazy[now]) return;
lazy[lson] = lazy[rson] = col[mid] = col[mid+1] = lazy[now];
sum[lson] = sum[rson] = 1;
lazy[now] = 0;
}
void modify(int L,int R,int l,int r,int now,int c) {
if(L <= l && r <= R) {
lazy[now] = col[l] = col[r] = c;
sum[now] = 1;
return;
}
int mid = Mid;
pushdown(mid,now);
if(L <= mid) modify(L,R,l,mid,lson,c);
if(R >= mid+1) modify(L,R,mid+1,r,rson,c);
pushup(mid,now);
}
int query(int L,int R,int l,int r,int now) {
if(L <= l && r <= R) return sum[now];
int mid = Mid;
pushdown(mid,now);
if(R <= mid) return query(L,R,l,mid,lson);
else if(L >= mid+1) return query(L,R,mid+1,r,rson);
else return query(L,R,l,mid,lson) + query(L,R,mid+1,r,rson) - (col[mid] == col[mid+1]);
}
int getcol(int l,int r,int now,int k) {
if(lazy[now]) return lazy[now];
int mid = Mid;
pushdown(mid,now);
if(l == r) return col[l];
if(k <= mid) return getcol(l,mid,lson,k);
else if(k >= mid+1) return getcol(mid+1,r,rson,k);
}
// Paints every vertex on the tree path between x and y with colour c.
// The path is split into heavy-chain pieces; each piece is a contiguous
// range of DFS indices and is painted with one modify() call.
// Always returns 0. The int return type is kept for existing callers; the
// function previously ended without a return statement, which is undefined
// behaviour for a value-returning function in C++.
int getmodify(int x,int y,int c) {
    while(top[x] != top[y]) {
        // Always lift the endpoint whose chain top is deeper.
        if(dpth[top[x]] < dpth[top[y]]) swap(x,y);
        modify(dfn[top[x]],dfn[x],1,n,1,c);
        x = fa[top[x]];
    }
    // Both endpoints are on one chain now; the shallower one has the smaller index.
    if(dpth[x] > dpth[y]) swap(x,y);
    modify(dfn[x],dfn[y],1,n,1,c);
    return 0;
}
// Returns the number of colour runs on the tree path between x and y.
// Each heavy-chain piece is counted separately; when the top of a piece and
// its parent (the vertex where the path continues) share a colour, their runs
// are one and the same, so one is subtracted.
int getquery(int x,int y) {
    int runs = 0;
    for(; top[x] != top[y]; x = fa[top[x]]) {
        // Always lift the endpoint whose chain top is deeper.
        if(dpth[top[y]] > dpth[top[x]]) swap(x,y);
        const int chainTop = top[x];
        runs += query(dfn[chainTop],dfn[x],1,n,1);
        const int topColour = getcol(1,n,1,dfn[chainTop]);
        const int aboveColour = getcol(1,n,1,dfn[fa[chainTop]]);
        if(topColour == aboveColour) --runs;
    }
    // Same chain: order the endpoints so the range is [shallower, deeper].
    if(dpth[y] < dpth[x]) swap(x,y);
    return runs + query(dfn[x],dfn[y],1,n,1);
}
int main() {
scanf("%d%d",&n,&m);
for(int i = 1; i <= n; i++)
scanf("%d",&w[i]);
for(int i = 1; i <= n-1; i++) {
scanf("%d%d",&x,&y);
add(x,y), add(y,x);
}
dfs1(1);
dfs2(1,1);
build(1,n,1);
// for(int i = 1; i <= n; i++)
// col[i] = w[po[i]];
while(m--) {
scanf("%s",op);
scanf("%d%d",&x,&y);
if(op[0] == 'C') {
scanf("%d",&z);
getmodify(x,y,z);
}
if(op[0] == 'Q')
printf("%d
",getquery(x,y));
}
return 0;
}