题意
给定\(n,k,D\),求下式的值(代码中答案对\(998244353\)取模):
\[\sum_{a_i\ge 0,\ \sum_{i=1}^n a_i=D}\frac{D!}{\prod_{i=1}^n (a_i+k)!}
\]
(\(1\le n\le 50\),\(0\le k\le 50\),\(0\le D\le 10^8\))
题解
由于\(e^x=\sum_{m=0}^{\infty}\frac{x^m}{m!}\),所以有
\[\sum_{a_i\ge 0,\ \sum_{i=1}^n a_i=D}\frac{1}{\prod_{i=1}^n a_i!}=[x^D](e^x)^n=[x^D]e^{nx}
\]
设\(b_i=a_i+k\),则原问题化为求
\[\sum_{b_i\ge k,\ \sum_{i=1}^n b_i=D+nk}\frac{D!}{\prod_{i=1}^n b_i!}
\]
考虑\(b_i\ge k\)的限制,只需将\(e^x\)中次数\(<k\)的项去掉,所以
\[\sum_{b_i\ge k,\ \sum_{i=1}^n b_i=D+nk}\frac{1}{\prod_{i=1}^n b_i!}=[x^{D+nk}]\left(e^x-\sum_{i=0}^{k-1}\frac{x^i}{i!}\right)^n
\]
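由于\(n,k\)很小,上式可以直接暴力计算,记\(A(x)=-\sum_{i=0}^{k-1}\frac{x^i}{i!}\),预处理\(A(x),A^2(x),\dots,A^n(x)\),再根据\((e^x+A(x))^n=\sum_{i=0}^{n}\binom{n}{i}A^i(x)e^{(n-i)x}\)即可算出答案:对每个\(i\),枚举\(A^i(x)\)的\(x^j\)项系数,乘上\([x^{D+nk-j}]e^{(n-i)x}=\frac{(n-i)^{D+nk-j}}{(D+nk-j)!}\)。计算过程中出现的\(\frac{1}{(D+nk-j)!}\)(\(j\in[0,A.len]\),其中\(A.len\)为\(A^i(x)\)的次数,至多\(n(k-1)\))无法快速计算,但由于整体乘上了\(D!\),而\(\frac{D!}{(D+nk-j)!}\)只是\(nk-j\)个连续整数\((D+1)(D+2)\cdots(D+nk-j)\)之积的逆元,可以快速计算。最终复杂度\(O(n^2k^2)\)。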
#include <bits/stdc++.h>
#define pb(x) emplace_back(x)
using namespace std;
// 64-bit signed integer used for all modular arithmetic in this file.
using ll = long long;

// Capacity of every coefficient/table array below (pol::a, s, sv).
constexpr int N = 2510;

// Modulus for all arithmetic; inv() relies on it being prime.
constexpr ll M = 998244353;

// Number of variables a_1..a_n; read from input by f1().
int n;

// s[i] = i! mod M; f1() fills indices 0..50.
ll s[N];

// sv[i] = (i!)^{-1} mod M; f1() fills indices 0..50.
ll sv[N];

// Target sum D; read from input by f1().
ll D;

// D + n*k, the exponent whose coefficient is extracted; set by f1().
ll D2;
// Reduces x modulo M in place, but only when x >= M.
// Values already below M (including negative values) are left untouched.
inline void MOD(ll& x) {
    if (x < M) {
        return;
    }
    x %= M;
}
// Modular exponentiation: returns x^b mod M using binary (square-and-multiply)
// exponentiation, in O(log b) multiplications.
//
// x: base; any 64-bit value. It is reduced modulo M up front so that the
//    products res*x and x*x below always stay within the signed 64-bit range
//    (without the reduction, |x| above ~3e9 overflows, which is undefined).
// b: exponent, expected to be >= 0. b == 0 yields 1 (so 0^0 == 1). A negative
//    b is treated like 0 instead of looping forever.
ll pm(ll x, ll b) {
    x %= M;
    ll res = 1;
    while (b > 0) {
        if (b & 1) {
            res = res * x % M;
        }
        x = x * x % M;
        b >>= 1;
    }
    return res;
}
// Modular inverse of x modulo M, computed as x^(M-2) mod M (Fermat's little
// theorem, valid because M is prime). If x is a multiple of M no inverse
// exists and the result is 0.
ll inv(ll x) {
    const ll fermat_exponent = M - 2;
    return pm(x, fermat_exponent);
}
// Dense polynomial with coefficients stored modulo M.
struct pol {
    // Index of the highest coefficient in use (the degree). f1() sets it to
    // k-1 for ps[1], so it is negative when the polynomial has no terms.
    int len;

    // a[i] is the coefficient of x^i; at most N coefficients fit.
    ll a[N];

    // Mutable access to the coefficient at index x. No bounds checking.
    ll& operator[](size_t x) { return a[x]; }
};

// ps[i] holds the i-th power of ps[1]; ps[0] is the constant 1 (filled by f1()).
// Lives in static storage, so every coefficient starts out as zero.
pol ps[52];
// Polynomial product modulo M: c = a * b.
//
// a, b: factors; only coefficients 0..len of each are read, so stale data past
//       `len` can no longer leak into the result (the previous version scanned
//       a[0..k] and b[0..k] for every k and relied on the storage beyond `len`
//       being zero). A negative `len` denotes a polynomial with no terms.
// c:    destination; must be a different object from a and b. c.len is set to
//       a.len + b.len, which must stay below N.
//
// Cost is (a.len+1)*(b.len+1) multiplications instead of roughly
// (a.len+b.len)^2 / 2.
void mul(pol& a, pol& b, pol& c) {
    c.len = a.len + b.len;
    for (int k = 0; k <= c.len; k++) {
        // Valid pairs (i, k-i) need 0 <= i <= a.len and 0 <= k-i <= b.len.
        const int lo = k > b.len ? k - b.len : 0;
        const int hi = k < a.len ? k : a.len;
        // Each term is < M < 2^30 and there are fewer than N of them, so the
        // running sum cannot overflow 64 bits before the final reduction.
        ll acc = 0;
        for (int i = lo; i <= hi; i++) {
            acc += a[i] * b[k - i] % M;
        }
        c[k] = acc;
        MOD(c[k]);
    }
}
// Binomial coefficient "n choose m" modulo M, read off the factorial tables:
// n! * ((n-m)!)^{-1} * (m!)^{-1}.
// Requires 0 <= m <= n with s[] / sv[] already filled up to index n; there is
// no range checking. Note that the parameter n shadows the global n.
ll C(ll n, ll m) {
    const ll partial = s[n] * sv[n - m] % M;
    return partial * sv[m] % M;
}
// Returns x! / y! modulo M.
//
// The product of the integers strictly above min(x, y) up to max(x, y) is
// accumulated in ascending order; when x < y that product is the denominator,
// so its modular inverse is returned instead. Takes |x - y| multiplications.
// If one of the factors is a multiple of M the result degenerates to 0.
ll cal1(ll x, ll y) {
    const bool invert = x < y;
    const ll from = invert ? x : y;
    const ll to = invert ? y : x;

    ll prod = 1;
    for (ll t = from + 1; t <= to; ++t) {
        prod = prod * t % M;
    }
    return invert ? inv(prod) : prod;
}
// Returns D! * x^y / y! modulo M, i.e. D! times the coefficient of t^y in
// the series e^{x t}. D is the global target sum.
//
// x: base of the exponential (f1() passes n - i).
// y: exponent / factorial argument, y >= 0.
//
// The previous version short-circuited y == 0 to 1, which equals
// D! * x^0 / 0! only when D! == 1. The general formula is now used for every
// y, so y == 0 yields D! (with 0^0 == 1). For x == 0 and y > 0 the power
// vanishes and the factorial ratio is skipped.
ll cal2(ll x, ll y) {
    if (x == 0 && y != 0) {
        return 0;
    }
    // Both factors are already reduced below M, so the product fits in 64 bits.
    return pm(x, y) * cal1(D, y) % M;
}
void f1(){
int k;
scanf("%d%d%lld",&n,&k,&D);
D2=D+n*k;
s[0]=1;sv[0]=1;
for(int i=1;i<=50;i++){s[i]=s[i-1]*i%M;}
sv[50]=inv(s[50]);
for(int i=49;i>=1;i--){sv[i]=sv[i+1]*(i+1)%M;}
ps[0][0]=1;ps[0].len=0;
ps[1].len=k-1;
for(int i=0;i<k;i++){ps[1][i]=M-sv[i];}
for(int i=2;i<=n;i++){mul(ps[i-1],ps[1],ps[i]);}
ll ans=0,tmp=0;
for(int i=0;i<=n;i++){
tmp=0;
for(int j=0,l=min<ll>(ps[i].len,D2);j<=l;j++){
tmp+=cal2(n-i,D2-j)*ps[i][j]%M;
if(tmp>=M)tmp-=M;
}
ans+=tmp*C(n,i)%M;
if(ans>=M)ans-=M;
}
printf("%lld",ans);
}
// Program entry point: solves the single test case read from stdin.
int main() {
    f1();
    // Falling off the end of main is equivalent to `return 0;`.
}