[BZOJ2588][Spoj 10628]Count on a tree
试题描述
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
输入
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
输出
M行,表示每个询问的答案。最后一个询问不输出换行符
输入示例
8 5 105 2 9 3 8 5 7 7 1 2 1 3 1 4 3 5 3 6 3 7 4 8 2 5 1 0 5 2 10 5 3 11 5 4 110 8 2
输出示例
2 8 9 105 7
数据规模及约定
N,M<=100000
题解
我们可以把主席树按照树形结构来建,即每一个节点上的版本从它父亲节点的版本修改而来,那么一个节点上的主席树记录的就是该节点到根节点的权值信息了,于是利用 d(a, b) = dep(a) + dep(b) - dep(lca(a, b)) - dep(fa[lca(a, b)]) 这个公式(其中 d(a, b) 表示路径 a 到 b 的权值和,dep(u) = d(root, u),root 为根节点,lca(a, b) 为 a 与 b 的最近公共祖先,fa[u] 为 u 的父亲)二分。
#include <iostream> #include <cstdio> #include <algorithm> #include <cmath> #include <stack> #include <vector> #include <queue> #include <cstring> #include <string> #include <map> #include <set> using namespace std; const int BufferSize = 1 << 16; char buffer[BufferSize], *Head, *Tail; inline char Getchar() { if(Head == Tail) { int l = fread(buffer, 1, BufferSize, stdin); Tail = (Head = buffer) + l; } return *Head++; } int read() { int x = 0, f = 1; char c = Getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); } return x * f; } #define maxn 100010 #define maxm 200010 #define maxlog 17 #define maxnode 2000010 int n, rt[maxn], val[maxn], num[maxn]; int ToT, sumv[maxnode], lc[maxnode], rc[maxnode]; void update(int& y, int x, int l, int r, int p) { sumv[y = ++ToT] = sumv[x] + 1; if(l == r) return ; int mid = l + r >> 1; lc[y] = lc[x]; rc[y] = rc[x]; if(p <= mid) update(lc[y], lc[x], l, mid, p); else update(rc[y], rc[x], mid + 1, r, p); return ; } int m, head[maxn], next[maxm], to[maxm], fa[maxlog][maxn], dep[maxn]; void AddEdge(int a, int b) { to[++m] = b; next[m] = head[a]; head[a] = m; swap(a, b); to[++m] = b; next[m] = head[a]; head[a] = m; return ; } void build(int u) { update(rt[u], rt[fa[0][u]], 1, n, val[u]); for(int i = 1; i < maxlog; i++) fa[i][u] = fa[i-1][fa[i-1][u]]; for(int e = head[u]; e; e = next[e]) if(to[e] != fa[0][u]) { fa[0][to[e]] = u; dep[to[e]] = dep[u] + 1; build(to[e]); } return ; } int lca(int a, int b) { if(dep[a] < dep[b]) swap(a, b); for(int i = maxlog - 1; i >= 0; i--) if(dep[a] - dep[b] >= (1 << i)) a = fa[i][a]; for(int i = maxlog - 1; i >= 0; i--) if(fa[i][a] != fa[i][b]) a = fa[i][a], b = fa[i][b]; return a == b ? a : fa[0][b]; } int solve(int a, int b, int k) { int lrt[2] = {rt[a], rt[b]}, c = lca(a, b), rrt[2] = {rt[c], rt[fa[0][c]]}; int l = 1, r = n; while(l < r) { int mid = l + r >> 1, sum = 0; for(int i = 0; i < 2; i++) if(lrt[i] && lc[lrt[i]]) sum += sumv[lc[lrt[i]]]; for(int i = 0; i < 2; i++) if(rrt[i] && lc[rrt[i]]) sum -= sumv[lc[rrt[i]]]; if(sum < k) { k -= sum; l = mid + 1; for(int i = 0; i < 2; i++) if(lrt[i]) lrt[i] = rc[lrt[i]]; for(int i = 0; i < 2; i++) if(rrt[i]) rrt[i] = rc[rrt[i]]; } else { r = mid; for(int i = 0; i < 2; i++) if(lrt[i]) lrt[i] = lc[lrt[i]]; for(int i = 0; i < 2; i++) if(rrt[i]) rrt[i] = lc[rrt[i]]; } } return num[l]; } int main() { n = read(); int q = read(); for(int i = 1; i <= n; i++) val[i] = num[i] = read(); sort(num + 1, num + n + 1); for(int i = 1; i <= n; i++) val[i] = lower_bound(num + 1, num + n + 1, val[i]) - num; for(int i = 1; i < n; i++) { int a = read(), b = read(); AddEdge(a, b); } build(1); int lst = 0; while(q--) { int a = read() ^ lst, b = read(), k = read(); lst = solve(a, b, k); if(q) printf("%d ", lst); else printf("%d", lst); } return 0; }