先考虑怎样配对最优,发现对于一条边,若其两端的子树内需要配对的点的个数都为奇数,则该边会有 (1) 的贡献,否则没有贡献,得这样为最优情况。
那么对于一棵子树,将其内部的点标记为 (1),得到一个 (01) 串,子树父边的贡献为 (01) 串中长度为偶数且区间和为奇数的区间个数。
可以用线段树合并来优化,线段树上每个节点维护 (v(cur,0/1,0/1)),表示 (cur) 对应的区间中前缀和为偶数或奇数,位置为偶数或奇数的位置个数,设点 (x) 的父边权值为 (val),得其对答案的贡献为:
[large val(v(rt_x,0,0)v(rt_x,1,0)+v(rt_x,0,1)v(rt_x,1,1))
]
为方便计算位置个数,线段树区间设为 ([1,m+1])。
#include<bits/stdc++.h>
#define maxn 200010
#define maxm 20000010
#define p 998244353
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,m,tot;
ll ans;
int rt[maxn],ls[maxm],rs[maxm],sum[maxm],v[maxm][2][2];
struct edge
{
int to,nxt;
ll v;
}e[maxn];
int head[maxn],edge_cnt;
void add(int from,int to,int val)
{
e[++edge_cnt]={to,head[from],val},head[from]=edge_cnt;
}
void pushup(int cur,int l,int r)
{
sum[cur]=sum[ls[cur]]+sum[rs[cur]];
int k=sum[ls[cur]]&1;
for(int i=0;i<=1;++i)
for(int j=0;j<=1;++j)
v[cur][i][j]=v[ls[cur]][i][j]+v[rs[cur]][i^k][j];
if(!ls[cur]) v[cur][0][0]+=mid/2-(l-1)/2,v[cur][0][1]+=(mid+1)/2-l/2;
if(!rs[cur]) v[cur][k][0]+=r/2-mid/2,v[cur][k][1]+=(r+1)/2-(mid+1)/2;
}
void insert(int l,int r,int pos,int &cur)
{
if(!cur)
{
cur=++tot;
v[cur][0][0]=r/2-(l-1)/2;
v[cur][0][1]=(r+1)/2-l/2;
}
if(l==r)
{
sum[cur]++;
return;
}
if(pos<=mid) insert(l,mid,pos,ls[cur]);
else insert(mid+1,r,pos,rs[cur]);
pushup(cur,l,r);
}
int merge(int x,int y,int l,int r)
{
if(!x||!y) return x+y;
int cur=++tot;
if(l==r) sum[cur]=sum[x]+sum[y];
else
{
ls[cur]=merge(ls[x],ls[y],l,mid);
rs[cur]=merge(rs[x],rs[y],mid+1,r);
pushup(cur,l,r);
}
return cur;
}
void dfs(int x,int fa)
{
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa) continue;
dfs(y,x);
ans=(ans+e[i].v*((ll)v[rt[y]][0][0]*v[rt[y]][1][0]%p+(ll)v[rt[y]][0][1]*v[rt[y]][1][1]%p)%p)%p;
rt[x]=merge(rt[x],rt[y],1,m+1);
}
}
int main()
{
read(n),read(m);
for(int i=1;i<n;++i)
{
int x,y,v;
read(x),read(y),read(v);
add(x,y,v),add(y,x,v);
}
for(int i=1,x;i<=m;++i) read(x),insert(1,m+1,i,rt[x]);
dfs(1,0),printf("%lld",ans);
return 0;
}