分析
g[x][i]表示x点向上i条边的所有关键点都被覆盖的代价
f[x][i]表示x点向下i条边有关键点未被覆盖的代价
转移即可
代码
#include<bits/stdc++.h>
using namespace std;
const int inf = 1e9+7;
int n,m,d,w[500100],is[500100],f[500100][22],g[500100][22];
vector<int>v[500100];
inline void dfs(int x,int fa){
if(is[x])f[x][0]=g[x][0]=w[x];
for(int i=1;i<=d;i++)g[x][i]=w[x];
g[x][d+1]=inf;
for(int i=0;i<v[x].size();i++)
if(v[x][i]!=fa){
dfs(v[x][i],x);
for(int j=0;j<=d;j++)g[x][j]=min(g[x][j]+f[v[x][i]][j],f[x][j+1]+g[v[x][i]][j+1]);
for(int j=d;j>=0;j--)g[x][j]=min(g[x][j],g[x][j+1]);
f[x][0]=g[x][0];
for(int j=1;j<=d;j++)f[x][j]+=f[v[x][i]][j-1];
for(int j=1;j<=d;j++)f[x][j]=min(f[x][j],f[x][j-1]);
}
return;
}
int main(){
int i,j,k;
scanf("%d%d",&n,&d);
for(i=1;i<=n;i++)scanf("%d",&w[i]);
scanf("%d",&m);
for(i=1;i<=m;i++){
int x;
scanf("%d",&x);
is[x]=1;
}
for(i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
v[x].push_back(y);
v[y].push_back(x);
}
dfs(1,0);
printf("%d
",g[1][0]);
return 0;
}