因为距离为 2 的两个点必有唯一的公共邻居(即中间点),所以枚举中间点即可:每个中间点的邻居两两组成的有序点对,恰好就是所有距离为 2 的点对。
// For every ordered pair (u, v) of tree vertices at distance exactly 2, the
// joint weight is w[u] * w[v]. The program prints the maximum joint weight and
// the sum of all joint weights modulo 10007.
//
// Two vertices at distance 2 share exactly one common neighbour (the middle
// vertex). So every such pair is counted exactly once by enumerating the
// middle vertex and pairing up its neighbours.
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>

#define maxn 200010

using namespace std;

// One directed edge in a forward-star adjacency list.
struct node {
    int ed;   // target vertex
    int nxt;  // index of the next edge leaving the same source (0 = end)
};

node edge[maxn << 1];        // each undirected edge is stored in both directions
int n, first[maxn], cnt;     // first[v]: index of v's most recently added edge
long long w[maxn], d[maxn];  // vertex weights and vertex degrees
const long long mod = 10007;
long long ans = 0, sum = 0;  // maximum joint weight, joint-weight sum mod 10007

// Prepends the directed edge s -> e to s's adjacency list.
inline void add_edge(int s, int e) {
    ++cnt;
    edge[cnt].ed = e;
    edge[cnt].nxt = first[s];
    first[s] = cnt;
}

int main() {
    scanf("%d", &n);
    // A tree on n vertices has n - 1 edges.
    for (int i = 1; i < n; ++i) {
        int s, e;
        scanf("%d%d", &s, &e);
        add_edge(s, e);
        add_edge(e, s);
        d[s]++;
        d[e]++;
    }
    for (int i = 1; i <= n; ++i) scanf("%lld", &w[i]);

    // Treat each vertex i as the middle vertex of distance-2 pairs.
    for (int i = 1; i <= n; ++i) {
        // A vertex with a single neighbour is the middle of no pair.
        if (d[i] == 1) continue;

        long long max_1 = 0, max_2 = 0;  // largest and second-largest neighbour weight
        long long sum1 = 0, sum2 = 0;    // sum of neighbour weights / of their squares, mod 10007
        for (int j = first[i]; j; j = edge[j].nxt) {
            const long long we = w[edge[j].ed];
            if (we > max_1) {
                max_2 = max_1;
                max_1 = we;
            } else if (we > max_2) {
                max_2 = we;
            }
            sum1 = (sum1 + we) % mod;
            sum2 = (sum2 + we * we) % mod;
        }

        // Sum of w[u] * w[v] over ordered neighbour pairs with u != v equals
        // (sum of w)^2 - (sum of w^2). Adding mod before subtracting keeps
        // the intermediate value non-negative.
        sum = (sum + sum1 * sum1 % mod + mod - sum2) % mod;

        // The largest product among i's neighbours uses the two largest weights.
        // NOTE(review): assumes non-negative weights, since max_1/max_2 start at 0.
        ans = max(ans, max_1 * max_2);
    }

    printf("%lld %lld\n", ans, sum);
    return 0;
}