题解 (by;zjvarphi)
首先,分析一下这个答案:本质上是求在一条路径上,选择了一些点,这些点的贡献是它周围的点权和 - 它上一步的点权
对于一棵树,可以先确定一个根,然后每条路径就可以分成向上和向下的两部分
那么定状态 (dp_{i,j,0}) 表示由 (i) 向 (i) 的子树走,选了 (j) 个点放磁铁,(dp_{i,j,1}) 则表示向上走
那么转移方程就很好想
[dp_{i,j,0}=max{dp_{sonin son_i,j,0},dp_{sonin son_i,j-1,0}+sum_x-num_{son}}
]
[dp_{i,j,1}=max{dp_{sonin son_i,j,1},dp_{sonin son_i,j-1,1}+sum_x-num_{fa}}
]
这个方程就是由下往上转移,且对于一个放了的点,它上一个经过的点无法被它管到。
初始化时要把 (dp_{x,i,0}) 都初始化为 (sum_x),因为转移时可以是一个点加上一条向下的路径。
转移答案时要注意:对于一个节点的所有儿子要从前往后转移一遍,同时从后往前转移一遍,这样才能保证所有分支都有可能向下或向上。
还要在逆向转移之前恢复 (dp) 数组原有状态。
Code
#include<bits/stdc++.h>
#define ri register signed
#define p(i) ++i
using namespace std;
namespace IO{
char buf[1<<21],*p1=buf,*p2=buf;
#define gc() p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++
template<typename T>inline void read(T &x) {
ri f=1;x=0;register char ch=gc();
while(ch<'0'||ch>'9') {if (ch=='-') f=0;ch=gc();}
while(ch>='0'&&ch<='9') {x=(x<<1)+(x<<3)+(ch^48);ch=gc();}
x=f?x:-x;
}
}
using IO::read;
namespace nanfeng{
#define FI FILE *IN
#define FO FILE *OUT
template<typename T>inline T cmax(T x,T y) {return x>y?x:y;}
template<typename T>inline T cmin(T x,T y) {return x>y?y:x;}
typedef long long ll;
static const int N=1e5+7;
int first[N],nm[N],tmp[N],cnt,t=1,n,vn;
struct edge{int v,nxt;}e[N<<1];
inline void add(int u,int v) {
e[t].v=v,e[t].nxt=first[u],first[u]=t++;
e[t].v=u,e[t].nxt=first[v],first[v]=t++;
}
ll dp[N][107][2],ans,sum[N];
inline void update(int x,int v,int fa) {
for (ri i(1);i<vn;p(i)) ans=cmax(ans,dp[x][i][0]+dp[v][vn-i][1]);
for (ri i(1);i<=vn;p(i)) {
dp[x][i][0]=cmax(dp[x][i][0],cmax(dp[v][i][0],dp[v][i-1][0]+sum[x]-nm[v]));
dp[x][i][1]=cmax(dp[x][i][1],cmax(dp[v][i][1],dp[v][i-1][1]+sum[x]-nm[fa]));
}
}
void dfs(int x,int fa) {
for (ri i(first[x]),v;i;i=e[i].nxt) {
if ((v=e[i].v)==fa) continue;
dfs(v,x);
}
for (ri i(1);i<=vn;p(i)) dp[x][i][0]=sum[x];
cnt=0;
for (ri i(first[x]);i;i=e[i].nxt) update(x,tmp[p(cnt)]=e[i].v,fa);
for (ri i(1);i<=vn;p(i)) dp[x][i][0]=sum[x],dp[x][i][1]=0;
for (ri i(cnt);i;--i) update(x,tmp[i],fa);
ans=cmax(ans,cmax(dp[x][vn][1],dp[x][vn][0]));
}
inline int main() {
// FI=freopen("nanfeng.in","r",stdin);
// FO=freopen("nanfeng.out","w",stdout);
// printf("cost = %d
",sizeof(dp)>>20);
read(n),read(vn);
for (ri i(1);i<=n;p(i)) read(nm[i]);
for (ri i(1),u,v;i<n;p(i)) {
read(u),read(v);
add(u,v);
}
for (ri i(1);i<=n;p(i)) {
for (ri j(first[i]),v;j;j=e[j].nxt) sum[i]+=nm[e[j].v];
}
dfs(1,0);
printf("%lld
",ans);
return 0;
}
}
int main() {return nanfeng::main();}