题目描述
有一个n行m列的整数矩阵,其中1到nm之间的每个整数恰好出现一次。如果一个格子比所有相邻格子(相邻是指有公共边或公共顶点)都小,我们说这个格子是局部极小值。
给出所有局部极小值的位置,你的任务是判断有多少个可能的矩阵。
输入
输入第一行包含两个整数n和m(1<=n<=4, 1<=m<=7),即行数和列数。以下n行每行m个字符,其中“X”表示局部极小值,“.”表示非局部极小值。
输出
输出仅一行,为可能的矩阵总数除以12345678的余数。
样例输入
3 2
X.
..
.X
样例输出
60
题解
容斥原理+状压dp
“给出所有局部极小值的位置” 有两层含义:
1.给出的位置是局部最小值;
2.非给出的位置不是局部最小值。
先考虑第一层含义怎么做:
我们把数从小到大填入矩阵中,那么如果一个格子是局部最小值且没有填入数,那么它周围的数都不能填。除此之外的位置均可选择。
因此设 $f[i][j]$ 表示填入前 $i$ 个数,局部最小值的填入情况为 $j$ 的方案数。
那么对于 $f[i][j]$ ,有两种转移:
不填局部最小值的位置,那么 $f[i][j]=f[i-1][j]+可以填的位置数$ 。我们预处理每种局部最小值填入情况下可以填入多少个数 $v[j]$,之后就能算出可以填的位置数 $v[j]-i+1$ 。
填局部最小值的位置,那么枚举填入了第 $k$ 个局部最小值,有 $f[i][j]=f[i-1][j-2^k]$ 。
最终 $f[nm][2^{局部最小值个数}-1]$ 即为答案。
再考虑第二层含义:
考虑容斥,那么讨论其它位置为局部最小值的情况,同样的方法进行dp,乘以容斥系数 $(-1)^{多填的位置数}$ 累计到答案中即可。
注意判断无解的情况。
由于容斥过程时刻要求一个局部最小值的八连通位置不能存在局部最小值,因此状态数是很小的。
时间复杂度 $O(跑得飞快)$
#include <cstdio> #include <cstring> #include <algorithm> #define mod 12345678 using namespace std; const int dx[] = {0 , 1 , 1 , 1 , 0 , -1 , -1 , -1 , 0} , dy[] = {0 , 1 , 0 , -1 , -1 , -1 , 0 , 1 , 1}; char str[10]; int n , m , px[30] , py[30] , p , vis[6][10] , filled[6][10] , v[260] , f[30][260]; int solve() { int i , j , k; memset(v , 0 , sizeof(v)); for(i = 0 ; i < (1 << p) ; i ++ ) { memset(vis , 0 , sizeof(vis)); for(j = 0 ; j < p ; j ++ ) if(!(i & (1 << j))) for(k = 0 ; k < 9 ; k ++ ) vis[px[j] + dx[k]][py[j] + dy[k]] = 1; for(j = 1 ; j <= n ; j ++ ) for(k = 1 ; k <= m ; k ++ ) v[i] += !vis[j][k]; } memset(f , 0 , sizeof(f)); f[0][0] = 1; for(i = 1 ; i <= n * m ; i ++ ) { for(j = 0 ; j < (1 << p) ; j ++ ) { if(v[j] >= i) f[i][j] = f[i - 1][j] * (v[j] - i + 1) % mod; for(k = 0 ; k < p ; k ++ ) if(j & (1 << k)) f[i][j] = (f[i][j] + f[i - 1][j ^ (1 << k)]) % mod; } } return f[n * m][(1 << p) - 1]; } int dfs(int x , int y) { if(y > m) x ++ , y = 1; if(x > n) return solve(); int i , ans = dfs(x , y + 1); for(i = 0 ; i < 9 ; i ++ ) if(filled[x + dx[i]][y + dy[i]]) break; if(i == 9) { px[p] = x , py[p ++ ] = y , filled[x][y] = 1; ans = (ans - dfs(x , y + 1) + mod) % mod; p -- , filled[x][y] = 0; } return ans; } int main() { int i , j; scanf("%d%d" , &n , &m); for(i = 1 ; i <= n ; i ++ ) { scanf("%s" , str + 1); for(j = 1 ; j <= m ; j ++ ) if(str[j] == 'X') px[p] = i , py[p ++ ] = j; } for(i = 0 ; i < p ; i ++ ) { for(j = 0 ; j < i ; j ++ ) { if(abs(px[i] - px[j]) <= 1 && abs(py[i] - py[j]) <= 1) { puts("0"); return 0; } } filled[px[i]][py[i]] = 1; } printf("%d " , dfs(1 , 1)); return 0; }