问题描述
大中锋的学院要组织学生参观博物馆,要求学生们在博物馆中排成一队进行参观。他的同学可以分为四类:一部分最喜欢唱、一部分最喜欢跳、一部分最喜欢rap,还有一部分最喜欢篮球。如果队列中k,k + 1,k + 2,k + 3位置上的同学依次,最喜欢唱、最喜欢跳、最喜欢rap、最喜欢篮球,那么他们就会聚在一起讨论蔡徐坤。大中锋不希望这种事情发生,因为这会使得队伍显得很乱。大中锋想知道有多少种排队的方法,不会有学生聚在一起讨论蔡徐坤。两个学生队伍被认为是不同的,当且仅当两个队伍中至少有一个位置上的学生的喜好不同。由于合法的队伍可能会有很多种,种类数对998244353取模。
输入格式
输入数据只有一行。每行5个整数,第一个整数n,代表大中锋的学院要组织多少人去参观博物馆。接下来四个整数a、b、c、d,分别代表学生中最喜欢唱的人数、最喜欢跳的人数、最喜欢rap的人数和最喜欢篮球的人数。保证(a+b+c+d ge n)
输出格式
每组数据输出一个整数,代表你可以安排出多少种不同的学生队伍,使得队伍中没有学生聚在一起讨论蔡徐坤。结果对998244353取模。
样例输入
4 4 3 2 1
样例输出
174
数据范围
对于20%的数据,有(n=a=b=c=dle500)
对于100%的数据,有(n le 1000), (a, b, c, d le 500)
解析
我们可以考虑容斥。设 (f_i) 表示至少有 (i) 组学生会讨论蔡徐坤,那么我们只需要确定每组的第一个学生的位置即可,即
对于剩下的 (n-4i) 个位置,我们可以随便排列。我们可以发现,这其实是一个可重排列,但是要满足排列长度等于 (n-4i) 的限制 。设 (f_i) 对应的随便排列的方案数为 (g_i) 结合可重排列的公式,我们有
实际上,(g_i) 是一个卷积的形式。我们可以通过NTT在规定时间内求出 (g_i) 。最后的答案为
其中当 (i=0) 时,得到的是总方案数。
代码
#include <iostream>
#include <cstdio>
#define int long long
#define N 10002
using namespace std;
const int mod=998244353;
const int G=3;
int n,a,b,c,d,i,j,fac[N],inv[N],r[N],A[N],B[N],C[N],D[N],ans;
int read()
{
char c=getchar();
int w=0;
while(c<'0'||c>'9') c=getchar();
while(c<='9'&&c>='0'){
w=w*10+c-'0';
c=getchar();
}
return w;
}
int poww(int a,int b)
{
int ans=1,base=a;
while(b){
if(b&1) ans=ans*base%mod;
base=base*base%mod;
b>>=1;
}
return ans;
}
int cal(int n,int m)
{
return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
void NTT(int *a,int inv,int n)
{
for(int i=0;i<n;i++){
if(i<r[i]) swap(a[i],a[r[i]]);
}
for(int l=2;l<=n;l<<=1){
int mid=l/2;
int cur=poww(G,(mod-1)/l);
if(inv==-1) cur=poww(cur,mod-2);
for(int i=0;i<n;i+=l){
int omg=1;
for(int j=0;j<mid;j++,omg=omg*cur%mod){
int tmp=omg*a[i+j+mid]%mod;
a[i+j+mid]=(a[i+j]-tmp+mod)%mod;
a[i+j]=(a[i+j]+tmp)%mod;
}
}
}
if(inv==-1){
for(int i=0;i<n;i++) a[i]=a[i]*poww(n,mod-2)%mod;
}
}
signed main()
{
n=read();a=read();b=read();c=read();d=read();
for(i=fac[0]=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
inv[n]=poww(fac[n],mod-2);
for(i=n-1;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod;
for(i=0;i<=n/4;i++){
if(a<i||b<i||c<i||d<i) continue;
int m=1,lim=0;
while(m<a+b+c+d-4*i) m<<=1,lim++;
for(j=0;j<m;j++) r[j]=(r[j>>1]>>1)|((j&1)<<(lim-1));
for(j=0;j<m;j++) A[j]=(j<=a-i)?inv[j]:0;
for(j=0;j<m;j++) B[j]=(j<=b-i)?inv[j]:0;
for(j=0;j<m;j++) C[j]=(j<=c-i)?inv[j]:0;
for(j=0;j<m;j++) D[j]=(j<=d-i)?inv[j]:0;
NTT(A,1,m);NTT(B,1,m);NTT(C,1,m);NTT(D,1,m);
for(j=0;j<m;j++) A[j]=A[j]*B[j]%mod*C[j]%mod*D[j]%mod;
NTT(A,-1,m);
int tmp=cal(n-3*i,i)*A[n-4*i]%mod*fac[n-4*i]%mod;
if(i%2==0) ans=(ans+tmp)%mod;
else ans=(ans-tmp+mod)%mod;
}
printf("%lld
",ans);
return 0;
}