// (Code)
#include<cstdio>
#include<algorithm>
using namespace std;
// Global state shared by the union-find helpers and main().
int n, m;         // n: node count (n-1 edges are read), m: number of machine nodes read
int vis[100005];  // vis[i] != 0: the set rooted at i contains a machine (read at roots by merge)
int fa[100005];   // union-find parent pointers; fa[i] == i marks a root

// One weighted edge between nodes l and r.
struct nd {
    int l;
    int r;
    int v;  // edge weight
} a[100005];

// Strict weak ordering that places heavier edges first (for std::sort).
bool cmp(nd x, nd y)
{
    return y.v < x.v;
}
// Returns the root of x's set in the global parent array `fa`, then rewrites
// every node on the walked path to point directly at that root (path compression).
// Written iteratively: merge() links roots without union-by-rank, so parent
// chains can be long, and a recursive walk would use one stack frame per link.
int find(int x)
{
    int root = x;
    while (fa[root] != root)
        root = fa[root];

    // Second pass: re-point each node on the path straight at the root.
    while (fa[x] != root)
    {
        int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}
// Decides whether the edge (x, y) may be kept, joining the two sets if so.
// Returns 0 when both endpoints' sets already contain a machine (the edge must
// be cut); otherwise returns 1. When the sets are distinct they are united, and
// the machine-bearing root (if any) stays the root so its vis flag is preserved.
int merge(int x,int y)
{
    const int rx = find(x);
    const int ry = find(y);
    // Already in the same set: keeping the edge connects nothing new, so it is
    // always safe (unreachable for a tree, but correct for general graphs).
    if (rx == ry) return 1;
    if (vis[rx] && vis[ry]) return 0;
    if (vis[rx]) fa[ry] = rx;  // rx is a root, so this reads the set's flag directly
    else fa[rx] = ry;
    return 1;
}
// Reads n and m, then m machine node ids, then n-1 edges "l r v".
// Node ids are shifted by +1 before use, so valid ids are 0..n-1.
// Edges are processed heaviest first; an edge is kept whenever merge() allows
// it, and the program prints the total weight of the edges that are not kept.
// Returns 1 without output on malformed or out-of-range input, which the
// original would have turned into out-of-bounds array writes or an invalid
// sort range (n == 0).
int main()
{
    const int kMaxNode = 100000;  // global arrays have 100005 slots; indices used are 1..n
    if (scanf("%d%d", &n, &m) != 2 || n < 1 || n > kMaxNode) return 1;

    for (int i = 1; i <= m; i++)
    {
        int q;
        if (scanf("%d", &q) != 1 || q < 0 || q >= n) return 1;
        vis[q + 1] = 1;
    }

    long long sum = 0, ans = 0;
    for (int i = 1; i < n; i++)
    {
        if (scanf("%d%d%d", &a[i].l, &a[i].r, &a[i].v) != 3) return 1;
        if (a[i].l < 0 || a[i].l >= n || a[i].r < 0 || a[i].r >= n) return 1;
        sum += (long long)a[i].v;
    }

    // Heaviest edges first, so the greedy keeps as much weight as it can.
    sort(a + 1, a + n, cmp);
    for (int i = 1; i <= n; i++) fa[i] = i;
    for (int i = 1; i < n; i++)
        if (merge(a[i].l + 1, a[i].r + 1))
            ans += (long long)a[i].v;

    printf("%lld", sum - ans);
    return 0;
}