## 题目链接
## 题目分析
树链剖分经典题
首先考虑在序列上如何维护区间染色与色段个数：每次 pushup 时把左儿子和右儿子的色段数相加，再判断左儿子右端点与右儿子左端点的颜色是否相同——若相同，说明中间连成了同一段，总数减一。
在树上同理：按重链剖分后的 DFS 序建一棵线段树，并按上述方式维护。查询时每沿一条重链往上跳，都要记录这段重链两端的颜色，并与相邻一段在连接处的端点颜色比较，相同则答案减一；最后两点跳到同一条重链上时，两端都要各判断一次。
也可用LCT做,还没学,学完回来补档。
写树剖细节很多,尽量考虑全面。
## 代码
//记录线段树区间左端点右端点颜色和每段重链的头尾颜色
#include<bits/stdc++.h>
#define N (100000 + 5)
using namespace std;
// Reads one (optionally negative) decimal integer from stdin via getchar().
// Non-digit characters before the number are skipped; every '-' among them
// flips the sign. Returns 0 if end of input is reached before any digit.
inline int read() {
    int cnt = 0, f = 1;
    // Keep getchar()'s int result: a char cannot represent EOF distinctly, and
    // isdigit() is only defined for EOF or values representable as unsigned char.
    int c = getchar();
    while (c != EOF && !isdigit(c)) {
        if (c == '-') f = -f;
        c = getchar();
    }
    while (isdigit(c)) {   // isdigit(EOF) is false, so this stops at end of input
        cnt = cnt * 10 + c - '0';
        c = getchar();
    }
    return cnt * f;
}
// n vertices, m operations; a[v] = initial colour of vertex v (solve() stores input + 1).
int n, m, a[N];
// Adjacency lists (nxt / first / to, tot = edges used) and heavy-light arrays:
// num[v] = DFS position of v, id[pos] = vertex at that position, top[v] = head of
// v's heavy chain, father, siz = subtree size, dep = depth (root has depth 1),
// son = heavy child (0 if none); idx = last DFS position handed out.
int nxt[N<<1], first[N], to[N<<1], tot, num[N], id[N], top[N], idx = 0, father[N], siz[N], dep[N], son[N];
// ans1 / ans2: per-side colour bookkeeping used by chain_query();
// Lc / Rc: colours at the left / right ends of the range of the last query().
int ans1, ans2, Lc, Rc;
// Scratch variables for reading the input operations.
int x, y, z; char opr = '0';
// One segment-tree node over DFS positions.
struct node {
    int l, r;        // covered position range [l, r]
    long long tag;   // pending colour for the whole range; 0 means no pending tag
    long long dat;   // number of maximal same-colour segments in [l, r]
    long long lc;    // colour at position l
    long long rc;    // colour at position r
};
node tree[N * 4];

// Field accessors used throughout the segment-tree code.
#define l(p) tree[p].l
#define r(p) tree[p].r
#define dat(p) tree[p].dat
#define tag(p) tree[p].tag
#define lc(p) tree[p].lc
#define rc(p) tree[p].rc
// Inserts the directed edge x -> y at the head of x's adjacency list.
// Edge slots are numbered from 1, so index 0 terminates a list.
void add(int x, int y) {
    to[++tot] = y;
    nxt[tot] = first[x];
    first[x] = tot;
}
void pushdown(int p);
void pushup(int p);
// Debug helper: prints the colour of every leaf under node p (which covers
// positions [l, r]) in position order. Pending tags are pushed down on the way
// so the printed leaf colours are current.
void debug(int p, int l, int r) {
    if (l == r) {
        // Leaf: print this node's colour. The index must be the node p, not the
        // position l, otherwise an unrelated tree node would be read.
        cout << lc(p) << " ";
        return;
    }
    pushdown(p);
    const int mid = (l + r) >> 1;
    debug(p << 1, l, mid);
    debug(p << 1 | 1, mid + 1, r);
    pushup(p);
}
// First heavy-light pass over the subtree of cur (whose parent is fa):
// records father, depth and subtree size, and sets son[cur] to the child with
// the largest subtree.
void dfs1(int cur, int fa) {
    father[cur] = fa;
    siz[cur] = 1;
    dep[cur] = dep[fa] + 1;
    // `register` was dropped: it is ill-formed since C++17.
    for (int i = first[cur]; i; i = nxt[i]) {
        const int v = to[i];
        if (v == fa) continue;
        dfs1(v, cur);
        siz[cur] += siz[v];
        // son[cur] starts at 0 and siz[0] stays 0, so the first child is taken
        // and later children replace it only if strictly larger.
        if (siz[son[cur]] < siz[v]) son[cur] = v;
    }
}
// Second heavy-light pass: assigns DFS positions (num / id), visiting the heavy
// child first so that every heavy chain occupies a contiguous position range.
// tp is the head of the chain cur belongs to.
void dfs2(int cur, int tp) {
    top[cur] = tp;
    num[cur] = ++idx;
    id[idx] = cur;
    if (son[cur]) dfs2(son[cur], tp);   // the heavy child continues the chain
    // `register` was dropped: it is ill-formed since C++17.
    for (int i = first[cur]; i; i = nxt[i]) {
        const int v = to[i];
        // The father and the heavy child are already numbered, so the remaining
        // unnumbered neighbours are light children; each starts a new chain.
        if (!num[v]) dfs2(v, v);
    }
}
void pushup(int p) {
lc(p) = lc(p << 1), rc(p) = rc(p << 1 | 1);
dat(p) = dat(p << 1) + dat(p << 1 | 1);
if(rc(p << 1) == lc(p << 1 | 1)) --dat(p);
}
// Moves node p's pending colour tag (0 = none) onto both children. Each child
// becomes a single segment of that colour and inherits the tag.
void pushdown(int p) {
    const long long colour = tag(p);
    if (!colour) return;
    for (int child = p << 1; child <= (p << 1 | 1); ++child) {
        tag(child) = colour;
        lc(child) = rc(child) = colour;
        dat(child) = 1;
    }
    tag(p) = 0;
}
// Builds node p over DFS positions [l, r]. The leaf at position pos holds the
// colour of vertex id[pos].
void build_tree(int p, int l, int r) {
    l(p) = l;
    r(p) = r;
    if (l != r) {
        const int mid = (l + r) >> 1;
        build_tree(p << 1, l, mid);
        build_tree(p << 1 | 1, mid + 1, r);
        pushup(p);
        return;
    }
    dat(p) = 1;
    lc(p) = rc(p) = a[id[l]];
}
// Paints every position in [l, r] that lies in node p's range with colour d.
void modify(int p, int l, int r, int d) {
    if (l(p) >= l && r(p) <= r) {
        // Fully covered: the node becomes one segment of colour d. The tag
        // defers the update of the children until they are visited.
        lc(p) = rc(p) = d;
        dat(p) = 1;
        tag(p) = d;
        return;
    }
    pushdown(p);
    const int mid = (l(p) + r(p)) >> 1;
    const int ls = p << 1, rs = ls | 1;
    if (l <= mid) modify(ls, l, r, d);
    if (mid < r) modify(rs, l, r, d);
    pushup(p);
}
long long query(int p, int l, int r) {
// cout<<"p= "<<p<<" l(p)= "<<l(p)<<" r(p)= "<<r(p)<<endl;
if (l(p) == l) Lc = lc(p);
if (r(p) == r) Rc = rc(p);
// if(l(p)>r||r(p)<l)return 0;
if (l <= l(p) && r >= r(p)) {
// cout<<" dat(p)= "<<dat(p)<<" "<<l<<" "<<r<<endl;
return dat(p);
}
pushdown(p);
long long val = 0;
int mid = (l(p) + r(p)) >> 1;
if (l > mid) val += query(p << 1 | 1, l, r);
else if (r <= mid) val += query(p << 1, l, r);
else {
// cout<<val<<endl;
if (rc(p << 1) == lc(p << 1 | 1)) val += query(p << 1, l, r) + query(p << 1 | 1, l, r) - 1;
else val += query(p << 1, l, r) + query(p << 1 | 1, l, r);
}
return val;
}
// Paints every vertex on the tree path u..v with colour d. It climbs heavy
// chains, each time painting from the deeper chain head down to that endpoint.
void chain_modify(int u, int v, int d) {
    for (; top[u] != top[v]; u = father[top[u]]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        modify(1, num[top[u]], num[u], d);
    }
    // Both endpoints are on one chain: paint from the shallower to the deeper.
    if (dep[u] < dep[v]) swap(u, v);
    modify(1, num[v], num[u], d);
}
// Returns the number of colour segments on the tree path u..v.
// It climbs heavy chains like chain_modify. Each query() returns the segment
// count of one chain piece, with the colour at the piece's upper end (chain
// head side) left in Lc and at its lower end in Rc.
// ans1 / ans2 hold, for the endpoint currently named u / v, the colour at the
// top of the last piece climbed on that side (0 = none yet; colours are stored
// as input + 1 elsewhere, so presumably 0 never matches a real colour).
long long chain_query(int u, int v) {
long long ans = 0;
ans1 = ans2 = 0;
while(top[u] != top[v]) {
// Climb from the side whose chain head is deeper; swap the remembered colours
// together with the endpoints so each stays attached to its side.
if (dep[top[u]] < dep[top[v]]) swap(u, v), swap(ans1, ans2);
ans += query(1, num[top[u]], num[u]);
// u is the parent of the previous piece's head on this side: equal colours
// mean the two pieces form one segment.
if (ans1 == Rc) ans--;
ans1 = Lc;
u = father[top[u]];
}
// Same chain: make v the shallower endpoint and count the final piece.
if (dep[u] < dep[v]) swap(u, v), swap(ans1, ans2);
ans += query(1, num[v], num[u]);
// The lower end (u) joins u's side; the upper end (v) joins v's side.
if (Rc == ans1) ans--;
if (Lc == ans2) ans--;
return ans;
}
// Reads the tree and the operations from stdin and processes them:
//   "C x y z": paint every vertex on the path x..y with colour z;
//   "Q x y":   print the number of colour segments on the path x..y.
// Colours are stored shifted by +1, presumably so that 0 can mean
// "no pending tag" in the segment tree and "no colour yet" in chain_query.
void solve() {
    n = read();
    m = read();
    // `register` was dropped from the loop counters: it is ill-formed since C++17.
    for (int i = 1; i <= n; i++) a[i] = read() + 1;
    for (int i = 1; i < n; i++) {
        x = read();
        y = read();
        add(x, y);
        add(y, x);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    build_tree(1, 1, n);
    for (int i = 1; i <= m; i++) {
        cin >> opr;
        x = read();
        y = read();
        if (opr == 'C') {
            z = read();
            chain_modify(x, y, z + 1);
        } else if (opr == 'Q') {
            const long long res = chain_query(x, y);
            // The original split this format string across a raw newline,
            // which does not compile; "\n" is the intended line break.
            printf("%lld\n", res);
        }
    }
}
// Program entry point: all work is done by solve(), which reads from stdin.
// For local testing, stdin can be redirected first, e.g.
//   freopen("input.in", "r", stdin);
int main() {
    solve();
    return 0;
}