原文链接https://www.cnblogs.com/zhouzhendong/p/CF1109D.html
题意
所有边权都是 [1,m] 中的整数的所有 n 个点的树中,点 a 到点 b 的距离恰好是 m 的有几个。
$$n,mleq 10^6$$
题解
首先显然 a 和 b 的具体值是没用的。
于是我们就可以直接计数:
枚举树链 ab 上除了 a 和 b 有几个节点,假设是 i 个节点,那么这种情况下的方案总数是多少?
首先,ab 路径上 i+1 条 [1,m] 的边的和是 m ,共有 $inom{m-1}{i}$ 种边权的取值;
然后,ab 路径上共有 i 个点,方案数是从剩下的 n-2 个点种选出 i 个并排列,即 $inom{n-2}i i!$ 。
然后,剩下的 n-2-i 条边每条都有 m 种取值,方案数是 $m^{n-2-i}$ 。
最后,考虑生成树的个数,用 prufer 序列的结论推一推就可以知道方案数是 $n^{n-3-i}(i+2)$ 。
所以答案是
$$sum_{i=0}^{n-2} inom{n-2}{i}inom{m-1}{i}m^{n-2-i}n^{n-3-i}i!(i+2)$$
UPD(2019-03-04): 更新一下关于那个 prufer 编码推出的公式的证明:
prufer 编码有几个性质:
1. 假如是 n 个点,那么编码长度为 n-2 ,且每一个位置可以放 1~n 之间的任意数,每一个 prufer 编码与每一个树一一对应。
2. 假设树中一个点 x 的度是 d[x] ,那么在对应的 prufer 编码中,x 出现 d[x]-1 次。
假设我们有 n 个点,被分成了 k 个点集,每个点集里的点已经连通,不同点集之间的点两两无边,现在我们要在这个 n 个点 n-k 条边的基础上求生成树个数。设第 i 个点集包含的点数为 size[i] 。
那么,如果我们把这 k 个点集每一个点都看作一个点,做一个 k 个点的生成树,那么有 $k^{k-2}$ 种方案;但是由于这里的每一个点都是一个点集,所以假设它是点集 i,那么从他连出去的每一条边的属于集合i的端点,都有 size[i] 种选法。也就是说,对于一个 k 个点的 prufer 编码,假设在这个编码中,数字 i 出现了 c[i] 次,那么这个编码对应到原树上就会贡献 $prod_{i=1}^k size[i] ^ {c[i]+1} $ 次。
我们把每一个 "c[i]+1" 中多出来的 1 提出,看作常量,我们来对于所有 prufer 编码求贡献总和:
$$sum_{P是一个prufer编码} prod_{i=1}^k size[i] ^ {c[i]}$$ 。
考虑到这个prufer编码的每一位选择第 i 个点集,就会对乘积有 $size[i]$ 的贡献,根据乘法分配律,我们可以得到上面的那个式子就是: $(sum_{i=1}^k size[i])^{k-2} = n ^ {k-2}$ 。
再乘上之前提出的东西,所以答案就是:
$$(sum_{i=1}^k size[i])^{k-2} cdot prod_{i=1}^k size[i] $$
本题要求的那个,只是这个模型的弱化版。至此已经可以解决这个问题了。
代码
#include <bits/stdc++.h> using namespace std; typedef long long LL; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=1000005,mod=1e9+7; int n,m,a,b; int Fac[N],Inv[N]; int C(int n,int m){ if (m>n||m<0) return 0; return (LL)Fac[n]*Inv[m]%mod*Inv[n-m]%mod; } int Pow(int x,int y){ if (y<0) return Pow(x,y+mod-1); int ans=1; for (;y;y>>=1,x=(LL)x*x%mod) if (y&1) ans=(LL)ans*x%mod; return ans; } void init(int n){ for (int i=Fac[0]=1;i<=n;i++) Fac[i]=(LL)Fac[i-1]*i%mod; Inv[n]=Pow(Fac[n],mod-2); for (int i=n;i>=1;i--) Inv[i-1]=(LL)Inv[i]*i%mod; } void Add(int &x,int y){ if ((x+=y)>=mod) x-=mod; } int main(){ n=read(),m=read(),a=read(),b=read(); init(max(n,m)); int ans=0; for (int i=0;i<=n-2;i++) Add(ans,(LL)C(n-2,i)*C(m-1,i)%mod*Fac[i]%mod *Pow(m,n-2-i)%mod*Pow(n,n-3-i)%mod*(i+2)%mod); cout<<ans<<endl; return 0; }