分析:
如果我们知道了有哪些点需要访问,最短距离是多少呢
建出虚树,所有边权和为\(Sum\),直径为\(L\),那么答案为\(2Sum-L\)
期望=总贡献/方案
方案肯定为\(\binom{m}{k}\),我们开始算总贡献
先求\(Sum\),考虑每条边会在多少种情况下做贡献,显然是其两端都有关键点的情况下
设其在树上所接的儿子为\(u\)
这里的方案为\(\binom{m}{k}-\binom{m-sz_u}{k}-\binom{sz_u}{k}\)
应该不用解释什么
开始算直径
(我记得直径期望不是个很恐怖的东西吗(错乱)
由于这里\(m\)只有500,我们可以暴力处理出每对点的距离
我们暴力枚举,强行让某两点做直径端点,遇到同样大小的取编号最小,看剩下哪些点是可以选择的,假设有\(P\)个
那么这条直径做贡献的方案数为\(\binom{P}{k-2}\)
总复杂度\(O(nlogn+m^2logn+m^3)\),可以通过
\(O(nlogn+m^2logn)\)这里看自己的LCA求法吧,我主要为了省事(
#include<cstdio>
#include<cmath>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<vector>
#include<string>
#define maxn 200005
#define maxm 505
#define INF 0x3f3f3f3f
#define MOD 998244353
using namespace std;
inline long long getint()
{
long long num=0,flag=1;char c;
while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
while(c>='0'&&c<='9')num=num*10+c-48,c=getchar();
return num*flag;
}
int n,m,K;
int fir[maxn],nxt[maxn],to[maxn],len[maxn],cnt;
int f[maxn][18],sz[maxn],dpt[maxn];
long long dis[maxn];
int C[maxm][maxm];
int p[maxm];
long long D[maxm][maxm];
int ans;
inline int upd(int x){return x<MOD?x:x-MOD;}
inline int ksm(int num,int k)
{
int ret=1;
for(;k;k>>=1,num=1ll*num*num%MOD)if(k&1)ret=1ll*ret*num%MOD;
return ret;
}
inline void newnode(int u,int v,int w)
{to[++cnt]=v,nxt[cnt]=fir[u],fir[u]=cnt,len[cnt]=w;}
inline void dfs(int u)
{
for(int i=fir[u];i;i=nxt[i])if(to[i]!=f[u][0])
{
dpt[to[i]]=dpt[u]+1,dis[to[i]]=dis[u]+len[i],f[to[i]][0]=u;
dfs(to[i]),sz[u]+=sz[to[i]];
int tmp=upd(C[m][K]-upd(C[m-sz[to[i]]][K]+C[sz[to[i]]][K])+MOD);
ans=(ans+1ll*len[i]*tmp)%MOD;
}
}
inline int LCA(int u,int v)
{
if(dpt[u]<dpt[v])swap(u,v);
for(int i=17;~i;i--)if((dpt[u]-dpt[v])&(1<<i))u=f[u][i];
if(u==v)return u;
for(int i=17;~i;i--)if(f[u][i]!=f[v][i])u=f[u][i],v=f[v][i];
return f[u][0];
}
inline long long getdis(int u,int v)
{return dis[u]+dis[v]-2*dis[LCA(u,v)];}
int main()
{
n=getint(),m=getint(),K=getint();
for(int i=1;i<=m;i++)sz[p[i]=getint()]=1;
for(int i=1;i<n;i++)
{
int u=getint(),v=getint(),w=getint();
newnode(u,v,w),newnode(v,u,w);
}
if(K==1){printf("0\n");return 0;}
for(int i=0;i<=m;i++)
{
C[i][0]=1;
for(int j=1;j<=i;j++)C[i][j]=upd(C[i-1][j-1]+C[i-1][j]);
}
dfs(1);
ans=upd(2*ans);
for(int j=1;j<18;j++)for(int i=1;i<=n;i++)f[i][j]=f[f[i][j-1]][j-1];
for(int i=1;i<=m;i++)for(int j=1;j<=m;j++)D[i][j]=getdis(p[i],p[j]);
for(int i=1;i<=m;i++)for(int j=i+1;j<=m;j++)
{
int P=0;
long long L=D[i][j];
for(int k=1;k<=m;k++)
{
long long L1=D[i][k],L2=D[j][k];
if((L>L1||(L==L1&&j<k))&&(L>L2||(L==L2&&i<k)))P++;
}
L%=MOD;
ans=upd(ans-1ll*L*C[P][K-2]%MOD+MOD);
}
ans=1ll*ans*ksm(C[m][K],MOD-2)%MOD;
printf("%d\n",ans);
}