题面
题解
大力分类讨论题。
显然,(L(L(T))) 中的点代表着 (T) 上的一条三点链。
所以 (L(L(T))) 上两点的最短路显然是沿着 (T) 上两条三点链间的唯一路径走。
然后就可以大力分类讨论了。
有很多种情况,这里举一种情况吧:
如图,我现在想从红线框起来的三点链走到蓝线框起来的三点链。不妨设 (w_1<w_2),那么 (dis=3w_1+w_2+W+3w_3+w_4)。
注意还要计算两条三点链重合的情况。
代码如下:
#include<bits/stdc++.h>
#define N 500010
#define int long long
#define ll long long
#define mod 998244353
#define div2 499122177
#define div3 332748118
using namespace std;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
int n,deg[N],fa[N];
int cnt,head[N],nxt[N<<1],to[N<<1],w[N<<1];
int p[N];
ll ans,f[N],g[N];
ll val[N],sum[N],pre[N],suf[N],su[N],sv[N];
//f[i]:i子树内三点链的个数
//g[i]:i子树外三点链的个数
void adde(int u,int v,int wi)
{
to[++cnt]=v;
w[cnt]=wi;
nxt[cnt]=head[u];
head[u]=cnt;
}
bool cmp(int a,int b)
{
return val[a]<val[b];
}
void dfs(int u)
{
f[u]=1ll*deg[u]*(deg[u]-1)%mod*div2%mod;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa[u]) continue;
fa[v]=u;
dfs(v);
f[u]=(f[u]+f[v])%mod;
}
}
void solve(int u)
{
for(int i=head[u];i;i=nxt[i])
if(to[i]!=fa[u]) solve(to[i]);
int tot=0;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
sum[u]=(sum[u]+w[i])%mod;
p[++tot]=v,val[v]=w[i];
su[v]=(v==fa[u]?f[u]:g[v]);
sv[v]=(v==fa[u]?g[u]:f[v]);
}
sort(p+1,p+tot+1,cmp);
pre[0]=suf[tot+1]=0;
for(int i=1;i<=tot;i++)
pre[i]=(pre[i-1]+sv[p[i]])%mod;
for(int i=tot;i>=1;i--)
suf[i]=(suf[i+1]+sv[p[i]])%mod;
for(int i=1;i<=tot;i++)
{
int v=p[i];
ans=(ans+1ll*(tot-1)*(tot-2)%mod*div2%mod*4ll*val[v]%mod)%mod;//Y字形
if(v!=fa[u])
{
ans=(ans+(sv[v]-(deg[v]-1)+mod)*(su[v]-(deg[u]-1)+mod)%mod*4ll*val[v]%mod)%mod;//中间4倍的边
ans=(ans+1ll*(deg[u]-1)*(deg[v]-1)%mod*2ll*val[v]%mod+(sum[u]-val[v]+mod)*(deg[v]-1)%mod+(sum[v]-val[v]+mod)*(deg[u]-1)%mod)%mod;//四点链
}
ans=(ans+1ll*(tot-1)*(sv[v]-(deg[v]-1)+mod)%mod*3ll*val[v]%mod+(sum[u]-val[v]+mod)*(sv[v]-(deg[v]-1)+mod)%mod)%mod;//端点连出
ans=(ans+1ll*(su[v]-1ll*deg[u]*(deg[u]-1)%mod*div2%mod+mod)%mod*(tot-2)%mod*val[v])%mod;//中点连出
ans=(ans+1ll*(su[v]-pre[i-1]+mod-1ll*deg[u]*(deg[u]-1)%mod*div2%mod+mod)%mod*(tot-i-1)%mod*2ll*val[v])%mod;//中点连出,当前边最小,且中点连出边比当前边大
ans=(ans+1ll*(su[v]-suf[i+1]+mod-1ll*deg[u]*(deg[u]-1)%mod*div2%mod+mod)%mod*(tot-i)%mod*2ll*val[v])%mod;//中点连出,当前边最小,且中点连出边比当前边小
ans=(ans+1ll*(tot-i)*(tot-i-1)%mod*(tot-i-2)%mod*div2%mod*div3%mod*9ll*val[v]%mod)%mod;//X字形,当前边为第1小边
ans=(ans+1ll*(i-1)*(tot-i)%mod*(tot-i-1)%mod*div2%mod*7ll*val[v]%mod)%mod;//X字形,当前边为第2小边
ans=(ans+1ll*(i-1)*(i-2)%mod*div2%mod*(tot-i)%mod*5ll*val[v]%mod)%mod;//X字形,当前边为第3小边
ans=(ans+1ll*(i-1)*(i-2)%mod*(i-3)%mod*div2%mod*div3%mod*3ll*val[v]%mod)%mod;//X字形,当前边为第4小边
}
}
signed main()
{
n=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read(),w=read();
adde(u,v,w),adde(v,u,w);
deg[u]++,deg[v]++;
}
dfs(1);
for(int i=1;i<=n;i++) g[i]=f[1]-f[i];
solve(1);
printf("%lld
",ans);
return 0;
}
/*
5
5 3 1
5 2 1
1 2 1
5 4 3
*/