平衡树
概述
一种数据结构。代码巨长。
其实平衡树的思想挺简单的,代码也不难写。
平衡树满足的性质:
(1、)左儿子权值小于父亲,右儿子权值大于父亲
(2、)左右儿子分别是平衡树
若仅是这样,很容易被毒瘤出题人卡成链,所以我们再人为的(虽然之前的性质也是人为的)给他加上一个性质(k),让这棵树不仅是权值满足上述性质,(k)也满足上述性质。
(k)是随机化出来的。我们可以依据(k)通过旋转改变树高,这样复杂度就变低了。期望时间复杂度是(O(nlogn))的。
例题
题目描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
(1.)插入(x)数
(2.)删除(x)数(若有多个相同的数,因只删除一个)
(3.)查询(x)数的排名(排名定义为比当前数小的数的个数(+1)。若有多个相同的数,因输出最小的排名)
(4.)查询排名为(x)的数
(5.)求(x)的前驱(前驱定义为小于(x),且最大的数)
(6.)求(x)的后继(后继定义为大于(x),且最小的数)
输入输出格式
输入格式:
第一行为(n),表示操作的个数,下面(n)行每行有两个数(opt)和(x),(opt)表示操作的序号( (1 leq opt leq 6))
输出格式:
对于操作(3,4,5,6)每行输出一个数,表示对应答案
懒得放样例了= =
定义
struct szh{
int v, k, ls, rs, su, si, f;
//权值 k 左儿子 右儿子 v出现的次数 树的大小 父亲;
szh(){v = k = -inf, ls = rs = -1, su = si = 0, f = -1;}
}tr[100005];
初始化
我们搞一个权值炒鸡大,(k)炒鸡小的炒鸡点来当做根。
void T_begin(){
tr[0].v = inf, tr[0].k = -inf, tr[0].su = tr[0].si = 1;
}
添加
要加入一个数,首先我们要找到他应该位于哪个位置。按照性质,只要左右判断一下就好啦。
void add(int u, int v){ //寻找添加点
Dier &t = treap[u];
if(t.v == -INF){build(u, v); return;} //如果这个节点不存在的话,就新建一个
if(t.v == v){add(u); return;} //如果当前点就是我们要找的点,就给这个点的计数器加一
//以上两个函数下面会说
if(t.v > v){ //若要加入的值大于当前点,说明它应该在它的左子树里
if(t.ls == -1) t.ls = ++cnt, treap[cnt].f = u; //若左子树没有,就新建一个
add(t.ls, v); //递归左子树
}
else{ //若大于它,就在它的右子树里
if(t.rs == -1) t.rs = ++cnt, treap[cnt].f = u; //没有就新建
add(t.rs, v); //递归右子树
}
turn(u), updata(u); //旋转,更新,一会讲
}
新建一个子树非常容易,只要记录一下(v、sum、size、k)就好啦
void build(int u, int v){ //初始化
treap[u].v = v, treap[u].size = treap[u].sum = 1, treap[u].k = rand();
}
若这个点出现多次,就++计数器
void add(int u){ //添加
++treap[u].sum, ++treap[u].size;
}
喜欢压行的可以把这三个函数变成一个
(updata)
更新节点信息的时候,只需要修改(size)
void updata(int u){ //更新节点信息
Dier &t = treap[u];
t.size = t.sum;
if(t.ls != -1) t.size += treap[t.ls].size; //若左子树不为空
if(t.rs != -1) t.size += treap[t.rs].size; //若右子树不为空
}
旋转
旋转是按照我们随机化出来的(k)的大小操作的。每加入一个值,我们就看一下这棵树需不需要旋转。显然,每次我们只需要旋转一次就够了。证明脑补。
每次旋转,分为左旋和右旋。这里的方向与平时我们所说的方向相反,即左旋为将左儿子作为根,原根的左儿子变成新根的右儿子,右旋相反。
void turn(int u){ //旋转
Dier &t = treap[u];
if((t.ls != -1) && (t.k > treap[t.ls].k)) left_turn(u);
else if((t.rs != -1) && (t.k > treap[t.rs].k)) right_turn(u);
}
这是判断该左旋还是右旋
void left_turn(int u){ //向左转
Dier &t = treap[u];
int f = t.f, newt = t.ls;
t.ls = treap[newt].rs; //原根的左儿子连到新根的右儿子上
if(treap[newt].rs != -1) treap[treap[newt].rs].f = u; //更新爸爸
t.f = newt, treap[newt].rs = u; //原根的父亲连到新根上,新根的右儿子是原根
if(treap[f].ls == u) treap[f].ls = newt, treap[newt].f = f; //将原根的爸爸与新根相连
else treap[f].rs = newt, treap[newt].f = f;
updata(u), updata(newt); //更新一下,注意顺序
}
void right_turn(int u){ //向右转
Dier &t = treap[u];
int f = t.f, newt = t.rs;
t.rs = treap[newt].ls; //原根的右儿子即新根的左儿子
if(treap[newt].ls != -1) treap[treap[newt].ls].f = u;
t.f = newt, treap[newt].ls = u; //原根与新根相连
if(treap[f].ls == u) treap[f].ls = newt, treap[newt].f = f; //原根的爸爸与新根相连
else treap[f].rs = newt, treap[newt].f = f;
updata(u), updata(newt); //更新
}
删除
同添加,先找到点,再删除
void del(int u, int v){ //寻找删除点
Dier &t = treap[u];
if(t.v == v){del(u); return;} //若就是当前点,直接删除
if(t.v > v) del(t.ls, v); //递归左儿子
else del(t.rs, v); //递归右儿子
updata(u); //不要忘记更新
}
void del(int u){ //删除点
if(treap[u].sum != 1) --treap[u].size, --treap[u].sum; //如果这个点出现过很多次,计数器--
else end(u); //否则,我们要将他旋转到叶子结点再删除
}
旋转到叶子结点,就按照旋转的规则,转下去就好了
void end(int u){ //将某个点旋转到叶子结点
Dier &t = treap[u];
t.k = INF;
while(t.ls != -1 || t.rs != -1) //只要她还有儿子
if(t.ls != -1) //若有左儿子
if(t.rs != -1) //若也有右儿子
if(treap[t.ls].k < treap[t.rs].k) left_turn(u); //若左儿子的k小,左旋
else right_turn(u); //反正右旋
else left_turn(u); //反之左旋
else right_turn(u); //反之右旋
if(treap[t.f].ls == u) treap[t.f].ls = -1; //删除
else treap[t.f].rs = -1;
for(int i = t.f; i != -1; i = treap[i].f) updata(i); //更新这条路径上的所有点
}
查询
查询(x)的排名
int rak(int u,int k){ //查询x数的排名
Dier &t = treap[u];
if(t.v == k) return treap[t.ls].size + 1; //根据定义
if(t.v > k) return rak(t.ls, k);
return rak(t.rs, k) + treap[t.ls].size + t.sum;
}
查询排名为(k)的数
int ask_rak(int u, int k){ //查询排名为x的数
Dier &t = treap[u];
if(treap[t.ls].size >= k) return ask_rak(t.ls, k);
int s = treap[t.ls].size + t.sum;
if(s >= k) return t.v;
return ask_rak(t.rs, k - s);
}
找第一个大于他的数
int ask_upper(int u,int k){ //找第一个小于他的数
if(u == -1) return -INF;
Dier &t = treap[u];
if(t.v < k) return max(t.v, ask_upper(t.rs, k));
return ask_upper(t.ls, k);
}
找第一个小于他的数
int ask_lower(int u, int k){//找第一个大于他的数
if(u == -1) return INF;
Dier &t = treap[u];
if(t.v > k) return min(t.v, ask_lower(t.ls, k));
return ask_lower(t.rs, k);
}
完整代码
#include <iostream>
#include <cstdlib>
#include <cstdio>
using namespace std;
long long read(){
long long x = 0; int f = 0; char c = getchar();
while(c < '0' || c > '9') f |= c == '-', c = getchar();
while(c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return f? -x:x;
}
const int inf = 2147483647;
int n, cnt;
struct szh{
int v, k, ls, rs, su, si, f;
szh(){v = k = -inf, ls = rs = -1, su = si = 0, f = -1;}
}tr[100005];
void T_begin(){
tr[0].v = inf, tr[0].k = -inf, tr[0].su = tr[0].si = 1;
}
void u_d(int u){
szh &t = tr[u];
t.si = t.su;
if(t.ls != -1) t.si += tr[t.ls].si;
if(t.rs != -1) t.si += tr[t.rs].si;
}
void l_t(int u){
szh &t = tr[u];
int f = t.f, nt = t.ls;
t.ls = tr[nt].rs;
if(tr[nt].rs != -1) tr[tr[nt].rs].f = u;
t.f = nt, tr[nt].rs = u;
if(tr[f].ls == u) tr[f].ls = nt, tr[nt].f = f;
else tr[f].rs = nt, tr[nt].f = f;
u_d(u), u_d(nt);
}
void r_t(int u){
szh &t = tr[u];
int f = t.f, nt = t.rs;
t.rs = tr[nt].ls;
if(tr[nt].ls != -1) tr[tr[nt].ls].f = u;
t.f = nt, tr[nt].ls = u;
if(tr[f].ls == u) tr[f].ls = nt, tr[nt].f = f;
else tr[f].rs = nt, tr[nt].f = f;
u_d(u), u_d(nt);
}
void turn(int u){
szh &t = tr[u];
if(t.ls != -1 && tr[t.ls].k < t.k) l_t(u);
else if(t.rs != -1 && tr[t.rs].k < t.k) r_t(u);
}
void build(int u, int v){
tr[u].v = v, tr[u].si = tr[u].su = 1, tr[u].k = rand();
}
void add(int u){
++tr[u].su, ++tr[u].si;
}
void add(int u, int v){
szh &t = tr[u];
if(t.v == -inf){build(u, v); return;}
if(t.v == v){add(u); return;}
if(t.v > v){
if(t.ls == -1) t.ls = ++cnt, tr[cnt].f = u;
add(t.ls, v);
}
else{
if(t.rs == -1) t.rs = ++cnt, tr[cnt].f = u;
add(t.rs, v);
}
turn(u), u_d(u);
}
void end(int u){
szh &t = tr[u];
while(t.ls != -1 || t.rs != -1)
if(t.ls != -1)
if(t.rs != -1)
if(tr[t.ls].k < tr[t.rs].k) l_t(u);
else r_t(u);
else l_t(u);
else r_t(u);
if(tr[t.f].ls == u) tr[t.f].ls = -1;
else tr[t.f].rs = -1;
for(int i = t.f; ~i; i = tr[i].f) u_d(i);
}
void del(int u){
szh &t = tr[u];
if(t.su != 1) t.su--, t.si--;
else end(u);
}
void del(int u, int v){
szh &t = tr[u];
if(t.v == v){del(u); return;}
if(t.v > v) del(t.ls, v);
else del(t.rs, v);
u_d(u);
}
int rak(int u, int k){
szh &t = tr[u];
if(t.v == k) return tr[t.ls].si + 1;
if(t.v > k) return rak(t.ls, k);
return rak(t.rs, k) + tr[t.ls].si + t.su;
}
int a_r(int u, int k){
szh &t = tr[u];
if(tr[t.ls].si >= k) return a_r(t.ls, k);
int s = tr[t.ls].si + t.su;
if(s >= k) return t.v;
return a_r(t.rs, k - s);
}
int a_u(int u, int k){
if(u == -1) return -inf;
szh &t = tr[u];
if(t.v >= k) return a_u(t.ls, k);
return max(t.v, a_u(t.rs, k));
}
int a_l(int u, int k){
if(u == -1) return inf;
szh &t = tr[u];
if(t.v > k) return min(t.v, a_l(t.ls, k));
return a_l(t.rs, k);
}
int main(){
n = read();
srand(37022059);
T_begin();
while(n--){
int a, b; a = read(); b = read();
switch(a){
case 1: add(0, b); break;
case 2: del(0, b); break;
case 3: printf("%d
", rak(0, b)); break;
case 4: printf("%d
", a_r(0, b)); break;
case 5: printf("%d
", a_u(0, b)); break;
case 6: printf("%d
", a_l(0, b)); break;
}
}
return 0;
}
(144)行,是我写过的最长的代码。