题目大意为,求一个树的直径(最长路),以及直径的数量
朴素的dp只能找出某点开始的最长路径,但这个最长路径却不一定是树的直径,本弱先开始就想简单了,一直wa
直到我看了某位大牛的题解。。。
按照那位大牛的思路,我们来考虑直径的构成:
情况1:由某叶子节点出发产生的最长路径直接构成
情况2:由某有多个儿子的节点出发产生的两条长路径组成,这其中,又可以分为两条长路径长度相等与否两种情况
所以 在dp的时候,我们需要记录每个节点出发产生的最长路径和次长路径,以及他们的数量,数量的统计也是非常麻烦
详细请见代码:
#include<stdio.h> #include<iostream> #include<stdlib.h> #include<math.h> #include<ctype.h> #include<algorithm> #include<string> #include<string.h> #include<queue> #define mod 998244353 #define MAX 100000000 using namespace std; int t,n,m,p,k,tt,f; int x; int head[10010]; typedef struct Node { int en; int value; int next; }node; node edge[20010]; typedef struct DPnode { int dp1,dp2,len,nn; int n1,n2; }DP; DP dp[10010]; void ini() { int x,y,z; for(int i=1;i<=n-1;i++) { scanf("%d%d%d",&x,&y,&k); edge[2*i-1].en=y; edge[2*i-1].next=head[x]; edge[2*i-1].value=k; head[x]=2*i-1; edge[2*i].en=x; edge[2*i].next=head[y]; edge[2*i].value=k; head[y]=2*i; } } void dfs(int s,int p) { dp[s].dp1=dp[s].dp2=dp[s].len=dp[s].n1=dp[s].n2=dp[s].nn=0; int leaf=1; for(int i=head[s];i;i=edge[i].next) { int q=edge[i].en; if(q==p) continue; leaf=0; dfs(q,s); int tmp=dp[q].dp1+edge[i].value; if(tmp>dp[s].dp1) { dp[s].dp2=dp[s].dp1; dp[s].n2=dp[s].n1; dp[s].dp1=tmp; dp[s].n1=dp[q].n1; } else if(tmp==dp[s].dp1) { dp[s].n1+=dp[q].n1; } else if(tmp>dp[s].dp2) { dp[s].dp2=tmp; dp[s].n2=dp[q].n1; } else if(tmp==dp[s].dp2) { dp[s].n2+=dp[q].n1; } } if(leaf) { dp[s].n1=1;dp[s].nn=1; dp[s].len=0; dp[s].dp1=0; return; } int c1=0,c2=0; for(int i=head[s];i;i=edge[i].next) { int q=edge[i].en; if(q==p) continue; int tmp=dp[q].dp1+edge[i].value; if(tmp==dp[s].dp1) c1++; else if(tmp==dp[s].dp2&&dp[s].dp2) c2++; } if(c1>1) { dp[s].len=dp[s].dp1*2; int sum=0; for(int i=head[s];i;i=edge[i].next) { int q=edge[i].en; if(q==p) continue; if(dp[q].dp1+edge[i].value==dp[s].dp1) { dp[s].nn+=sum*dp[q].n1; sum+=dp[q].n1; } } } else if(c2>0) { dp[s].len=dp[s].dp1+dp[s].dp2; for(int i=head[s];i;i=edge[i].next) { int q=edge[i].en; if(q==p) continue; if(dp[q].dp1+edge[i].value==dp[s].dp2) { dp[s].nn+=dp[s].n1*dp[q].n1; } } } else { dp[s].len=dp[s].dp1; dp[s].nn=dp[s].n1; } return ; } void solve() { int ans=0; int num=0; for(int i=1;i<=n;i++) { if(dp[i].len>ans) { ans=dp[i].len; num=dp[i].nn; } else if(dp[i].len==ans) { num+=dp[i].nn; } } printf("%d %d ",ans,num); } int main() { while(scanf("%d",&n)!=EOF) { memset(head,0,sizeof(head)); ini(); dfs(1,0); solve(); } return 0; }