SP10707 COT2 - Count on a tree II - 洛谷
- 给定 (n) 个结点的树,每个结点有一种颜色。
- (m) 次询问,每次询问给出 (u,v),回答 (u,v) 之间的路径上的结点的不同颜色数。
很显然,如果是在序列上的化就是最基础的莫队模板题
考虑转化到序列上
欧拉序列 即 (DFS) 序
可以解决这个问题
对于询问 ((x,y)) 让 (x) 是在 (DFS) 中先出现的那个
分为两种情况:
- (x == lca(x,y)) , 则对应的区间是 ([in(x),in(y)]) 中出现次数为一的点 ,如上图红色询问
- (x eq lca(x,y)) , 则对应区间是 ([out(x), in(y)]) 中出现次数为一的点,加上 (lca)
对于翻转操作,可以用一个数组 (st) 进行标识,增添和删除其实都是一样的
/*
* @Author: zhl
* @Date: 2020-11-19 15:34:09
*/
#include<bits/stdc++.h>
#pragma GCC optimize(3)
#define pb(x) push_back(x)
using namespace std;
const int N = 1e5 + 10;
vector<int>G[N];
int n, m, len, ID[N], ans[N];
int w[N], nums[N], in[N], out[N], ord[N];
int tot, dep[N], f[N][32], cnt[N], st[N];
void add(int x, int& res) {
st[x] ^= 1;
if (!st[x]) {
if (!--cnt[w[x]])res--;
}
else {
if (!cnt[w[x]]++)res++;
}
}
void dfs(int u, int p) {
in[u] = ++tot;
ord[tot] = u;
dep[u] = dep[p] + 1; f[u][0] = p;
for (int x = 1; (1 << x) < dep[u]; x++) {
f[u][x] = f[f[u][x - 1]][x - 1];
}
for (int v : G[u]) {
if (v == p)continue;
dfs(v, u);
}
out[u] = ++tot;
ord[tot] = u;
}
int LCA(int x, int y) {
if (dep[x] < dep[y])swap(x, y);
while (dep[x] != dep[y]) {
int u = dep[x] - dep[y];
int v = 0;
while (!(u & (1 << v)))v++;
x = f[x][v];
}
while (x != y) {
int v = 0;
while (f[x][v] != f[y][v])v++;
x = f[x][max(0, v - 1)]; y = f[y][max(0, v - 1)];
}
return x;
}
struct Query {
int id, l, r, lca;
bool operator < (const Query& b)const {
if (ID[l] != ID[b.l])return ID[l] < ID[b.l];
return r < b.r;
}
}q[N];
int main() {
scanf("%d%d", &n, &m); int numID = 0;
for (int i = 1; i <= n; i++)scanf("%d", w + i), nums[++numID] = w[i];
sort(nums + 1, nums + 1 + n);
numID = unique(nums + 1, nums + 1 + n) - nums - 1;
for (int i = 1; i <= n; i++)w[i] = lower_bound(nums + 1, nums + 1 + numID, w[i]) - nums;
for (int i = 1; i < n; i++) {
int a, b; scanf("%d%d", &a, &b);
G[a].pb(b); G[b].pb(a);
}
dfs(1, 0);
len = sqrt(tot);
for (int i = 1; i <= tot; i++)ID[i] = i / len;
for (int i = 1; i <= m; i++) {
int x, y; scanf("%d%d", &x, &y);
int lca = LCA(x, y);
if (in[x] > in[y])swap(x, y);
if (lca == x) {
q[i] = { i,in[x],in[y],0 };
}
else {
q[i] = { i,out[x],in[y],lca };
}
}
sort(q + 1, q + 1 + m);
int res = 0, l = 1, r = 0;
for (int i = 1; i <= m; i++) {
int ql = q[i].l, qr = q[i].r, lca = q[i].lca;
if (lca != 0)add(lca, res);
while (l < ql)add(ord[l++], res);
while (l > ql)add(ord[--l], res);
while (r < qr)add(ord[++r], res);
while (r > qr)add(ord[r--], res);
ans[q[i].id] = res;
if (lca != 0)add(lca, res);
}
for (int i = 1; i <= m; i++)printf("%d
", ans[i]);
}