这个题有点卡常数。。我的常数比较大所以是吸着氧气跑过去的。。。
题意:计算对于序列中每个位置(p),([1,p-1])区间内比它大的数的个数,和([p + 1, N])区间内比它小的数的个数和,要求支持修改操作,带修主席树可以解决。
通过主席树来维护权值状态和比某个数大/小的数的个数,用树状数组来支持修改和维护一个主席树的前缀和(主席树前缀和具有可减性)。时间空间(O(Nlog^2N)),(1000ms+128MB)的限制对本算法略为苛刻,但卡常可过。
#include <bits/stdc++.h>
using namespace std;
const int N = 100010;
#define mid ((l + r) >> 1)
#define lowbit(x) (x & -x)
int n, m, tot, rt[N], arr[N], pos[N];
struct Segment_Node {
int sz, ls, rs;
}t[N << 8];
void modify (int &_rt, int l, int r, int w, int del) {
if (_rt == 0) _rt = ++tot;
t[_rt].sz += del;
if (l != r) {
if (w <= mid) {
modify (t[_rt].ls, l, mid, w, del);
} else {
modify (t[_rt].rs, mid + 1, r, w, del);
}
}
}
namespace rev {
int a[N];
inline void add (int pos, int val) {
while (pos <= n) {
a[pos] += val;
pos += lowbit (pos);
}
}
inline int get_sum (int pos) {
register int res = 0;
while (pos) {
res += a[pos];
pos -= lowbit (pos);
}
return res;
}
long long get_rev () {
long long res = 0;
for (int i = n; i >= 1;--i) {
res += get_sum (arr[i]);
add (arr[i], 1);
}
return res;
}
}
int _query (int _rt, int l, int r, int nl, int nr) {
if (nl <= l && r <= nr) return t[_rt].sz;
register int res = 0;
if (nl <= mid) res += _query (t[_rt].ls, l, mid, nl, nr);
if (mid < nr) res += _query (t[_rt].rs, mid + 1, r, nl, nr);
return res;
}
inline int query (int l, int r, int w, int type) {
l = l - 1;
//求序列里面[l, r]内有多少数大于w (type = 1)
//求序列里面[l, r]内有多少数小于w (type = 2)
// printf ("l = %d, r = %d, w = %d, type = %d
", l, r, w, type);
register int i, res = 0;
for (i = l; i != 0; i -= lowbit (i)) {
if (type == 1) res -= _query (rt[i], 0, n + 1, w + 1, n);
if (type == 2) res -= _query (rt[i], 0, n + 1, 1, w - 1);
}
// printf ("res = %d
", res);
for (i = r; i != 0; i -= lowbit (i)) {
if (type == 1) res += _query (rt[i], 0, n + 1, w + 1, n);
if (type == 2) res += _query (rt[i], 0, n + 1, 1, w - 1);
}
// printf ("l = %d, r = %d, w = %d, type = %d, res = %d
", l, r, w, type, res);
return res;
}
inline int read () {
int s = 0, w = 1, ch = getchar ();
while ('9' < ch || ch < '0') {
ch = getchar ();
}
while ('0' <= ch && ch <= '9') {
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * w;
}
int main () {
n = read (), m = read ();
register int i, j, w;
for (i = 1; i <= n; ++i) {
arr[i] = read ();
pos[arr[i]] = i;
for (j = i; j <= n; j += lowbit (j)) {
modify (rt[j], 0, n + 1, arr[i], +1);
}
}
long long ans = rev :: get_rev ();
for (i = 1, w = 0; i <= m; ++i) {
printf ("%lld
", ans);
w = read ();
ans -= query (1, pos[w] - 1, w, 1);
ans -= query (pos[w] + 1, n, w, 2);
for (j = pos[w]; j <= n; j += lowbit (j)) {
modify (rt[j], 0, n + 1, w, -1);
}
}
}