题目
给你三个正整数 (n),(a),(b),定义 (A) 为一个排列中是前缀最大值的数的个数,
定义 (B) 为一个排列中是后缀最大值的数的个数,求长度为 (n) 的排列中满足 (A = a) 且 (B = b) 的排列个数。
分析
最大值一定是 (n),左边有 (a-1) 的前缀最大值,右边有 (b-1) 的后缀最大值。
假设已经钦定 (a+b-2) 个最大值,等于要将 (n-1) 个数分成 (a+b-2) 个部分且最大值要顶在最外面,也就是 (a+b-2) 个圆排列,
然后还要把 (a+b-2) 个部分给左边 (a-1) 个,那也就是 (C(a+b-2,a-1)*Stir(n-1,a+b-2)),求第一类斯特林数即可
代码
#include <cstdio>
#include <cctype>
#include <cmath>
#include <cstring>
#include <algorithm>
#define rr register
#define mem(f,n) memset(f,0,sizeof(int)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(int)*(n))
using namespace std;
const int mod=998244353,inv3=332748118,N=200011;
typedef long long lll; typedef unsigned long long ull;
int n,m,Gmi[31],Imi[31],len,fac[N],X,Y,inv[N],ff[N<<1],gg[N<<1],tt[N<<1];
inline signed iut(){
rr int ans=0; rr char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
inline void print(int ans){
if (ans>9) print(ans/10);
putchar(ans%10+48);
}
inline signed ksm(int x,int y){
rr int ans=1;
for (;y;y>>=1,x=1ll*x*x%mod)
if (y&1) ans=1ll*ans*x%mod;
return ans;
}
namespace Theoretic{
int rev[N<<1],LAST; ull Wt[N<<1],F[N<<1];
inline void Pro(int n){
if (LAST==n) return; LAST=n,Wt[0]=1;
for (rr int i=0;i<n;++i)
rev[i]=(rev[i>>1]>>1)|((i&1)?n>>1:0);
}
inline void NTT(int *f,int n,int op){
Pro(n);
for (rr int i=0;i<n;++i) F[i]=f[rev[i]];
for (rr int o=1,len=1;len<n;++o,len<<=1){
rr int W=(op==1)?Gmi[o]:Imi[o];
for (rr int j=1;j<len;++j) Wt[j]=Wt[j-1]*W%mod;
for (rr int i=0;i<n;i+=len+len)
for (rr int j=0;j<len;++j){
rr int t=Wt[j]*F[i|j|len]%mod;
F[i|j|len]=F[i|j]+mod-t,F[i|j]+=t;
}
if (o==10) for (rr int j=0;j<n;++j) F[j]%=mod;
}
if (op==-1){
rr int invn=ksm(n,mod-2);
for (rr int i=0;i<n;++i) F[i]=F[i]%mod*invn%mod;
}else for (rr int i=0;i<n;++i) F[i]%=mod;
for (rr int i=0;i<n;++i) f[i]=F[i];
}
inline void Cb(int *f,int *g,int n){
for (rr int i=0;i<n;++i) f[i]=1ll*f[i]*g[i]%mod;
}
inline void Shift(int *f,int *g,int n,int sh){
for (rr int i=0;i<n;++i) tt[n-i-1]=1ll*g[i]*fac[i]%mod;
for (rr int i=0,t=1;i<n;++i,t=1ll*t*sh%mod) f[i]=1ll*t*inv[i]%mod;
rr int len=1; for (;len<n+n;len<<=1);
NTT(tt,len,1),NTT(f,len,1),Cb(tt,f,len),NTT(tt,len,-1);
for (rr int i=0;i<n;++i) f[n-i-1]=1ll*tt[i]*inv[n-i-1]%mod;
mem(f+n,len-n),mem(tt,len);
}
inline void Doubly(int *f,int n){
if (!n) {f[0]=1; return;}
else if (n&1){
Doubly(f,n-1),f[n]=0;
for (rr int i=n;i;--i)
f[i]=(f[i-1]+1ll*f[i]*(n-1)%mod)%mod;
f[0]=1ll*f[0]*(n-1)%mod;
}else{
Doubly(f,n>>1);
Shift(gg,f,(n>>1)+1,n>>1);
rr int len=1; for (;len<n+2;len<<=1);
NTT(f,len,1),NTT(gg,len,1),Cb(f,gg,len),
NTT(f,len,-1),mem(gg,len),mem(f+n+2,len-n-2);
}
}
}
inline void GmiImi(){
for (rr int i=0;i<31;++i) Gmi[i]=ksm(3,(mod-1)/(1<<i));
for (rr int i=0;i<31;++i) Imi[i]=ksm(inv3,(mod-1)/(1<<i));
}
signed main(){
n=iut(),X=iut(),Y=iut(),GmiImi(),fac[0]=fac[1]=inv[0]=inv[1]=1;
for (rr int i=2;i<=n*2;++i) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
for (rr int i=2;i<=n*2;++i) fac[i]=1ll*fac[i-1]*i%mod,inv[i]=1ll*inv[i]*inv[i-1]%mod;
Theoretic::Doubly(ff,n-1);
return !printf("%lld",1ll*ff[X+Y-2]*fac[X+Y-2]%mod*inv[X-1]%mod*inv[Y-1]%mod);
}