题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6567
#include<iostream> #include<algorithm> using namespace std; #define ll long long #define inf 0x3f3f3f3f #define maxn 100005 int n,cnt,root,maxx,head[maxn],size[maxn],vis[maxn],sum,n1,n2; struct edge{ int next,to; }e[maxn<<1]; void add(int u,int v) { e[++cnt].to=v; e[cnt].next=head[u]; head[u]=cnt; } void dfs0(int u) { if(vis[u])return ; vis[u]=1; sum++; for(int i=head[u];i;i=e[i].next) { dfs0(e[i].to); } } void getroot(int u,int fa,int N) { size[u]=1; int f=0; for(int i=head[u];i;i=e[i].next) { if(e[i].to!=fa) { getroot(e[i].to,u,N); size[u]+=size[e[i].to]; f=max(f,size[e[i].to]); } } f=max(f,N-size[u]); if(f<maxx) { root=u; maxx=f; } } ll ans; void dfs(int u,int f) { size[u]=1; for(int i=head[u];i;i=e[i].next) { int v=e[i].to; if(v==f)continue; dfs(v,u); size[u]+=size[v]; ans+=(ll)(size[v])*(ll)(n-size[v]); } } int main() { cin>>n; int u,v; cnt=0; for(int i=1;i<=n-2;i++) { cin>>u>>v; add(u,v); add(v,u); } sum=0; int rot1=1,rot2; dfs0(1); n1=sum; for(int i=1;i<=n;i++) { if(vis[i]==0) { sum=0; rot2=i; dfs0(i); n2=sum; break; } } maxx=inf; getroot(rot1,0,n1); rot1=root; maxx=inf; getroot(rot2,0,n2); rot2=root; add(rot1,rot2); add(rot2,rot1); ans=0; dfs(1,0); cout<<ans<<endl; return 0; }