从叶子往上先拓扑一下,建立虚拟root,从root开始dfs。注意到每个点的最优取值一定是一个区间(中位数区间),从儿子区间推出父亲区间即可
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=500005;
int n,m,l[N],r[N],h[N],cnt,q[N],tot,d[N],fa[N],a[N<<1];
bool v[N];
long long ans;
pair<int,int>b[N<<1];
struct qwe
{
int ne,to;
}e[N<<1];
int read()
{
int r=0,f=1;
char p=getchar();
while(p>'9'||p<'0')
{
if(p=='-')
f=-1;
p=getchar();
}
while(p>='0'&&p<='9')
{
r=r*10+p-48;
p=getchar();
}
return r*f;
}
void add(int u,int v)
{//cerr<<u<<" "<<v<<endl;
cnt++;
d[u]++;
fa[v]=u;
e[cnt].ne=h[u];
e[cnt].to=v;
h[u]=cnt;
}
void dfs(int u)
{
if(h[u]==0)
return;
for(int i=h[u];i;i=e[i].ne)
dfs(e[i].to);
long long tmp=1e18,c=0,s=0,con=0,now;
for(int i=h[u];i;i=e[i].ne)
a[++con]=l[e[i].to],a[++con]=r[e[i].to],c--,s+=l[e[i].to];
sort(a+1,a+1+con);
for(int i=1;i<=con;i++)
{
c++,s-=a[i],now=s+1ll*a[i]*c;
if(now<tmp)
l[u]=a[i],tmp=now;
if(now==tmp)
r[u]=a[i];
}
ans+=tmp;
}
int main()
{
n=read(),m=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read();
add(x,y),add(y,x);
}
for(int i=1;i<=m;i++)
l[i]=r[i]=read();
int ll=1,rr=0;
for(int i=1;i<=m;i++)
v[q[++rr]=i]=1;
for(int x;ll<=rr;ll=x+1)
{
for(int u=ll;u<=rr;u++)
for(int i=h[q[u]];i;i=e[i].ne)
if(!v[e[i].to])
b[++tot]=make_pair(e[i].to,q[u]);
x=rr;
for(int u=ll;u<=x;u++)
for(int i=h[q[u]];i;i=e[i].ne)
if(!v[e[i].to])
if((--d[e[i].to])<=1)
v[q[++rr]=e[i].to]=1;
}
memset(h,0,sizeof(h));
memset(fa,0,sizeof(fa));
cnt=0;
for(int i=1;i<=tot;i++)
add(b[i].first,b[i].second);
for(int i=1;i<=n;i++)
if(!fa[i])
add(0,i);
dfs(0);
printf("%lld
",ans);
return 0;
}