虚树
对树上一类问题的处理
这类问题中,询问的点数远远小于树的点数,有些边可以一起统计
那我们只保留需要用的点就可以了
明显是关键点以及拐点(都是某些LCA)
考虑把每个点按照dfs序排序,按顺序求得两两lca并且去重就能求出虚树上的所有点
构建
具体构建我们用栈实现,栈里面的元素按dfs序单调递增
按dfs序访问所有关键点,对于每个要加入的点 (p),有两种情况:
- (p) 和栈顶元素 (x) 的lca是(x),说明 (p) 在 (x) 子树里面,将其入栈
- 否则栈顶元素的子树一定访问完毕(dfs序),考虑构建
设栈顶元素为(x),第二个元素为(y)
- 若(dfn[y]>dfn[lca]),可以连边(y) −> (x) ,将 (x) 出栈;
- 若(dfn[y]=dfn[lca]),即(y=lca),连边(lca−>x),此时 (lca(y)) 的子树构建完毕;
- 若(dfn[y]<dfn[lca]),即(lca)在(y,x)之间,连边(lca−>x),(x)出栈,再将(lca)入栈。此时lca的子树构建完毕(break)
代码:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define ll long long
const int N = 1000010,M=2000010,K = 23,inf=192608170;
using namespace std;
int n;
inline int read(){
int x=0,pos=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') pos=0;
for(;isdigit(ch);ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
return pos?x:-x;
}
struct node{
int v,nex;
};
ll ans1,ans2,f[N],maxs[N],mins[M],sz[N];
int fa[N][23],dfn[N],tim,bo[N],pi[N],st[N],dep[N],dis[N],num,bq[N];
struct graph{
node edge[M];int top,head[N];
void add(int u,int v){
edge[++top].v=v;edge[top].nex=head[u];head[u]=top;
}
void dfs1(int x){
dfn[x]=++tim;
for(int i=1;i<=20;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for(int i=head[x];i;i=edge[i].nex){
int v=edge[i].v;
if(v!=fa[x][0])
dis[v]=dis[x]+1,dep[v]=dep[x]+1,fa[v][0]=x,dfs1(v);
}
}
void dfs2(int x){
sz[x]=bo[x],maxs[x]=0,mins[x]=inf,f[x]=0;
for(int i=head[x];i;i=edge[i].nex){
int v=edge[i].v,d=dis[v]-dis[x];
dfs2(v);sz[x]+=sz[v];
ans1=min(ans1,mins[x]+mins[v]+d),ans2=max(ans2,maxs[x]+maxs[v]+d);
mins[x]=min(mins[x],mins[v]+d),maxs[x]=max(maxs[x],maxs[v]+d);
f[x]+=f[v]+1ll*d*(num-sz[v])*sz[v];
}
if(bo[x]) ans1=min(ans1,mins[x]),ans2=max(ans2,maxs[x]),mins[x]=0;
head[x]=0;
}
}g1,g2;
int q;
int cmp(int a,int b){
return dfn[a]<dfn[b];
}
int lca(int a,int b){
if(dep[a]<dep[b]) swap(a,b);
for(int h=dep[a]-dep[b],i=20;i>=0;i--){
if(h>=(1<<i)){
h-=(1<<i);
a=fa[a][i];
}
}
if(a==b) return a;
for(int i=20;i>=0;i--) if(fa[a][i]!=fa[b][i]) a=fa[a][i],b=fa[b][i];
return fa[a][0];
}
void work(){
sort(pi+1,pi+num+1,cmp);
int tp=0;
for(int i=1;i<=num;i++){
if(!tp){
st[++tp]=pi[i];continue;
}
int u=lca(st[tp],pi[i]);
while(dfn[u]<dfn[st[tp]]){
if(dfn[u]>=dfn[st[tp-1]]){
g2.add(u,st[tp]);
if(st[--tp]!=u) st[++tp]=u;
break;
}
g2.add(st[tp-1],st[tp]);tp--;
}
st[++tp]=pi[i];
}
while(tp>1) g2.add(st[tp-1],st[tp]),tp--;
ans1=inf,ans2=0,g2.dfs2(st[1]);
printf("%lld %d %d
",f[st[1]],ans1,ans2);
for(int i=1;i<=num;i++) bo[pi[i]]=0;
for(int i=1;i<=g2.top;i++) g2.edge[i].nex=g2.edge[i].v=0;g2.top=0;
}
int main(){
n=read();g1.top=g2.top=0;
for(int i=1,u,v;i<n;i++){
u=read(),v=read();
g1.add(u,v);g1.add(v,u);
}
g1.dfs1(1);q=read();
for(int i=1;i<=q;i++){
num=read();
for(int j=1;j<=num;j++){
pi[j]=read();bo[pi[j]]=1;
}
work();
}
return 0;
}