题解
考虑如果我们选出了一个偶数区间,那我们可以在树上标出这些点,考虑贪心,如果一个子树内能匹配的就尽量匹配,所以一个子树延伸上去的点不会超过一个,并且我们发现每条边最多被选一次,所以我们考虑每条边的贡献。考虑到确定一段区间后,如果一条边两端子树内的点为奇数的话,那这个区间下这条边就肯定要选。所以我们考虑一个子树内在一段偶数区间内含有奇数个点的方案数。考虑暴力,我们可以把子树内每个点在序列上对应的位置上打上标记,然后做前缀和,如果有 $l$ 和 $r(0 le l < r le m)$ 满足 $l,r$ 奇偶性相同并且 $s_l,s_r$ 奇偶性不同的话,那区间 $(l,r]$ 就是符合要求的,于是我们线段树处理即可。这样是 $O(n^2logm)$ 的,于是我们可以进行 $ ext{dsu on tree}$ ,复杂度就是 $O(nlognlogm)$ 的了。(这题一定要写读优!)
代码
#include <bits/stdc++.h> #define _(d) while(d(isdigit(c=getchar()))) using namespace std; int Rd(){char c;_(!);int x=c^48;_()x=(x<<3)+(x<<1)+(c^48);return x;} const int N=1e5+5,P=998244353;vector<int>h[N]; int n,m,hd[N],V[N<<1],W[N<<1],nx[N<<1],t,f[N],s[2][2][N<<2],tg[N<<2],ans,sz[N],son[N]; void add(int u,int v,int w){ nx[++t]=hd[u];V[hd[u]=t]=v;W[t]=w; } #define Ls k<<1 #define Rs k<<1|1 #define mid ((l+r)>>1) void up(int k){ for (int i=0;i<2;i++) for (int j=0;j<2;j++) s[i][j][k]=s[i][j][Ls]+s[i][j][Rs]; } void build(int k,int l,int r){ if (l==r){s[l&1][0][k]=1;return;} build(Ls,l,mid);build(Rs,mid+1,r);up(k); } void Dfs(int u,int fr){ sz[u]=1; for (int i=hd[u],v;i;i=nx[i]){ if ((v=V[i])==fr) continue; Dfs(v,u),sz[u]+=sz[v];f[v]=W[i]; if (sz[son[u]]<sz[v]) son[u]=v; } } void push(int k){ swap(s[0][0][k],s[0][1][k]); swap(s[1][0][k],s[1][1][k]); tg[k]^=1; } void down(int k){ push(Ls);push(Rs);tg[k]=0; } void upd(int k,int l,int r,int L,int R){ if (L<=l && r<=R) return push(k); if (tg[k]) down(k); if (mid>=L) upd(Ls,l,mid,L,R); if (mid<R) upd(Rs,mid+1,r,L,R); up(k); } void ins(int u){ int z=h[u].size(); for (int i=0;i<z;i++) upd(1,0,m,h[u][i],m); } void upd(int u,int fr){ ins(u); for (int i=hd[u];i;i=nx[i]) if (V[i]!=fr) upd(V[i],u); } void dfs(int u,int fr,int tp){ for (int i=hd[u];i;i=nx[i]) if (V[i]!=fr && V[i]!=son[u]) dfs(V[i],u,0); if (son[u]) dfs(son[u],u,1); for (int i=hd[u];i;i=nx[i]) if (V[i]!=fr && V[i]!=son[u]) upd(V[i],u);ins(u); (ans+=1ll*f[u]*(1ll*s[0][0][1]*s[0][1][1]%P+1ll*s[1][0][1]*s[1][1][1]%P)%P)%=P; if (!tp) upd(u,fr); } int main(){ n=Rd(),m=Rd(); for (int i=1,u,v,w;i<n;i++) u=Rd(),v=Rd(),w=Rd(), add(u,v,w),add(v,u,w); for (int i=1;i<=m;i++) h[Rd()].push_back(i); build(1,0,m);Dfs(1,0);dfs(1,0,0); printf("%d ",ans);return 0; }