回到二维哈希的问题,不难想到我们对于一个((1,1)-(i,j))的矩阵,先把每一行的哈希值压成一个数,这个过程和一维是完全一样的,就是对原数组的进制转换,然后对于每一列将压缩好的信息直接再做一次(hash),注意这一次不用管原数组了,注意两个压缩过程尽量选用不同的(base)比如(87)和(312)。
如何(O(1))求一个矩阵的哈希值?考虑二维前缀和,先对行消去影响,再对列消去影响,再对被重复消去的部分容斥
即
((x,y)-(i,j))的矩阵:(hash[i][j]-hash[i][y-1]*base1^{j-y}-hash[x-1][j]*base2^{i-x}+hash[x-1][y-1]*base1^{j-y}*base2^{i-x})
以上大功告成
再回到本题,我们判断一个正方形是否对称,只需要预处理按纵轴和按横轴翻转的两个矩阵(相当于镜像,将它翻转),在这三个矩阵中找到同一个正方形的位置做(hash),如果三个(hash)值都一样说明怎么翻转这个正方形都完全重合,也就一定是对称的了。
万事俱备只欠二分,发现正方形的中点可能落在数字上也可能落在四个数字之间的空格上,对于这两种情况分别处理。
扫到每一个这样的点就二分正方形边长的一半,然后哈希(check),找到最大的边长(x),易知以该点为中心的正方形边长(<=mid)也一定合法,因此对答案的贡献是(ans+=x)。
再总结一下流程
1.预处理两个方向翻转的二维数组,以及对这三个数组做二维(hash)
2.对于每一个点二分扩展,分中点在数字上和中点在空格上考虑,找到最大边长(ans+=x)
3.最后(ans+=m*n),因为每个小正方形也是答案
总复杂度(O(nmlogn))
注:本人此题开(unsigned long long)自然溢出不取模
更多细节看代码
(Code)
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define re register
#define int unsigned long long
#define base1 87
#define base2 312
const int maxn=1010;
using namespace std;
inline int read()
{
int x=0,f=1; char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int n,m,ans;
int a[maxn][maxn],b[maxn][maxn],fac1[maxn],fac2[maxn];
int bb[maxn][maxn],le[maxn][maxn],bbb[maxn][maxn],up[maxn][maxn];
void copys()
{
for(re int i=1;i<=n;++i) //复制左右颠倒的矩阵
for(re int j=1;j<=m;++j)
bb[i][m-j+1]=b[i][j];
for(re int i=1;i<=n;++i)//复制上下颠倒的矩阵
for(re int j=1;j<=m;++j)
bbb[n-i+1][j]=b[i][j];
}
void hasha()
{
for(re int i=1;i<=n;++i)
for(re int j=1;j<=m;++j)
a[i][j]=a[i][j-1]*base1+b[i][j];
for(re int i=1;i<=n;++i)
for(re int j=1;j<=m;++j)
a[i][j]+=a[i-1][j]*base2;
for(re int i=1;i<=n;++i)
for(re int j=1;j<=m;++j)
le[i][j]=le[i][j-1]*base1+bb[i][j];//注意这里合并信息
for(re int i=1;i<=n;++i)
for(re int j=1;j<=m;++j)
le[i][j]+=le[i-1][j]*base2;//这里直接合并hash不用管bb了
for(re int i=1;i<=n;++i)
for(re int j=1;j<=m;++j)
up[i][j]=up[i][j-1]*base1+bbb[i][j];
for(re int i=1;i<=n;++i)
for(re int j=1;j<=m;++j)
up[i][j]+=up[i-1][j]*base2;
}
inline int check(int x,int y,int len)
{
int yy=y;
//注意unsigned没有负数,所以不能写x-len<0
if(x>n||y>m||x<len||y<len) return 0;
int ans1=a[x][y]-a[x-len][y]*fac2[len]-a[x][y-len]*fac1[len]+a[x-len][y-len]*fac1[len]*fac2[len];
y=m-(y-len);//y-len是到最下角,翻转过去再被m-就是右下角
//不明白为什么可以手推一下
int ans2=le[x][y]-le[x-len][y]*fac2[len]-le[x][y-len]*fac1[len]+le[x-len][y-len]*fac2[len]*fac1[len];
y=yy;//这里一定要复原,因为我们处理的是只有上下颠倒
x=n-(x-len);
int ans3=up[x][y]-up[x-len][y]*fac2[len]-up[x][y-len]*fac1[len]+up[x-len][y-len]*fac1[len]*fac2[len];
if(ans1==ans2&&ans2==ans3) return 1;
return 0;
}
void work()
{
for(re int i=1;i<n;++i)
for(re int j=1;j<m;++j)
{
int l=0,r=n+1,sums=0;
while(l<=r)
{
int mid=(l+r)>>1;
if(check(i+mid,j+mid,mid+mid))
{
if(mid) sums=mid;//这是对于中心在空格上的情况
l=mid+1;
}
else r=mid-1;
}
ans+=sums;
}
for(re int i=1;i<n;++i)
for(re int j=1;j<m;++j)
{
int l=0,r=n+1,sums=0;
while(l<=r)
{
int mid=(l+r)>>1;
if(check(i+mid,j+mid,mid+mid+1))
{
if(mid) sums=mid;
l=mid+1;
}
else r=mid-1;
}
ans+=sums;
}
ans+=n*m;//别忘了加上
return;
}
signed main()
{
n=read(),m=read();
fac1[0]=fac2[0]=1;//预处理
for(re int i=1;i<=n;++i) fac1[i]=fac1[i-1]*base1;
for(re int i=1;i<=m;++i) fac2[i]=fac2[i-1]*base2;
for(re int i=1;i<=n;++i)
for(re int j=1;j<=m;++j)
b[i][j]=read();
copys();
hasha();
work();
printf("%u
",ans);
return 0;
}