题目概述
一个$n imes m$的整点集。其中$q$个点被m被设置为不能访问。
问这个点集中含有多少个不同的正方形,满足不包含任何一个不能访问的点。
对于$50\%$的数据满足$1 leq n,m leq 10^4, 1 leq q leq 10^3$
对于另外$50\%$的数据满足$1 leq n,m leq 2 imes 10^5, 1 leq q leq 200$
Solution
我们规定行递增的方向为$x$的正方向,列递增的方向为$y$的正方向。
设$f[i][j]$表示以$(i,j)$为左上角的正方形最大边长。
则$sumlimits_{i=1}^{n} sumlimits_{j=1}^{m} f[i][j]$就是所求。
我们可以枚举对角线,然后再枚举每个标记点,分别计算各个标记点对这条对角线的影响。
不妨设当前直线是$y = x + b$。
- 若点$(i,j)$在直线下方(即顺时针方向),其会限制到点$(j-b,j)$左上方的所有点。
- 若点$(i,j)$在直线上,其会限制到这个点左上方的所有点。
- 若点$(i,j)$在直线上方(即逆时针方向),其会限制到点$(i,i+b)$左上方的所有点
注意到,对于每一条对角线,由于标记点个数是$q$个,那么限制的记录点最多也是$q$个。
两个记录点间的转移可以使用一些数学计算来加速做到$O(1)$,但是需要注意非常多的细节。
人傻常数大,后面用一个set,复杂度到了$O(nq log_2 q)$
# pragma GCC optimize(3) # include<bits/stdc++.h> # define int long long using namespace std; const int N=2e5+10; int n,m,q; namespace fast_IO{ const int IN_LEN = 10000000, OUT_LEN = 10000000; char ibuf[IN_LEN], obuf[OUT_LEN], *ih = ibuf + IN_LEN, *oh = obuf, *lastin = ibuf + IN_LEN, *lastout = obuf + OUT_LEN - 1; inline char getchar_(){return (ih == lastin) && (lastin = (ih = ibuf) + fread(ibuf, 1, IN_LEN, stdin), ih == lastin) ? EOF : *ih++;} inline void putchar_(const char x){if(oh == lastout) fwrite(obuf, 1, oh - obuf, stdout), oh = obuf; *oh ++= x;} inline void flush(){fwrite(obuf, 1, oh - obuf, stdout);} int read(){ int x = 0; int zf = 1; char ch = ' '; while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar_(); if (ch == '-') zf = -1, ch = getchar_(); while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar_(); return x * zf; } void write(int x){ if (x < 0) putchar_('-'), x = -x; if (x > 9) write(x / 10); putchar_(x % 10 + '0'); } } set< pair<int,int> >st; struct point { int x,y;}r[N]; int cross(point a,point b) { return a.x*b.y-a.y*b.x; } int direct (point a,point b,point c) { point ba={a.x-b.x,a.y-b.y}; point bc={c.x-b.x,c.y-b.y}; int ret=cross(bc,ba); if (ret==0) return 0; else if (ret<0) return -1; else if (ret>0) return 1; } pair<int,int>bag[N],p[N]; int work(int b) { int cnt=0; int Lx=max(-b,0ll),Rx=min(n,m-b); int Ly=Lx+b,Ry=Rx+b; for (int i=1;i<=q;i++) { int rec=direct(r[i],(point){-b,0},(point){0,b}); if (b<0) rec=-rec; if (b==0) rec=direct(r[i],(point){-1,-1},(point){0,0}); if (rec==-1) { if (r[i].y-b>=Lx&&r[i].y-b<=Rx&&r[i].y>=Ly&&r[i].y<=Ry) bag[++cnt]=make_pair(r[i].y-b,r[i].x-r[i].y+b-1); } else if (rec==0) { if (r[i].x>=Lx&&r[i].x<=Rx&&r[i].y>=Ly&&r[i].y<=Ry) bag[++cnt]=make_pair(r[i].x,0); } else if (rec==1) { if (r[i].x>=Lx&&r[i].x<=Rx&&r[i].x+b>=Ly&&r[i].x+b<=Ry) bag[++cnt]=make_pair(r[i].x,r[i].y-r[i].x-b-1); } } sort(bag+1,bag+1+cnt); int tot=0; for (int i=1,j;i<=cnt;i=j) { j=i; int ret=bag[i].second; while (bag[j].first==bag[i].first&&j<=cnt) j++; p[++tot]=make_pair(bag[i].first,ret); } if (tot==0) { int tmp=min(n-Lx,m-Ly); return (tmp+1)*(tmp)/2; } int ans=(min(n-(p[tot].first+1),m-(p[tot].first+1+b)))*(min(n-(p[tot].first+1),m-(p[tot].first+1+b))+1)/2; p[tot].second=min(p[tot].second,min(n-p[tot].first,m-p[tot].first-b)); int last=-1; for (int i=tot;i>=1;i--) { if (last==-1) { ans+=p[i].second; last=i; continue;} int num=p[last].first-p[i].first-1; int val=p[last].second+1; if (st.find(make_pair(p[last].first,p[last].first+b))!=st.end()) val--; ans+=num*(val+val+num-1)/2; if (st.find(make_pair(p[last].first,p[last].first+b))!=st.end()) p[i].second=min(p[i].second,p[last].second+num); else p[i].second=min(p[i].second,p[last].second+num+1); ans+=p[i].second; last=i; } int num=p[1].first-Lx; int val=p[1].second+1; if (st.find(make_pair(p[1].first,p[1].first+b))!=st.end()) val--; ans+=num*(val+val+num-1)/2; return ans; } using namespace fast_IO; signed main() { n=read(),m=read(),q=read(); n--; m--; for (int i=1;i<=q;i++) { r[i].x=read();r[i].y=read(); r[i].x--; r[i].y--; st.insert(make_pair(r[i].x,r[i].y)); } int ans=0; for (int i=-n;i<=m;i++) ans+=work(i); write(ans); flush(); return 0; }