[题目链接]
https://codeforces.com/contest/1349/problem/D
[题解]
首先设 (E_{x}) 表示所有饼干在 (x) 手里的情况 , 概率乘期望的和 , 则答案为 (sum_{x}{E_{x}})。
设 (E'_{x}) 表示如果游戏只在所有饼干都在 (x) 手里才会结束的期望步数。
设 (P_{x}) 为最终游戏结束时所有饼干在 (x) 手里的概率 , 则 (sum_{x}{P_{x}} = 1)。
再设常数 (C) 表示将所有饼干从一个人手里转移到另一个人手里的期望步数。
考虑容斥 , 不难发现存在恒等式 : (E(x) = E'(x) - sum_{x}{[i eq x](P_{i} * C + E_{i})})
这个等式相当于枚举了在 (x) 之前谁先得到了所有饼干。
一共有 (N) 个这样的等式 , 将它们求和 , 得到 : (Nsum_{i}{E_{i}} = sum_{i}{E'_{i}} - C(N - 1)sum{P_{i}}) 。
用 (ans) 替换 (sum{E_{i}}) , 用 (1) 替换 (sum{P_{i}})。
得到 (n cdot ans = sum_{i}{E'_{i} - C(N - 1)})。
注意到对于 (E'_{x}) 和 (C) 而言 , 只需关心饼干是否到了指定的某个人 , 不妨设 (f_{m}) 为剩余 (m) 个小饼干在目标人手中 , 最后全部转移到其手中的期望时间。 则不难得到
(f_i=egin{cases}frac{s-i}{s}left(frac{1}{n-1}f_{i+1}+frac{n-2}{n-1}f_i ight)+frac{i}{s}f_{i-1}+1,&0<i<s\frac{n-2}{n-1}f_i+frac{1}{n-1}f_{i+1}+1,&i=0\0,&i=send{cases})
首先解第二个方程 , 得到 (f_{0} = f_{1} + n - 1)。
直接通过消元求 (f_{2} ... f_{n}) 是不稳妥的 , 因为可能出现除以 (0) 的情况。
那么不妨定义 (g_{i} = f_{i} - f_{i + 1}) , 则 (f_i=sum_{j=i}^sg_j)。 将 (f) 代换为 (g) 得到 (sum_{j=i}^sg_j=frac{s-i}{s}left(frac{1}{n-1}sum_{j=i+1}^sg_j+frac{n-2}{n-1}sum_{j=i}^sg_j ight)+frac{i}{s}sum_{j=i-1}^sf_j+1)。
(j > i) 的项直接抵消了 , 因此可以得到 :
(g_i=frac{s(n-1)+i(n-1)g_{i-1}}{s-i})
又因为 (g_{0} = f_{0} - f_{1} = n - 1) , 所以可以直接递推求 (g) , 而 (f) 就是 (g) 的后缀和。
时间复杂度 : (O(NlogN))
[代码]
#include<bits/stdc++.h>
using namespace std;
#define rep(i , l , r) for (int i = (l); i < (r); ++i)
typedef long long LL;
const int MN = 3e6 + 5 , mod = 998244353;
int N , s , f[MN] , g[MN] , a[MN];
inline void inc(int &x , int y) {
x = x + y < mod ? x + y : x + y - mod;
}
inline void dec(int &x , int y) {
x = x - y >= 0 ? x - y : x - y + mod;
}
inline int qPow(int a , int b) {
int c = 1;
for (; b; b >>= 1 , a = 1ll * a * a % mod) if (b & 1) c = 1ll * c * a % mod;
return c;
}
int main() {
scanf("%d" , &N);
for (int i = 1; i <= N; ++i) {
scanf("%d" , &a[i]);
s += a[i];
}
g[0] = N - 1;
for (int i = 1; i < s; ++i)
g[i] = (1ll * s * (N - 1) % mod + 1ll * i * (N - 1) % mod * g[i - 1] % mod) % mod * qPow(s - i , mod - 2) % mod;
for (int i = s; ~i; --i) f[i] = (f[i + 1] + g[i]) % mod;
int ans = 0;
for (int i = 1; i <= N; ++i) inc(ans , f[a[i]]);
dec(ans , 1ll * f[0] * (N - 1) % mod);
ans = 1ll * ans * qPow(N , mod - 2) % mod;
printf("%d
" , ans);
return 0;
}