题目描述
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
输入
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
输出
对于每个询问操作,输出一行答案。
样例输入
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
样例输出
3
1
2
提示
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间
题解
裸的树链剖分+线段树。
区间修改非常恶心,很多细节。
多写写应该就能好了吧。。。
#include <stdio.h> #include <algorithm> using namespace std; #define lson l , mid , x << 1 #define rson mid + 1 , r , x << 1 | 1 #define N 100005 int fa[N] , deep[N] , si[N] , val[N] , bl[N] , pos[N] , tot; int head[N] , to[N << 1] , next[N << 1] , cnt; int sum[N << 2] , lc[N << 2] , rc[N << 2] , mark[N << 2] , n; char str[10]; void add(int x , int y) { to[++cnt] = y; next[cnt] = head[x]; head[x] = cnt; } void dfs1(int x) { int i , y; si[x] = 1; for(i = head[x] ; i ; i = next[i]) { y = to[i]; if(y != fa[x]) { fa[y] = x; deep[y] = deep[x] + 1; dfs1(y); si[x] += si[y]; } } } void dfs2(int x , int c) { int k = 0 , i , y; bl[x] = c; pos[x] = ++tot; for(i = head[x] ; i ; i = next[i]) { y = to[i]; if(fa[x] != y && si[y] > si[k]) k = y; } if(k != 0) { dfs2(k , c); for(i = head[x] ; i ; i = next[i]) { y = to[i]; if(fa[x] != y && y != k) dfs2(y , y); } } } void pushup(int x) { lc[x] = lc[x << 1]; rc[x] = rc[x << 1 | 1]; sum[x] = sum[x << 1] + sum[x << 1 | 1]; if(rc[x << 1] == lc[x << 1 | 1]) sum[x] -- ; } void pushdown(int x) { int tmp = mark[x]; mark[x] = 0; if(tmp) { sum[x << 1] = sum[x << 1 | 1] = 1; lc[x << 1] = rc[x << 1] = lc[x << 1 | 1] = rc[x << 1 | 1] = tmp; mark[x << 1] = mark[x << 1 | 1] = tmp; } } void update(int b , int e , int v , int l , int r , int x) { if(b <= l && r <= e) { sum[x] = 1; lc[x] = rc[x] = v; mark[x] = v; return; } pushdown(x); int mid = (l + r) >> 1; if(b <= mid) update(b , e , v , lson); if(e > mid) update(b , e , v , rson); pushup(x); } void solveupdate(int x , int y , int v) { while(bl[x] != bl[y]) { if(deep[bl[x]] < deep[bl[y]]) { swap(x , y); } update(pos[bl[x]] , pos[x] , v , 1 , n , 1); x = fa[bl[x]]; } if(deep[x] > deep[y]) swap(x , y); update(pos[x] , pos[y] , v , 1 , n , 1); } int query(int b , int e , int l , int r , int x) { if(b <= l && r <= e) { return sum[x]; } pushdown(x); int mid = (l + r) >> 1 , ans = 0; if(b <= mid) ans += query(b , e , lson); if(e > mid) ans += query(b , e , rson); if(b <= mid && e > mid && rc[x << 1] == lc[x << 1 | 1]) ans -- ; return ans; } int getcl(int p , int l , int r , int x) { if(l == r) return lc[x]; pushdown(x); int mid = (l + r) >> 1; if(p <= mid) return getcl(p , lson); else return getcl(p , rson); } int solvequery(int x , int y) { int ans = 0; while(bl[x] != bl[y]) { if(deep[bl[x]] < deep[bl[y]]) swap(x , y); ans += query(pos[bl[x]] , pos[x] , 1 , n , 1); if(getcl(pos[bl[x]] , 1 , n , 1) == getcl(pos[fa[bl[x]]] , 1 , n , 1)) ans -- ; x = fa[bl[x]]; } if(deep[x] > deep[y]) swap(x , y); ans += query(pos[x] , pos[y] , 1 , n , 1); return ans; } int main() { int i , x , y , z , m; scanf("%d%d" , &n , &m); for(i = 1 ; i <= n ; i ++ ) scanf("%d" , &val[i]); for(i = 1 ; i < n ; i ++ ) { scanf("%d%d" , &x , &y); add(x , y); add(y , x); } dfs1(1); dfs2(1 , 1); for(i = 1 ; i <= n ; i ++ ) update(pos[i] , pos[i] , val[i] , 1 , n , 1); while(m -- ) { scanf("%s" , str); switch(str[0]) { case 'C': scanf("%d%d%d" , &x , &y , &z); solveupdate(x , y , z); break; default: scanf("%d%d" , &x , &y); printf("%d " , solvequery(x , y)); } } return 0; }