题解
warning:式子全都抄的题解。
我们可以先套一层(min-max)反演。
[ans=sum_{i=1}^n (-1)^{i-1}inom{n}{i}g_i
]
那么(g_i)就表示喂饱(i)只鸽子中至少一只的期望步数。
[g_i=sum_{igeq 1}i*P(x=i)
]
[=sum_{igeq 1}P(xgeq i)
]
然后考虑设计一个(dp),设(f(sum,cnt))表示喂(sum)只鸽子,喂了(cnt)次,都没有喂饱的概率。
[g_i=sum_{jgeq 1}sum_{s=0}^{i-1}inom{i-1}{s}f(i,s)(frac{n-i}{n}) ^{i-1-s}
]
考虑枚举有一次喂食喂到了(i)只鸽子中,根据鸽巢原理,
[g_i=sum_{s=0}^{i(k-1)}f(i,s)sum_{j geq 0}inom{s+j}{s}(frac{n-i}{n})^j
]
有一个不知道为什么的东西:
[(frac{1}{1-x})^k=sum_{igeq 0}inom{i+k-1}{k-1}x^i
]
那么:
[sum_{jgeq 0}inom{s+t}{t}(frac{n-c}{n})^t=(frac{1}{1-frac{n-c}{n}})^{s+1}=(frac{n}{c})^{s+1}
]
[g_i=sum_{s=0}^{i(k-1)}f(i,s)(frac{n}{c})^{s+1}
]
[f(c,s)=sum_{i=0}^{min(s,k-1)}inom{s}{i}frac{1}{n^i}f(c-1,s-i)
]
[frac{f(c,s)}{s!}=sum_{i=0}^{min(s,k-1)}frac{1}{n^ii!}frac{f(c-1,s-i)}{(s-i)!}
]
然后就可以(NTT)算了。
代码
#include<bits/stdc++.h>
#define N 52
#define K 1002
#define M 68002
using namespace std;
typedef long long ll;
int n,k,rev[M];
ll dp[N][M],inv[M],jie[M],ni[M],ans,g[N];
const int G=3;
const int Gi=332748118;
const int mod=998244353;
inline ll rd(){
ll x=0;char c=getchar();bool f=0;
while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
return f?-x:x;
}
inline ll power(ll x,ll y){
ll ans=1;
while(y){
if(y&1)ans=ans*x%mod;
x=x*x%mod;
y>>=1;
}
return ans;
}
inline void MOD(ll &x){x=x>=mod?x-mod:x;}
inline ll C(int n,int m){return jie[n]*ni[m]%mod*ni[n-m]%mod;}
inline void NTT(ll *a,int l,int tag){
for(int i=1;i<l;++i)if(i>rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<l;i<<=1){
ll wn=power(tag?G:Gi,(mod-1)/(i<<1));
for(int j=0;j<l;j+=(i<<1)){
ll w=1;
for(int k=0;k<i;++k,w=w*wn%mod){
ll x=a[j+k],y=a[i+j+k]*w%mod;
MOD(a[j+k]=x+y);MOD(a[i+j+k]=x-y+mod);
}
}
}
if(!tag){
ll ny=power(l,mod-2);
for(int i=0;i<l;++i)a[i]=a[i]*ny%mod;
}
}
inline void prework(int n){
jie[0]=1;
for(int i=1;i<=n;++i)jie[i]=jie[i-1]*i%mod;
ni[n]=power(jie[n],mod-2);
for(int i=n-1;i>=0;--i)ni[i]=ni[i+1]*(i+1)%mod;
}
int main(){
n=rd();k=rd();
prework(n*k);
for(int i=0;i<k;++i)inv[i]=power(power(n,i),mod-2)*ni[i]%mod;
int maxn=n*(k-1);
dp[0][0]=1;
int l=1,L=0;
while(l<=maxn)l<<=1,L++;
for(int i=1;i<l;++i)rev[i]=rev[i>>1]>>1|((i&1)<<(L-1));
NTT(dp[0],l,1);NTT(inv,l,1);
for(int i=1;i<=n;++i){
for(int j=0;j<l;++j)dp[i][j]=dp[i-1][j]*inv[j]%mod;
}
for(int i=1;i<=n;++i){
NTT(dp[i],l,0);
int x=i*(k-1);
ll nii=1ll*n*power(i,mod-2)%mod,num=1;
for(int j=0;j<=x;++j){
dp[i][j]=dp[i][j]*jie[j]%mod;
num=num*nii%mod;
MOD(g[i]+=dp[i][j]*num%mod);
}
}
for(int i=1;i<=n;++i){
if(i&1)MOD(ans+=C(n,i)*g[i]%mod);
else MOD(ans=ans-C(n,i)*g[i]%mod+mod);
}
cout<<ans;
return 0;
}