题目大意
有四类东西,每类东西分别有(a,b,c,d)个。
从中选出(n)个排成一排,问有多少种不同的合法排法。
一种方法合法当且仅当该方法中不存在连续的四个东西,它们的种类依次是第一、二、三、四种。称这四个东西是不合法的段。
两种方法不同当且仅当存在一个位置,两种方法中该位置上的东西种类不同。
(nleq 1000;a,b,c,dleq 500;a+b+c+dgeq n;)
题解
考虑计算总方案数-不合法方案数。设选(x)个东西随便排的不同的方案数为(g(x))。
计算至少有(k)个不合法段的方案数(f(k))时,选(n-4 imes k)个东西出来随便排,再向其中插入(k)个不合法段,总共有(C_{n-3 imes k}^k imes g(n-4 imes k))种不同的方案。
发现这样计算会有重复,恰有(k_0(k_0geq k))个不合法段的方案被算了(C_{k_0}^{k})次。
所以需要容斥:(答案=sumlimits_{k=0}(-1)^k imes C_{n-3 imes k}^k imes g(n-4 imes k))。
考虑计算(g(n-4 imes k))。选(x)个东西随便排相当于有(x)个空位,每次选不超过一个数的若干个空位,选四次,四次选的空位恰有(x)个。
这个有两种计算方法:
(1)可以看成(g(n-4 imes k)=sumlimits_{a_0+b_0+c_0+d_0=x}frac{(n-4 imes k)!}{a_0! imes b_0! imes c_0! imes d_0!})
可以用生成函数求:四个多项式,最高项分别是(a-k,b-k,c-k,d-k),(i)次项的系数为(frac{1}{i!}),将它们相乘,((n-4 imes k))次项即为所求。
时间复杂度(Theta(n^2 imes log_2space n))
(2)观察(g(n-4 imes k))的意义可以发现:其实相当于选三次空位,因为第四次只能选前三次剩下的。设(n'=n-4 imes k,a'=a-k),(b',c',d')同理。
则(g(n')=sumlimits_{x_0=max(0,n'-b'-c'-d')}^{min(n',a')} C_{n'}^{x_0} imes sumlimits_{x_1=max(0,n'-c'-d')}^{min(n'-a',b')}C_{n'-a'}^{x_1} imes sumlimits_{x_2=max(0,n-d')}^{min(n'-a'-b',c')}C_{n'-a'-b'}^{x_2})。
求其中的(sumlimits_{x_2=max(0,n-d')}^{min(n'-a'-b',c')}C_{n'-a'-b'}^{x_2})可以求(C)的前缀和,但是前两个都不行。
考虑把它们分开:枚举(x_3)表示前两类东西的总数,这样第一、二类和第三类就可以分开求了。(g(n')=(sumlimits_{x_3=max(0,n'-c'-d')}^{min(n',a'+b')} C_{n'}^{x_3}) imes (sumlimits_{x_0=max(0,x_3-b')}^{min(x_3,a')}C_{x_3}^{x_1}) imes (sumlimits_{x_2=max(0,n-x_3-d')}^{min(n'-x_3,c')}C_{n'-x_3}^{x_2}))。
时间复杂度(Theta(n^2))。
代码
(1)生成函数
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
#define rep(i,x,y) for(register int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(register int i=(x);i>=(y);--i)
#define view(u,k) for(int k=fir[u];~k;k=nxt[k])
#define maxn 4007
#define LL long long
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')f=-1,ch=getchar();
while(isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return x*f;
}
void write(int x)
{
if(x==0){putchar('0'),putchar('
');return;}
int f=0;char ch[20];
if(x<0)putchar('-'),x=-x;
while(x)ch[++f]=x%10+'0',x/=10;
while(f)putchar(ch[f--]);
putchar('
');
return;
}
int n,a,b,c,d,ans,nown,nowlen,r[maxn],pa[maxn],pb[maxn],pc[maxn],pd[maxn],pw[maxn],ny[maxn];
const int mod=998244353;
int mo(int x){return x>=mod?x-mod:x;}
int mul(int x,int y){int res=1;while(y){if(y&1)res=(LL)res*x%mod;x=(LL)x*x%mod,y>>=1;}return res;}
void dnt(int *u,int fh)
{
rep(i,0,nown-1)if(i<r[i])swap(u[i],u[r[i]]);
for(int i=1;i<nown;i<<=1)
{
int wn=mul(3,(mod-1)/(i<<1));
if(fh==-1)wn=mul(wn,mod-2);
for(int j=0;j<nown;j+=(i<<1))
{
int w=1,x,y;
rep(k,0,i-1)
{
x=u[j+k],y=(LL)w*u[i+j+k]%mod,u[j+k]=mo(x+y),u[i+j+k]=mo(x-y+mod),w=(LL)w*wn%mod;
}
}
}
if(fh==-1)
{
int invn=mul(nown,mod-2);
rep(i,0,nown-1)u[i]=(LL)u[i]*invn%mod;
}
}
int main()
{
n=read(),a=read(),b=read(),c=read(),d=read();
pw[0]=1;rep(i,1,n)pw[i]=(LL)pw[i-1]*i%mod;
ny[n]=mul(pw[n],mod-2);dwn(i,n-1,0)ny[i]=(LL)ny[i+1]*(i+1)%mod;
int li=min(n/4,min(min(a,b),min(c,d)));
rep(i,0,li)
{
int a0=a-i,b0=b-i,c0=c-i,d0=d-i,len=a0+b0+c0+d0+4;
for(nown=1,nowlen=0;nown<len;nown<<=1,nowlen++);
rep(i,0,nown-1)r[i]=((i&1)<<(nowlen-1))|(r[i>>1]>>1);
rep(i,0,a0)pa[i]=ny[i];rep(i,a0+1,nown-1)pa[i]=0;dnt(pa,1);
rep(i,0,b0)pb[i]=ny[i];rep(i,b0+1,nown-1)pb[i]=0;dnt(pb,1);
rep(i,0,c0)pc[i]=ny[i];rep(i,c0+1,nown-1)pc[i]=0;dnt(pc,1);
rep(i,0,d0)pd[i]=ny[i];rep(i,d0+1,nown-1)pd[i]=0;dnt(pd,1);
rep(i,0,nown)pa[i]=(LL)pa[i]*pb[i]%mod*pc[i]%mod*pd[i]%mod;
dnt(pa,-1);
int ad=(LL)pw[n-3*i]*ny[i]%mod*pa[n-4*i]%mod;
if(i&1)ans=mo(ans-ad+mod);
else ans=mo(ans+ad);
}
write(ans);
return 0;
}
(2)
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
#define rep(i,x,y) for(register int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(register int i=(x);i>=(y);--i)
#define view(u,k) for(int k=fir[u];~k;k=nxt[k])
#define maxn 1007
#define LL long long
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')f=-1,ch=getchar();
while(isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return x*f;
}
void write(int x)
{
if(x==0){putchar('0'),putchar('
');return;}
int f=0;char ch[20];
if(x<0)putchar('-'),x=-x;
while(x)ch[++f]=x%10+'0',x/=10;
while(f)putchar(ch[f--]);
putchar('
');
return;
}
int n,a,b,c,d,C[maxn][maxn],sum[maxn][maxn],ans;
const int mod=998244353;
int mo(int x){return x>=mod?x-mod:x;}
LL mul(int x,int y){int res=1;while(y){if(y&1)res=(LL)res*x%mod;x=(LL)x*x%mod,y>>=1;}return res;}
int getc(int l,int r,int id)
{
if(l)return mo(sum[id][r]-sum[id][l-1]+mod);
else return sum[id][r];
}
int main()
{
n=read(),a=read(),b=read(),c=read(),d=read();
C[0][0]=sum[0][0]=1;
rep(i,1,n){rep(j,0,i)
{
C[i][j]=mo(C[i-1][j]+C[i-1][j-1]);
if(j)sum[i][j]=mo(sum[i][j-1]+C[i][j]);
else sum[i][j]=C[i][j];
}}
int li=min(n/4,min(min(a,b),min(c,d)));
rep(i,0,li)
{
int n0=n-4*i,a0=a-i,b0=b-i,c0=c-i,d0=d-i,dn=max(0,n0-c0-d0),up=min(n0,a0+b0),tmp=0;
rep(j,dn,up)
{
int ad=(LL)C[n0][j]*getc(max(0,j-b0),min(a0,j),j)%mod*getc(max(0,n0-j-d0),min(c0,n0-j),n0-j)%mod;
ad=(LL)ad*C[n0+i][i]%mod;
if(i&1)ans=mo(ans-ad+mod),tmp=mo(tmp-ad+mod);
else ans=mo(ans+ad),tmp=mo(tmp+ad);
}
}
write(ans);
return 0;
}