题目描述
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
输入输出格式
输入格式:
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
输出格式:
M行,表示每个询问的答案。
输入输出样例
输入样例#1: 复制
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
输出样例#1: 复制
2
8
9
105
7
说明
HINT:
N,M<=100000
暴力自重。。。
来源:bzoj2588 Spoj10628.
题解
一道比较涨见识的题目吧。
在这里我们要统计一条路径上的Kth值,考虑把路径转化为一条序列。怎么转换?
我们可以让新增的那个点在原有的fa节点上再建树。这样我们就可以利用树上差分了。
但是我们会发现,x,y两个节点减去两次它们共同的LCA时,LCA的值就没有算在里面了,这个时候我们就不要减去两次LCA应该减去一次LCA,减去一次fa[LCA]。这里的点都是指在主席树相对应的状态。
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
#include<algorithm>
using namespace std;
const int N=100001;
int top[N],size[N],fa[N],dep[N],son[N];
int ch[N],n,m,q,b[N],tr[N<<5],sum[N<<5],tot,l[N<<5],r[N<<5];
int num,head[N],cnt;
struct node{
int to,next;
}e[N<<1];
int read()
{
int x=0,w=1;char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
return x*w;
}
void add(int from,int to){
num++;
e[num].to=to;
e[num].next=head[from];
head[from]=num;
}
void dfs1(int x){
size[x]=1;
for(int i=head[x];i;i=e[i].next){
int v=e[i].to;
if(!dep[v]){
dep[v]=dep[x]+1;fa[v]=x;
dfs1(v);size[x]+=size[v];
if(son[son[x]]<size[v])son[x]=v;
}
}
}
void dfs2(int x,int tp){
top[x]=tp;if(son[x])dfs2(son[x],tp);
for(int i=head[x];i;i=e[i].next){
int v=e[i].to;if(v!=fa[x]&&v!=son[x])dfs2(v,v);
}
}
int cal(int x,int y){
int fx=top[x],fy=top[y];
while(fx!=fy){
if(dep[fx]<dep[fy])swap(fx,fy),swap(x,y);
x=fa[fx],fx=top[x];
}
if(dep[x]>=dep[y])return y;return x;
}
int build(int left,int right){
int mid=(left+right)>>1;
int root=++cnt;
sum[root]=0;
if(left<right){
l[root]=build(left,mid);
r[root]=build(mid+1,right);
}
return root;
}
int update(int pre,int left,int right,int v){
int mid=(left+right)>>1;
int root=++cnt;
l[root]=l[pre];r[root]=r[pre];sum[root]=sum[pre]+1;
if(left<right){
if(v<=mid)l[root]=update(l[pre],left,mid,v);
else r[root]=update(r[pre],mid+1,right,v);
}
return root;
}
int query(int u1,int u2,int v1,int v2,int left,int right,int k){
int mid=(left+right)>>1;
int x=sum[l[u1]]+sum[l[u2]]-sum[l[v1]]-sum[l[v2]];
if(left>=right)return left;
if(x>=k)return query(l[u1],l[u2],l[v1],l[v2],left,mid,k);
else return query(r[u1],r[u2],r[v1],r[v2],mid+1,right,k-x);
}
void dfs(int x){
int t=lower_bound(b+1,b+m+1,ch[x])-b;
tr[x]=update(tr[fa[x]],1,m,t);
for(int i=head[x];i;i=e[i].next){
int v=e[i].to;
if(v!=fa[x])dfs(v);
}
}
int main()
{
n=read();q=read();
for(int i=1;i<=n;i++)
{
b[i]=ch[i]=read();
}
sort(b+1,b+n+1);
m=unique(b+1,b+n+1)-b-1;
for(int i=1;i<n;i++)
{
int x=read(),y=read();
add(x,y);add(y,x);
}
fa[1]=0;dep[1]=1;
dfs1(1);dfs2(1,0);
tr[0]=build(1,m);
dfs(1);
int last=0;
while(q--)
{
int l=read(),r=read(),k=read();
int lca=cal(l^last,r);
int ff=fa[lca];
printf("%d
",last=b[query(tr[l^last],tr[r],tr[lca],tr[ff],1,m,k)]);
}
return 0;
}