AT5202 [AGC038E] Gachapon(min-max)
题目大意
有一个随机数生成器,生成 ([0,n-1]) 之间的整数,其中生成 (i) 的概率为 (frac{A_i}{S}),其中,(S=sum A_i)。
这个随机数生成器不断生成随机数,当 (forall iin[0,n-1]),(i) 至少出现了 (B_i) 次时,停止生成,否则继续生成。
求期望生成随机数的次数,输出答案对 (998244353) 取模的结果。
数据范围
(A_i,B_igeq 1),(sum A_i,sum B_i,nleq 400)。
解题思路
显然是一个 min-max 反演
[Ans = sum_{T subseteq S}(-1)^{|T|+1}frac {S}{sum_{iin T}A_i}f(T)
]
其中,(f(T)) 表示 T 集合中第一个至少出现了 (B_i) 次的期望次数。
考虑暴力求 T 集合的答案
[f(T) = sum_{i=1}P(x=i) imes i=sum_{i=0}^{sumB}P(x > i)
]
如何求出 (P(x>i)) 呢?考虑用方案数除以总方案数,方案数就是一个背包问题,用生成函数表示是
[f(T)=sum_{i=0}left[frac {x^i}{i!}
ight]left(prod_{jin T}sum_{t=0}^{B_j-1}A_j^tfrac {x^t}{t!}
ight)
]
容易发现我们只用一维即可,时间复杂度是 (Theta(n^2)) 的
观察发现只要选中的生成函数不会变,而且前面的 ((-1)^{|T|+1}) 可以乘进去,又发现 (S) 很小,我们用另一维状态去压缩它即可,时间复杂度 (Theta(n^3)),最后统计答案即可。
/*
/> フ
| _ _|
/`ミ _x 彡
/ |
/ ヽ ?
/ ̄| | | |
| ( ̄ヽ__ヽ_)_)
\二つ
*/
#include <queue>
#include <vector>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MP make_pair
#define ll long long
#define fi first
#define se second
using namespace std;
template <typename T>
void read(T &x) {
x = 0; bool f = 0;
char c = getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=1;
for (;isdigit(c);c=getchar()) x=x*10+(c^48);
if (f) x=-x;
}
template<typename F>
inline void write(F x, char ed = '
') {
static short st[30];short tp=0;
if(x<0) putchar('-'),x=-x;
do st[++tp]=x%10,x/=10; while(x);
while(tp) putchar('0'|st[tp--]);
putchar(ed);
}
template <typename T>
inline void Mx(T &x, T y) { x < y && (x = y); }
template <typename T>
inline void Mn(T &x, T y) { x > y && (x = y); }
const int P = 998244353;
const int N = 405;
ll inv[N], fac[N], a[N], b[N], A, B, ans, n;
ll fpw(ll x, ll mi) {
ll res = 1;
for (; mi; mi >>= 1, x = x * x % P)
if (mi & 1) res = res * x % P;
return res;
}
ll g[N][N], f[N][N];
int main() {
read(n);
inv[0] = fac[0] = inv[1] = fac[1] = 1;
for (int i = 2;i <= 400; i++) inv[i] = (P - P / i) * inv[P % i] % P;
for (int i = 2;i <= 400; i++)
inv[i] = inv[i-1] * inv[i] % P,
fac[i] = fac[i-1] * i % P;
for (int i = 1;i <= n; i++) {
read(a[i]), read(b[i]);
A += a[i], B += b[i];
}
f[0][0] = -1;
/* for (int i = 1;i <= 50; i++) write(inv[i], ' '), write(fac[i]); */
for (int i = 1;i <= n; i++) {
memcpy(g, f, sizeof(g));
for (int s = A;s >= 0; s--) {
for (int j = B;j >= 0; j--) {
if (s < a[i]) { f[s][j] = 0; continue; }
ll t = a[i];
f[s][j] = f[s-a[i]][j];
for (int k = 1;k < b[i]; k++, t = t * a[i] % P)
f[s][j] = (f[s][j] + t * inv[k] % P * f[s-a[i]][j-k]) % P;
}
}
for (int j = 0;j <= A; j++)
for (int k = 0;k <= B; k++)
f[j][k] = (g[j][k] - f[j][k] + P) % P;
}
for (int s = 1;s <= A; s++) {
ll tt = fpw(s, P - 2), t = A * tt % P;
ll res = 0;
for (int i = 0;i <= B; i++, t = t * tt % P)
res = (res + f[s][i] * fac[i] % P * t) % P;
res %= P, ans = (ans + res) % P;
}
write(ans);
return 0;
}
*/