求 (sum_{i=0}^{k}inom{m}{i}inom{n-m}{k-i}i^L) ((1leqslant n,mleqslant 2 imes 10^7,1leqslant Lleqslant 2 imes 10^5))
这个式子比较简洁,然后也没啥可推的,所以我们将 (i^L) 展开.
那么原式为 (sum_{i=0}^{k}inom{m}{i}inom{n-m}{k-i}sum_{j=0}^{i}inom{i}{j}S(L,j) imes (j!))
考虑将 (j) 前提,得 (sum_{j=0}^{k}(j!)S(L,j)sum_{i=0}^{k}inom{m}{i}inom{n-m}{k-i}inom{i}{j})
注意:即使 (i<j) 也是无所谓的,因为后面那个组合数可以帮我们抵消掉.
我们发现后面的组合数看起来很眼熟,可以考虑对组合数搞点事情.
(sum_{i=0}^{k}inom{m}{i}inom{n-m}{k-i}inom{i}{j})
(Rightarrow sum_{i=0}^{k}inom{m}{j}inom{m-j}{i-j}inom{n-m}{k-i})
(Rightarrow inom{m}{j}sum_{i=0}^{k}inom{m-j}{i-j}inom{n-m}{k-i})
后面那两个组合数有一个性质:上面的 (n) 之和和下面的 (m) 之和都是定值,所以可以用范德蒙德恒等式
(Rightarrow inom{m}{j}inom{n-j}{k-j})
那么最终答案就是 (sum_{j=0}^{k}(j!)S(L,j)inom{m}{j}inom{n-j}{k-j})
其中斯特林数可以用 (NTT) 预处理,然后枚举一下 (j) 就好了.
#include <bits/stdc++.h>
#define LL long long
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
const int M=2000003;
const int N=20000006;
const int mod=998244353,G=3;
inline int qpow(int x,int y)
{
int tmp=1;
for(;y;y>>=1,x=(LL)x*x%mod) if(y&1) tmp=(LL)tmp*x%mod;
return tmp;
}
inline int INV(int x) { return qpow(x,mod-2); }
inline void NTT(int *a,int len,int flag)
{
int i,j,k,mid;
for(i=k=0;i<len;++i)
{
if(i>k) swap(a[i],a[k]);
for(j=len>>1;(k^=j)<j;j>>=1);
}
for(mid=1;mid<len;mid<<=1)
{
int wn=qpow(G,(mod-1)/(mid<<1));
if(flag==-1) wn=INV(wn);
for(i=0;i<len;i+=(mid<<1))
{
int w=1;
for(j=0;j<mid;++j,w=(LL)w*wn%mod)
{
int x=a[i+j], y=(LL)a[i+j+mid]*w%mod;
a[i+j]=(LL)(x+y)%mod, a[i+j+mid]=(LL)(x-y+mod)%mod;
}
}
}
if(flag==-1)
{
int rev=INV(len);
for(i=0;i<len;++i) a[i]=(LL)a[i]*rev%mod;
}
}
int max_n,max_m,L;
int fac[N],inv[N],f[M],A[M],B[M];
inline int C(int x,int y) { return y>x?0:(LL)fac[x]*inv[y]%mod*inv[x-y]%mod; }
inline void Initialize()
{
int i,j,limit;
inv[0]=fac[0]=1;
for(i=1;i<N;++i) fac[i]=(LL)fac[i-1]*i%mod;
inv[N-1]=INV(fac[N-1]);
for(i=N-2;i>=1;--i) inv[i]=1ll*inv[i+1]*(i+1)%mod;
for(i=0;i<=L;++i)
{
A[i]=inv[i],B[i]=(LL)qpow(i,L)*inv[i]%mod;
if(i&1) A[i]=mod-A[i];
}
for(limit=1;limit<=2*(L+1);limit<<=1);
NTT(A,limit,1),NTT(B,limit,1);
for(i=0;i<limit;++i) A[i]=(LL)A[i]*B[i]%mod;
NTT(A,limit,-1);
for(i=0;i<=L;++i) f[i]=A[i];
}
inline void solve()
{
LL ans=0ll;
int i,j,n,m,k,Lim;
scanf("%d%d%d",&n,&m,&k),Lim=min(min(n,m),L);
for(i=0;i<=Lim;++i) (ans+=(LL)f[i]*fac[i]%mod*C(m,i)%mod*C(n-i,k-i))%=mod;
(ans*=(LL)fac[k]*fac[n-k]%mod*inv[n]%mod)%=mod;
printf("%lld
",ans);
}
int main()
{
// setIO("input");
int i,j,T;
scanf("%d%d%d%d",&max_n,&max_m,&T,&L);
Initialize();
while(T--) solve();
return 0;
}