Count on a tree 树上主席树
给(n)个树,每个点有点权,每次询问(u,v)路径上第(k)小点权,强制在线
求解区间静态第(k)小即用主席树。
树上主席树类似于区间上主席树,我们利用前缀和相减获得区间的信息,树上主席树也是这样,维护一个到根节点的前缀和。对于((u,v))路径,(sum[u]+sum[v]-sum[lca(u,v)]-sum[fa[lca(u,v)]])即可获得树上(u,v)路径的区间信息,然后按照区间查询即可。
#include <cstdio>
#include <algorithm>
#define MAXN 1001000
#define MAXM MAXN*30
#define LOG 30
int head[MAXN],nxt[MAXN*2],vv[MAXN*2],tot;
inline void add_edge(int u, int v){
vv[++tot]=v;
nxt[tot]=head[u];
head[u]=tot;
}
int n,m,s,cnt;
int f[MAXN][31],dep[MAXN];
int rot[MAXN];
int val[MAXN],idx[MAXN];
int tre[MAXM],sl[MAXM],sr[MAXM];
void build_tre(int &x, int l, int r){
x=++cnt;
if(l==r) return;
int mid=(l+r)>>1;
build_tre(sl[x], l, mid);
build_tre(sr[x], mid+1, r);
}
void change(int &x, int pre, int l, int r, int pos){
if(x==0) x=++cnt;
tre[x]=tre[pre]+1;
if(l==r) return;
int mid=(l+r)>>1;
if(pos<=mid) change(sl[x], sl[pre], l, mid, pos),sr[x]=sr[pre];
else change(sr[x], sr[pre], mid+1, r, pos),sl[x]=sl[pre];
}
int query(int x, int y, int t, int tf, int l, int r, int k){
if(l==r) return l;
int mid=(l+r)>>1;
int tmp=tre[sl[x]]+tre[sl[y]]-tre[sl[t]]-tre[sl[tf]];
if(tmp>=k) return query(sl[x], sl[y], sl[t], sl[tf], l, mid, k);
else return query(sr[x], sr[y], sr[t], sr[tf], mid+1, r, k-tmp);
}
void dfs(int u, int fa){
f[u][0]=fa;
dep[u]=dep[fa]+1;
change(rot[u], rot[fa], 1, s, idx[u]);
for(int i=head[u];i;i=nxt[i]){
int v=vv[i];
if(v==fa) continue;
dfs(v, u);
}
}
namespace lca{
void init(){
for(int i=1;i<=LOG;++i)
for(int j=1;j<=n;++j)
f[j][i]=f[f[j][i-1]][i-1];
}
int lca(int a, int b){
if(dep[a]<dep[b]) std::swap(a,b);
for(int i=LOG;i>=0;--i)
if(dep[f[a][i]]>=dep[b])
a=f[a][i];
if(a==b) return a;
for(int i=LOG;i>=0;--i)
if(f[a][i]!=f[b][i]){
a=f[a][i];
b=f[b][i];
}
return f[a][0];
}
}
inline int read(){
char ch=getchar();int s=0;
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s;
}
int val_sort[MAXN];
int main(){
n=read(),m=read();
for(int i=1;i<=n;++i) val_sort[i]=val[i]=read();
for(int i=1;i<n;++i){
int x=read(),y=read();
add_edge(x, y);add_edge(y, x);
}
std::sort(val_sort+1, val_sort+1+n);
s=std::unique(val_sort+1, val_sort+1+n)-val_sort;
for(int i=1;i<=n;++i)
idx[i]=std::lower_bound(val_sort+1, val_sort+1+s, val[i])-val_sort;
build_tre(rot[0], 1, s);
dfs(1, 0);
lca::init();
int lastans=0;
while(m--){
int u=read()^lastans,v=read(),k=read();
int tmp=lca::lca(u, v);
lastans=val_sort[query(rot[u], rot[v], rot[tmp], rot[f[tmp][0]], 1, s, k)];
printf("%d
", lastans);
}
return 0;
}