这题如果不要输出方案绝对没有2500。状态10分钟就开出来了,输出方案搞了2两个多小时。。。那个最小生成树解法太牛逼了,完全想不到,只能用树形dp了。
光靠题目所给的那个很像博弈的题面必然没法做,第一步先转化题面:对于每一个叶子找一个"控制"它的点,这个点在它到根的路径上,使得选出的点不重复且权值和最小。
显然对于每一个叶子必须且仅需一个点就可以把它控制成需要的权值。
然后是需要直觉的一步了,我也不知道我怎么想到的:一个以 (u) 为根的子树除 (u) 外的部分,至多一个叶子没有控制的节点,否则子树不合法。
如果有一叶子没被控制,那么我们选择 (u) 即可,都被控制了就不用选了。
那么dp方程就呼之欲出了:
(dp[u][0]) 表示:以 (u) 为根的子树内,叶子个数等于选出的节点个数的最小代价
(dp[u][1]) 表示:以 (u) 为根的子树内,不包括 (u),叶子个数等于选出的节点个数加一的最小代价。这里的不包括 (u) 很重要,不然没法转移(或者说我不会转移)。
(dp[u][1]=(sum dp[v][0])-max{dp[v][0]-dp[v][1] }) ,也就是说我们必须选择恰好一棵子树让它叶子个数等于选出节点个数加一。
(dp[u][0]=max(sum dp[v][0],dp[u][1]+w[u])),要么不包括 (u) 的地方少了一个,那么 (u) 必须选。要么每一个子树都恰好选满。
最小代价就是 (dp[1][0])
然后就是输出方案了。其实不走弯路也还好,就是要想清楚,然后大力乱搞。
输出方案可以参考一下我的思路,最好自己乱搞,毕竟乱搞还是自己想比较清楚。
我们发现一个点可能可以取必须满足 (dp[u][0]=dp[u][1]+w[u]) 。如果只取 (dp[u][1]) 说明这个点没取,取这个点就必须从 (dp[u][1]) 转移然后加上 (w[u]) ,又要让这颗子树代价最小,那么 (dp[u][0]) 就必须从 (dp[u][1]+w[u]) 转移过来。
但是这样还不够(好像过不去样例),因为 (dp[u][0]) 不一定可以转移到 (dp[1][0]) ,所以我们还需要判断每一个状态 (dp[u][i]) 是否可以转移到 (dp[1][0]),我们再用一个数组 (can[i][j]) 来记录这个。
再记一些辅助变量:
(g[u]) 表示 (max{dp[v][0]-dp[v][1] }),那么 (dp[u][0]=(sum dp[v][0])-g[u]);
(num[u]) 表示 (dp[u][0]) 可以从多少个 (dp[v][1]) 转移过来
(sum[u]) 表示 (sum dp[v][0])
然后转移 (can) ,主要思想就是判断 (v) 是否只能从 (dp[v][0]) 或 (dp[v][1]) 转移上来。
-
如果 (can[u][0]=1)
-
转移到 (can[v][0])
-
如果 (num[u]) 大于 (1) ,那么 (can[v][0]) 等于 (1) ,因为它不必须从 (dp[u][1]) 转移上去。
-
如果 (g[u] ot= dp[v][0]-dp[v][1]) ,那么 (can[v][0]) 等于 (1),因为 (dp[u][0]) 根本不是从 (dp[v][1]) 转移上去的,必然是从 (dp[v][0]) 转移上去。
-
如果 (sum[u]=dp[u][0]) ,那么 (can[v][0]) 等于 (1) ,因为 (u) 可以全部从 (dp[v][0]) 转移上来。
-
-
转移到 (can[v][1])
- 如果 (sum[u]-dp[v][0]+dp[v][1]+w[u]=dp[u][0]) ,那么 (can[v][1]=1) ,即强制从 (dp[v][1]) 转移上来,看看是否改变最优解。
-
-
如果 (can[u][1]=1)
-
转移到 (can[v][0])
-
如果 (g[u] ot= dp[v][0]-dp[v][1]) ,那么 (can[v][0]) 等于 (1) ,因为它根本不会转移到 (dp[u][1])。
-
如果 (num[u]>1) ,那么 (can[v][0]) 等于 (1) ,因为 (dp[u][1]) 有多种选择,(dp[v]) 可以从 (0) 转移上去。
-
-
转移到 (can[v][1])
- 如果 (g[u]=dp[v][0]-dp[v][1]) ,那么 (can[v][1]) 等于 (1) ,从上面的转移方程可以看出 (dp[u][1]) 可以从 (dp[v][1]) 转移上来。
-
调死我了,太乱搞了
而且为啥没有spj啊/fn,好好的 (O(n)) 又因为输出变成了 (O(nlog n)) ,和那种非常牛逼的最小生成树解法复杂度相同了。本来还以为“理论复杂度更优”的/kk
upd:2020.12.15:忘了可以桶排,不过没啥用就不管了,理论复杂度确实更优(
//Orz cyn2006
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f?x:-x;
}
const int N = 200005;
int n, w[N], deg[N], num[N];
vector <int> ans;
LL dp[N][2], sum[N], g[N];
int hed[N], et;
bool lef[N], can[N][2];
struct edge { int nx, to; }e[N << 1];
void adde(int u, int v) { e[++ et].nx = hed[u], e[et].to = v, hed[u] = et; }
void dfs(int u, int ft) {
if (deg[u] == 1 && u != 1) {
dp[u][0] = w[u], dp[u][1] = 0, lef[u] = 1;
return;
}
LL mx = 0;
for (int i = hed[u]; i; i = e[i].nx) {
int v = e[i].to; if (v == ft) continue;
dfs(v, u);
sum[u] += dp[v][0];
if(mx < dp[v][0] - dp[v][1]) mx = dp[v][0] - dp[v][1], g[u] = mx, num[u] = 1;
else if (mx == dp[v][0] - dp[v][1]) ++ num[u];
}
dp[u][1] = sum[u] - mx;
dp[u][0] = min(dp[u][1] + w[u], sum[u]);
}
void calc(int u, int ft) {
if (can[u][0] && dp[u][0] == dp[u][1] + w[u]) ans.pb(u);
for (int i = hed[u]; i; i = e[i].nx) {
int v = e[i].to; if (v == ft) continue;
if (can[u][0]) {
if (sum[u] - dp[v][0] + dp[v][1] + w[u] == dp[u][0]) can[v][1] = 1;
if (sum[u] == dp[u][0] || g[u] != dp[v][0] - dp[v][1] || num[u] > 1) can[v][0] = 1;
}
if (can[u][1]) {
if (g[u] == dp[v][0] - dp[v][1]) can[v][1] = 1;
if (g[u] != dp[v][0] - dp[v][1] || num[u] > 1) can[v][0] = 1;
}
calc(v, u);
}
}
signed main() {
n=read();
rep (i, 1, n) w[i] = read();
rep (i, 2, n) {
int x = read(), y = read();
adde(x, y), adde(y, x), ++ deg[x], ++ deg[y];
}
dfs(1, 0), can[1][0] = 1, can[1][1] = dp[1][0] == dp[1][1] + w[1], calc(1, 0);
printf("%lld %d
",dp[1][0], sz(ans));
sort(ans.begin(), ans.end());
rep(i, 0, sz(ans) - 1) printf("%d ", ans[i]);
puts("");
return 0;
}