Solution
每一列仅当取数超过 (lfloor {frac{k}{2}} floor) 时不合法,所以最多只有一列不合法。考虑容斥,用总方案数 - 不合法方案数。
设 (g_{i,j}) 表示前 i 行选了 j 个数的方案数,(s_i) 表示第 i 行的总和,转移:(g_{i,j}=g_{i-1,j}+g_{i-1,j-1}*s_i)
枚举不合法的列为 (col),设 (f_{i,j,k}) 表示前 (i) 行,当前列选了 (j) 个数,其他列选了 (k) 个数的方案数,分为三种情况:
1.当前行什么都不选 (f_{i,j,k}) += (f_{i-1,j,k})
2.当前行在当前列选 1 个数 (f_{i,j,k}) += (f_{i-1,j-1,k} * a_{i,col})
3.当前行在其他列选 1 个数 (f_{i,j,k}) += (f_{i-1,j,k-1} * (s_i - a_{i,col}))
综上,(f_{i,j,k}=f_{i-1,j,k}+f_{i-1,j-1,k}*a_{i,col}+f_{i-1,j,k-1}*(s_i-a_{i,col}))
这样做的时间复杂度是 (O(mn^3)),可以得到 84 分。
考虑到我们根本不关心 (j, k) 的实际大小,而是只关心它们的相对大小,重新定义状态 (f_{i,j}) 表示前 i 行,当前列和其他列的差值为 (j)。转移方程:
(f_{i,j}=f_{i-1,j}+f_{i-1,j-1}*a_{i,col}+f_{i-1,j+1}*(s_i-a_{i,col}))
为了防止爆数组,(j) 这一维 + n。复杂度 (O(2n^2m))
Code
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define LL long long
using namespace std;
const LL N = 1100, mod = 998244353;
LL n, m, ans = 0, a[666][2333], f[N][N], s[N], g[N][N];
int main()
{
scanf("%lld%lld", &n, &m);
memset(s, 0, sizeof(s));
for(LL i = 1; i <= n; i++)
for(LL j = 1; j <= m; j++)
scanf("%lld", &a[i][j]), s[i] = (s[i] + a[i][j]) % mod;
g[0][0] = 1;
for(LL i = 1; i <= n; i++)
for(LL j = 0; j <= n; j++)
g[i][j] = (g[i - 1][j] + g[i - 1][j - 1] * s[i]) % mod;
for(LL i = 1; i <= n; i++) ans = (ans + g[n][i] + mod) % mod;
for(LL col = 1; col <= m; col++)
{
f[0][n] = 1;
for(LL i = 1; i <= n; i++)
for(LL j = 1; j <= n * 2; j++)
f[i][j] = (f[i - 1][j] + a[i][col] * f[i - 1][j - 1]) % mod,
f[i][j] = (f[i][j] + (s[i] - a[i][col]) * f[i - 1][j + 1]) % mod;
for(LL i = n + 1; i <= 2 * n; i++)
ans = ((ans - f[n][i]) % mod + mod) % mod;
}
printf("%lld", ans % mod);
return 0;
}