题目链接:https://www.luogu.com.cn/problem/P5664
观察题目数据范围,发现前64pts可以用类似状压的思想来做(m<=3)。前84pts可以在O(n^3*m)的时间内完成。100pts需要在O(n^2*m)的时间内做。
总述:
注意总的初始化,初始化要为1,因为后面有乘的操作,最后的时候再将那个多余的1减去。
64pts:
设f[i][j][k][q]表示到第i行,第1列选了j个,第2列选了k个,第3列选了q个。
转移有四种情况:1.不选 2.选的是第一列的 3.选的是第二列的 4.选的是第三列的
则f[i][j][k][q]=f[i-1][j][k][q]+f[i-1][j-1][k][q]*a[i][1]+f[i-1][j][k-1][q]*a[i][2]+f[i-1][j][k][q-1]*a[i][3]
因为根据题意,要求j<=(j+k+q)/2,k<=(j+k+q)/2,q<=(j+k+q)/2,化简为:
j<=k+q,k<=j+q,q<=j+k。只有满足这个,ans+=f[n][j][k][q]。
代码:
1 #include<cstdio> 2 #include<iostream> 3 using namespace std; 4 typedef long long ll; 5 const ll mod=998244353; 6 const int N=50; 7 ll f[N][N][N][N],a[N][N]; 8 ll n,m,ans; 9 int main(){ 10 scanf("%lld%lld",&n,&m); 11 for(int i=1;i<=n;i++) 12 for(int j=1;j<=m;j++) scanf("%lld",&a[i][j]); 13 f[0][0][0][0]=1; 14 for(int i=1;i<=n;i++){ 15 for(int j=0;j<=i;j++) 16 for(int k=0;k<=i;k++) 17 for(int q=0;q<=i;q++){ 18 f[i][j][k][q]=(f[i][j][k][q]+f[i-1][j][k][q])%mod; 19 if(j) f[i][j][k][q]=(f[i][j][k][q]+f[i-1][j-1][k][q]*a[i][1])%mod; 20 if(k) f[i][j][k][q]=(f[i][j][k][q]+f[i-1][j][k-1][q]*a[i][2])%mod; 21 if(q) f[i][j][k][q]=(f[i][j][k][q]+f[i-1][j][k][q-1]*a[i][3])%mod; 22 } 23 } 24 for(int i=0;i<=n/2;i++) 25 for(int j=0;j<=n/2;j++) 26 for(int k=0;k<=n/2;k++){ 27 if(i<=k+j&&j<=i+k&&k<=i+j) ans+=f[n][i][j][k],ans%=mod; 28 } 29 printf("%lld ",ans-1); 30 return 0; 31 }
84pts:
枚举列数j,设f[i][k][q]表示到第i行,第j列一共选了k个,其余所有列一共选了q个。
转移有三种情况:1.不选 2.选的是第j列的 3.选的不是第j列的
则f[i][k][q]=f[i-1][k][q]+f[i-1][k-1][q]*a[i][j]+f[i-1][k][q-1]*(sum[i]-a[i][j])
然后运用容斥原理,用总的方案数减去不符合的(k>q)方案数即为答案。
代码:
1 #include<cstdio> 2 #include<iostream> 3 #include<cstring> 4 using namespace std; 5 typedef long long ll; 6 const ll mod=998244353; 7 const int N=55; 8 ll f[N][N][N],a[N][505],sum[N]; 9 ll n,m,ans=1,res; 10 int main(){ 11 scanf("%lld%lld",&n,&m); 12 for(int i=1;i<=n;i++){ 13 for(int j=1;j<=m;j++) scanf("%lld",&a[i][j]),sum[i]=(sum[i]+a[i][j])%mod; 14 sum[i]%=mod; 15 ans=(ans*(sum[i]+1)%mod)%mod; 16 } 17 for(int j=1;j<=m;j++){ 18 memset(f,0,sizeof(f)); 19 f[0][0][0]=1; 20 for(int i=1;i<=n;i++) 21 for(int k=0;k<=i;k++) 22 for(int q=0;q<=i-k;q++){ 23 f[i][k][q]=(f[i][k][q]+f[i-1][k][q])%mod; 24 if(k) f[i][k][q]=(f[i][k][q]+f[i-1][k-1][q]*a[i][j])%mod; 25 if(q) f[i][k][q]=(f[i][k][q]+f[i-1][k][q-1]*(sum[i]-a[i][j]))%mod; 26 } 27 for(int k=1;k<=n;k++){ 28 for(int q=0;q<=n-k;q++){ 29 if(k>q) res+=f[n][k][q]; 30 } 31 res=(res+mod)%mod; 32 } 33 } 34 printf("%lld ",(ans-res-1+mod)%mod); 35 return 0; 36 }
100pts:
可以发现,在84pts中的做法中,0<=k,q<=n,且k<=q,所以-n<=k-q<=n,然后便可以将上面的压成二维:
f[i][j]表示选到了第i行,j=k-q+n,j∈[0,2n]。
则f[i][j]=f[i-1][j]+f[i-1][j-1]*a[i][j]+f[i-1][j+1]*(sum[i]-a[i][j])。
其中n+1<=j<=2*n是不符合题意的方案数,减掉即为答案。注意初始化,因为整体加了n,所以f[0][n]=1。
代码:
1 #include<cstdio> 2 #include<iostream> 3 #include<cstring> 4 using namespace std; 5 typedef long long ll; 6 const ll mod=998244353; 7 const int N=105; 8 ll f[N][N<<1],a[N][2005],sum[N]; 9 ll n,m,ans=1,res; 10 int main(){ 11 scanf("%lld%lld",&n,&m); 12 for(int i=1;i<=n;i++){ 13 for(int j=1;j<=m;j++) scanf("%lld",&a[i][j]),sum[i]=(sum[i]+a[i][j])%mod; 14 sum[i]%=mod; 15 ans=(ans*(sum[i]+1)%mod)%mod; 16 } 17 for(int j=1;j<=m;j++){ 18 memset(f,0,sizeof(f)); 19 f[0][n]=1; 20 for(int i=1;i<=n;i++) 21 for(int k=n-i;k<=n+i;k++){ 22 f[i][k]=(f[i][k]+f[i-1][k])%mod; 23 if(k) f[i][k]=(f[i][k]+f[i-1][k-1]*a[i][j])%mod; 24 f[i][k]=(f[i][k]+f[i-1][k+1]*(sum[i]-a[i][j])%mod)%mod; 25 } 26 for(int k=n+1;k<=n*2;k++){ 27 res+=f[n][k]; 28 res=(res+mod)%mod; 29 } 30 } 31 printf("%lld ",(ans-res-1+mod)%mod); 32 return 0; 33 }