题目描述
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
输入
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
输出
M行,表示每个询问的答案。
样例输入
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
样例输出
2
8
9
105
7
题解
主席树+最近公共祖先
需要明确主席树的原理:线段树相加减。
那么A到B的路径就是 A到根的路径+B到根的路径-最近公共祖先到根的路径-最近公共祖先的父亲到根的路径。
可以直接在树上建立主席树,注意每棵树是从它父亲的树推来的。
然后查询即可。
注意最后一行千万不能有换行,否则无限PE!
#include <cstdio> #include <algorithm> #define N 100001 using namespace std; struct data { int num , rank; }a[N]; int root[N] , lp[N << 5] , rp[N << 5] , sum[N << 5] , val[N] , top , tot; int head[N] , to[N << 1] , next[N << 1] , cnt , fa[N] , bl[N] , deep[N] , si[N] , q[N] , tail; bool cmp1(data a , data b) { return a.num < b.num; } bool cmp2(data a , data b) { return a.rank < b.rank; } void add(int x , int y) { to[++cnt] = y; next[cnt] = head[x]; head[x] = cnt; } void dfs1(int x) { int i; si[x] = 1; for(i = head[x] ; i ; i = next[i]) { if(to[i] != fa[x]) { fa[to[i]] = x; deep[to[i]] = deep[x] + 1; dfs1(to[i]); si[x] += si[to[i]]; } } } void dfs2(int x , int c) { int i , k = 0; bl[x] = c; q[++tail] = x; for(i = head[x] ; i ; i = next[i]) if(to[i] != fa[x] && si[to[i]] > si[k]) k = to[i]; if(k) { dfs2(k , c); for(i = head[x] ; i ; i = next[i]) if(to[i] != fa[x] && to[i] != k) dfs2(to[i] , to[i]); } } int getlca(int x , int y) { while(bl[x] != bl[y]) { if(deep[bl[x]] < deep[bl[y]]) swap(x , y); x = fa[bl[x]]; } if(deep[x] < deep[y]) return x; return y; } void pushup(int x) { sum[x] = sum[lp[x]] + sum[rp[x]]; } void ins(int x , int &y , int l , int r , int p) { y = ++tot; if(l == r) { sum[y] = sum[x] + 1; return; } int mid = (l + r) >> 1; if(p <= mid) rp[y] = rp[x] , ins(lp[x] , lp[y] , l , mid , p); else lp[y] = lp[x] , ins(rp[x] , rp[y] , mid + 1 , r , p); pushup(y); } int query(int a , int b , int c , int d , int l , int r , int p) { if(l == r) return val[l]; int mid = (l + r) >> 1; if(sum[lp[a]] + sum[lp[b]] - sum[lp[c]] - sum[lp[d]] >= p) return query(lp[a] , lp[b] , lp[c] , lp[d] , l , mid , p); else return query(rp[a] , rp[b] , rp[c] , rp[d] , mid + 1 , r , p - sum[lp[a]] - sum[lp[b]] + sum[lp[c]] + sum[lp[d]]); } int main() { int n , m , i , x , y , z , f , last = 0; scanf("%d%d" , &n , &m); for(i = 1 ; i <= n ; i ++ ) { scanf("%d" , &a[i].num); a[i].rank = i; } sort(a + 1 , a + n + 1 , cmp1); val[0] = 0x80000000; for(i = 1 ; i <= n ; i ++ ) { if(a[i].num != val[top]) val[++top] = a[i].num; a[i].num = top; } sort(a + 1 , a + n + 1 , cmp2); for(i = 1 ; i < n ; i ++ ) { scanf("%d%d" , &x , &y); add(x , y); add(y , x); } dfs1(1); dfs2(1 , 1); for(i = 1 ; i <= tail ; i ++ ) ins(root[fa[q[i]]] , root[q[i]] , 1 , top , a[q[i]].num); while(m -- ) { scanf("%d%d%d" , &x , &y , &z); x ^= last; f = getlca(x , y); last = query(root[x] , root[y] , root[f] , root[fa[f]] , 1 , top , z); printf("%d" , last); if(m) printf(" "); } return 0; }