- 源自 ditoly 大爷的 FJ 省队集训课件
Statement
-
有 (m) 个正整数变量,求有多少种取值方案
-
使得所有变量的和不超过 (S)
-
并且前 (n) 个变量的值都不超过 (t)
-
答案对 (10^9+7) 取模
-
(m-nle1000) ,(mle 10^9) ,(tle 10^5) ,(ntle sle 10^{18})
Solution
-
由于 (ntle sle 10^{18}) ,所以我们的解可以表示成:
-
[sum_{x_1=1}^tsum_{x_2=1}^tdotssum_{x_n=1}^tinom{S-sum_{i=1}^nx_i}{m-n} ]
-
设 (s=S-(m-n)+1-sum_{i=1}^nx_i)
-
注意到上面的和式里面的组合数是一个关于 (s) 的多项式,具体地,它等于 (frac{s^{overline{m-n}}}{(m-n)!})
-
其中 (n^{overline m}) 表示 (n) 的 (m) 次上升幂,即 (prod_{i=0}^{m-1}(n+i))
-
由第一类斯特林数的生成函数可得,这个多项式的 (i) 次项系数为 (frac{egin{bmatrix}m-n\iend{bmatrix}}{(m-n)!}) ,其中 (egin{bmatrix}n\mend{bmatrix}) 为第一类斯特林数,即把 (n) 个元素分成 (m) 个圆排列的方案数
-
把组合数表示成多项式的形式之后,我们考虑如果我们求出了在前 (n) 个变量所有 (t^n) 种取值下对应的 (s^i) ((0le ile m-n))之和,那么问题就能很好地解决了
-
考虑 DP:(f[i][j]) 表示 (x_{1dots i}) 所有取值下 ((S-m+n+1-sum_{k=1}^ix_k)^j) 的和
-
考虑如何从 (f[i]) 转移到 (f[i+1])
-
这个转移即枚举 (x_{i+1}) 的取值,设 (s=S-m+n+1-sum_{k=1}^ix_k)
-
可以发现:
-
[sum_{k=1}^t(s-k)^j=sum_{k=0}^j(-1)^{j-k}inom jks^ksum(j-k,t) ]
-
其中 (sum(p,n)=sum_{i=1}^ni^p) ,可以利用插值 (O(p)) 求出
-
于是我们有了一个 DP 转移:
-
[f[i+1][j]=sum_{k=0}^j(-1)^{j-k}inom jksum(j-k,t)f[i][k] ]
-
把组合数拆开:
-
[frac{f[i+1][j]}{j!}=sum_{k=0}^jfrac{(-1)^{j-k}sum(j-k,t)}{(j-k)!}frac{f[i][k]}{k!} ]
-
设多项式 (F_i(x)=sum_{j=0}^{m-n}f[i][j]x^j) ,(G(x)=sum_{i=0}^{m-n}frac{(-1)^isum(i,t)}{i!}x^i)
-
那么很容易得到 (F_n(x)=G(x)^n) ,倍增快速幂即可
-
复杂度 (O((m-n)^2log n)) 或 (O((m-n)^2))
Code
#include <bits/stdc++.h>
const int N = 1010, rqy = 1e9 + 7;
typedef long long ll;
ll s;
int t, n, m, l, pw[N][N], f[N], fac[N], inv[N], invt[N], g[N], ans, S[N][N];
int qpow(int a, int b)
{
int res = 1;
while (b)
{
if (b & 1) res = 1ll * res * a % rqy;
a = 1ll * a * a % rqy;
b >>= 1;
}
return res;
}
int calc(int T)
{
int sum = 0, res = 0, al = 1;
for (int i = 1; i <= T + 2; i++) al = 1ll * al * (t - i + rqy) % rqy;
for (int i = 1; i <= T + 2; i++)
{
sum = (sum + pw[i][T]) % rqy;
if (i == t) return sum;
int delta = 1ll * al * invt[i] % rqy *
inv[i - 1] % rqy * inv[T + 2 - i] % rqy;
if (T + 2 - i & 1) delta = (rqy - delta) % rqy;
res = (1ll * delta * sum + res) % rqy;
}
return res;
}
int main()
{
std::cin >> s >> t >> n >> m;
l = m - n; s -= l - 1;
f[0] = fac[0] = inv[0] = inv[1] = 1;
for (int i = 1; i <= l + 2; i++) fac[i] = 1ll * fac[i - 1] * i % rqy;
for (int i = 2; i <= l + 2; i++)
inv[i] = 1ll * (rqy - rqy / i) * inv[rqy % i] % rqy;
for (int i = 2; i <= l + 2; i++) inv[i] = 1ll * inv[i] * inv[i - 1] % rqy;
for (int i = 1; i <= l + 2; i++)
{
pw[i][0] = 1;
for (int j = 1; j <= l; j++) pw[i][j] = 1ll * pw[i][j - 1] * i % rqy;
}
for (int i = 1; i <= l + 2; i++) invt[i] = qpow((t - i + rqy) % rqy, rqy - 2);
for (int i = 1; i <= l; i++) f[i] = s % rqy * f[i - 1] % rqy;
for (int i = 0; i <= l; i++) f[i] = 1ll * f[i] * inv[i] % rqy;
for (int i = 0; i <= l; i++)
{
g[i] = 1ll * calc(i) * inv[i] % rqy;
if (i & 1) g[i] = (rqy - g[i]) % rqy;
}
while (n)
{
if (n & 1) for (int i = l; i >= 0; i--)
{
f[i] = 1ll * f[i] * g[0] % rqy;
for (int j = 1; j <= i; j++)
f[i] = (1ll * f[i - j] * g[j] + f[i]) % rqy;
}
for (int i = l; i >= 0; i--)
{
g[i] = 1ll * g[i] * g[0] % rqy;
if (i) g[i] = (g[i] + g[i]) % rqy;
for (int j = 1; j < i; j++)
g[i] = (1ll * g[i - j] * g[j] + g[i]) % rqy;
}
n >>= 1;
}
S[0][0] = 1;
for (int i = 1; i <= l; i++)
for (int j = 1; j <= i; j++)
S[i][j] = (1ll * S[i - 1][j] * (i - 1) + S[i - 1][j - 1]) % rqy;
for (int i = 0; i <= l; i++)
ans = (1ll * f[i] * fac[i] % rqy * S[l][i] + ans) % rqy;
return std::cout << 1ll * ans * inv[l] % rqy << std::endl, 0;
}