Description
Japate, while traveling through the forest of Mala, saw $ N $ bags of gold lying in a row. Each bag has a distinct weight of gold between $ 1 $ and $ N $ . Japate can carry only one bag of gold with him, so he uses the following strategy to choose a bag.
Initially, he starts with an empty bag (zero weight). He considers the bags in some order. If the current bag has a higher weight than the bag in his hand, he picks the current bag.
Japate put the bags in some order. Japate realizes that he will pick $ A $ bags, if he starts picking bags from the front, and will pick $ B $ bags, if he starts picking bags from the back. By picking we mean replacing the bag in his hand with the current one.
Now he wonders how many permutations of bags are possible, in which he picks $ A $ bags from the front and $ B $ bags from back using the above strategy.
Since the answer can be very large, output it modulo $ 998244353 $ .
Solution
前缀最大值和后缀最大值其实等价(翻转原串)
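设$dp_{i,j}$表示$i$个数的排列中,恰有$j$个数是前缀最大值的方案数。考虑最小的数:放在最前面时它是前缀最大值,放在其余$i-1$个位置时不是,于是
$$dp_{i,j}=dp_{i-1,j-1}+(i-1)\,dp_{i-1,j}$$
这正是第一类(无符号)斯特林数$\begin{bmatrix} i\\ j\end{bmatrix}$的递推式。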
题中要求的$a$个数字必定在$n$及其前面,$b$个数字必定在$n$及其后面
那么枚举$n$的位置得
$$Ans=\sum_{i=1}^{n}\begin{bmatrix} i-1\\ a-1\end{bmatrix}\begin{bmatrix} n-i\\ b-1\end{bmatrix}\binom{n-1}{i-1}$$
把左边的$i-1$个数分成$a-1$个环,右边剩下的$n-i$个数分成$b-1$个环,等价于把$n-1$个数分成$a+b-2$个环,再从这些环中选出$a-1$个放到左边
$$Ans=\begin{bmatrix} n-1\\ a+b-2\end{bmatrix}\binom{a+b-2}{a-1}$$
由第一类斯特林数的生成函数,可以分治套FFT去求
时间复杂度$O(n\log^2 n)$
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

// Counts the permutations of 1..N that have exactly A prefix maxima and
// B suffix maxima, modulo 998244353.
//
// The answer is c(N-1, A+B-2) * C(A+B-2, A-1), where c(n, k) is the unsigned
// Stirling number of the first kind, i.e. the coefficient of x^k in
// x (x+1) ... (x+n-1).  That product is built by divide and conquer, with each
// merge done by a number-theoretic transform (NTT).
//
// NOTE(review): all buffers are statically sized; this presumably relies on
// N <= 100000 -- confirm against the problem constraints.

int N, a, b;                  // input: permutation length, required prefix / suffix maxima
int tot, s;                   // current transform size s == 2^tot
int rev[400005];              // bit-reversal permutation for the current transform size
long long fac[200005] = {1};  // fac[i] = i! mod p
long long inv[200005];        // inv[i] = (i!)^{-1} mod p
long long F[400005], G[400005];  // scratch buffers for the two polynomials being multiplied
const long long mod = 998244353;
std::vector<long long> ve[400005];  // ve[k] = coefficients (low to high) held by segment-tree node k

// Reads the next (optionally negative) decimal integer from stdin.
// Returns 0 if the input ends before any digit is seen.
inline int read() {
    int value = 0;
    int sign = 1;
    // getchar() returns int so that EOF stays distinguishable from every
    // character; stopping on EOF prevents an endless loop on truncated input.
    int ch = std::getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}

// Returns a^p mod `mod` by binary exponentiation (p >= 0).
long long ksm(long long a, long long p) {
    // Reduce the base first so the products below stay within 64 bits.
    a %= mod;
    if (a < 0) a += mod;
    long long ret = 1;
    while (p) {
        if (p & 1) ret = ret * a % mod;
        a = a * a % mod;
        p >>= 1;
    }
    return ret;
}

// In-place iterative NTT of a[0..n) modulo `mod`, using 3 as primitive root.
// n must be a power of two and `rev` must already hold the bit-reversal
// permutation for that n.  INV == 1 runs the forward transform; INV == -1 runs
// the inverse transform, including the final division by n.
void ntt(long long *a, int n, int INV) {
    for (int i = 0; i < n; i++) {
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    }
    for (int half = 1; half < n; half <<= 1) {
        // Primitive (2*half)-th root of unity, inverted for the inverse transform.
        long long wn = ksm(3, (mod - 1) / half / 2);
        if (INV == -1) wn = ksm(wn, mod - 2);
        for (int base = 0; base < n; base += half * 2) {
            long long w = 1;
            for (int k = base; k < base + half; k++) {
                const long long x = a[k];
                const long long y = w * a[k + half] % mod;
                a[k] = (x + y) % mod;
                a[k + half] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if (INV == -1) {
        const long long n_inverse = ksm(n, mod - 2);
        for (int i = 0; i < n; i++) a[i] = a[i] * n_inverse % mod;
    }
}

// Stores in ve[k] the coefficients of prod_{i=l}^{r} (x + i), lowest degree
// first, so ve[k] ends up with r - l + 2 entries.  k is the segment-tree node
// index; its children are 2k and 2k+1.
void solve(int k, int l, int r) {
    if (l == r) {
        // Leaf: the linear factor (l + x).
        ve[k].assign({static_cast<long long>(l), 1LL});
        return;
    }
    const int mid = (l + r) >> 1;
    const int left = k << 1;
    const int right = k << 1 | 1;
    solve(left, l, mid);
    solve(right, mid + 1, r);

    // Choose the transform size only after the recursive calls, because they
    // overwrite the shared globals s, tot and rev.
    s = 2;
    tot = 1;
    while (s <= r - l + 3) {
        s <<= 1;
        tot++;
    }
    for (int i = 0; i < s; i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (tot - 1));
    }

    std::vector<long long> &lhs = ve[left];
    std::vector<long long> &rhs = ve[right];
    std::copy(lhs.begin(), lhs.end(), F);
    std::fill(F + lhs.size(), F + s, 0LL);
    std::copy(rhs.begin(), rhs.end(), G);
    std::fill(G + rhs.size(), G + s, 0LL);

    ntt(F, s, 1);
    ntt(G, s, 1);
    for (int i = 0; i < s; i++) F[i] = F[i] * G[i] % mod;
    ntt(F, s, -1);

    ve[k].assign(F, F + (r - l + 2));

    // The children are never read again; release their storage.
    std::vector<long long>().swap(lhs);
    std::vector<long long>().swap(rhs);
}

// Reads N, A, B from stdin and prints the number of valid permutations.
int main() {
    for (int i = 1; i <= 200000; i++) fac[i] = fac[i - 1] * i % mod;
    inv[200000] = ksm(fac[200000], mod - 2);
    for (int i = 199999; i >= 0; i--) inv[i] = inv[i + 1] * (i + 1) % mod;

    N = read();
    a = read();
    b = read();

    // The largest element is always picked from both ends, so both counts are
    // at least 1 and together they involve at most N distinct elements.
    if (a < 1 || b < 1 || a + b - 1 > N) {
        std::puts("0");
        return 0;
    }
    if (N == 1) {
        std::puts("1");
        return 0;
    }

    // ve[1] becomes x (x+1) ... (x+N-2); its x^j coefficient is c(N-1, j).
    solve(1, 0, N - 2);
    const int cycles = a + b - 2;
    // NOTE(review): the format string ends in a space rather than '\n', exactly
    // as in the original; a newline was probably intended -- confirm.
    std::printf("%lld ", ve[1][cycles] * fac[cycles] % mod * inv[a - 1] % mod * inv[b - 1] % mod);
    return 0;
}