题目描述
编写一棵二叉排序树,来支持以下 (6) 种操作:
- 插入 (x) 数
- 删除 (x) 数(若有多个相同的数,因只删除一个;如果 (x) 不存在则不需要删除)
- 查询 (x) 数的排名(排名定义为比当前数小的数的个数 (+1) ;如果 (x) 不存在则输出 (-1))
- 查询排名为 (x) 的数(如果 (x) 大于树中元素个数,则输出 (-1))
- 求 (x) 的前驱(前驱定义为小于 (x),且最大的数;如果没有输出 (-1) )
- 求 (x) 的后继(后继定义为大于 (x),且最小的数;如果没有输出 (-1) )
输入格式
第一行为 (n)((1 le n le 10000)),表示操作的个数,下面 (n) 行每行有两个数 ( ext{opt}) 和 (x),( ext{opt}) 表示操作的序号( (1 leq ext{opt} leq 6) )
输出格式
对于操作 (3,4,5,6) 每行输出一个数,表示对应答案
样例输入
10
1 3
1 7
1 15
1 12
3 7
2 7
3 7
4 1
5 8
6 8
样例输出
2
-1
3
3
12
实现代码如下:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 100010;
int lson[maxn], rson[maxn], val[maxn], sz, cnt, tot[maxn];
void Insert(int num) {
if (cnt == 0) { // 如果树为空,则直接插入根节点
cnt ++;
val[++sz] = num;
tot[sz] = 1;
return;
}
// 判断num是否存在
int x = 1;
while (true) {
if (num == val[x]) // 存在,直接返回
return;
else if (num < val[x]) {
if (lson[x]) x = lson[x];
else break;
}
else {
if (rson[x]) x = rson[x];
else break;
}
}
// 插入num
cnt ++;
val[++sz] = num;
tot[sz] = 1;
x = 1;
while (true) {
tot[x] ++;
if (num < val[x]) {
if (lson[x]) x = lson[x];
else {
lson[x] = sz;
break;
}
}
else {
if (rson[x]) x = rson[x];
else {
rson[x] = sz;
break;
}
}
}
}
void Delete(int num) {
if (sz == 0) return;
if (cnt == 1) {
if (val[1] != num) return;
cnt --;
lson[1] = rson[1] = 0;
return;
}
int x = 1, p = 0, y, q;
while (true) {
if (num == val[x]) break;
else if (num < val[x]) {
p = x;
if (!lson[x]) return;
x = lson[x];
}
else {
p = x;
if (!rson[x]) return;
x = rson[x];
}
}
cnt --;
x = 1; p = 0;
while (true) {
tot[x] --;
if (num == val[x]) break;
else if (num < val[x]) {
p = x;
x = lson[x];
}
else {
p = x;
x = rson[x];
}
}
if (!lson[x] && !rson[x]) { // 要删除的x是叶子节点
if (p) {
if (lson[p] == x) lson[p] = 0;
else rson[p] = 0;
}
}
else if (lson[x]) {
y = lson[x], q = x;
while (rson[y]) {
tot[y] --;
q = y;
y = rson[y];
}
if (lson[q] == y) lson[q] = lson[y];
else rson[q] = lson[y];
val[x] = val[y];
}
else {
y = rson[x], q = x;
while (lson[y]) {
tot[y] --;
q = y;
y = lson[y];
}
if (lson[q] == y) lson[q] = rson[y];
else rson[q] = rson[y];
val[x] = val[y];
}
}
int getRank(int num) {
if (cnt == 0) return -1;
// 判断num是否存在
int x = 1;
bool exist = false;
while (true) {
if (num == val[x]) {
exist = true;
break;
}
else if (num < val[x]) {
if (lson[x]) x = lson[x];
else break;
}
else {
if (rson[x]) x = rson[x];
else break;
}
}
if (!exist) return -1;
// 然后从上到下判断
x = 1;
int res = 0;
while (true) {
if (val[x] == num) {
res ++;
if (lson[x]) res += tot[lson[x]];
break;
}
else if (val[x] < num) {
res ++;
if (lson[x]) res += tot[lson[x]];
if (rson[x]) x = rson[x];
else break;
}
else {
if (lson[x]) x = lson[x];
else break;
}
}
return res;
}
int getNumByRank(int rk) {
if (rk > cnt) return -1;
int x = 1;
while (true) {
int left_num = 1;
if (lson[x]) left_num += tot[lson[x]];
if (left_num == rk) return val[x];
else if (left_num > rk) x = lson[x];
else {
rk -= left_num;
x = rson[x];
}
}
}
int getPre(int num) {
int res = -1;
if (cnt == 0) return -1;
int x = 1;
while (true) {
if (val[x] < num) {
res = val[x];
if (rson[x]) x = rson[x];
else break;
}
else {
if (lson[x]) x = lson[x];
else break;
}
}
return res;
}
int getNext(int num) {
int res = -1;
if (cnt == 0) return -1;
int x = 1;
while (true) {
if (val[x] > num) {
res = val[x];
if (lson[x]) x = lson[x];
else break;
}
else {
if (rson[x]) x = rson[x];
else break;
}
}
return res;
}
int n, op, x;
int main() {
cin >> n;
while (n --) {
cin >> op >> x;
if (op == 1) Insert(x);
else if (op == 2) Delete(x);
else if (op == 3) cout << getRank(x) << endl;
else if (op == 4) cout << getNumByRank(x) << endl;
else if (op == 5) cout << getPre(x) << endl;
else if (op == 6) cout << getNext(x) << endl;
}
return 0;
}
使用 set 来实现上述功能的代码:
#include <bits/stdc++.h>
using namespace std;
set<int> st;
int n, op, x;
int main() {
cin >> n;
while (n --) {
cin >> op >> x;
if (op == 1) st.insert(x);
else if (op == 2) {
set<int>::iterator it = st.lower_bound(x);
if (it != st.end() && (*it) == x) st.erase(it);
}
else if (op == 3) {
set<int>::iterator it = st.lower_bound(x);
if (it == st.end() || (*it) != x) cout << -1 << endl;
else cout << distance(st.begin(), it) + 1 << endl;
}
else if (op == 4) {
if (x > st.size()) cout << -1 << endl;
else {
set<int>::iterator it = st.begin();
for (int i = 1; i < x; i ++) it ++;
cout << (*it) << endl;
}
}
else if (op == 5) {
set<int>::iterator it = st.lower_bound(x);
if (it == st.begin()) cout << -1 << endl;
else {
it --;
cout << (*it) << endl;
}
}
else {
set<int>::iterator it = st.upper_bound(x);
if (it == st.end()) cout << -1 << endl;
else cout << (*it) << endl;
}
}
return 0;
}
注意 distance()
函数的时间复杂度是 (O(n)) 的。