MMSet2
给定一棵n个节点的树,点编号为1…n。
Q次询问,每次询问给定一个点集S,令,(f(u)=max_{vin S}dist(u,v))
你需要求出(min_{u=1dots n}f(u))
其中dist(u,v)表示树上路径(u,v)的边数。
输入描述:
第一行一个整数n,接下来n−1行每行两个整数表示树上的一条边。
接下来一行一个整数Q,接着Q行,每行第一个数是|S|,剩下|S|个互不相同的数代表这个集合。
输出描述:
输出Q行,每行一个整数表示答案。
示例1
输入
3
1 2
1 3
1
2 2 3
输出
1
备注:
n≤3×105,|S|≥1,∑|S|≤106
每条边的长度是1,显然答案ans就是S的若干条直径中两条半径差值最小时较大的那条半径,因为如果有u'点到S的其他点的距离均小于刚刚的答案,那么直径就可以减小了,当然,这样的半径也能成为f(u),因为如果f(u)能扩大,则直径就能增大了。
由于边权均为1,假设两半径为(r_i,r_2)且(r_1>r_2)如果(r_1,r_2)相差d>=2,那么可以将中心点左移或右移一个点,使d-1,所以,d<=1;
ans=直径/2(向上取整)
再来考虑如何计算S的直径:
设D(tree)=(a,b)表示tree的直径是a,b间的距离,那么(D(treecup x)=max(dist(a,b),dist(a,x),dis(b,x)))
证明:
设有点c使得dist(c,x)大于直径。那么a或b在c->x的路径上,因为如果不在,
则x->c->a(或b)更大,所以a->b->c长度大于直径,矛盾。
伪代码:
1.刚开始有一个点a;
2.加入一个点b,如果dist (a,b)使直径变大,b记为x,
3.重复2(期间a不变)直到加完。
这样我们就计算了所有dist(a,x),但是没计算dist(b,x);
最后再循环一次,计算所有dist((b_i),x),遇到更大的就更新直径。
#include<bits/stdc++.h>
using namespace std;
const int MAXN=3e5+8;
struct E{int y,nt;}e[MAXN<<1];
int head[MAXN],cnt;
void add(int x,int y){
e[++cnt].nt=head[x];
e[cnt].y=y;
head[x]=cnt;
}
int tot[MAXN],deep[MAXN],son[MAXN],fa[MAXN];
int dfs1(int now,int pre,int dep){
tot[now]=1;
fa[now]=pre;
deep[now]=dep;
int max_son=-1;
for(int i=head[now];i;i=e[i].nt){
int to=e[i].y;
if(to==pre)continue;
tot[now]+=dfs1(to,now,dep+1);
if(tot[to]>max_son){
max_son=tot[to];
son[now]=to;
}
}
return tot[now];
}
int top[MAXN];
void dfs2(int now,int topfa){
top[now]=topfa;
if(!son[now])return;
dfs2(son[now],topfa);
for(int i=head[now];i;i=e[i].nt){
int to=e[i].y;
if(!top[to])dfs2(to,to);
}
}
int lca(int x,int y){
while(top[x]^top[y]){
if(deep[top[x]]<deep[top[y]])swap(x,y);
x=fa[top[x]];
}
if(deep[x]<deep[y])return x;
return y;
}
inline int dist(int x,int y){return deep[x]+deep[y]-2*deep[lca(x,y)];}
int n,s[MAXN];
int main() {
scanf("%d",&n);
int x,y;
for(int i=1; i<n; ++i) {
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
dfs1(1,0,1);
dfs2(1,1);
int q;
scanf("%d",&q);
while(q--) {
int S;
scanf("%d",&S);
int d=-1,p=0;
for(int i=0;i<S;++i){
scanf("%d",s+i);
int tmp=dist(s[0],s[i]);
if(tmp>d){d=tmp,p=i;}
}
for(int i=0;i<S;++i){
d=max(d,dist(s[p],s[i]));
}
printf("%d
",(d+1)/2);
}
return 0;
}