2588: Spoj 10628. Count on a tree
Time Limit: 12 Sec Memory Limit: 128 MBSubmit: 8100 Solved: 2010
[Submit][Status][Discuss]
Description
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。最后一个询问不输出换行符
Sample Input
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
Sample Output
2
8
9
105
7
8
9
105
7
HINT
N,M<=100000
暴力自重。。。
Source
分析:挺裸的主席树.每一颗点所代表的主席树建立在它的父节点的基础上.建树就很容易地解决了.查询的话首先要把(u,v)这条链给找出来,就是u到根的路径 + v到根的路径 - lca到根的路径 - lca的父亲到根的路径(lca的权值要计算).那么查询的时候在这4棵主席树上跑跑就可以了.
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn = 100010; int n,m,head[maxn],to[maxn * 2],nextt[maxn * 2],tot = 1,a[maxn],fa[maxn][21]; int b[maxn],cnt,num,root[maxn],lastans,deep[maxn]; struct node { int left,right,sum; }e[maxn * 20]; void add(int x,int y) { to[tot] = y; nextt[tot] = head[x]; head[x] = tot++; } void update(int l,int r,int x,int &y,int v) { e[y = ++num] = e[x]; e[y].sum++; if (l == r) return; int mid = (l + r) >> 1; if (v <= mid) update(l,mid,e[x].left,e[y].left,v); else update(mid + 1,r,e[x].right,e[y].right,v); } void dfs(int u,int faa,int dep) { fa[u][0] = faa; deep[u] = dep; int pos = lower_bound(b + 1,b + 1 + cnt,a[u]) - b; update(1,cnt,root[faa],root[u],pos); for (int i = head[u];i;i = nextt[i]) { int v = to[i]; if (v == faa) continue; dfs(v,u,dep + 1); } } int lca(int x,int y) { if (x == y) return x; if (deep[x] < deep[y]) swap(x,y); for (int j = 19; j >= 0; j--) if (deep[fa[x][j]] >= deep[y]) x = fa[x][j]; if (x == y) return x; for (int j = 19; j >= 0; j--) if (fa[x][j] != fa[y][j]) x = fa[x][j],y = fa[y][j]; return fa[x][0]; } int query(int l,int r,int a,int b,int c,int d,int k) { if (l == r) return l; int mid = (l + r) >> 1; int temp = e[e[a].left].sum + e[e[b].left].sum - e[e[c].left].sum - e[e[d].left].sum; if (k <= temp) return query(l,mid,e[a].left,e[b].left,e[c].left,e[d].left,k); else return query(mid + 1,r,e[a].right,e[b].right,e[c].right,e[d].right,k - temp); } int main() { scanf("%d%d",&n,&m); for (int i = 1; i <= n; i++) scanf("%d",&a[i]); memcpy(b,a,sizeof(a)); sort(b + 1,b + 1 + n); cnt = unique(b + 1,b + 1 + n) - b - 1; for (int i = 1; i < n; i++) { int x,y; scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1,0,1); for (int j = 1; j <= 19; j++) for (int i = 1; i <= n; i++) fa[i][j] = fa[fa[i][j - 1]][j - 1]; for (int i = 1; i <= m; i++) { int u,v,k; scanf("%d%d%d",&u,&v,&k); u ^= lastans; int LCA = lca(u,v); lastans = b[query(1,cnt,root[u],root[v],root[LCA],root[fa[LCA][0]],k)]; if (i != m) printf("%d ",lastans); else printf("%d",lastans); } return 0; }