「CTS2019 | CTSC2019」随机立方体
传送门
题解
首先看到这个恰好(k)个可以很轻松的想到容斥,这个东西就是一个二项式反演.
接着我们思考如何求恰好(i)个合法的个数?设(dp_i)为所求,
[f_i=prod_{j=0}^{i-1}(n-j)(m-j)(l-j)
\
g_i=n imes m imes l-(n-i) imes (m-i) imes (l-i)
\
h_i=h_{i-1} imes frac{(g_i-1)!}{g_{i-1}!}
\
egin{align}
dp_i&=inom{n imes m imes l}{g_i} imes f_i imes h_i imes (n imes m imes l-g_i)!
\
&=(n imes m imes l)! imes f_i imes prod_{j=0}^{i}frac{1}{g_j!}prod_{j=1}^i(g_j-1)
\
&=(n imes m imes l)! imes f_i imes prod_{j=1}^{i}frac{1}{g_j}
end{align}
]
具体解释就是(f_i)表示恰好(i)个的放置方案数,(g_i)表示可以放置的位置的数量,(h_i)表示极大数中数放置的可能方案数.
这个求的是方案数,把((n imes m imes l)!)去掉就是概率了,然后直接对(dp_i)二项式反演即可.注意最后那个分数可以线性求逆元求.
代码
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<iostream>
using namespace std;
#define ll long long
#define REP(a,b,c) for(int a=b;a<=c;a++)
#define re register
#define file(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout)
typedef pair<int,int> pii;
#define mp make_pair
inline int gi()
{
int f=1,sum=0;char ch=getchar();
while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
return f*sum;
}
const int N=5000010,Mod=998244353;
int qpow(int a,int b){int ret=1;while(b){if(b&1)ret=1ll*ret*a%Mod;b>>=1;a=1ll*a*a%Mod;}return ret;}
int n,m,l,k,f[N],g[N],ifac[N],M,dp[N],inv[N],fac[N];
void init()
{
fac[0]=fac[1]=ifac[0]=ifac[1]=inv[0]=inv[1]=1;
for(int i=2;i<=5000000;i++)
fac[i]=1ll*fac[i-1]*i%Mod,
inv[i]=1ll*(Mod-Mod/i)*inv[Mod%i]%Mod,
ifac[i]=1ll*ifac[i-1]*inv[i]%Mod;
}
int C(int n,int m){return 1ll*fac[n]*ifac[m]%Mod*ifac[n-m]%Mod;}
int main()
{
int T=gi();init();
while(T--)
{
n=gi();m=gi();l=gi();k=gi();f[0]=ifac[0]=1;
M=min(n,min(m,l));
if(M<k){puts("0");continue;}
for(int i=0;i<M;i++)f[i+1]=1ll*f[i]*(n-i)%Mod*(m-i)%Mod*(l-i)%Mod;
for(int i=1;i<=M;i++)g[i]=(1ll*n*m%Mod*l%Mod-1ll*(n-i)*(m-i)%Mod*(l-i)%Mod+Mod)%Mod;
int now=1;
for(int i=1;i<=M;i++)now=1ll*now*g[i]%Mod;now=qpow(now,Mod-2);
for(int i=M;i;i--){int pre=g[i];g[i]=now;now=1ll*now*pre%Mod;}
for(int i=1;i<=M;i++)dp[i]=1ll*f[i]*g[i]%Mod;
int ans=0;
for(int i=k,f=1;i<=M;i++,f=Mod-f)
ans=(ans+1ll*f*C(i,k)%Mod*dp[i]%Mod)%Mod;
printf("%d
",ans);
}
return 0;
}