@description@
求有多少个长度为 n 的排列,从左往右遍历有 a 个数比之前遍历的所有数都大,从右往左遍历有 b 个数比之前遍历的所有数都大。
模 998244323。
input
一行三个整数 n, a, b。1 ≤ n ≤ 10^5,0 ≤ A, B ≤ n。
output
输出排列数模 998244353。
sample input
5 2 2
sample output
22
@solution@
@part - 1@
首先从左往右和从右往左都会在最大值的地方停下来。
我们枚举最大值的位置,并记 dp(i, j) 表示 i 个元素顺序遍历有 j 个符合要求的元素的方案数。
则:
为什么要减去 1 呢?因为我们的最后一个元素一定是最大值。
考虑怎么求解 dp(i, j)。我们为了避免繁杂的枚举,直接考虑 i 个元素中最小的那个元素的位置。如果最小的元素是第一个,则它一定被计算进去,剩下的状态变为 dp(i-1, j-1);否则,它一定不会被计算进去,就可以删除它,变为 dp(i-1, j)。
故:
如果你对组合数学足够熟悉,就会发现上面那个式子其实是第一类斯特林数 (s(i, j)) 的递推式。
考虑其组合意义。如果我们最后符合要求的数为 (a_{p_1}, a_{p_2}, dots , a_{p_j}),则一定有 (a_{p_1+1dots p_2-1} < a_{p_1})。
如果我们把 (a_{p_1...p_2-1}) 看成一个整体,则这个整体对答案的贡献其实是圆排列——每个排列都必须要保证 (a_{p_1}) 在第一个位置,就像是某个圆排列将 (a_{p_1}) 旋转到第一个位置。
如果我们确定了数放在哪一个圆排列中,则圆排列之间的相对位置是唯一的,因为我们必须要满足 (a_{p_1} < a_{p_2} < dots < a_{p_j})。也就是说最后的方案数就是将 i 个数分成 j 个圆排列的方案数——即第一类斯特林数的定义。
既然扯到了组合意义,那么最初那个枚举最大值的位置可不可以直接用组合数学的方法来搞定了?
我们可以这样理解:先将 n-1(除最大值以外)个数分成 a+b-2 个圆排列,再将这 a+b-2 个圆排列黑白染色,选择 a-1 个染黑色(放在最大值左边),剩下的染白色(放在最大值右边)。则:
@part - 2@
【接下来只是来讲讲怎么 O(nlogn) 求解第一类斯特林数的,如果你已经很熟悉了可以直接跳过这一段】
我们根据这样一个公式进行求解:
有些类似于二项式定理。可以根据对最后一项是选择 x 还是 n-1 得到和我们递推公式一样的结果。
我们利用倍增解决这一问题。
记 (f_n(x)=prod_{i=0}^{n-1}(x+i) = a_0+a_1x+dots+a_{n-1}x^{n-1})。
则 (f_{2n}(x) = f(x)*f(x+n)),(f_{2n+1}(x)=f(x)*f(x+n)*(x+2n))。
如果已知 (f(x+n)),则可以用 fft 快速计算多项式乘法。
考虑怎么已知 (f_n(x)) 求 (f_n(x+n))。将 (f_n(x+n)) 的式子写出来:
二项式展开:
把内层的求和去掉:
把组合数拆成阶乘形式,并适当整理:
如果记 (A_i = a_i*i!),(B_i = frac{n^j}{j!}),则我们相当于是要求解 A 与 B 的减法卷积。将 A 翻转一下就可以正常用 fft 做加法卷积,然后把结果再翻转回来即可。
@accepted code@
注意一些该特判的地方还是要特判。
#include<cstdio>
#include<algorithm>
using namespace std;
const int G = 3;
const int MOD = 998244353;
const int MAXN = 400000;
int pow_mod(int b, int p) {
int ret = 1;
while( p ) {
if( p & 1 ) ret = 1LL*ret*b%MOD;
b = 1LL*b*b%MOD;
p >>= 1;
}
return ret;
}
int fct[MAXN + 5], inv[MAXN + 5];
void ntt(int *A, int n, int type) {
for(int i=0,j=0;i<n;i++) {
if( i < j ) swap(A[i], A[j]);
for(int l=(n>>1);(j^=l)<l;l>>=1);
}
for(int s=2;s<=n;s<<=1) {
int t = (s>>1);
int u = (type == 1) ? pow_mod(G, (MOD-1)/s) : pow_mod(G, (MOD-1)-(MOD-1)/s);
for(int i=0;i<n;i+=s) {
int p = 1;
for(int j=0;j<t;j++,p=1LL*p*u%MOD) {
int x = A[i+j], y = 1LL*A[i+j+t]*p%MOD;
A[i+j] = (x + y)%MOD, A[i+j+t] = (x + MOD - y)%MOD;
}
}
}
if( type == -1 ) {
int k = 1LL*fct[n-1]*inv[n]%MOD;
for(int i=0;i<n;i++)
A[i] = 1LL*A[i]*k%MOD;
}
}
void init() {
fct[0] = 1;
for(int i=1;i<=MAXN;i++)
fct[i] = 1LL*fct[i-1]*i%MOD;
inv[MAXN] = pow_mod(fct[MAXN], MOD - 2);
for(int i=MAXN-1;i>=0;i--)
inv[i] = 1LL*inv[i+1]*(i+1)%MOD;
}
int comb(int n, int m) {
return 1LL*fct[n]*inv[m]%MOD*inv[n-m]%MOD;
}
int tmp1[MAXN + 5], tmp2[MAXN + 5], tmp3[MAXN + 5];
void sterling1(int *A, int n) {
if( !n ) {
A[0] = 1;
return ;
}
int m = n/2, pw = 1, len;
sterling1(A, m);
for(len = 1;len <= n;len <<= 1);
for(int i=0;i<=m;i++) tmp1[m - i] = 1LL*fct[i]*A[i]%MOD;
for(int i=0;i<=m;i++) tmp2[i] = 1LL*inv[i]*pw%MOD, pw=1LL*pw*m%MOD;
for(int i=m+1;i<len;i++) tmp1[i] = tmp2[i] = 0;
ntt(tmp1, len, 1), ntt(tmp2, len, 1);
for(int i=0;i<len;i++) tmp1[i] = 1LL*tmp1[i]*tmp2[i]%MOD;
ntt(tmp1, len, -1);
for(int i=0;i<=m;i++) tmp3[m - i] = 1LL*tmp1[i]*inv[m - i]%MOD;
for(int i=0;i<=m;i++) tmp1[i] = A[i];
for(int i=m+1;i<len;i++) tmp1[i] = tmp3[i] = 0;
if( n & 1 ) {
tmp2[1] = 1, tmp2[0] = (MOD + n - 1);
for(int i=2;i<len;i++) tmp2[i] = 0;
}
else {
tmp2[0] = 1;
for(int i=1;i<len;i++) tmp2[i] = 0;
}
ntt(tmp1, len, 1), ntt(tmp2, len, 1), ntt(tmp3, len, 1);
for(int i=0;i<len;i++) tmp1[i] = 1LL*tmp1[i]*tmp2[i]%MOD*tmp3[i]%MOD;
ntt(tmp1, len, -1);
for(int i=0;i<=n;i++) A[i] = tmp1[i];
}
int f[MAXN + 5];
int main() {
int n, a, b; init();
scanf("%d%d%d", &n, &a, &b);
if( a + b > n + 1 || a == 0 || b == 0 ) {
printf("%d
", 0);
return 0;
}
sterling1(f, n - 1);
/*
for(int i=0;i<=n-1;i++)
printf("%d ", f[i]);
puts("");
*/
printf("%lld
", 1LL*f[a + b - 2]*comb(a + b - 2, a - 1)%MOD);
}
@details@
写程序的时候突然发现斯特林数的简写是 STL。
我就说用998244353这个模数肯定是ntt嘛。
不要忘记乘上 (frac{1}{(i-j)!})。
奇数长度的还要多乘一个多项式。
边界当 n = 0 的时候,返回一个常数 1。