Description
给定一棵有(n)个节点的无根树和(m)个操作,操作有(2)类:
- 将节点(a)到节点(b)路径上所有点都染成颜色(c);
- 询问节点(a)到节点(b)路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这(m)个操作。
Input
第一行包含(2)个整数(n)和(m),分别表示节点数和操作数;
第二行包含(n)个正整数表示(n)个节点的初始颜色
下面 行每行包含两个整数(x)和(y),表示(x)和(y)之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点(a)到节点(b)路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点(a)到节点(b)(包括(a)和(b))路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
HINT
数(N leqslant 10^5),操作数(M leqslant 10^5),所有的颜色(C)为整数且在([0, 10^9])之间。
Solution
这个题啊,不过就是14个函数、160行的重工业题吗(雾。
链剖+线段树,详见代码。
话说Azrael_Death今天写了超重工业题猪国杀,%%%orz。
#include<bits/stdc++.h>
using namespace std;
inline int read() {
int x = 0, flag = 1; char ch = getchar();
while (ch > '9' || ch < '0') { if (ch == '-') flag = -1; ch = getchar(); }
while (ch <= '9' && ch >= '0') { x = x * 10 + ch - '0'; ch = getchar(); }
return x * flag;
}
inline void write(int x) { if (x >= 10) write(x / 10); putchar(x % 10 + '0'); }
#define N 500001
#define rep(i, a, b) for (int i = a; i <= b; i++)
#define drp(i, a, b) for (int i = a; i >= b; i--)
#define fech(i, x) for (int i = 0; i < x.size(); i++)
#define ls rt << 1
#define rs ls | 1
int n;
int begCol[N];
vector<int> tr[N];
int Size[N], dep[N], fa[N][18];
int pos[N], belong[N], sz;
struct segmentTree { int l, r, s, lc, rc, tag; }seg[N];
int lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
int d = dep[x] - dep[y];
rep(i, 0, 17) if (d & (1 << i)) x = fa[x][i];
drp(i, 17, 0) if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
if (x == y) return x;
return fa[x][0];
}
void dfs1(int u, int f) {
Size[u] = 1;
fa[u][0] = f;
rep(i, 1, 17) fa[u][i] = fa[fa[u][i - 1]][i - 1];
fech(i, tr[u]) if (tr[u][i] != f) {
int v = tr[u][i];
dep[v] = dep[u] + 1;
dfs1(v, u);
Size[u] += Size[v];
}
}
void dfs2(int u, int chain) {
pos[u] = ++sz; belong[u] = chain;
int k = 0;
fech(i, tr[u]) if (fa[u][0] != tr[u][i] && Size[tr[u][i]] > Size[k]) k = tr[u][i];
if (!k) return;
dfs2(k, chain);
fech(i, tr[u]) if (tr[u][i] != fa[u][0] && tr[u][i] != k) dfs2(tr[u][i], tr[u][i]);
}
void build(int l, int r, int rt) {
seg[rt].l = l, seg[rt].r = r, seg[rt].s = 1, seg[rt].tag = -1;
if (l == r) return;
int mid = l + r >> 1;
build(l, mid, ls), build(mid + 1, r, rs);
}
void pushDown(int rt) {
int tmp = seg[rt].tag; seg[rt].tag = -1;
if (tmp == -1 || seg[rt].l == seg[rt].r)return;
seg[ls].s = seg[rs].s = 1;
seg[ls].tag = seg[rs].tag = tmp;
seg[ls].lc = seg[ls].rc = tmp;
seg[rs].lc = seg[rs].rc = tmp;
}
void pushUp(int rt) {
seg[rt].lc = seg[ls].lc; seg[rt].rc = seg[rs].rc;
if (seg[ls].rc != seg[rs].lc) seg[rt].s = seg[ls].s + seg[rs].s;
else seg[rt].s = seg[ls].s + seg[rs].s - 1;
}
void change(int x, int y, int v, int rt) {
pushDown(rt);
int l = seg[rt].l, r = seg[rt].r;
if (l == x && r == y) {
seg[rt].lc = seg[rt].rc = v; seg[rt].s = 1; seg[rt].tag = v; return;
}
int mid = (l + r) >> 1;
if (mid >= y) change(x, y, v, ls);
else if (mid<x) change(x, y, v, rs);
else
change(x, mid, v, ls),
change(mid + 1, y, v, rs);
pushUp(rt);
}
int getc(int rt, int x) {
pushDown(rt);
int l = seg[rt].l, r = seg[rt].r;
if (l == r) return seg[rt].lc;
int mid = l + r >> 1;
if (x <= mid) return getc(ls, x);
else return getc(rs, x);
}
int ask(int x, int y, int rt) {
pushDown(rt);
int l = seg[rt].l, r = seg[rt].r;
if (l == x && r == y) return seg[rt].s;
int mid = l + r >> 1;
if (mid >= y) return ask(x, y, ls);
else if (mid < x) return ask(x, y, rs);
else {
int tmp = 1;
if (seg[ls].rc != seg[rs].lc) tmp = 0;
return ask(x, mid, ls) + ask(mid + 1, y, rs) - tmp;
}
}
int solveSum(int x, int f) {
int sum = 0;
while (belong[x] != belong[f]) {
sum += ask(pos[belong[x]], pos[x], 1);
if (getc(1, pos[belong[x]]) == getc(1, pos[fa[belong[x]][0]])) sum--;
x = fa[belong[x]][0];
}
sum += ask(pos[f], pos[x], 1);
return sum;
}
void solveChange(int x, int f, int c) {
while (belong[x] != belong[f])
change(pos[belong[x]], pos[x], c, 1),
x = fa[belong[x]][0];
change(pos[f], pos[x], c, 1);
}
int main() {
cin >> n; int q = read();
rep(i, 1, n) begCol[i] = read();
rep(i, 2, n) {
int u = read(), v = read();
tr[u].push_back(v), tr[v].push_back(u);
}
dfs1(1, 0); dfs2(1, 1);
build(1, n, 1);
rep(i, 1, n) change(pos[i], pos[i], begCol[i], 1);
while (q--) {
char op[10];
scanf("%s", op);
int a = read(), b = read(); int t = lca(a, b);
if (op[0] == 'Q') write(solveSum(a, t) + solveSum(b, t) - 1), puts("");
else {
int c = read();
solveChange(a, t, c), solveChange(b, t, c);
}
}
return 0;
}