Count on a tree
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。最后一个询问不输出换行符
Sample Input
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
Sample Output
2
8
9
105
7
Hint
N,M<=100000
标签:DFS序+LCA+值域主席树
这时一个区间第k小的问题,所以可以很自然的想到值域主席树。但是此题将区间移到了树上,在树上套线段树,可以想到DFS序和树链剖分。此题应该是DFS序。
在解区间第k小的时候,对于每次询问区间[a,b],我们需要找到a-1位置的线段树和b位置的线段树,然后递归query的时候用个数相减。对于这道题,我们把每个结点到根的那条链作为一个序列,用区间第k小的方法存储,然后找到u和v的LCA(假定它为t),递归query的时候计算左区间数的个数,即u结点对应线段树左区间数的个数+v结点....数的个数-t结点...数的个数-t的父结点...数的个数。即tmp = tr[tr[u].ls].val+tr[tr[v].ls].val-tr[tr[t].ls].val-tr[tr[fa[t]].ls].val。
写的时候注意强制在线的操作方式和读入数后先离散化。
最后附上AC代码:
#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
#define MAX_N 100000
using namespace std;
int n, m, c[MAX_N+5];
int cnt = 0, root[MAX_N+5];
int tot = 0, map[MAX_N+5];
int anc[MAX_N+5][25], dep[MAX_N+5];
bool vis[MAX_N+500];
vector <int> G[MAX_N+500];
struct Pre {int id, val;} pre[MAX_N+5];
struct TNode {int ls, rs, val;} tr[MAX_N*32+500];
bool cmp(const Pre &a, const Pre &b) {return a.val < b.val;}
void DFS(int u) {
vis[u] = true;
for (int i = 1; (1<<i) <= dep[u]; i++) anc[u][i] = anc[anc[u][i-1]][i-1];
for (int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if (!vis[v]) anc[v][0] = u, dep[v] = dep[u]+1, DFS(v);
}
}
int LCA(int a, int b) {
int i, j;
if (dep[a] < dep[b]) swap(a, b);
for (i = 0; (1<<i) <= dep[a]; i++) ; i--;
for (j = i; j >= 0; j--)
if (dep[a]-(1<<j) >= dep[b])
a = anc[a][j];
if (a == b) return a;
for (j = i; j >= 0; j--)
if (anc[a][j] != anc[b][j])
a = anc[a][j], b = anc[b][j];
return anc[a][0];
}
void init(int &v, int s, int t) {
v = ++cnt;
if (s == t) return;
int mid = s+t>>1;
init(tr[v].ls, s, mid);
init(tr[v].rs, mid+1, t);
}
void insert(int v, int o, int s, int t, int val) {
tr[v] = tr[o];
if (s == t) {tr[v].val++; return;}
int mid = s+t>>1;
if (val <= mid) insert(tr[v].ls = ++cnt, tr[o].ls, s, mid, val);
else insert(tr[v].rs = ++cnt, tr[o].rs, mid+1, t, val);
tr[v].val = tr[tr[v].ls].val+tr[tr[v].rs].val;
}
void build(int u) {
root[u] = ++cnt;
insert(root[u], root[anc[u][0]], 1, tot, c[u]);
for (int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if (v != anc[u][0]) build(v);
}
}
int query(int v1, int v2, int v3, int v4, int s, int t, int k) {
if (s == t) return s;
int mid = s+t>>1, tmp = tr[tr[v1].ls].val+tr[tr[v2].ls].val-tr[tr[v3].ls].val-tr[tr[v4].ls].val;
if (k <= tmp) return query(tr[v1].ls, tr[v2].ls, tr[v3].ls, tr[v4].ls, s, mid, k);
return query(tr[v1].rs, tr[v2].rs, tr[v3].rs, tr[v4].rs, mid+1, t, k-tmp);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) pre[i].id = i, scanf("%d", &pre[i].val);
sort(pre+1, pre+n+1, cmp);
for (int i = 1; i <= n; i++) {
if (i == 1 || pre[i].val != pre[i-1].val) map[++tot] = pre[i].val;
c[pre[i].id] = tot;
}
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v), G[v].push_back(u);
}
DFS(1);
init(root[0], 1, tot), build(1);
int ans = 0;
while (m--) {
int u, v, k;
scanf("%d%d%d", &u, &v, &k); u ^= ans;
int lca = LCA(u, v);
ans = map[query(root[u], root[v], root[lca], root[anc[lca][0]], 1, tot, k)];
printf("%d", ans);
if (m >= 1) printf("
");
}
return 0;
}