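// 题面 (problem): the original statement text is not included here. Inferred from the
// code below: given a tree rooted at vertex 1 and several queries, each listing a set
// of key vertices, print the sum, the minimum and the maximum of the distances over
// all pairs of key vertices.
// 思路 (approach): 虚树 (virtual tree) + 贡献 (count each virtual edge's contribution
// to the distance sum).
// 题解 (solution):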
// Virtual-tree (虚树) solution.
//
// Input: n, then n-1 undirected edges of a tree rooted at vertex 1; then q
// queries, each giving m followed by m key vertices (assumed distinct).
// For each query one line is printed: over all unordered pairs of key
// vertices, the sum of their distances, the minimum and the maximum.
//
// Per query, the key vertices plus the LCAs of DFS-order-adjacent key
// vertices form the virtual tree. Every virtual edge (x, c) of length L is
// crossed by siz[c] * (total - siz[c]) key pairs, which gives the distance
// sum; min/max pairs are combined at each virtual vertex from its best two
// child branches (or from the vertex itself when it is a key vertex).
//
// All tree walks are iterative so that a path-shaped tree with ~10^6
// vertices cannot overflow the call stack.
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <stack>
#include <utility>
#include <vector>
using namespace std;

typedef long long ll;

const int N = 1000050;    // capacity for 1-based vertex ids, plus slack
const int LOG = 21;       // binary-lifting levels 0..20
const int INF = 7654321;  // sentinel larger than any distance in a tree of < N vertices

int n, q, m, cnt;
int a[2 * N], yy[2 * N], top;  // a: query vertices, then bracket events; yy: stack
int f[N][LOG], dep[N];         // f[x][i]: 2^i-th ancestor (the root's parent is itself)
bool vis[N], imp[N];           // vis: in current virtual tree; imp: key vertex of query
int s;                         // root of the current virtual tree
ll sumdis, maxdis, mindis;
vector<int> to[N];   // adjacency lists of the input tree
vector<int> son[N];  // children in the current virtual tree
// siz:   number of key vertices in x's virtual subtree
// minc:  min distance from x down to a key vertex in its subtree (0 if x is key)
// maxc:  max distance from x down to a key vertex in its subtree (0 if x is key)
// nsmin: like minc, but only key vertices strictly below x
int siz[N], maxc[N], minc[N], nsmin[N];
int dfin[N], dfou[N];  // entry / exit timestamps drawn from one shared counter

// Record x's parent, depth, binary-lifting table and entry timestamp.
// Ancestors must already be entered (guaranteed by DFS order).
static void enterVertex(int x, int ff, int d) {
    f[x][0] = ff;
    dep[x] = d;
    for (int i = 1; i < LOG; i++) f[x][i] = f[f[x][i - 1]][i - 1];
    dfin[x] = ++cnt;
}

// Iterative DFS from x, whose parent is ff and depth is d. Fills f, dep,
// dfin and dfou. Children are visited in reverse adjacency order.
void dfs(int x, int ff, int d) {
    vector<pair<int, int> > st;  // (vertex, next adjacency index, scanning downward)
    enterVertex(x, ff, d);
    st.push_back(make_pair(x, (int)to[x].size() - 1));
    while (!st.empty()) {
        const int u = st.back().first;
        int i = st.back().second;
        while (i >= 0 && to[u][i] == f[u][0]) i--;  // skip the edge back to the parent
        if (i < 0) {
            dfou[u] = ++cnt;
            st.pop_back();
            continue;
        }
        st.back().second = i - 1;  // resume after this child later
        const int v = to[u][i];
        enterVertex(v, u, dep[u] + 1);
        st.push_back(make_pair(v, (int)to[v].size() - 1));
    }
}

// Order vertices by DFS entry time.
bool cmp1(int x, int y) { return dfin[x] < dfin[y]; }

// Order bracket events: +x sorts at dfin[x], -x sorts at dfou[x].
bool cmp2(int x, int y) {
    const int key1 = x > 0 ? dfin[x] : dfou[-x];
    const int key2 = y > 0 ? dfin[y] : dfou[-y];
    return key1 < key2;
}

// Lowest common ancestor via binary lifting.
int lca(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    for (int i = LOG - 1; i >= 0; i--)
        if (dep[f[u][i]] >= dep[v]) u = f[u][i];
    if (u == v) return u;
    for (int i = LOG - 1; i >= 0; i--)
        if (f[u][i] != f[v][i]) {
            u = f[u][i];
            v = f[v][i];
        }
    return f[u][0];
}

// Compute siz/minc/maxc/nsmin for virtual vertex x from its children,
// which must already have been processed. Non-recursive.
void dp(int x) {
    if (imp[x]) {
        siz[x] = 1;
        minc[x] = 0;
        maxc[x] = 0;
    } else {
        siz[x] = 0;
        minc[x] = INF;
        maxc[x] = -INF;
    }
    nsmin[x] = INF;
    for (size_t k = 0; k < son[x].size(); k++) {
        const int c = son[x][k];
        const int len = dep[c] - dep[x];  // virtual edge length
        siz[x] += siz[c];
        nsmin[x] = min(nsmin[x], minc[c] + len);
        minc[x] = min(minc[x], minc[c] + len);
        maxc[x] = max(maxc[x], maxc[c] + len);
    }
}

// Fold vertex x into the answers: edge contributions to sumdis, and the
// best/worst key pairs whose LCA is x. Requires dp() on every virtual
// vertex first (it reads siz[s], the total key count). Non-recursive.
void dp2(int x) {
    if (imp[x]) {  // pairs (x, key vertex below x)
        maxdis = max(maxdis, (ll)maxc[x]);
        mindis = min(mindis, (ll)nsmin[x]);
    }
    int s1 = -INF, s2 = -INF;  // two largest child branch depths
    int b1 = INF, b2 = INF;    // two smallest child branch depths
    for (size_t k = 0; k < son[x].size(); k++) {
        const int c = son[x][k];
        const int len = dep[c] - dep[x];
        const int farDist = maxc[c] + len;
        const int nearDist = minc[c] + len;
        // Each key pair split by this edge crosses it once.
        sumdis += (ll)(siz[s] - siz[c]) * siz[c] * len;
        if (farDist > s1) {
            s2 = s1;
            s1 = farDist;
        } else if (farDist > s2) {
            s2 = farDist;
        }
        if (nearDist < b1) {
            b2 = b1;
            b1 = nearDist;
        } else if (nearDist < b2) {
            b2 = nearDist;
        }
    }
    // Pairs taken from two different child branches; sentinels keep a
    // missing second branch from winning against any real distance.
    if (s1 + s2 > maxdis) maxdis = s1 + s2;
    if (b1 + b2 < mindis) mindis = b1 + b2;
}

// Debug helper (not called): prints cur and x, then x's virtual-tree children.
void view(int cur, int x) {
    cout << cur << " " << x << endl;
    for (int i = (int)son[x].size() - 1; i >= 0; i--) cout << son[x][i] << " ";
    cout << endl;
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        to[u].push_back(v);
        to[v].push_back(u);
    }
    cnt = 0;
    dfs(1, 1, 1);

    vector<int> order;  // virtual-tree vertices in preorder (parents before children)
    scanf("%d", &q);
    for (int ti = 1; ti <= q; ti++) {
        scanf("%d", &m);
        for (int i = 1; i <= m; i++) {
            scanf("%d", &a[i]);
            vis[a[i]] = true;
            imp[a[i]] = true;
        }

        // Virtual-tree vertices: key vertices plus LCAs of DFS-adjacent pairs.
        sort(a + 1, a + m + 1, cmp1);
        cnt = m;
        for (int i = 1; i < m; i++) {
            const int t = lca(a[i], a[i + 1]);
            if (!vis[t]) {
                vis[t] = true;
                a[++cnt] = t;
            }
        }

        // Bracket sequence: +x at dfin[x], -x at dfou[x]. Replaying it with a
        // stack attaches each closed vertex to the vertex now on top.
        int cnt2 = cnt;
        for (int i = 1; i <= cnt; i++) a[++cnt2] = -a[i];
        sort(a + 1, a + cnt2 + 1, cmp2);
        top = 0;
        order.clear();
        for (int i = 1; i <= cnt2; i++) {
            if (a[i] > 0) {
                yy[++top] = a[i];
                order.push_back(a[i]);
            } else {
                top--;
                if (top > 0) son[yy[top]].push_back(-a[i]);
                else s = -a[i];
            }
        }

        // Reverse preorder visits children before parents.
        for (int i = (int)order.size() - 1; i >= 0; i--) dp(order[i]);
        sumdis = 0;
        maxdis = 0;
        mindis = 987654321;
        for (size_t k = 0; k < order.size(); k++) dp2(order[k]);
        printf("%lld %lld %lld\n", sumdis, mindis, maxdis);  // one line per query

        // Reset only what this query touched.
        for (size_t k = 0; k < order.size(); k++) {
            const int x = order[k];
            vis[x] = false;
            imp[x] = false;
            son[x].clear();
        }
    }
    return 0;
}