Description
求 (sumlimits_{i=0}^{n-1}sumlimits_{j=0}^{i}C(i,j) imes (j+1)^moperatorname{mod}998244353)
(nleq10^9,mleq 100000)
Solution
傻逼推式子题...
首先 (sumlimits_{i=0}^nC(i,j)=C(n+1,j+1)),所以原式可化为
[sum_{i=1}^nC(n,i) imes i^m
]
斯特林展开 (n^k=sumlimits_{i=0}^nS(k,i) imes i! imes C(n,i))
[sum_{i=1}^nC(n,i) imes sum_{k=0}^mC(i,k) imes k! imes S(m,k)
]
因为 (S(i,j)=0(i<j)),所以将 (k) 的枚举提前
[sum_{k=0}^mS(m,k) imes k! imes sum_{i=1}^nC(n,i) imes C(i,k)
]
观察 (sumlimits_{i=1}^nC(n,i) imes C(i,k)) 的组合意义,即先从 (n) 个球中选 (i) 个,再从 (i) 个球中选 (k) 个。这和从 (n) 个球中先取 (k) 个,剩下的球随意拿是等价的。所以 (sumlimits_{i=1}^nC(n,i) imes C(i,k)=C(n,k) imes 2^{n-k})
[sum_{k=0}^mS(m,k) imes k! imes C(n,k) imes 2^{n-k}
]
将组合数拆开
[sum_{k=0}^mS(m,k) imes frac{n! imes 2^{n-k}}{(n-k)!}
]
这是个卷积的形式,那么就先 (NTT) 一遍求出第二类斯特林数,再 (NTT) 求答案就行了。
因为 (n) 很大但是 (k) 很小,所以 (frac{n!}{(n-k)!}) 是可以算的,数组下标再平移一下就好了。
Code
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
#define inv(x) ksm(x,mod-2)
const int N=4e5+5;
const int mod=998244353;
int fac[N],ifac[N];
int rev[N],a[N],b[N];
int n,m,lim,c[N],d[N];
int ksm(int a,int b,int ans=1){
while(b){
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;b>>=1;
} return ans;
}
int getint(){
int X=0,w=0;char ch=getchar();
while(!isdigit(ch))w|=ch=='-',ch=getchar();
while( isdigit(ch))X=X*10+ch-48,ch=getchar();
if(w) return -X;return X;
}
void ntt(int *f,int opt){
for(int i=0;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int mid=1;mid<lim;mid<<=1){
int tmp=ksm(3,(mod-1)/(mid<<1));
if(opt<0) tmp=inv(tmp);
for(int R=mid<<1,j=0;j<lim;j+=R){
int w=1;
for(int k=0;k<mid;k++,w=1ll*w*tmp%mod){
int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
}
}
} if(opt<0){
for(int in=inv(lim),i=0;i<lim;i++)
f[i]=1ll*f[i]*in%mod;
}
}
void mul(int *a,int *b){
ntt(a,1),ntt(b,1);
for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
ntt(a,-1);
}
signed main(){
n=getint(),m=getint();
fac[0]=ifac[0]=1;
for(int i=1;i<=m;i++) fac[i]=1ll*fac[i-1]*i%mod;
ifac[m]=inv(fac[m]);
for(int i=m-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
if(n<m){
int ans=0;
for(int i=1;i<=n;i++)
ans=(ans+1ll*fac[n]%mod*ifac[i]%mod*ifac[n-i]%mod*ksm(i,m)%mod)%mod;
printf("%d
",ans);return 0;
}
lim=1;while(lim<=m+m) lim<<=1;
for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
for(int i=0;i<=m;i++){
a[i]=1ll*(i&1?mod-1:1)*ifac[i]%mod;
b[i]=1ll*ksm(i,m)*ifac[i]%mod;
} mul(a,b);int now=1;
for(int i=n;i>=n-m;i--){
c[i-n+m]=1ll*ksm(2,i)*now%mod;
now=1ll*now*i%mod;
} mul(a,c);
printf("%d
",a[m]);
return 0;
}