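// Balanced-tree practice problem.
// Having just learned balanced trees: for each query, look up the predecessor and the
// successor of the value directly, then take whichever has the smaller absolute difference.
// Don't forget to delete the matched element from the tree.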
#include <bits/stdc++.h>
// pushup(u): refresh an internal node from its children. siz becomes the sum of the
// children's sizes and val becomes the right child's val.
// NOTE: this expands to an unbraced `if`, so in the idiom `pushup(u), maintain(u);`
// the maintain(u) call is part of the if-body and runs only when u->ls is non-empty
// (i.e. never on a leaf).
#define pushup(u) if(u -> ls -> siz) u -> siz = u -> ls -> siz + u -> rs -> siz, u -> val = u -> rs -> val
// new_Node(siz, val, ls, rs): take the next free slot st[cnt] of the preallocated pool,
// construct a Node in it and yield its address. cnt is not checked against the pool size.
#define new_Node(a, b, c, d) (&(*st[cnt++] = Node(a, b, c, d)))
// merge(a, b): allocate a new internal node with children a and b (size = sum, val = b's val).
#define merge(a, b) new_Node(a -> siz + b -> siz, b -> val, a, b)
// ratio: balance factor. maintain() rotates when one child's size exceeds ratio times
// the other's. NOTE(review): a plain `ratio` variable would likely be ambiguous with
// std::ratio under `using namespace std;` -- confirm before converting this macro.
#define ratio 3
using namespace std;
// Number of slots in the node pool (t / st) that new_Node hands out.
constexpr int N = 2 * 100000 + 5;
// The accumulated answer is reported modulo this value.
constexpr int MOD = 1000000;
using ll = long long;

// Running total of matching costs, kept reduced modulo MOD.
ll ans = 0;
// Index of the next free slot in the pool st.
int cnt = 0;
// n: number of operations; opt / x: the current operation's kind and value;
// last: the kind of the entries currently stored in the tree.
int n, opt, x, last;
// Reads the next integer from stdin into t.
// Non-digit characters before the number are skipped. If a '-' occurs among the
// skipped characters, the result is negated. If stdin runs out before any digit
// is found, t is left as 0 instead of looping forever on EOF.
template <typename T>
inline void read(T &t) {
    t = 0;
    T sign = 1;
    int ch = getchar();  // int, not char: EOF must stay distinguishable from every byte
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        t = t * 10 + (ch - '0');
        ch = getchar();
    }
    t *= sign;
}
// Node of the leafy weight-balanced tree.
// Leaves have siz == 1 and hold one stored value in val. An internal node stores
// the number of leaves below it in siz and its right child's val in val (see pushup).
// The shared `null` node is constructed with siz == 0 and null child pointers.
// root / null are heap-allocated in main; all other nodes come from the pool t via st.
struct Node {
    int val = 0, siz = 0;
    Node *ls = nullptr, *rs = nullptr;
    Node() = default;
    // Initializers are listed in declaration order (val before siz), which is the
    // order in which members are actually initialized.
    Node(int a, int b, Node *c, Node *d) : val(b), siz(a), ls(c), rs(d) {}
} *root, *null, *st[N], t[N];
// Restores weight balance at internal node u with a single rotation when one
// child's size exceeds `ratio` times the other's. u's size and key are unchanged.
// The child detached by the rotation is reused as the new inner node instead of
// allocating a fresh one (the original leaked one pool slot per rotation), so
// rebalancing never consumes pool capacity.
void maintain(Node *u) {
    if (u->ls->siz > u->rs->siz * ratio) {
        // Rotate right: u = (A, B, C) with ls = (A, B)  ->  u = (A, (B, C)).
        Node *mid = u->ls;
        u->ls = mid->ls;
        mid->ls = mid->rs;
        mid->rs = u->rs;
        mid->siz = mid->ls->siz + mid->rs->siz;
        mid->val = mid->rs->val;
        u->rs = mid;
    }
    if (u->rs->siz > u->ls->siz * ratio) {
        // Rotate left: u = (A, B, C) with rs = (B, C)  ->  u = ((A, B), C).
        Node *mid = u->rs;
        u->rs = mid->rs;
        mid->rs = mid->ls;
        mid->ls = u->ls;
        mid->siz = mid->ls->siz + mid->rs->siz;
        mid->val = mid->rs->val;
        u->ls = mid;
    }
}
// Inserts value x below u.
// While u is internal, descend right if x exceeds the left child's key, otherwise
// left. At a leaf, split it into two new leaves: the smaller of (x, old value) on
// the left and the larger on the right. Sizes/keys are refreshed and balance is
// restored on the way back up.
void ins(Node *u, int x) {
    if (u->siz != 1) {
        Node *child = x > u->ls->val ? u->rs : u->ls;
        ins(child, x);
    } else {
        const int lo = min(x, u->val);
        const int hi = max(u->val, x);
        u->ls = new_Node(1, lo, null, null);
        u->rs = new_Node(1, hi, null, null);
    }
    pushup(u), maintain(u);
}
// Removes one leaf holding value x from the subtree rooted at u.
// When the matching leaf is a direct child of u, u takes over the contents of
// the other child. Both the removed leaf and that now-unreferenced sibling node
// are returned to the pool (the original leaked them). Otherwise the search
// descends like ins. Calling this on a leaf or null node, or with a value that
// is not stored, is a no-op (the original dereferenced null->ls).
void del(Node *u, int x) {
    if (u->siz <= 1) return;  // leaf or null: no child to remove
    if (u->ls->siz == 1 && x == u->ls->val) {
        Node *gone = u->ls, *sib = u->rs;
        *u = *sib;
        st[--cnt] = gone;
        st[--cnt] = sib;
    } else if (u->rs->siz == 1 && x == u->rs->val) {
        Node *gone = u->rs, *sib = u->ls;
        *u = *sib;
        st[--cnt] = gone;
        st[--cnt] = sib;
    } else {
        del(x > u->ls->val ? u->rs : u->ls, x);
    }
    pushup(u), maintain(u);
}
// Returns the value of the x-th leaf (1-based, left to right) under u.
// For x <= 0 the walk always goes left and yields the leftmost leaf.
// For x greater than u's size it always goes right and yields the rightmost leaf.
int kth(Node *u, int x) {
    while (u->siz != 1) {
        if (x <= u->ls->siz) {
            u = u->ls;
        } else {
            x -= u->ls->siz;
            u = u->rs;
        }
    }
    return u->val;
}
// Returns the 1-based position of the leaf reached by descending right whenever
// x exceeds the left child's key. With leaves in ascending order (as ins keeps
// them) and internal keys equal to their right child's key (pushup), this is
// 1 + the number of stored values smaller than x.
int rnk(Node *u, int x) {
    int skipped = 0;
    while (u->siz != 1) {
        if (x > u->ls->val) {
            skipped += u->ls->siz;
            u = u->rs;
        } else {
            u = u->ls;
        }
    }
    return skipped + 1;
}
int main() {
null = new Node(0, 0, 0, 0);
root = new Node(1, INT_MAX, null, null);
for(int i = 0; i < N; i++) st[i] = &t[i];
read(n);
for(int i = 1; i <= n; i++) {
read(opt), read(x);
if(root -> siz == 1) ins(root, x), last = opt;
else {
if(opt == last) ins(root, x);
else {
int l1 = kth(root, rnk(root, x) - 1), l2 = kth(root, rnk(root, x + 1));
if(abs(x - l1) <= abs(x - l2)) {
ans = (ans + abs(x - l1)) % MOD;
del(root, l1);
} else {
ans = (ans + abs(x - l2)) % MOD;
del(root, l2);
}
}
}
}
printf("%lld
", ans);
return 0;
}