Address
Solution
-
\[ans=\sum_{i=0}^{n}\sum_{j=0}^{i}S(i,j)*2^j*(j!) \]
- 因为 \(j>i\) 时 \(S(i,j)=0\),所以可以把 \(j\) 的上界放宽到 \(n\):
\[ans=\sum_{i=0}^{n}\sum_{j=0}^{n}S(i,j)*2^j*(j!)
\]
众所周知:
\[S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k*(j-k)^i*C_j^k
\]
因此:
\[ans=\sum_{i=0}^{n}\sum_{j=0}^{n}\sum_{k=0}^{j}(-1)^k*(j-k)^i*C_j^k*2^j
\]
- 发现 \(2^j\) 只包含了变量 \(j\),所以把它提到前面:
\[ans=\sum_{j=0}^{n}2^j*\sum_{i=0}^{n}\sum_{k=0}^{j}(-1)^k*(j-k)^i*C_j^k
\]
- 然后把 \(C_j^k=\frac{j!}{k!\,(j-k)!}\) 拆成阶乘形式,再整理得:
\[ans=\sum_{j=0}^{n}2^j*(j!)*\sum_{k=0}^{j}\frac{(-1)^k}{k!}*\frac{\sum_{i=0}^{n}(j-k)^i}{(j-k)!}
\]
- 于是令 \(f(i)=\frac{(-1)^i}{i!},g(j)=\frac{\sum_{i=0}^nj^i}{j!}\)
- 当 \(j\neq 1\) 时,\(g(j)\) 可以用等比数列求和公式变成(\(j=0\) 时该式等于 \(1\),与 \(0^0=1\) 一致):
\[g(j)=\frac{j^{n+1}-1}{j!\,(j-1)}
\]
- 而 \(j=1\) 时需要单独计算:\(g(1)=n+1\)。
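- 于是 \(ans=\sum_{j=0}^{n}2^j*(j!)*\sum_{k=0}^{j}f(k)\,g(j-k)\),内层正是 \(f\) 与 \(g\) 的卷积。那么用 \(NTT\) 把 \(f\) 和 \(g\) 乘起来,取前 \(n+1\) 项即可。注意乘积次数为 \(2n\),变换长度必须严格大于 \(2n\)。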
Code
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
// Capacity of every global array below; must exceed the NTT length lim.
const int e = 1e6 + 5;
// Prime modulus used for all arithmetic and for the NTT.
const int mod = 998244353;

int n;          // input value read in main
int ans;        // accumulated answer, reduced modulo mod
int lim = 1;    // transform length (power of two), grown in main
int rev[e];     // bit-reversal permutation for length lim
int a[e];       // coefficients of f, later overwritten by the product f * g
int b[e];       // coefficients of g
// NOTE(review): fa, g, cc, dd are not referenced anywhere in the visible code.
int fa[e];
int g[e];
int cc[e];
int dd[e];
// Binary (square-and-multiply) exponentiation: returns x^y reduced modulo
// `mod`. For y == 0 the result is 1. Intended for y >= 0 and 0 <= x < mod.
// Products are formed in 64-bit before reduction, so they cannot overflow.
inline int ksm(int x, int y)
{
	long long base = x;
	long long acc = 1;
	for (; y; y >>= 1)
	{
		// Fold in the current power of x when this exponent bit is set.
		if (y & 1) acc = acc * base % mod;
		base = base * base % mod;
	}
	return static_cast<int>(acc);
}
// In-place iterative number-theoretic transform modulo `mod`.
//   n  : transform length; the loops assume a power of two, and rev[] must
//        already hold the bit-reversal permutation for this length.
//   a  : n residues, transformed in place.
//   op : 1 selects the forward transform (root 3, a primitive root of
//        998244353); any other value selects the inverse root 3^{-1}.
// The inverse transform is NOT divided by n here; the caller must scale.
inline void fft(int n, int *a, int op)
{
	// (mod + 1) / 3 == 332748118 == 3^{-1} (mod 998244353).
	const int root = (op == 1) ? 3 : (mod + 1) / 3;

	// Reorder into bit-reversed positions so butterflies can run bottom-up.
	for (int idx = 0; idx < n; ++idx)
	{
		if (idx < rev[idx]) swap(a[idx], a[rev[idx]]);
	}

	for (int half = 1; half < n; half <<= 1)
	{
		const int span = half << 1;
		// Principal span-th root of unity for this stage.
		const int step = ksm(root, (mod - 1) / span);
		for (int base = 0; base < n; base += span)
		{
			int tw = 1;
			for (int off = 0; off < half; ++off)
			{
				int &lo = a[base + off];
				int &hi = a[base + off + half];
				const int u = lo;
				const int v = 1ll * tw * hi % mod;
				lo = (u + v) % mod;
				hi = (u - v + mod) % mod;
				tw = 1ll * tw * step % mod;
			}
		}
	}
}
// Entry point. Reads n and prints
//   ans = sum_{i=0}^{n} sum_{j=0}^{i} S(i,j) * 2^j * j!   (mod 998244353),
// evaluated as ans = sum_{j=0}^{n} 2^j * j! * (f * g)(j), where
//   f(k) = (-1)^k / k!,   g(t) = (sum_{i=0}^{n} t^i) / t!,
// and the convolution f * g is computed with the NTT in fft().
int main()
{
	cin >> n;

	// f * g has degree 2n, so the cyclic transform length must be strictly
	// greater than 2n. lim == 2n is NOT enough: the x^{2n} coefficient would
	// wrap onto index 0 and corrupt the j = 0 term (n = 1 would yield 1, not 3).
	int k = 0;
	while (lim <= n * 2)
	{
		lim <<= 1;
		k++;
	}
	for (int i = 1; i < lim; i++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));

	// Fill f into a[] and g into b[] for indices 0..n; higher slots stay 0.
	// This loop must finish before any transform is taken.
	int fac = 1;
	for (int i = 0; i <= n; i++)
	{
		if (i != 0) fac = 1ll * fac * i % mod;
		const int inv_fac = ksm(fac, mod - 2);
		// f(i) = (-1)^i / i!
		a[i] = (i & 1) ? (mod - inv_fac) % mod : inv_fac;
		// g(i) = (sum_{t=0}^{n} i^t) / i!. The geometric closed form
		// (i^{n+1} - 1) / (i - 1) needs i != 1; for i = 0 the sum is 1 (0^0 = 1).
		if (i == 0) b[i] = 1;
		else if (i == 1) b[i] = n + 1;
		else
			b[i] = 1ll * (ksm(i, n + 1) + mod - 1) % mod * ksm(i - 1, mod - 2) % mod
			       * inv_fac % mod;
	}

	fft(lim, a, 1);
	fft(lim, b, 1);
	for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * b[i] % mod;
	fft(lim, a, -1);  // inverse transform; fft() does not divide by lim

	// Sum (f*g)(j) * j! * 2^j. The 1/lim scaling of the inverse transform is
	// applied here, only to the n + 1 coefficients actually used.
	const int inv_lim = ksm(lim, mod - 2);
	int pow2 = 1;
	fac = 1;
	for (int j = 0; j <= n; j++)
	{
		if (j != 0) fac = 1ll * fac * j % mod;
		const int c = 1ll * a[j] * inv_lim % mod;
		ans = (ans + 1ll * c * fac % mod * pow2) % mod;
		pow2 = 2ll * pow2 % mod;
	}
	cout << ans << endl;
	return 0;
}