题目大意
给你一棵(n)个点的树,每个点有一种颜色;现在有(m)个询问,每次询问你(x)到(y)的路径上,若将(a)颜色视作(b)颜色,不同的颜色有几种。
(nleq 50000,mleq 100000)
分析
如果是把问题放到序列上:询问区间([l,r])不同的颜色有几种。这个问题有两个已知的解法:
- 主席树(传送门)
- 莫队
看这题的数据范围显然是让你莫队了。(雾
树上莫队的第一步,是把树上问题转换为序列问题。我们求出原树的欧拉序,可以发现这个序列有这样的性质:
将一个点在欧拉序中首次出现和第二次出现的位置分别记作(fir_u)和(las_u),对于一条路径((x,y))(假定(fir_x<fir_y))。
若(lca(x,y)=x),那么这条路径对应欧拉序中的区间([fir_x,fir_y])。但是区间中出现两次的点要去掉,因为它们不属于这条路径。
若(lca(x,y) eq x),那么这条路径对应欧拉序中的区间([las_x,fir_y])。同样的要去掉出现两次的点,并且这个区间没有包括上(lca),要将(lca)再单独统计。
这样,树上问题就变成了序列问题。
为了不计算出现两次的点,我们开个标记数组,一个点每次出现,都把标记数组对应位置异或(1),那么一个点在标记数组中的值为(1)时才能被计算,当一个点对应的值变为(0)时又把它的贡献删去,这样问题便迎刃而解。再注意计算(lca)的答案即可。关于将(a)颜色视作(b)颜色的,只需判断区间中是否同时有(a)颜色和(b)颜色,有的话答案减(1),注意(a=b)要特判,不然要炸!
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 200007;
int n, m, col[N], ans[N], ord[N];
int tot, dfn, st[N], to[N << 1], nx[N << 1], fir[N], las[N], anc[N][17], dep[N];
void add(int u, int v) { to[++tot] = v, nx[tot] = st[u], st[u] = tot; }
void dfs(int u)
{
fir[u] = ++dfn, ord[dfn] = u;
for (int i = st[u]; i; i = nx[i]) if (!fir[to[i]]) anc[to[i]][0] = u, dep[to[i]] = dep[u] + 1, dfs(to[i]);
las[u] = ++dfn, ord[dfn] = u;
}
int getlca(int u, int v)
{
if (dep[u] < dep[v]) swap(u, v);
for (int i = 16; i >= 0; i--) if (dep[anc[u][i]] >= dep[v]) u = anc[u][i];
if (u == v) return u;
for (int i = 16; i >= 0; i--) if (anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
return anc[u][0];
}
int block, ret, be[N], tag[N], buc[N];
struct note { int l, r, id, a, b, lca; } q[N];
int cmp(note a, note b) { return be[a.l] == be[b.l] ? ((be[a.l] & 1) ? a.r < b.r : a.r > b.r) : a.l < b.l; }
void ins(int c, int v)
{
if (v == 1) { if (!buc[c]) ret++; buc[c]++; }
else { buc[c]--; if (!buc[c]) ret--; }
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &col[i]);
for (int i = 1, u, v; i <= n; i++)
{
scanf("%d%d", &u, &v);
if (u && v) add(u, v), add(v, u);
}
dep[1] = 1, dfs(1);
for (int j = 1; j <= 16; j++) for (int i = 1; i <= n; i++) anc[i][j] = anc[anc[i][j - 1]][j - 1];
block = sqrt(2 * n);
for (int i = 1; i <= 2 * n; i++) be[i] = i / block + 1;
for (int i = 1, x, y, a, b, lca; i <= m; i++)
{
scanf("%d%d%d%d", &x, &y, &a, &b);
if (fir[x] > fir[y]) swap(x, y);
lca = getlca(x, y);
if (lca == x) q[i] = (note){fir[x], fir[y], i, a, b, 0};
else q[i] = (note){las[x], fir[y], i, a, b, lca};
}
sort(q + 1, q + m + 1, cmp);
for (int i = 1, l = 1, r = 0; i <= m; i++)
{
while (l < q[i].l) tag[ord[l]] ^= 1, ins(col[ord[l]], tag[ord[l]]), ++l;
while (l > q[i].l) --l, tag[ord[l]] ^= 1, ins(col[ord[l]], tag[ord[l]]);
while (r < q[i].r) ++r, tag[ord[r]] ^= 1, ins(col[ord[r]], tag[ord[r]]);
while (r > q[i].r) tag[ord[r]] ^= 1, ins(col[ord[r]], tag[ord[r]]), --r;
if (q[i].lca) ins(col[q[i].lca], 1);
ans[q[i].id] = ret;
if (q[i].a != q[i].b && buc[q[i].a] && buc[q[i].b]) ans[q[i].id]--;
if (q[i].lca) ins(col[q[i].lca], 0);
}
for (int i = 1; i <= m; i++) printf("%d
", ans[i]);
return 0;
}