这道题用到了4个dfs,分别是找出所有家的最小生成树,找出一点距离树的最小距离,找出每个点儿子距离的最大值(不包括父亲,也就是指不包括根节点的子树),用父亲的值来更新自己
因为我们可以知道:如果我们在树上,那么最短的距离就是树的长度的两倍-距自己最远的点的距离,当我们不在树上时,就得先走到树上(这条路径是唯一的),然后再重复刚才的过程
找出生成树比较简单,重点是找出距树上一点最远的点的距离,这里先找出除了父亲之外每个子树的距离,求出最大和第二,然后再用父亲更新自己的距离,很难想到
#include<iostream> #include<stdio.h> #include<string.h> #define N 1000010 using namespace std; typedef long long ll; ll tot; int cnt=-1; int head[N],to[N],next[N],w[N],used[N],nearest_node[N]; ll dis[N],max_dis[N],sec_dis[N],min_dis[N]; int max_dis_node[N],sec_dis_node[N]; bool k[N],on_tree[N]; void insert(int u,int v,int c) { next[++cnt]=head[u]; head[u]=cnt; to[cnt]=v; w[cnt]=c; } ll max(ll x,ll y) { return x>y?x:y; } ll min(ll x,ll y) { return x<y?x:y; } void init() { memset(head,-1,sizeof(head)); memset(to,-1,sizeof(to)); memset(next,-1,sizeof(next)); memset(used,0,sizeof(used)); memset(on_tree,false,sizeof(on_tree)); memset(min_dis,0,sizeof(min_dis)); memset(nearest_node,0,sizeof(nearest_node)); memset(dis,0,sizeof(dis)); memset(max_dis,0,sizeof(max_dis)); } bool dfs1(int u)//找树 { used[u]=1; for(int i=head[u];~i;i=next[i]) { int v=to[i]; if(!used[v]) { bool f2=dfs1(v); if(f2) { on_tree[u]|=f2; tot+=w[i]; } dfs1(v); } } return on_tree[u]; } void dfs2(int u)//找树距 { used[u]=1; if(on_tree[u]) nearest_node[u]=u; for(int i=head[u];~i;i=next[i]) { int v=to[i]; if(!on_tree[v]&&!used[v]) { min_dis[v]=min_dis[u]+w[i]; nearest_node[v]=nearest_node[u]; } if(!used[v]) dfs2(v); } } ll dfs3(int u)//儿子的最大距离 { used[u]=1; for(int i=head[u];~i;i=next[i]) { int v=to[i]; if(!used[v]&&on_tree[v]) { ll dis=dfs3(v)+w[i]; if(dis>max_dis[u]) { sec_dis[u]=max_dis[u]; max_dis[u]=dis; max_dis_node[u]=v; } else if(dis>sec_dis[u]) { sec_dis[u]=dis; } } } return max_dis[u]; } void dfs4(int u,ll last_max) { used[u]=1; for(int i=head[u];~i;i=next[i]) { int v=to[i]; ll temp=max(last_max,max_dis[u]); dis[u]=max(last_max,max_dis[u]); if(!used[v]&&on_tree[v]) { if(max_dis_node[u]==v) { dfs4(v,max(last_max,sec_dis[u])+w[i]); } else dfs4(v,max(last_max,max_dis[u])+w[i]); } } } int main() { init(); int n,k; scanf("%d%d",&n,&k); for(int i=1;i<n;i++) { int u,v,c; scanf("%d%d%d",&u,&v,&c); insert(u,v,c); insert(v,u,c); } int x; for(int i=1;i<=k;i++) { scanf("%d",&x); on_tree[x]=true; } dfs1(x); memset(used,0,sizeof(used)); dfs2(x); memset(used,0,sizeof(used)); dfs3(x); memset(used,0,sizeof(used)); dfs4(x,0); /* printf("tot=%d ",tot); printf("------------- "); for(int i=1;i<=n;i++) printf("%d ",dis[i]); printf(" ------------ "); */ for(int i=1;i<=n;i++) printf("%lld ",2*tot+min_dis[i]-dis[nearest_node[i]]); return 0; }