传送门
观察到 (m) 的值很小,考虑把不同血量的随从的个数计入状态。
设 (dp_{i, A, B, C}) 表示在第 (i) 次攻击之后,还剩 (A) 个一血怪,(B) 个二血怪,(C) 个三血怪的概率。
转移很显然,只需要注意生成新怪的情况即可。
但是这对于 (N le 10^{18}) 来说是完全不行的。
我们考虑矩阵加速。
首先我们对每一个合法的三元组 ((A, B, C)) 进行编号(可以发现可能的最大值是 (166))。
然后我们把这些编号看做矩阵的行列,构造转移矩阵。
简单的说,我们记录状态之间转移的概率,然后再新增一行、一列来记录期望。
对于一种局面 ((A, B, C)),它对答案的贡献是 (frac{dp_{A, B, C}}{A + B + C + 1})。
然后我们就可以快乐的矩阵快速幂了…………才怪。
分析一下复杂度,发现每次都跑一次快速幂是跑不过的。
所以我们不妨预处理 (dp_i) 表示 (2 ^ i) 个转移矩阵的乘积,然后每次只要用一个行向量乘上 (log) 次 (dp_i) 就可以了。
参考代码:
#include <cstdio>
const int p = 998244353;
int T, m, k, ans[170], tmp[170]; long long N;
int n = 1, inv[12], id[12][12][12];
void Add(int& a, int b) { a += b, a >= p ? a -= p : 0; }
struct Matrix {
int a[170][170];
int* operator [] (int x) { return a[x]; }
Matrix operator * (Matrix b) const {
Matrix ans;
for (int i = 1; i <= n + 1; ++i)
for (int j = 1; j <= n + 1; ++j) ans[i][j] = 0;
for (int i = 1; i <= n + 1; ++i)
for (int j = 1; j <= n + 1; ++j)
for (int k = 1; k <= n + 1; ++k)
Add(ans[i][j], 1ll * a[i][k] * b[k][j] % p);
return ans;
}
} dp[170];
void mul(Matrix a) {
for (int i = 1; i <= n + 1; ++i) tmp[i] = 0;
for (int i = 1; i <= n + 1; ++i)
for (int j = 1; j <= n + 1; ++j)
Add(tmp[i], 1ll * ans[j] * a[j][i] % p);
for (int i = 1; i <= n + 1; ++i) ans[i] = tmp[i];
}
void power(long long N) {
for (int i = 0; N; N >>= 1, ++i) if (N & 1) mul(dp[i]);
}
int main() {
#ifndef ONLINE_JUDGE
freopen("cpp.in", "r", stdin), freopen("cpp.out", "w", stdout);
#endif
scanf("%d %d %d", &T, &m, &k);
inv[0] = inv[1] = 1;
for (int i = 2; i < 10; ++i)
inv[i] = 1ll * (p - p / i) * inv[p % i] % p;
for (int A = 0; A <= k; ++A)
for (int B = 0; B <= (m > 1 ? k - A : 0); ++B)
for (int C = 0; C <= (m > 2 ? k - A - B : 0); ++C)
id[A][B][C] = ++n;
for (int A = 0; A <= k; ++A)
for (int B = 0; B <= (m > 1 ? k - A : 0); ++B)
for (int C = 0; C <= (m > 2 ? k - A - B : 0); ++C) {
int x = id[A][B][C], y = A + B + C < k;
if (m == 1)
if (A) dp[0][x][id[A - 1][B][C]] = 1ll * A * inv[A + B + C + 1] % p;
if (m == 2) {
if (A) dp[0][x][id[A - 1][B][C]] = 1ll * A * inv[A + B + C + 1] % p;
if (B) dp[0][x][id[A + 1][B - 1 + y][C]] = 1ll * B * inv[A + B + C + 1] % p;
}
if (m == 3) {
if (A) dp[0][x][id[A - 1][B][C]] = 1ll * A * inv[A + B + C + 1] % p;
if (B) dp[0][x][id[A + 1][B - 1][C + y]] = 1ll * B * inv[A + B + C + 1] % p;
if (C) dp[0][x][id[A][B + 1][C - 1 + y]] = 1ll * C * inv[A + B + C + 1] % p;
}
dp[0][x][x] = dp[0][x][n + 1] = inv[A + B + C + 1];
}
dp[0][n + 1][n + 1] = 1;
for (int i = 1; i <= 60; ++i) dp[i] = dp[i - 1] * dp[i - 1];
while (T--) {
scanf("%lld", &N);
for (int i = 1; i <= n + 1; ++i) ans[i] = 0;
if (m == 1) ans[id[1][0][0]] = 1;
if (m == 2) ans[id[0][1][0]] = 1;
if (m == 3) ans[id[0][0][1]] = 1;
power(N), printf("%d
", ans[n + 1]);
}
return 0;
}