• [学习笔记]树套树


    引言

    树套树,顾名思义,就是要将两种或多种树形数据结构结合起来,解决一些单独无法解决的问题。

    如果说要解决区间上的问题,如最大值,区间修改等,肯定会想到线段树

    但是线段树不能查询第k大,不能查询一个数在区间的排名,自然也不能查询前驱和后继。

    平衡树可以解决查询排名、前驱、后继等问题,但其不能限定区间。

    文艺平衡树中有操作可以把区间锁定在一个结点的子树,问题是只能通过翻转左右子树,来实现区间翻转。

    既然单独无法解决这个问题,那就将两种树形数据结构结合起来。

    原理

    很多人都对树套树望而生畏,包括我。。。

    以前只知道通过树套树可以解决的问题,但没有敲过

    经常听到队友说几个线段树,再用一个主席树维护什么什么的,但其实原理不难,只要懂树套树中的这两种树。

    举个例子:

    如图这是个线段树

    假设这个序列是: 5 2 3 4 5 7 8 9 3 1 (随便写的)

    现在我要查2-7区间中第5个数即5在这个区间排第几小,

    [3-7]区间即:[3],[4-5],[6-7]

    第几小即计算有多少个比它小,然后加一

    [3]:1个

    [4-5]:1个

    [6-7]:0个

    所以他是第3小的。

    将每个子区间得到的答案求和利用的是线段树

    而中间每个区间查询有多少比它小利用的是平衡树Splay

    线段树的每个结点建立一个Splay

    有人会怀疑空间复杂度不够,如果把Splay封装,每个Splay都是(N)的大小必然不够,我们不需要事先开辟那么多空间来建Splay

    代码中是:(开局一个root,然后记录每个线段树结点的root就行了)

    void build(int p,int l,int r){
        t[p].l = l,t[p].r = r;
        //线段树每个结点建立一个splay
        sp.ins(t[p].rt,-inf);
        sp.ins(t[p].rt,inf);
        for(int i = l;i <= r;++i){
            sp.ins(t[p].rt,arr[i]);
        }
        if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
        int mid = (l+r) >> 1;
        build(p<<1,l,mid);
        build(p<<1|1,mid + 1,r);
        pushUp(p);
    }
    

    这里有个问题,root会发生变化,所以线段树结点中定义的root并不是一成不变的,这需要用到引用,即传地址

    还有就是要插入两个无穷大结点,来解决不存在的情况。

    应用-模板题

    1. 查询k在区间内的排名
    2. 查询区间内排名为k的值
    3. 修改某一位值上的数值
    4. 查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
    5. 查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)

    先复制上封装好的Splay

    struct Splay {
         int get(int x) {return s[s[x].fa].ch[1] == x;}
    
         void Clear(int x) {
             s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
         }
    
         void maintain(int x){
             s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
         }
    
         void Rorate(int x){
             int y = s[x].fa, z = s[y].fa, chk = get(x);
    
             s[y].ch[chk] = s[x].ch[chk ^ 1];
             s[s[x].ch[chk ^ 1]].fa = y;
    
             s[y].fa = x;
             s[x].ch[chk ^ 1] = y;
    
             s[x].fa =z;
             if(z) s[z].ch[s[z].ch[1] == y] = x;
    
             maintain(y);
             maintain(x);
         }
    
         void splay(int &rt,int x,int y){
             for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
                 if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
             }
             if(y==0) rt = x;
         }
    
         void ins(int &root ,int val){
             if(!root) {
                 root = ++tot;
                 s[root].val = val;
                 s[root].cnt++;
                 maintain(root);
                 return ;
             }
    
             int f = 0, x = root;
             while(true){
                 if(s[x].val == val){
                     s[x].cnt ++;
                     maintain(x);
                     maintain(f);
                     splay(root,x,0);
                     return;
                 }
    
                 f = x;
                 x = s[x].ch[s[x].val < val];
                 if(!x) {
                     s[++tot].val = val;
                     s[tot].cnt = 1;
                     s[tot].fa = f;
                     s[f].ch[s[f].val < val] = tot;
                     maintain(tot);
                     maintain(f);
                     splay(root,tot,0);
                     return ;
                 }
             }
         }
    
        inline int Find(int &rt,int k) {
            int res = 0,now = rt;
            while(true) {
                if(k<s[now].val) {
                    now = s[now].ch[0];
                }else {
                    //否则加上右子树的个数
                    res += s[s[now].ch[0]].sz;
                    //中序遍历,如果找到这个节点返回res+1
                    if(k == s[now].val) {
                        splay(rt,now,0);
                        return res + 1;
                    }
                    res += s[now].cnt;
                    now = s[now].ch[1];
                }
            }
        }
    
         int getPre(int rt){
             int now = s[rt].ch[0];
             while (s[now].ch[1]) now = s[now].ch[1];
             return now;
         }
    
         int getNxt(int rt){
             int now = s[rt].ch[1];
             while (s[now].ch[0]) now = s[now].ch[0];
             return now;
         }
         
        inline void del(int &rt,int k){
           Find(rt,k);//先让该点成为根节点
            if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
                s[rt].cnt--;
                maintain(rt);
                return;
            }
            //如果只有一个点
            if(!s[rt].ch[0] && !s[rt].ch[1]){
                Clear(rt);
                rt = 0;
                return;
            }
            //没有左儿子,让右儿子成为根节点
            if(!s[rt].ch[0]){
                int tmp = rt;
                rt = s[rt].ch[1];
                s[rt].fa=0;
                Clear(tmp);
                return;
            }
            //没有右儿子,让左儿子成为根节点
            if(!s[rt].ch[1]){
                int tmp = rt;
                rt = s[rt].ch[0];
                s[rt].fa = 0;
                Clear(tmp);
                return;
            }
            //有左右儿子,让前驱成为根节点
            int x = getPre(rt) , now = rt;
            splay(rt,x,0);
            s[s[now].ch[1]].fa = x;
            s[x].ch[1] = s[now].ch[1];
            Clear(now);
            maintain(rt);
        }
    }sp;
    
    • 问题1之前提到了,就是在Splay中插入这个结点,然后返回这个结点的左儿子的Size就行,记得减去无穷大的那个点。
    int query_order(int p,int l,int r,int val){
        //查询顺序,就是查有多少个比他小
        if(l <= t[p].l && t[p].r <= r){
           sp.ins(t[p].rt,val);
           int res = s[s[t[p].rt].ch[0]].sz-1;
           sp.del(t[p].rt,val);
           return res;
        }
        int mid = (t[p].l + t[p].r) >> 1,res = 0;
        if(l <= mid) res += query_order(p << 1,l,r,val);
        if(mid < r) res += query_order(p << 1|1,l,r,val);
        return res;
    }
    
    • 问题2求排名k的值,这需要用到二分,二分check函数就是问题1的query_order,在区间权值范围内二分,权值越大排名越大,就是在单调递增区间中查询小于k的数的最大值(因为有一个无穷小结点,所以不能小于等于)。二分模板也很明显:
    int query_number(int L,int R,int val){
        int l = 1,r = getMax(1,L,R) ,mid,tmp;
        while(l < r){
            mid = (l + r + 1)>>1;
            tmp = query_order(1,L,R,mid);
            if(tmp < val){
                l = mid;
            }else{
                r = mid - 1;
            }
        }
        return l;
    }
    
    • 问题3是修改,这个不难,这个点所在的所有线段树结点都要删除该点在Splay树上的结点,然后加入新值。
    void modify(int p,int pos,int val){
         sp.del(t[p].rt,arr[pos]);
         sp.ins(t[p].rt,val);
         if(t[p].l == t[p].r){
             t[p].mx = val;
             t[p].mn = val;
             arr[pos] = val;
             return;
         }
         int mid = (t[p].l + t[p].r) >> 1;
         if(pos <= mid) modify(p << 1,pos,val);
         if(pos > mid) modify(p << 1 | 1,pos,val);
         pushUp(p);
    }
    
    • 问题4,查询前驱,即查询每个线段树区间最大的比该数小的数,最后取个最大值。5同理
    int query_Pre(int p,int l,int r,int val){
        if(l <= t[p].l && r >= t[p].r){
            sp.ins(t[p].rt,val);
            int res = s[sp.getPre(t[p].rt)].val;
            sp.del(t[p].rt,val);
            return res;
        }
        int res = -inf,mid = (t[p].l + t[p].r) >> 1;
        if(l <= mid)  res = max(res,query_Pre(p << 1,l,r,val));
        if(r > mid)  res = max(res,query_Pre(p << 1|1,l,r,val));
        return res;
    }
    
    int query_Nxt(int p,int l,int r,int val){
        if(l <= t[p].l && r >= t[p].r){
            sp.ins(t[p].rt,val);
            int res = s[sp.getNxt(t[p].rt)].val;
            sp.del(t[p].rt,val);
            return res;
        }
        int res = inf,mid = (t[p].l + t[p].r) >> 1;
        if(l <= mid)  res = min(res,query_Nxt(p << 1,l,r,val));
        if(r > mid)  res = min(res,query_Nxt(p << 1|1,l,r,val));
        return res;
    }
    
    • 中途为了优化二分(也没什么用),还加了线段树查询最大值和最小值的
    int getMax(int p,int l,int r){
        if(l <= t[p].l && t[p].r <=r) return t[p].mx;
        int mid = (t[p].l + t[p].r) >> 1,res = -inf;
        if(l <= mid) res = max(res,getMax(p << 1,l,r));
        if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
        return res;
    }
    
    int getMin(int p,int l,int r){
        if(l <= t[p].l && t[p].r <= r) return t[p].mx;
        int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
        if(l <= mid) res = min(res,getMin(p << 1,l,r));
        if(mid < r) res = min(res,getMin(p << 1|1,l,r));
        return res;
    }
    

    完整代码

    #pragma GCC optimize(2)
    #pragma GCC optimize(3,"Ofast","inline")
    #include<bits/stdc++.h>
    using namespace std;
    #define ll long long
    
    const int N = 1e7+7;
    const int inf = 2147483647;
    int tot;//节点个数
    struct node {
        int fa;//父亲节点
        int ch[2];//子节点
        int val;//权值
        int sz;//子树大小
        int cnt;
    }s[N];
    struct Tree{
        int rt,l,r,mx,mn;
    }t[N];
    int arr[N];
    struct Splay {
         int get(int x) {return s[s[x].fa].ch[1] == x;}
    
         void Clear(int x) {
             s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
         }
    
         void maintain(int x){
             s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
         }
    
         void Rorate(int x){
             int y = s[x].fa, z = s[y].fa, chk = get(x);
    
             s[y].ch[chk] = s[x].ch[chk ^ 1];
             s[s[x].ch[chk ^ 1]].fa = y;
    
             s[y].fa = x;
             s[x].ch[chk ^ 1] = y;
    
             s[x].fa =z;
             if(z) s[z].ch[s[z].ch[1] == y] = x;
    
             maintain(y);
             maintain(x);
         }
    
         void splay(int &rt,int x,int y){
             for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
                 if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
             }
             if(y==0) rt = x;
         }
    
         void ins(int &root ,int val){
             if(!root) {
                 root = ++tot;
                 s[root].val = val;
                 s[root].cnt++;
                 maintain(root);
                 return ;
             }
    
             int f = 0, x = root;
             while(true){
                 if(s[x].val == val){
                     s[x].cnt ++;
                     maintain(x);
                     maintain(f);
                     splay(root,x,0);
                     return;
                 }
    
                 f = x;
                 x = s[x].ch[s[x].val < val];
                 if(!x) {
                     s[++tot].val = val;
                     s[tot].cnt = 1;
                     s[tot].fa = f;
                     s[f].ch[s[f].val < val] = tot;
                     maintain(tot);
                     maintain(f);
                     splay(root,tot,0);
                     return ;
                 }
             }
         }
    
        inline int Find(int &rt,int k) {
            int res = 0,now = rt;
            while(true) {
                if(k<s[now].val) {
                    now = s[now].ch[0];
                }else {
                    //否则加上右子树的个数
                    res += s[s[now].ch[0]].sz;
                    //中序遍历,如果找到这个节点返回res+1
                    if(k == s[now].val) {
                        splay(rt,now,0);
                        return res + 1;
                    }
                    res += s[now].cnt;
                    now = s[now].ch[1];
                }
            }
        }
    
         int getPre(int rt){
             int now = s[rt].ch[0];
             while (s[now].ch[1]) now = s[now].ch[1];
             return now;
         }
    
         int getNxt(int rt){
             int now = s[rt].ch[1];
             while (s[now].ch[0]) now = s[now].ch[0];
             return now;
         }
    
        inline void del(int &rt,int k){
           Find(rt,k);//先让该点成为根节点
            if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
                s[rt].cnt--;
                maintain(rt);
                return;
            }
            //如果只有一个点
            if(!s[rt].ch[0] && !s[rt].ch[1]){
                Clear(rt);
                rt = 0;
                return;
            }
            //没有左儿子,让右儿子成为根节点
            if(!s[rt].ch[0]){
                int tmp = rt;
                rt = s[rt].ch[1];
                s[rt].fa=0;
                Clear(tmp);
                return;
            }
            //没有右儿子,让左儿子成为根节点
            if(!s[rt].ch[1]){
                int tmp = rt;
                rt = s[rt].ch[0];
                s[rt].fa = 0;
                Clear(tmp);
                return;
            }
            //有左右儿子,让前驱成为根节点
            int x = getPre(rt) , now = rt;
            splay(rt,x,0);
            s[s[now].ch[1]].fa = x;
            s[x].ch[1] = s[now].ch[1];
            Clear(now);
            maintain(rt);
        }
    }sp;
    
    void pushUp(int x){
        t[x].mx = max(t[x<<1].mx,t[x<<1|1].mx);
        t[x].mn = min(t[x<<1].mn,t[x<<1|1].mn);
    }
    
    void build(int p,int l,int r){
        t[p].l = l,t[p].r = r;
        //线段树每个结点建立一个splay
        sp.ins(t[p].rt,-inf);
        sp.ins(t[p].rt,inf);
        for(int i = l;i <= r;++i){
            sp.ins(t[p].rt,arr[i]);
        }
        if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
        int mid = (l+r) >> 1;
        build(p<<1,l,mid);
        build(p<<1|1,mid + 1,r);
        pushUp(p);
    }
    
    int getMax(int p,int l,int r){
        if(l <= t[p].l && t[p].r <=r) return t[p].mx;
        int mid = (t[p].l + t[p].r) >> 1,res = -inf;
        if(l <= mid) res = max(res,getMax(p << 1,l,r));
        if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
        return res;
    }
    
    int getMin(int p,int l,int r){
        if(l <= t[p].l && t[p].r <= r) return t[p].mx;
        int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
        if(l <= mid) res = min(res,getMin(p << 1,l,r));
        if(mid < r) res = min(res,getMin(p << 1|1,l,r));
        return res;
    }
    
    int query_order(int p,int l,int r,int val){
        //查询顺序,就是查有多少个比他小
        if(l <= t[p].l && t[p].r <= r){
           sp.ins(t[p].rt,val);
           int res = s[s[t[p].rt].ch[0]].sz-1;
           sp.del(t[p].rt,val);
           return res;
        }
        int mid = (t[p].l + t[p].r) >> 1,res = 0;
        if(l <= mid) res += query_order(p << 1,l,r,val);
        if(mid < r) res += query_order(p << 1|1,l,r,val);
        return res;
    }
    
    void modify(int p,int pos,int val){
         sp.del(t[p].rt,arr[pos]);
         sp.ins(t[p].rt,val);
         if(t[p].l == t[p].r){
             t[p].mx = val;
             t[p].mn = val;
             arr[pos] = val;
             return;
         }
         int mid = (t[p].l + t[p].r) >> 1;
         if(pos <= mid) modify(p << 1,pos,val);
         if(pos > mid) modify(p << 1 | 1,pos,val);
         pushUp(p);
    }
    
    int query_Pre(int p,int l,int r,int val){
        if(l <= t[p].l && r >= t[p].r){
            sp.ins(t[p].rt,val);
            int res = s[sp.getPre(t[p].rt)].val;
            sp.del(t[p].rt,val);
            return res;
        }
        int res = -inf,mid = (t[p].l + t[p].r) >> 1;
        if(l <= mid)  res = max(res,query_Pre(p << 1,l,r,val));
        if(r > mid)  res = max(res,query_Pre(p << 1|1,l,r,val));
        return res;
    }
    
    int query_Nxt(int p,int l,int r,int val){
        if(l <= t[p].l && r >= t[p].r){
            sp.ins(t[p].rt,val);
            int res = s[sp.getNxt(t[p].rt)].val;
            sp.del(t[p].rt,val);
            return res;
        }
        int res = inf,mid = (t[p].l + t[p].r) >> 1;
        if(l <= mid)  res = min(res,query_Nxt(p << 1,l,r,val));
        if(r > mid)  res = min(res,query_Nxt(p << 1|1,l,r,val));
        return res;
    }
    
    int query_number(int L,int R,int val){
        int l = 1,r = getMax(1,L,R) ,mid,tmp;
        while(l < r){
            mid = (l + r + 1)>>1;
            tmp = query_order(1,L,R,mid);
            if(tmp < val){
                l = mid;
            }else{
                r = mid - 1;
            }
        }
        return l;
    }
    
    int main(){
        int n,q,op,l,r,pos;
        scanf("%d%d",&n,&q);
        for(int i=1;i<=n;++i) scanf("%d",&arr[i]);
        build(1,1,n);
        while(q--){
            scanf("%d",&op);
            if(op == 1){
                scanf("%d%d%d",&l,&r,&pos);
                printf("%d
    ",query_order(1,l,r,pos)+1);
            }else if(op == 2){
                scanf("%d%d%d",&l,&r,&pos);
                printf("%d
    ",query_number(l,r,pos));
            }else if(op == 3){
                scanf("%d%d",&l,&pos);
                modify(1,l,pos);
            }else if(op == 4){
                scanf("%d%d%d",&l,&r,&pos);
                printf("%d
    ",query_Pre(1,l,r,pos));
            }else if(op == 5){
                scanf("%d%d%d",&l,&r,&pos);
                printf("%d
    ",query_Nxt(1,l,r,pos));
            }
        }
        return 0;
    }
    
    

    代码不加O2优化会超时,如果要优化的话,可以加个输入输出挂。

    后记

    博客两周年快乐。

    这是第一篇博客https://www.cnblogs.com/smallocean/p/8525932.html:2018.3.7

    发现自己留下的东西都可以当作时间胶囊,等未来某天翻看的时候,仿佛能看到那个时候的自己。

  • 相关阅读:
    什么是https?
    Gojs
    GoJs 01讲解
    你真的了解WebSocket吗?
    django channels
    序列化及反序列化
    全角转半角
    Thread Culture
    设置输入法
    token的认证使用
  • 原文地址:https://www.cnblogs.com/smallocean/p/12436130.html
Copyright © 2020-2023  润新知