• 替罪羊树


    替罪羊树是一种依靠重构操作维持平衡的重量平衡树。替罪羊树会在插入、删除操作时,检测净土的节点,若发现失衡,则将以该节点为根的子树重构。

    序言:

    我们知道在一棵平衡的二叉搜索树内进行查询等操作时,时间就可以稳定在(log(n))但是每一次的插入和删除节点,都会使得这棵树不平衡,最会情况就是退化成一条链,显然我们不想要这种树,于是各种维护的方法出现了,有通过旋转的,有拆树在合并的,然而替罪羊树就很优美的,因为一旦发现不平衡的子树,立即拍扁重构,于是替罪羊树的核心是:暴力重建

    正题:

    替罪羊树的每个节点都包含什么:

    (left)(right):记录该节点的左右儿子

    (x):该节点的值

    (tot):有多少个值为(x)的数

    (sz,trsz,whsz)(sz)表示以该节点为根的子树内有多少个节点,(trsz)表示有多少个有效节点,(whsz)表示有多少个数(也就数子树内所有节点的(tot)的和)

    (fa):该点的父亲

    (vis):该点是否有删除标记

    操作一:加点

    先找到一个特殊的节点,如果那个节点的值等于要加的那个点,那么直接让那个节点的(tot+1)即可,如果比那个节点的值要小,就让新加的节点称为它的左儿子,大的话就是右儿子。

    那么如何找到那个"特殊的节点"?假如我们以(x)为关键字去查找,先从根节点开始,假如(x)比根节点的值要小,那我就去它的左儿子哪里,不然去右儿子,直到满足以下两个条件之一:

    • 找到了值为(x)的节点
    • 不能继续往下找

    那么这个所谓的特殊的节点的性质也就很显然,就是与新加入的节点值相同的点,或者新加入的节点的前驱或后继

    找点:

    int find(int x,int now){//now表示当前找到哪个点
        if(x < tr[now].x && tr[now].left) return find(x,tr[now].left);//比当前点的值要小并且有左儿子
        if(x > tr[now].x && tr[now].right) return find(x,tr[now].right);
        return now;
    }
    

    加点:

    void add(int x){
        if(root == 0){//假如当前没有根节点,也就是当前的树是空的,那么直接让他成为根
            build(x,root = New(),0);//新建节点(后面有讲)
            return;
        }
        int p = find(x,root);//找到特殊点
        if(x == tr[p].x{
            tr[p].tot++;
            if(tr[p].vis) tr[p].vis = 0,updata(p,1,0,1);
            else updata(p,0,0,1);
        }
        else if(x < tr[p].x) build(x,tr[p].left = New(),p),updata(p,1,1,1);
        else build(x,tr[p].right = New(),p),updata(p,1,1,1);
        find_rebuild(root,x);
    }
    

    上面用到的几个函数:

    新建节点:

    void build(int x,int y,int fa){//初始化树上编号为y的节点,它的值为x,父亲为fa
        tr[y].left = tree[y].right = 0;tr[y].fa = fa;tree[y].vis = 0;
        tr[y].x = x;tr[y].tot = tr[y].sz = tr[y].trsz = tr[y].whsz = 1;
    }
    

    (updata)函数,更新父亲以及爷爷以及祖先们的(sz,trsz还有whsz)

    void updata(int x,int y,int z,int k){
        if(!x) return;
        tr[x].trsz += y;
        tr[x].sz += z;
        tr[x].whsz += k;
        updata(tr[x].fa,y,z,k);
    }
    

    (New)函数就是个内存池。

    操作二:删点

    删点,严格来说是删掉一个数,假如我们要删掉一个值为(x)的数,那就先找到值为(x)的节点,然后(tot-1)

    难道就这么简单么?当然不是,假如(tot-1)之后变成(0)了怎么办?这意味着这个节点不存在了,然后我们删掉这个节点么?如果把它删了,他的左右儿子怎么办?所以我们不能动他,给它打个标记,标记它被删除了

    代码:

    void del(int x){
        int p = find(x,root);
        tr[p].tot--;
        if(!tr[p].tot) tr[p].vis = true,updata(p,-1,0,-1);
        else updata(p,0,0,-1);
        find_rebuild(root,x);
    }
    

    我们上面的代码提到了一个函数(find)_(rebuild),每一次的加点和删点都有可能使这棵树不平衡,假如有一颗子树不平衡,我们就需要将其重建,所以,(find) _(rebuild)就是用来查找需要重建的子树。

    先说一下怎么重建。

    因为需要重建的子树比订书二叉搜索树,那么这棵子树的中序遍历一定是一个严格上升的序列,于是我们就先中序遍历一下,把树上的有效节点放到一个数组里里面,注意无效节点(就是被打了删除标记的节点)不要。

    然后我们在把数组中的节点重建成一棵极其平衡的完全二叉树(按完全二叉树的方法来建),具体放大就是每一次选取数组中间的节点,让它成为跟,左边的当它左子树,右边的当它右子树,然后再对左右儿子进行相同的操作。

    怎么找需要重建的子树:

    我们每次(add)(del)的数为(x),在将这个数加入到树中或从树中删除之后,加入在树中值为(x)的节点是(y),那我们考虑到其实每一次可能小重构的子树只会是以“根到(y)路径上的节点”为根的子树,那么我们就可以从根往(y)走一次,看看谁要重建就好了。不是(y)往根走的原因:如果根到(y)的路径上只有两个点(a,b),并且(a)(b)的祖先,然后特别巧的是(a,b)都是需要重建的,那么这个时候我们只要重建祖先节点为根的子树,因为重建之后,另一个为根的子树在其内部也重建完了,如果(y)往根走就会出现重建两遍的情况。

    判断替罪羊树是否平衡:

    在替罪羊树中定义了个一个平衡因子(alpha)(alpha)的范围因题而异,一般在(0.5-1.0)之间。判断一棵子树是否平衡的方法:如果(x)的左(右)儿子的节点数量大于以(x)为根的子树的节点数量(*alpha),那么以(x)为根的这棵子树就是不平衡的,这时就将它重建。

    还有一种情况就是,打了删除标记的点多了,效率自然会变慢,所欲如果在一棵子树内有超过30%的几点被打了删除标记,就把这棵树重建。

    (find)_(rebuild)

    void find_rebuild(int now,int x){
        if(1.0 * tr[tr[now].left].sz > 1.0 * tr[now].sz * alpha || 1.0 * tr[tr[now].right].sz > 1.0 * tr[now].sz * alpha|| 1.0 * tr[now].sz - 1.0 * tr[now].trsz >1.0 * tr[now].sz * 0.3){
            rebuild(now);
            return;
        }
        if(tr[now].x != x) find_rebuild(x < tr[now].x ? tr[now].left : tree[now].right,x);
    }
    

    (rebuild)

    void rebuild(int x){//重建以x为根的子树
        tt = 0;
        dfs_rebuild(x);//进行中序遍历并将有效节点压入数组
        if(x == root) root = readd(1,tt,0);//x就是根,那么root就变成重建之后的那棵树的根
        //readd用来把数组里的节点重新建成一棵完全二叉树,并返回这棵树的根
        else{
            updata(tr[x].fa,0,-(tr[x].sz - tree[x].trsz),0);//因为拍扁重建后的树中,被打了删除标记的节点将消失,所以要将祖先们的size进行更改,也就是减去被删去的节点
            if(tr[tr[x].fa].left == x) tr[tr[x].fa].left = readd(1,tt,tr[x].fa);
            else tr[tr[x].fa].right = readd(1,tt,tr[x].fa);
        }
    }
    

    (readd):

    int readd(int l,int r,int fa){
        if(l > r)return 0
        int mid = (l + r) >> 1;//选中间的点作为根
        int id = New();
        tree[id].fa = fa;//更新各项
        tree[id].tot = num[mid].tot;
        tree[id].x = num[mid].x;
        tree[id].left = readd(l,mid - 1,id);
        tree[id].right = readd(mid + 1,r,id);
        tree[id].whsz = tr[tr[id].left].whsz + tr[tr[id].right].whsz + num[mid].tot;
        tree[id].sz = tr[id].trsz = r - l + 1;
        tree[id].vis = 0;
        return id;
    }
    

    中序遍历:

    void dfs_rebuild(int x){
        if(x == 0) return;
        dfs_rebuild(tr[x].left);
        if(!tr[x].vis) num[++tt].x = tr[x].x,num[tt].tot = tree[x].tot;//假如没有删除标记,就只将他的x和tot加进数组,因为其他东西都没有用
        ck[++t] = x;//仓库,存下废弃的节点
        dfs_rebuild(tr[x].right);
    }
    

    然后就是(New)

    int New(){
        if(t > 0) return ck[t--];//假如仓库内有点,就直接用
        else return ++len;//否则再创造一个点
    }
    

    然后我们就可以进行剩下的几个基本操作了

    操作三:查找(x)的排名

    我们只需要像(find)函数一样走一遍行。如果往右儿子走,就让(ans)加上左儿子的数的个数,再加上当前节点的(tot),否则就往左儿子走,走到值为(x)的点结束。

    void findx(int x){
        int now = root;
        int ans = 0;
        while(tr[now].x != x){
            if(x < tr[now].x) now = tr[now].left;
            else ans += tr[tr[now].left].whsz + tr[now].tot,now = tr[now].right;
        }
        ans += tr[tr[now].left].whsz;
        printf("%d
    ",ans + 1);
    }
    

    操作四:查找排名为(x)的数

    类似的,我们先从根走,假如当前节点的左子树的数的个数比(x)要小,那么让(x)减掉左子树的数的个数,然后在看一下当前节点的(tot)是否大于(x),如果大于的话,答案就是这个节点了,否则让(x)减去它的(tot),然后往右儿子去,重复以上操作即可。

    void findrkx(int x){
        int now = root;
        while(1){
            if(x <= tr[tr[now].left].whsz) now = tr[now].left;
            else{
                x -= tr[tr[now].left].whsz;
                if(x <= tr[now].tot){
                    printf("%d
    ",tr[now].x);
                    return;
                }
                x -= tr[now].tot;
                now = tr[now].right;
            }
        }
    }
    

    要注意!这两个函数里用的都是(whsz)

    操作五:查找(x)的前驱

    因为替罪羊树有删除标记这个东西,所以它查找前驱和后继的时候会慢一点。

    具体做法:先找到值为(x)的节点,然后普看看有没有左儿子,如果有就将左子树遍历一遍,顺序是:右儿子->根->左儿子,找到第一个没有被删除的节点就是答案。

    因为被删除的点不超过30%,所以不用担心算法会退化成(O(n))

    void dfs_rml(int x){
        if(tr[x].right != 0) dfs_rml(tr[x].right);
        if(ans) return;
        if(!tr[x].vis){
            printf("%d
    ",tr[x].x);
            ans = 1;
            return;
        }
        if(tr[x].left != 0) dfs_rml(tr[x].left);
    }
    void pre(int now,int x,bool z){
        if(!z){
        	pre(tr[now].fa,x,tr[tr[now].fa].right == now);
        	return;
        }
        if(!tr[now].vis && tr[now].x < x){
        	printf("%d
    ",tr[now].x);
        	return;
        }
        if(tr[now].left){
        	ans = 0;
            dfs_rml(tr[now].left);
            return;
        }
        pre(tr[now].fa,x,tr[tr[now].fa].right == now);
    }
    

    操作六:查找(x)的后继

    跟前驱类似

    void dfs_lmr(int x){
        if(tr[x].left != 0) dfs_lmr(tr[x].left);
        if(ans) return;
        if(!tr[x].vis){
            printf("%d
    ",tre[x].x);
            ans = 1;
            return;
        }
        if(tr[x].right != 0) dfs_lmr(tr[x].right);
    }
    void nxt(int now,int x,bool z){
        if(!z){
        	nxt(tr[now].fa,x,tr[tr[now].fa].right != now);
        	return;
        }
        if(!tr[now].vis && tr[now].x > x){
        	printf("%d
    ",tr[now].x);
        	return;
        }
        if(tr[now].right){
        	ans = 0;
            dfs_lmr(tr[now].right);
            return;
        }
        nxt(tr[now].fa,x,tr[tr[now].fa].right != now);
    }
    

    后记:

    • 关于(alpha)

      (alpha)的值究竟与效率的关系,当的(alpha)值越小,那么替罪羊树就越容易重构,那么树也就越平衡,查询的效率也就越高,自然修改(加点和删点)的效率也就低了。所以,如果查询操作比较多的话,就可以将(alpha)的值设小一点。反之,假如修改操作多,自然(alpha)的值就要大一点了。

      还有,(alpha)不能等于(1) (or) (0.5),假如它等于(0.5),那么当一棵树被重构之后如果因为节点数问题,不能完全重构成一个完全二叉树,那么显然,对于这棵树的根,他的"左子树节点数量 - 右子树节点数量"很可能会等于(1),那么如果往多的那棵子树上加一个节点,那么这棵树又得重构一次,最坏情况时间会变成(n^2)。那么等于1...会有一棵子树的大小大于整棵树的大小咩w?

    • 关于时间复杂度:

      除了重构操作,其他操作的时间复杂度显然都是(log(n))的,那么下面看一下重构的时间复杂度。

      虽然重构一次的时间复杂度是(O(n))的,但是,均摊下来其实只是(O(logn))

      考虑极端情况,每次都把整棵树重构。

      那么我们就需要每次都往根的一棵子树内加点,假设一开始是平衡的,那么左右子树各有50%的节点,那么要使一棵子树内含有超过75%的节点,那么这棵子树就需要在原来的基础上增加(2)倍的节点数。也就是说,当最差情况时,整棵替罪羊树的节点数要翻个倍,才会重构。那么最差情况时也就是在(4,8,16,32……)个节点时才会重构,于是重构的总的时间复杂度也就是(O(nlogn))了,加上一些杂七杂八的重构,也不过就是加上一个很小的常数,可以省略不计。所以,替罪羊树的时间复杂度依然是(O(nlogn))的。

    完整代码

    #define B cout << "BreakPoint" << endl;
    #define O(x) cout << #x << " " << x << endl;
    #define O_(x) cout << #x << " " << x << " ";
    #define Msz(x) cout << "Sizeof " << #x << " " << sizeof(x)/1024/1024 << " MB" << endl;
    #include<cstdio>
    #include<cmath>
    #include<iostream>
    #include<cstring>
    #include<algorithm>
    #include<queue>
    #include<set>
    #define LL long long
    const int inf = 1e9 + 9;
    const int N = 1e7 + 5;
    using namespace std;
    inline int read() {
    	int s = 0,w = 1;
    	char ch = getchar();
    	while(ch < '0' || ch > '9') {
    		if(ch == '-')
    			w = -1;
    		ch = getchar();
    	}
    	while(ch >= '0' && ch <= '9') {
    		s = s * 10 + ch - '0';
    		ch = getchar();
    	}
    	return s * w;
    }
    struct node{
    	int left,right,x,tot,sz,trsz,whsz,fa;
    	bool vis;
    } tr[N];
    struct sl{
    	int x,tot;
    }num[N];
    int len,n,root,ck[N],t;
    double alpha = 0.75;
    void build(int x,int y,int fa){
        tr[y].left = tr[y].right = 0;
    	tr[y].fa = fa,tr[y].vis = false;
        tr[y].x = x,tr[y].tot = tr[y].sz = tr[y].trsz = tr[y].whsz = 1;
    }
    inline int New(){
        if(t > 0) return ck[t--];
        else return ++len;
    }
    void updata(int x,int y,int z,int k){
        if(!x) return;
        tr[x].trsz += y;
        tr[x].sz += z;
        tr[x].whsz += k;
        updata(tr[x].fa,y,z,k);
    }
    int find(int x,int now){
        if(x < tr[now].x && tr[now].left) return find(x,tr[now].left);
        if(x > tr[now].x && tr[now].right) return find(x,tr[now].right);
        return now;
    }
    int tt;
    void dfs_rebuild(int x){
        if(x == 0)return;
        dfs_rebuild(tr[x].left);
        if(!tr[x].vis) num[++tt].x = tr[x].x,num[tt].tot = tr[x].tot;
        ck[++t] = x;
        dfs_rebuild(tr[x].right);
    }
    int readd(int l,int r,int fa){
        if(l > r) return 0;
        int mid = (l+r)>>1;
    	int id = New();
        tr[id].fa = fa;
        tr[id].tot = num[mid].tot;
        tr[id].x = num[mid].x;
        tr[id].left = readd(l,mid-1,id);
        tr[id].right = readd(mid+1,r,id);
        tr[id].whsz = tr[tr[id].left].whsz + tr[tr[id].right].whsz + num[mid].tot;
        tr[id].sz = tr[id].trsz = r - l + 1;
        tr[id].vis = false;
        return id;
    }
    void rebuild(int x){
        tt = 0;
        dfs_rebuild(x);
        if(x == root) root = readd(1,tt,0);
        else{
            updata(tr[x].fa,0,-tr[x].sz + tr[x].trsz,0);
            if(tr[tr[x].fa].left == x) tr[tr[x].fa].left = readd(1,tt,tr[x].fa);
            else tr[tr[x].fa].right = readd(1,tt,tr[x].fa);
        }
    }
    void find_rebuild(int now,int x){
        if(1.0 * tr[tr[now].left].sz > 1.0 * tr[now].sz * alpha || 1.0 * tr[tr[now].right].sz > 1.0 * tr[now].sz * alpha|| 1.0 * tr[now].sz - 1.0 * tr[now].trsz >1.0 * tr[now].sz * 0.3){
            rebuild(now);
            return;
        }
        if(tr[now].x != x) find_rebuild(x < tr[now].x ? tr[now].left : tr[now].right,x);
    }
    void add(int x){
        if(root == 0){
            build(x,root = New(),0);
            return;
        }
        int p = find(x,root);
        if(x == tr[p].x){
            tr[p].tot++;
            if(tr[p].vis) tr[p].vis = 0,updata(p,1,0,1);
            else updata(p,0,0,1);
        }
        else if(x < tr[p].x) build(x,tr[p].left = New(),p),updata(p,1,1,1);
        else build(x,tr[p].right = New(),p),updata(p,1,1,1);
        find_rebuild(root,x);
    }
    void del(int x){
        int p = find(x,root);
        tr[p].tot--;
        if(!tr[p].tot) tr[p].vis = 1,updata(p,-1,0,-1);
        else updata(p,0,0,-1);
        find_rebuild(root,x);
    }
    void findx(int x){
        int now = root;
        int ans = 0;
        while(tr[now].x != x){
            if(x < tr[now].x) now = tr[now].left;
            else ans += tr[tr[now].left].whsz + tr[now].tot,now = tr[now].right;
        }
        ans += tr[tr[now].left].whsz;
        printf("%d
    ",ans + 1);
    }
    void findrkx(int x){
        int now = root;
        while(1){
            if(x <= tr[tr[now].left].whsz) now = tr[now].left;
            else{
                x -= tr[tr[now].left].whsz;
                if(x <= tr[now].tot){
                    printf("%d
    ",tr[now].x);
                    return;
                }
                x -= tr[now].tot;
                now = tr[now].right;
            }
        }
    }
    bool ans;
    void dfs_rml(int x){
        if(tr[x].right != 0) dfs_rml(tr[x].right);
        if(ans) return;
        if(!tr[x].vis){
            printf("%d
    ",tr[x].x);
            ans = 1;
            return;
        }
        if(tr[x].left != 0) dfs_rml(tr[x].left);
    }
    void pre(int now,int x,bool z){
        if(!z){
        	pre(tr[now].fa,x,tr[tr[now].fa].right == now);
        	return;
        }
        if(!tr[now].vis && tr[now].x < x){
        	printf("%d
    ",tr[now].x);
        	return;
        }
        if(tr[now].left){
        	ans = 0;
            dfs_rml(tr[now].left);
            return;
        }
        pre(tr[now].fa,x,tr[tr[now].fa].right == now);
    }
    void dfs_lmr(int x){
        if(tr[x].left != 0) dfs_lmr(tr[x].left);
        if(ans) return;
        if(!tr[x].vis){
            printf("%d
    ",tr[x].x);
            ans = 1;
            return;
        }
        if(tr[x].right != 0) dfs_lmr(tr[x].right);
    }
    void nxt(int now,int x,bool z){
        if(!z){
        	nxt(tr[now].fa,x,tr[tr[now].fa].right != now);
        	return;
        }
        if(!tr[now].vis && tr[now].x > x){
        	printf("%d
    ",tr[now].x);
        	return;
        }
        if(tr[now].right){
        	ans = 0;
            dfs_lmr(tr[now].right);
            return;
        }
        nxt(tr[now].fa,x,tr[tr[now].fa].right != now);
    }
    int main(){
        n = read();
        while(n--){
            int id = read(),x = read();
            if(id == 1) add(x);
            if(id == 2) del(x);
            if(id == 3) findx(x);
            if(id == 4) findrkx(x);
            if(id == 5) pre(find(x,root),x,1);
            if(id == 6) nxt(find(x,root),x,1);
        }
    }
    
  • 相关阅读:
    POJ 3258 (NOIP2015 D2T1跳石头)
    POJ 3122 二分
    POJ 3104 二分
    POJ 1995 快速幂
    409. Longest Palindrome
    389. Find the Difference
    381. Insert Delete GetRandom O(1)
    380. Insert Delete GetRandom O(1)
    355. Design Twitter
    347. Top K Frequent Elements (sort map)
  • 原文地址:https://www.cnblogs.com/excellent-zzy/p/12329177.html
Copyright © 2020-2023  润新知