分块,暴力。
将序列分成$sqrt(n)$块,每块$sqrt(n)$个元素,每块内排序。
每次操作要计算这个区间中比$a[p1]$大的有几个,小的有几个,比$a[p2]$大的有几个,小的有几个,端点的块内暴力找,中间的块内二分找。
交换完数字之后,可以直接重新$sort$排个序。
总体时间复杂度$O(m*log(sqrt(n))*sqrt(n))$。
#include <cstdio> #include <cmath> #include <cstring> #include <map> #include <algorithm> using namespace std; int n,m; int b[200010];//每一个位置属于哪一块 int L[200010];//每一块最左端 int R[200010];//每一块最右端 int x[200010];//原数组 int y[200010];//分块后排序的数组 int main() { scanf("%d%d",&n,&m); int sz = (int)sqrt(1.0*n); for(int i=0;i<=200000;i++) { L[i] = 200001; R[i] = 0; } for(int i=1;i<=n;i++) { b[i] = i/sz; L[b[i]] = min(L[b[i]],i); R[b[i]] = max(R[b[i]],i); } for(int i=1;i<=n;i++) x[i] = i, y[i] = i; long long ans = 0; while(m--) { int p1,p2; scanf("%d%d",&p1,&p2); if(p1>p2) swap(p1,p2); if(p1==p2) { printf("%lld ",ans); continue; } //属于同一块 if(b[p1]==b[p2]) { for(int i=p1+1;i<=p2-1;i++) { if(x[i]<x[p1]) ans--; if(x[i]>x[p1]) ans++; if(x[i]<x[p2]) ans++; if(x[i]>x[p2]) ans--; } if(x[p1]<x[p2]) ans++; else ans--; swap(x[p1],x[p2]); } //属于相邻两块 else if(b[p1]+1 == b[p2]) { for(int i=p1+1;i<=p2-1;i++) { if(x[i]<x[p1]) ans--; if(x[i]>x[p1]) ans++; if(x[i]<x[p2]) ans++; if(x[i]>x[p2]) ans--; } if(x[p1]<x[p2]) ans++; else ans--; int t1 = x[p1], t2 = x[p2]; swap(x[p1],x[p2]); int py1,py2; for(int i=L[b[p1]];i<=R[b[p1]];i++) if(y[i] == t1) py1 = i; for(int i=L[b[p2]];i<=R[b[p2]];i++) if(y[i] == t2) py2 = i; swap(y[py1],y[py2]); sort(y+L[b[p1]],y+R[b[p1]]+1); sort(y+L[b[p2]],y+R[b[p2]]+1); } else { for(int i=p1+1;i<=R[b[p1]];i++) { if(x[i]<x[p1]) ans--; if(x[i]>x[p1]) ans++; if(x[i]<x[p2]) ans++; if(x[i]>x[p2]) ans--; } for(int i=b[p1]+1;i<=b[p2]-1;i++) { int ll = L[i], rr = R[i], pos=-1; while(ll<=rr) { int mid = (ll+rr)/2; if(y[mid]>x[p2]) pos = mid, rr = mid-1; else ll = mid+1; } if(pos == -1) { ans = ans - 0; ans = ans + R[i] - L[i] + 1; } else { ans = ans - (R[i] - pos + 1); ans = ans + (pos - L[i]); } ll = L[i], rr = R[i], pos=-1; while(ll<=rr) { int mid = (ll+rr)/2; if(y[mid]>x[p1]) pos = mid, rr = mid-1; else ll = mid+1; } if(pos == -1) { ans = ans + 0; ans = ans - (R[i] - L[i] + 1); } else { ans = ans + (R[i] - pos + 1); ans = ans - (pos - L[i]); } } for(int i=L[b[p2]];i<=p2-1;i++) { if(x[i]<x[p1]) ans--; if(x[i]>x[p1]) ans++; if(x[i]<x[p2]) ans++; if(x[i]>x[p2]) ans--; } if(x[p1]<x[p2]) ans++; else ans--; int t1 = x[p1], t2 = x[p2]; swap(x[p1],x[p2]); int py1,py2; for(int i=L[b[p1]];i<=R[b[p1]];i++) if(y[i] == t1) py1 = i; for(int i=L[b[p2]];i<=R[b[p2]];i++) if(y[i] == t2) py2 = i; swap(y[py1],y[py2]); sort(y+L[b[p1]],y+R[b[p1]]+1); sort(y+L[b[p2]],y+R[b[p2]]+1); } printf("%lld ",ans); } return 0; }