题解:
定义 dp[u][0] 为遍历完u中的所有节点, 但不回到u点的路径花费值。
定义 dp[u][1] 为遍历完u中的所有节点, 且要回到u点的路径花费值。
转移方程.
dp[u][1] = sum(dp[v][1] + 2).
dp[u][0] = max(dp[v][1] + 2 - dp[v][0] - 1).
需要注意的是,不要把不需要走的路径值传递上来。
只有这个路径会遍历一个需要清除的点的时候,才可以转移状态。
这样从1dfs完之后,我们就可以计算出上面定义的状态的值。
然后我们反着dfs一遍,就可以算出从每个点出发清除完所有点的值。
代码:
#include<bits/stdc++.h> using namespace std; #define Fopen freopen("_in.txt","r",stdin); freopen("_out.txt","w",stdout); #define LL long long #define ULL unsigned LL #define fi first #define se second #define pb push_back #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 #define lch(x) tr[x].son[0] #define rch(x) tr[x].son[1] #define max3(a,b,c) max(a,max(b,c)) #define min3(a,b,c) min(a,min(b,c)) typedef pair<int,int> pll; const int inf = 0x3f3f3f3f; const int _inf = 0xc0c0c0c0; const LL INF = 0x3f3f3f3f3f3f3f3f; const LL _INF = 0xc0c0c0c0c0c0c0c0; const LL mod = (int)1e9+7; const int N = 2e5 + 100; vector<int> vc[N]; LL dp[N][2]; /// 1->back 0-no-back LL dif[N][2]; /// dif[0] > dif[1] int vis[N]; int dfs(int o, int u){ for(int v : vc[u]){ if(o == v) continue; int f = dfs(u, v); if(!f) continue; LL t1 = dp[v][1] + 2; LL t2 = dp[v][0] + 1; t2 = t1 - t2; dp[u][1] += t1; if(dif[u][0] < t2) swap(t2, dif[u][0]); if(dif[u][1] < t2) swap(t2, dif[u][1]); } dp[u][0] = dp[u][1] - dif[u][0]; if(dp[u][1] == 0 && !vis[u]) return 0; return 1; } LL ans = INF, ansid; int fdfs(int o, int u){ if(o){ LL t1 = dp[o][1]; if(dp[u][1] || vis[u]) t1 -= (dp[u][1] + 2); if(t1 == 0 && !vis[o]) ; else { LL t2; LL now_dif = 0; if(dp[u][1] || vis[u]) now_dif= (dp[u][1]+2) - (dp[u][0]+1); if(now_dif == dif[o][0]) t2 = t1 - dif[o][1]; else t2 = t1 - dif[o][0]; t2 += 1; t1 += 2; dp[u][1] += t1; now_dif = t1 - t2; if(dif[u][0] < now_dif) swap(dif[u][0], now_dif); if(dif[u][1] < now_dif) swap(dif[u][1], now_dif); dp[u][0] = dp[u][1] - dif[u][0]; } } if(dp[u][0] < ans){ ans = dp[u][0]; ansid = u; } else if(dp[u][0] == ans && ansid > u) ansid = u; for(int v : vc[u]){ if(o == v) continue; fdfs(u, v); } return 0; } int main(){ int n, m; scanf("%d%d", &n, &m); int u, v; for(int i = 1; i < n; ++i){ scanf("%d%d", &u, &v); vc[u].pb(v); vc[v].pb(u); } for(int i = 1; i <= m; ++i){ scanf("%d", &u); vis[u] = 1; } dfs(0,1); fdfs(0,1); cout << ansid << " "<< ans << endl; return 0; }