【BZOJ2588】Spoj 10628. Count on a tree
Description
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。
Sample Input
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
Sample Output
2
8
9
105
7
8
9
105
7
HINT
N,M<=100000
暴力自重。。。
题解:先树剖求出LCA,然后再树上搞一个主席树,第i棵线段树保存的是从根到i的路径上的所有点,然后查询的时候就在用a的线段树+b的线段树-lca(a,b)的线段树-fa(lca(a,b))的线段树就行了。
主席树写的还是不熟练,注意在建树的时候一定要按着DFS序建。
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn=100010; struct NUM { int num,org; }v[maxn]; struct NODE { int siz,ls,rs; }s[maxn*40]; int n,m,tot,nm,ans,cnt; int ref[maxn],root[maxn],q[maxn]; int fa[maxn],size[maxn],top[maxn],son[maxn],deep[maxn],to[maxn<<1],next[maxn<<1],head[maxn]; bool cmp1(NUM a,NUM b) { return a.num<b.num; } bool cmp2(NUM a,NUM b) { return a.org<b.org; } void add(int a,int b) { to[cnt]=b; next[cnt]=head[a]; head[a]=cnt++; } void dfs1(int x) { size[x]=1; for(int i=head[x];i!=-1;i=next[i]) { if(to[i]!=fa[x]) { fa[to[i]]=x,deep[to[i]]=deep[x]+1; dfs1(to[i]); size[x]+=size[to[i]]; if(size[to[i]]>size[son[x]]) son[x]=to[i]; } } } void dfs2(int x,int tp) { top[x]=tp; q[++q[0]]=x; if(son[x]) dfs2(son[x],tp); for(int i=head[x];i!=-1;i=next[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) dfs2(to[i],to[i]); } void insert(int x,int &y,int l,int r,int p) { y=++tot; if(l==r) { s[y].siz=s[x].siz+1; return ; } int mid=l+r>>1; if(p<=mid) s[y].rs=s[x].rs,insert(s[x].ls,s[y].ls,l,mid,p); else s[y].ls=s[x].ls,insert(s[x].rs,s[y].rs,mid+1,r,p); s[y].siz=s[s[y].ls].siz+s[s[y].rs].siz; } void query(int a,int b,int c,int d,int l,int r,int p) { if(l==r) { ans=ref[l]; return ; } int mid=l+r>>1; if(s[s[a].ls].siz+s[s[b].ls].siz-s[s[c].ls].siz-s[s[d].ls].siz>=p) query(s[a].ls,s[b].ls,s[c].ls,s[d].ls,l,mid,p); else query(s[a].rs,s[b].rs,s[c].rs,s[d].rs,mid+1,r,p-s[s[a].ls].siz-s[s[b].ls].siz+s[s[c].ls].siz+s[s[d].ls].siz); } int getlca(int x,int y) { while(top[x]!=top[y]) { if(deep[top[x]]>deep[top[y]]) x=fa[top[x]]; else y=fa[top[y]]; } if(deep[x]<deep[y]) return x; return y; } int readin() { int ret=0,f=1; char gc=getchar(); while(gc<'0'||gc>'9') {if(gc=='-')f=-f; gc=getchar();} while(gc>='0'&&gc<='9') ret=ret*10+gc-'0',gc=getchar(); return ret*f; } int main() { n=readin(),m=readin(); memset(head,-1,sizeof(head)); int i,a,b,c,d,e; for(i=1;i<=n;i++) v[i].num=readin(),v[i].org=i; sort(v+1,v+n+1,cmp1); ref[nm]=-1; for(i=1;i<=n;i++) { if(v[i].num>ref[nm]) ref[++nm]=v[i].num; v[i].num=nm; } sort(v+1,v+n+1,cmp2); for(i=1;i<n;i++) { a=readin(),b=readin(); add(a,b),add(b,a); } deep[1]=1,dfs1(1),dfs2(1,1); for(i=1;i<=n;i++) insert(root[fa[q[i]]],root[q[i]],1,nm,v[q[i]].num); for(i=1;i<=m;i++) { scanf("%d%d%d",&a,&b,&e); a^=ans; int c=getlca(a,b),d=fa[c]; query(root[a],root[b],root[c],root[d],1,nm,e); printf("%d",ans); if(i!=m) printf(" "); } return 0; }