题意
小Z是养鸽子的人。一天,小Z给鸽子们喂玉米吃。一共有(n)只鸽子,小Z每秒会等概率选择一只鸽子并给他一粒玉米。一只鸽子饱了当且仅当它吃了的玉米粒数量(≥k)。 小Z想要你告诉他,期望多少秒之后所有的鸽子都饱了。
数据范围: (n≤50,k≤1000),答案模(998244353)输出
题解
显然非常适合min-max容斥.令 (f(n)) 为有 (n) 只鸽子,将其中一只喂到饱的期望次数,就得到:
[ans = sum_{i = 0} ^ {n} (-1) ^ {i + 1} dbinom{n}{i} frac{n}{i} f(i)
]
要乘以 (frac{n}{i}) 是因为期望投喂 (frac{n}{i}) 次才能有一次投喂到(i)只鸽子中的一只.
来大力计算 (f(n)) ,假设被喂饱的是第一只鸽子,总共投喂了 (j + k) 次,第 (j+k) 次投喂把第 (1) 只喂饱了. (g(n,c)) 表示 (n) 只鸽子,喂了 (c) 颗玉米,没有一只被喂饱的方案数.因为 (n) 只鸽子都可能被喂饱,最后还要乘以 (n) :
[f(n) = nsum_{j} frac{(j + k)dbinom{j+k-1}{j}g(n-1,j)}{n^{j+k}}
]
而 (g(n,c)) 可以由EGF得到:
[g(n,c) = ((sum_{i = 0} ^ {k - 1} frac{x^i}{i!}) ^ n [x^c]) c!
]
令 (g = sum_{i = 0} ^ {k - 1} frac{x^i}{i!}) ,暴力算出它的前 (n) 次方就可以在 (O(nk^2log(nk))) 的复杂度内解决本题.
然而还有更优美的做法.
(e^x = sum_{i = 0} ^ {infty} frac{x^i}{i!}) ,而 ((e^x)' = e^x)
对于(g),可以类似得到 (g' = g - frac{x^{k - 1}}{(k - 1)!})
从而
[egin{split}
(g^n)' &= g^{n - 1} g'
\ &= (g^{n-1}) (g - frac{x^{k-1}}{(k-1)!})
\ &= g^n - frac{x^{k-1}}{(k-1)!} g^{n-1}
end{split}
]
从中可以得出 (g^n[x^n]) 和 (g^n[x^{n-1}]) 之间的关系, (O(1)) 计算一项,总复杂度 (O(nk^2))
#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define LL long long
#define pii pair<int, int>
using namespace std;
const int N = 5e4 + 10;
const int mod = 998244353;
template <typename T> T read(T &x) {
int f = 0;
register char c = getchar();
while (c > '9' || c < '0') f |= (c == '-'), c = getchar();
for (x = 0; c >= '0' && c <= '9'; c = getchar())
x = (x << 3) + (x << 1) + (c ^ 48);
if (f) x = -x;
return x;
}
namespace Comb {
const int Maxn = 1e6 + 10;
int fac[Maxn], fav[Maxn], inv[Maxn];
void comb_init() {
fac[0] = fav[0] = 1;
inv[1] = fac[1] = fav[1] = 1;
for (int i = 2; i < N; ++i) {
fac[i] = 1LL * fac[i - 1] * i % mod;
inv[i] = 1LL * -mod / i * inv[mod % i] % mod + mod;
fav[i] = 1LL * fav[i - 1] * inv[i] % mod;
}
}
inline int C(int x, int y) {
if (x < y || y < 0) return 0;
return 1LL * fac[x] * fav[y] % mod * fav[x - y] % mod;
}
inline int Qpow(int x, int p) {
int ans = 1;
for (; p; p >>= 1) {
if (p & 1) ans = 1LL * ans * x % mod;
x = 1LL * x * x % mod;
}
return ans;
}
inline int Inv(int x) {
return Qpow(x, mod - 2);
}
inline void upd(int &x, int y) {
(x += y) >= mod ? x -= mod : 0;
}
inline int add(int x, int y) {
return (x += y) >= mod ? x - mod : x;
}
inline int dec(int x, int y) {
return (x -= y) < 0 ? x + mod : x;
}
}
using namespace Comb;
int n, k;
int f[51], g[51][N];
int main() {
comb_init();
read(n); read(k);
g[0][0] = 1;
for (int i = 0; i < k; ++i) {
g[1][i] = fav[i];
}
for (int i = 2; i <= n; ++i) {
g[i][0] = 1;
for (int j = 1; j <= i * (k - 1); ++j) {
g[i][j] = 1LL * i * g[i][j - 1] % mod;
if (j >= k) g[i][j] = dec(g[i][j], 1LL * i * fav[k - 1] % mod * g[i - 1][j - k] % mod);
g[i][j] = 1LL * g[i][j] * inv[j] % mod;
}
}
int ans = 0;
for (int i = 1; i <= n; ++i) {
int pw = Qpow(inv[i], k);
for (int j = 0; j <= (i - 1) * (k - 1); ++j) {
int d = 1LL * (j + k) * C(j + k - 1, j) % mod * g[i - 1][j] % mod * fac[j] % mod * pw % mod;
upd(f[i], 1LL * i * d % mod);
pw = 1LL * pw * inv[i] % mod;
}
int del = 1LL * C(n, i) * n % mod * inv[i] % mod * f[i] % mod;
if (!(i & 1)) del = mod - del;
upd(ans, del);
}
cout << ans << endl;
return 0;
}