http://acm.hdu.edu.cn/showproblem.php?pid=4616
要记录各种状态的段 a[2][4]
a[0][j]表示以trap为起点一共有j个trap的最优值
a[1][j]表示不以trap为起点一共有j个trap的最优值
dp[x][i][j] 表示以x为根节点的子树从各个叶子到x节点的各状态最优值
每到一个节点 要枚举经过此节点的所有符合要求的段中最优的(需要合并段)
代码:
#include<iostream> #include<cstdio> #include<string> #include<cstring> #include<cmath> #include<set> #include<map> #include<stack> #include<vector> #include<algorithm> #include<queue> #include<bitset> #include<deque> #include<numeric> #pragma comment(linker, "/STACK:1024000000,1024000000") using namespace std; typedef long long ll; typedef unsigned int uint; typedef pair<int,int> pp; const double eps=1e-9; const int INF=0x3f3f3f3f; const ll MOD=1000000007; const int N=100005; int head[N],I; struct node { int j,next; }edge[N*2]; int value[N],trap[N]; int dp[N][2][4]; int ans,C; void add(int i,int j) { edge[I].j=j; edge[I].next=head[i]; head[i]=I++; } void init(int n) { for(int i=0;i<n;++i) scanf("%d %d",&value[i],&trap[i]); memset(head,-1,sizeof(head));I=0; for(int i=1;i<n;++i) { int l,r; scanf("%d %d",&l,&r); add(l,r); add(r,l); } } void copyArr(int (*b)[4],int (*a)[4]) { for(int i=0;i<2;++i) for(int j=0;j<4;++j) b[i][j]=a[i][j]; } void clArr(int (*b)[4]) { for(int i=0;i<2;++i) for(int j=0;j<4;++j) b[i][j]=-1; b[0][0]=b[1][0]=0; } void update(int (*b)[4],int x) { if(trap[x]==0) { for(int i=0;i<2;++i) for(int j=0;j<4;++j) if(b[i][j]!=-1) b[i][j]+=value[x]; b[0][0]=0; }else { for(int i=0;i<2;++i) for(int j=3;j>0;--j) { if(b[i][j-1]!=-1) b[i][j]=b[i][j-1]+value[x]; } b[0][0]=0; b[1][0]=0; } } void print(int (*b)[4]) { for(int i=0;i<2;++i) { for(int j=0;j<4;++j) printf("%4d ",b[i][j]);printf(" "); }printf(" "); } void findAns(int (*b)[4],int (*v1)[4],int (*v2)[4],int x) { int c=C-trap[x]; int tmp=0; for(int i=0;i<2;++i) for(int j=0;j<4;++j) { for(int l=0;l<2;++l) for(int r=0;r<4;++r) { if(j+r>c) continue; if(j+r==c) { if(i+l==2) continue; if(i!=l) { if(i==0&&j==0) continue; if(l==0&&r==0) continue; } } if(v1[l][r]!=b[l][r]) tmp=max(tmp,max(0,v1[l][r])+max(0,b[i][j])); else tmp=max(tmp,max(0,v2[l][r])+max(0,b[i][j])); } } ans=max(ans,tmp+value[x]); } void dfs(int pre,int x,int (*a)[4]) { int b[2][4]; copyArr(b,a); update(b,x); int v1[2][4],v2[2][4]; clArr(v1);clArr(v2); for(int t=head[x];t!=-1;t=edge[t].next) { int l=edge[t].j; if(l==pre) continue; dfs(x,l,b); for(int i=0;i<2;++i) for(int j=0;j<4;++j) { v2[i][j]=max(v2[i][j],dp[l][i][j]); if(v1[i][j]<v2[i][j]) swap(v1[i][j],v2[i][j]); } } copyArr(dp[x],v1); update(dp[x],x); for(int i=0;i<2;++i) for(int j=0;j<4;++j) { v2[i][j]=max(v2[i][j],a[i][j]); if(v1[i][j]<v2[i][j]) swap(v1[i][j],v2[i][j]); } findAns(a,v1,v2,x); for(int t=head[x];t!=-1;t=edge[t].next) { int l=edge[t].j; if(l==pre) continue; findAns(dp[l],v1,v2,x); } } int main() { //freopen("data.in","r",stdin); int T; scanf("%d",&T); while(T--) { int n; scanf("%d %d",&n,&C); init(n); int a[2][4]; clArr(a); memset(dp,-1,sizeof(dp)); ans=0; dfs(-1,0,a); printf("%d ",ans); } return 0; }