题意
给出一个长度为 n 的全排列,以及 m 个删除操作,每个操作给出要删除的数字。请输出每个删除操作前序列的逆序对数量。
思路
我们可以求出整个序列的逆序对数量,然后对于每个删除操作,找出这次删除操作会减少多少个逆序对。
当前删除的数字为 \(a_i\) ,其位置为 \(pos_i\) ,删除时间为 \(time_i\)。
删除数字 \(a_i\) 以后,会减少两种逆序对。
- \(time_j > time_i\),\(pos_j<pos_i\),\(a_j>a_i\)
- \(time_j>time_i\),\(pos_j>pos_i\),\(a_j<a_i\)
这就是两种情况的三维偏序。
使用 CDQ分治 算出删除每个数字会减少的逆序对,然后按照删除时间从小到大排序。
对于每个被删除的数字,输出 总逆序对 - 减少逆序对的前缀和
代码
#include <algorithm>
#include <iostream>
#include <map>
#include <math.h>
#include <queue>
#include <set>
#include <stack>
#include <stdio.h>
#include <string.h>
#include <string>
#include <vector>
#define emplace_back push_back
#define pb push_back
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int mod = 1e9 + 7;
const int seed = 12289;
const double eps = 1e-6;
const int inf = 0x3f3f3f3f;
const int N = 2e5 + 10;
struct note {
int pos, val, time, ans;
} arr[N];
bool cmp1(note x, note y)
{
return x.time > y.time;
}
bool cmp2(note x, note y)
{
return x.pos < y.pos;
}
bool cmp3(note x, note y)
{
return x.pos > y.pos;
}
bool cmp4(note x, note y)
{
return x.time < y.time;
}
struct treearray {
int tree[N], k;
int lowbit(int x)
{
return x & (-x);
}
void update(int pos, int val)
{
for (; pos <= k; pos += lowbit(pos)) {
tree[pos] += val;
}
}
int sum(int pos)
{
int ans = 0;
for (; pos; pos -= lowbit(pos)) {
ans += tree[pos];
}
return ans;
}
} tr; //tr
void cdq(int l, int r)
{
if (l == r)
return;
int mid = (l + r) / 2;
cdq(l, mid), cdq(mid + 1, r);
sort(arr + l, arr + mid + 1, cmp2);
sort(arr + mid + 1, arr + r + 1, cmp2);
int i = l, j = mid + 1;
for (; j <= r; j++) {
while (arr[i].pos < arr[j].pos && i <= mid) {
tr.update(arr[i].val, 1);
i++;
}
arr[j].ans += tr.sum(tr.k) - tr.sum(arr[j].val);
}
for (j = l; j < i; j++) {
tr.update(arr[j].val, -1);
}
sort(arr + l, arr + mid + 1, cmp3);
sort(arr + mid + 1, arr + r + 1, cmp3);
i = l, j = mid + 1;
for (; j <= r; j++) {
while (arr[i].pos > arr[j].pos && i <= mid) {
tr.update(arr[i].val, 1);
i++;
}
arr[j].ans += tr.sum(arr[j].val);
}
for (j = l; j < i; j++) {
tr.update(arr[j].val, -1);
}
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
tr.k = n;
ll rel = 0;
for (int i = 1; i <= n; i++) {
int x;
scanf("%d", &x);
rel += tr.sum(n) - tr.sum(x);
tr.update(x, 1);
arr[i].time = inf;
arr[x].val = x;
arr[x].pos = i;
}
for (int i = 1; i <= n; i++) {
tr.update(i, -1);
}
for (int i = 1; i <= m; i++) {
int x;
scanf("%d", &x);
arr[x].time = i;
}
sort(arr + 1, arr + 1 + n, cmp1);
cdq(1, n);
sort(arr + 1, arr + 1 + n, cmp4);
printf("%lld\n", rel);
ll now = 0;
for (int i = 1; i <= n; i++) {
now += arr[i].ans;
if (arr[i].time < m) {
printf("%lld\n", rel - now);
}
}
return 0;
}