分析:典型的两遍dfs树形dp,先统计到子树的,再统计从祖先来的,dp[i][0]代表从从子树回来的最大值,dp[i][1]代表不回来,id[i]记录从i开始到哪不回来
吐槽:赛场上想到了状态,但是不会更新,而且据说这是一种典型的树形dp,还是太弱
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #include <vector> #include <queue> #include <set> #include <map> #include <string> #include <cmath> #include <stdlib.h> using namespace std; typedef long long LL; const int N=1e5+10; const int BufferSize=1<<16; char buffer[BufferSize],*hea,*tail; inline char Getchar() { if(hea==tail) { int l=fread(buffer,1,BufferSize,stdin); tail=(hea=buffer)+l; } return *hea++; } inline int read() { int x=0,f=1;char c=Getchar(); for(;!isdigit(c);c=Getchar()) if(c=='-') f=-1; for(;isdigit(c);c=Getchar()) x=x*10+c-'0'; return x*f; } int T,n,a[N],head[N],tot,ret[N],dp[N][2],id[N],kase; struct Edge{ int v,w,next; }edge[N<<1]; void add(int u,int v,int w){ edge[tot].v=v; edge[tot].w=w; edge[tot].next=head[u]; head[u]=tot++; } void dfs(int u,int f){ dp[u][0]=dp[u][1]=a[u];id[u]=-1; for(int i=head[u];~i;i=edge[i].next){ int v=edge[i].v;if(v==f)continue;dfs(v,u); int tmp=max(0,dp[v][0]-2*edge[i].w); dp[u][1]+=tmp; if(dp[v][1]-edge[i].w>0){ if(dp[u][0]+dp[v][1]-edge[i].w>dp[u][1]){ dp[u][1]=dp[u][0]+dp[v][1]-edge[i].w; id[u]=v; } } dp[u][0]+=tmp; } } void get(int u,int f,int x0,int x1){ ret[u]=max(dp[u][0]+x1,dp[u][1]+x0); int w1=dp[u][1]+x0,w0=dp[u][0]; if(w0+x1>=w1){ w1=w0+x1; id[u]=f; } w0+=x0; for(int i=head[u];~i;i=edge[i].next){ int v=edge[i].v;if(v==f)continue; int tmp,tmp0,tmp1; if(v!=id[u]){ tmp=max(dp[v][0]-2*edge[i].w,0); tmp0=w0-tmp,tmp1=w1-tmp; get(v,u,max(0,tmp0-2*edge[i].w),max(0,tmp1-edge[i].w)); continue; } tmp0=a[u]+x0,tmp1=a[u]+x1; for(int j=head[u];~j;j=edge[j].next){ int to=edge[j].v;if(to==f||to==v)continue; tmp=max(0,dp[to][0]-2*edge[j].w); tmp1+=tmp; if(dp[to][1]-edge[j].w>0){ if(tmp0+dp[to][1]-edge[j].w>tmp1)tmp1=tmp0+dp[to][1]-edge[j].w; } tmp0+=tmp; } get(v,u,max(0,tmp0-2*edge[i].w),max(0,tmp1-edge[i].w)); } } int main(){ T=read(); while(T--){ n=read(); for(int i=1;i<=n;++i){ head[i]=-1;a[i]=read(); } tot=0; for(int i=1;i<n;++i){ int u=read(),v=read(),w=read(); add(u,v,w);add(v,u,w); } dfs(1,-1); get(1,-1,0,0); printf("Case #%d: ",++kase); for(int i=1;i<=n;++i) printf("%d ",ret[i]); } return 0; }