题意:有一颗苹果树,每个节点上有一些苹果,任何两个节点之间的距离都为1,每走一步距离都增加1,无论该边有没有被走过。问,在最多走k步的情况下,最多能吃到多少个苹果。
解法:设d[x][i][0]表示在以x为根的树上,最多走i步,并且最终回到节点x所能吃到的最多的苹果数。
设d[x][i][1]表示在以x为根的树上,最多走i步,并且最终不需要回到节点x所能吃到的最多的苹果数。注意到,一定有d[x][i][1] >= d[x][i][0]。
易得d[x][i][0] = max(d[x][i][0], d[x][i-j-2][0] + d[y][j][0]),其中y为所有x的子节点。
d[x][i][1] = max(d[x][i][1], d[x][i-j-1][0] + d[y][j][1])。
但是,我WA了很久,就是因为漏了这种情况:d[x][i][1] = max(d[x][i][1], d[x][i-j-2] + d[y][j][0]),即最终没有回到x节点,但是也并非停留在以y为根的树上,而是停留在别的子树上。
d[0][k][1]即为所求。
tag:树形dp, 背包
Ps:感觉做了三道树形dp的题,都几乎是dfs+背包的模式。
1 /* 2 * Author: Plumrain 3 * Created Time: 2013-11-20 11:21 4 * File Name: DP-POJ-2486.cpp 5 */ 6 #include <iostream> 7 #include <cstdio> 8 #include <cstring> 9 #include <vector> 10 11 using namespace std; 12 13 #define CLR(x) memset(x, 0, sizeof(x)) 14 #define PB push_back 15 #define out(x) cout<<#x<<":"<<(x)<<endl 16 17 int n, k; 18 bool vis[105]; 19 vector<int> pat[105], v[105]; 20 int c[105], d[105][205][2]; 21 22 void init() 23 { 24 CLR (vis); 25 for (int i = 0; i < n; ++ i){ 26 pat[i].clear(); 27 v[i].clear(); 28 } 29 30 for (int i = 0; i < n; ++ i) 31 scanf ("%d", &c[i]); 32 int t1, t2; 33 for (int i = 0; i < n-1; ++ i){ 34 scanf ("%d%d", &t1, &t2); 35 -- t1; -- t2; 36 pat[t1].PB (t2); 37 pat[t2].PB (t1); 38 } 39 } 40 41 void dfs1(int x) 42 { 43 vis[x] = 1; 44 int sz = pat[x].size(); 45 if (!sz) return; 46 47 for (int i = 0; i < sz; ++ i){ 48 int y = pat[x][i]; 49 if (vis[y]) continue; 50 v[x].PB (y); 51 dfs1(y); 52 } 53 } 54 55 void dfs2(int x, int p) 56 { 57 for (int i = 0; i <= k; ++ i) 58 d[x][i][0] = d[x][i][1] = c[x]; 59 if (!k) return; 60 61 int sz = v[x].size(); 62 for (int i = 0; i < sz; ++ i){ 63 int y = v[x][i]; 64 dfs2(y, k-1); 65 66 for (int j = p; j >= 0; -- j){ 67 for (int k = 0; k <= j - 2; ++ k){ 68 d[x][j][0] = max(d[x][j][0], d[x][j-k-2][0] + d[y][k][0]); 69 d[x][j][1] = max(d[x][j][1], d[x][j-k-2][1] + d[y][k][0]); 70 d[x][j][1] = max(d[x][j][1], d[x][j-k-1][0] + d[y][k][1]); 71 } 72 d[x][j][1] = max(d[x][j][1], d[x][0][0] + d[y][j-1][1]); 73 } 74 } 75 } 76 77 int main() 78 { 79 while (scanf ("%d%d", &n, &k) != EOF){ 80 init(); 81 82 dfs1(0); 83 84 CLR (d); 85 dfs2(0, k); 86 87 printf ("%d ", d[0][k][1]); 88 } 89 return 0; 90 }