测试地址:世界树
做法:本题需要用到虚树+树形DP。
首先一看这道题我们就知道要用虚树,因此我们先把询问点的虚树先建出来,然后考虑DP。
我们把虚树中每个点受哪个点管辖先求出来,这是通过两次DFS来完成的,一次处理向下方向的最近,一次处理向上方向的最近。然后对于每条虚树上的边,如果边的两端所属的点不同,则表示这条边需要切断,那么我们可以倍增求出断点,在每次切断时求出较下面的那一块的大小即可。在最后别忘了求出根所属块的大小。
然后就是一大堆细节了,比如当距离相同时要比较编号之类……本蒟蒻竟然因为写反一个符号调了4h,太弱了……
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,k,a[600010],b[300010];
int first[300010]={0},tot=0,firstv[300010]={0},totv;
int fa[300010][21],dep[300010],order[300010],tim=0;
int st[300010],top;
int down[300010],downp[300010],up[300010],upp[300010],ans[300010],siz[300010];
int belong[300010],dis[300010];
const int inf=1000000000;
bool vis[300010]={0};
struct edge
{
int v,next,w;
}e[600010],ev[300010];
void insert(int a,int b)
{
e[++tot].v=b,e[tot].next=first[a],first[a]=tot;
}
void insertv(int a,int b,int w)
{
ev[++totv].v=b,ev[totv].w=w,ev[totv].next=firstv[a],firstv[a]=totv;
}
void init(int v)
{
order[v]=++tim;
siz[v]=1;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa[v][0])
{
fa[e[i].v][0]=v;
dep[e[i].v]=dep[v]+1;
init(e[i].v);
siz[v]+=siz[e[i].v];
}
}
int lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for(int i=20;i>=0;i--)
if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if (x==y) return x;
for(int i=20;i>=0;i--)
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int findfa(int x,int y)
{
for(int i=20;i>=0;i--)
if ((1<<i)<=y) x=fa[x][i],y-=(1<<i);
return x;
}
bool cmp(int a,int b)
{
return order[a]<order[b];
}
void build()
{
totv=0;
sort(a+1,a+k+1,cmp);
for(int i=1;i<k;i++)
a[k+i]=lca(a[i],a[i+1]);
a[k<<1]=1;
sort(a+1,a+(k<<1)+1,cmp);
top=0;
for(int i=1;i<=(k<<1);i++)
if (i==1||a[top]!=a[i])
{
a[++top]=a[i];
firstv[a[top]]=0;
}
k=top;
top=1;st[1]=1;
for(int i=2;i<=k;i++)
{
while (top>1&&lca(st[top],a[i])!=st[top])
{
insertv(st[top-1],st[top],dep[st[top]]-dep[st[top-1]]);
top--;
}
st[++top]=a[i];
}
while (top>1)
{
insertv(st[top-1],st[top],dep[st[top]]-dep[st[top-1]]);
top--;
}
}
void getdown(int v)
{
down[v]=inf;
for(int i=firstv[v];i;i=ev[i].next)
{
getdown(ev[i].v);
if (down[v]>down[ev[i].v]+ev[i].w||(down[v]==down[ev[i].v]+ev[i].w&&downp[ev[i].v]<downp[v]))
{
downp[v]=downp[ev[i].v];
down[v]=down[ev[i].v]+ev[i].w;
}
}
if (vis[v]) down[v]=0,downp[v]=v;
}
void getup(int v,int lastw,int f)
{
if (up[f]<down[f]||(down[f]==up[f]&&upp[f]<downp[f])) up[v]=up[f]+lastw,upp[v]=upp[f];
else up[v]=down[f]+lastw,upp[v]=downp[f];
if (vis[v]) up[v]=0,upp[v]=v;
for(int i=firstv[v];i;i=ev[i].next)
getup(ev[i].v,ev[i].w,v);
}
int dp(int v)
{
int remain=siz[v];
for(int i=firstv[v];i;i=ev[i].next)
{
int s=dp(ev[i].v);
if (belong[ev[i].v]!=belong[v])
{
int cutlen=(dis[v]-dis[ev[i].v]+ev[i].w),cut;
if (cutlen%2==0&&belong[v]<belong[ev[i].v]) cut=findfa(ev[i].v,cutlen/2-1);
else cut=findfa(ev[i].v,cutlen/2);
ans[belong[ev[i].v]]=siz[cut]-siz[ev[i].v]+s;
remain-=siz[cut];
}
else remain-=siz[ev[i].v]-s;
}
if (v==1) ans[belong[v]]=remain;
return remain;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
insert(a,b),insert(b,a);
}
fa[1][0]=fa[0][0]=0;
dep[1]=1,dep[0]=0;
up[0]=down[0]=inf;
init(1);
for(int i=1;i<=20;i++)
for(int j=1;j<=n;j++)
fa[j][i]=fa[fa[j][i-1]][i-1];
scanf("%d",&m);
while(m--)
{
int pastk;
scanf("%d",&k);
pastk=k;
for(int i=1;i<=k;i++)
{
scanf("%d",&a[i]);
vis[a[i]]=1;
ans[a[i]]=0;
b[i]=a[i];
}
build();
getdown(1);
getup(1,0,0);
for(int i=1;i<=k;i++)
{
if (up[a[i]]<down[a[i]]||(up[a[i]]==down[a[i]]&&upp[a[i]]<downp[a[i]])) belong[a[i]]=upp[a[i]],dis[a[i]]=up[a[i]];
else belong[a[i]]=downp[a[i]],dis[a[i]]=down[a[i]];
}
dp(1);
for(int i=1;i<=pastk;i++)
printf("%d ",ans[b[i]]);
printf("
");
for(int i=1;i<=k;i++)
vis[a[i]]=0;
}
return 0;
}