先选一点为根节点找出所有父节点i到下面所有点距离和dp[i],该父节点下面有多少个点Node[i]。
然后求出所有节点的所有非子节点到该点的距离dp1[v]+=(dp1[u]+(dp[u]-dp[v]-Node[v]-1)+n-Node[v]-1)
dp[u]-dp[v]-Node[v]-1:u的子节点中除了v这一部分子节点到u的距离
n-Node[v]-1:非v的字节点的个数
#include<stdio.h> #include<string.h> #define N 50002 #define inf 0x3fffffff int head[N],num,vis[N],dp[N],Node[N],dp1[N],n,I,R; struct edge { int st,ed,next; }E[N*2]; void addedge(int x,int y) { E[num].st=x; E[num].ed=y; E[num].next=head[x]; head[x]=num++; } void dfs(int u) { vis[u]=1; int i,v; for(i=head[u];i!=-1;i=E[i].next) { v=E[i].ed; if(vis[v]==1)continue; dfs(v); dp[u]+=(dp[v]+Node[v]+1);//所有子节点到到父节点的距离 Node[u]+=(Node[v]+1);//子节点个数 } } long long mm; void dfs1(int u) { int i,v; vis[u]=1; for(i=head[u];i!=-1;i=E[i].next) { v=E[i].ed; if(vis[v]==1)continue; dp1[v]+=(dp1[u]+(dp[u]-dp[v]-Node[v]-1)+n-Node[v]-1);//除了子节点外所有节点到该点的距离 dfs1(v); } if(mm>dp[u]+dp1[u]) mm=dp[u]+dp1[u]; } int main() { int i,x,y,t; scanf("%d",&t); while(t--) { scanf("%d%d%d",&n,&I,&R); memset(head,-1,sizeof(head)); num=0; for(i=1;i<n;i++) { scanf("%d%d",&x,&y); addedge(x,y); addedge(y,x); } memset(dp,0,sizeof(dp)); memset(dp1,0,sizeof(dp1)); memset(Node,0,sizeof(Node)); memset(vis,0,sizeof(vis)); mm=inf; dfs(1); memset(vis,0,sizeof(vis)); dfs1(1); printf("%lld ",I*I*R*mm); for(i=1;i<=n;i++) { if(dp[i]+dp1[i]==mm) printf("%d ",i); } printf(" "); } return 0; }