题目大意是在一个平面直角坐标系中,会给出 N 个线段,线段保证是与坐标轴平行的,且共线的线段不会有交点,现在问这些线段组成的矩形有多
一道老题,初遇时因为不知道怎么处理矩形包含而踌躇。后来发现只要用扫描线的思路考虑每对平行的线段有多少同时穿过它们两个的线段即可算出所有包含情况。
用扫描线的想法:以一个线段为基,得到穿过它的线段终点位置,并在树状数组上标记穿过,再往上处理与它平行的线段,查询树状数组中二者相交区域的区间和,扫描过程中要根据之前的终点位置更新树状数组。然后不断循环以线段为基的过程。这是扫描线的思路,复杂度是O(n^2 logn),很像求线段交点思路只不过扫了多遍。
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 typedef pair<int,int > pii; 5 const int toadd = 5001; 6 const int maxn = 10000+1; 7 vector<pii> V[maxn+4],H[maxn+4]; 8 int Flick[maxn+4]; 9 vector<int> ter[maxn]; 10 11 int lowbit(int x){ 12 return x&(-x); 13 } 14 15 void FlickUpdate(int ind,int C){ 16 while(ind<=10001){ 17 Flick[ind] += C; 18 ind += lowbit(ind); 19 } 20 } 21 22 int FlickQuery(int ind){ 23 int ret = 0; 24 while(ind>0){ 25 ret += Flick[ind]; 26 ind -= lowbit(ind); 27 } 28 return ret; 29 } 30 31 int main(){ 32 int n; 33 scanf("%d",&n); 34 for(int i=0;i<n;++i){ 35 int x1,y1,x2,y2; 36 scanf("%d%d%d%d",&x1,&y1,&x2,&y2); 37 x1 += toadd ,y1 += toadd ,x2 += toadd ,y2 += toadd; 38 if(x1>x2) swap(x1,x2); 39 if(y1>y2) swap(y1,y2); 40 if(x1 == x2) V[x1].emplace_back(make_pair(y1,y2)); 41 else H[y1].emplace_back(make_pair(x1,x2)); 42 } 43 long long ans = 0; 44 for(int i=1;i<=maxn;++i){ 45 for(auto heng : H[i]){ 46 for(int l = heng.first;l<=heng.second;++l){ 47 for(auto zong : V[l]) 48 if(zong.first<=i&&zong.second>i) { 49 FlickUpdate(l,1); 50 ter[zong.second].push_back(l); 51 } 52 } 53 for(int j = i+1;j<=maxn;++j){ 54 for(auto heng2 : H[j]){ 55 int L = max(heng.first,heng2.first); 56 int R = min(heng.second,heng2.second); 57 if(L<=R) { 58 int cnt = FlickQuery(R) - FlickQuery(L-1); 59 // if(cnt) printf("%d %d is %d ",i-toadd,j-toadd,cnt); 60 ans += 1ll*cnt*(cnt-1)/2; 61 } 62 } 63 for(auto x : ter[j]){ 64 FlickUpdate(x,-1); 65 } 66 ter[j].clear(); 67 } 68 69 } 70 } 71 printf("%lld ",ans); 72 return 0; 73 }
还有一种通用的写法,求线段组成的所有四边形个数< O(n^3)>,思路有点类似,只不过不是扫描,而是给所有线段编上号,每一个线段有一个集合记录其与哪些线段相交。考虑一对线段的两个集合,取交集,交集大小即是同时穿过两个线段的线段数目,可用于计算。而这一个题目里面由于是平行垂直关系,所以可以二分(bipartite),只算一组平行线段的集合,表示线段交点和计算两个集合的交集时可以用bitset加速,所以 n ^ 3 里有 n ^ 2 * n / constant 。具体请看 评论区里 mikeweat 的说法,链接。
1 #pragma GCC optimize("O3", "unroll-loops") 2 #pragma GCC target("avx2") 3 4 #include <bits/stdc++.h> 5 using namespace std; 6 7 typedef pair<int , pair<int,int> > segment; 8 vector<segment> V,H; 9 bool getmask(segment i,segment j){ 10 return j.second.first <= i.first && i.first <= j.second.second && 11 i.second.first <= j.first && j.first<= i.second.second; 12 } 13 14 int main(){ 15 int N; 16 scanf("%d",&N); 17 for(int i=0;i<N;++i){ 18 int x1,y1,x2,y2; 19 scanf("%d%d%d%d",&x1,&y1,&x2,&y2); 20 if(x1>x2) swap(x1,x2); 21 if(y1>y2) swap(y1,y2); 22 if(x1==x2) V.emplace_back(make_pair(x1,make_pair(y1,y2))); 23 else H.emplace_back(make_pair(y1,make_pair(x1,x2))); 24 } 25 if(H.size()>V.size()) swap(H,V); 26 vector<bitset<5000> > mask(1*H.size()); 27 for(int i = 0;i<H.size();++i){ 28 for(int j=0;j<V.size();++j){ 29 mask[i][j] = getmask(H[i],V[j]); 30 // if(mask[i][j]) printf("%d and %d ",H[i].first,V[j].first); 31 } 32 } 33 long long ans = 0ll; 34 for(int i=0;i<H.size();++i){ 35 for(int j=i+1;j<H.size();++j){ 36 int cnt = (mask[i]&mask[j]).count(); 37 // printf("%d and %d cnt : %d ",H[i].first,H[j].first,cnt); 38 ans += 1ll*cnt*(cnt-1)/2; 39 } 40 } 41 printf("%lld ",ans); 42 return 0; 43 }
开 O3 和 unroll-loops可以从1300ms 优化到 300ms左右
#pragma GCC optimize("O3", "unroll-loops")
#pragma GCC target("avx2")