• 树链剖分学习笔记(一)


    0. 简介

    树链剖分用于处理树上问题(废话)
    思想是把一棵树分成若干条链,而链上的操作显然比树上好做,这样就降低了处理难度

    有多种方式能将树拆成链,因此树链剖分也有不同的种类
    这篇文章写的是其中一种:轻重链剖分(最常见,有时也被直接称为树链剖分)

    1. 概念

    定义 \(size_x\) 表示以 \(x\) 为根的子树大小(即包含的节点个数)

    \(x\) 的所有儿子中 \(size\) 值最大的那个为 \(x\)重儿子,记作 \(son_x\)
    (如果有多个儿子满足条件,随便选其中一个即可)

    对于每个节点 \(x\) ,将连接 \(x\)\(son_x\) 的边定为重边,其它的定为轻边

    将全部由重边组成的路径称为重链

    这样树就被分成了若干条重链和若干条轻边

    分完之后我们发现了一些有用的性质:

    1. \(y\)\(x\) 的儿子,但不是重儿子,则 \(size_y\le size_x/2.\)
      反证即可,假设 \(size_y>size_x/2\) ,则 \(size_y\) 比其它儿子的 \(size\) 值之和还大,故 \(y\) 是重儿子,与条件矛盾
      所以 \(size_y\le size_x/2\)

    2. 从任意一个节点到根节点的路径上最多有 \(O(\log n)\) 条轻边。
      由性质 1 可知,通过一条轻边往下走,子树大小至少减半,故显然性质成立

    3. 从任意一个节点到根节点的路径上最多有 \(O(\log n)\) 条重链。
      重链之间是用轻边分隔的(废话)
      因此重链的条数和轻边一样也是 \(O(\log n)\) 级别

    题目常会让我们对两点间路径上的所有点执行某操作
    由上面的性质易知,这条路径可被分成不超过 \(O(\log n)\) 条重链和轻边
    我们希望能快速处理重链上的操作

    对整棵树跑一遍深度优先遍历,并且优先遍历重儿子
    则同一条重链对应的 dfs 序必然是连续的一段
    然后重链上的操作就转化成了序列问题,用合适的数据结构维护即可

    下面通过一道模板题讲一下具体怎么使用

    2. 实现

    P3384 【模板】轻重链剖分/树链剖分

    首先预处理出必要的信息

    /**
     * ​ fa[x]: 节点 x 的爹
     * dep[x]: 节点 x 的深度
     *  sz[x]: 节点 x 的子树大小,即上文中的 size
     * son[x]: 节点 x 的重儿子
     * top[x]: 节点 x 所在重链的顶端的节点(深度最小)
     * ord[x]: 节点 x 的 dfs 序
     * rev[n]: dfs 序为 n 的节点,即 rev 是 ord 的逆映射
    **/
    void dfs(int x,int f) { // 第一遍 dfs ,处理 fa,dep,sz,son
       ​fa[x]=f,dep[x]=dep[f]+1,sz[x]=1;
       ​for (int i=last[x]; i; i=E[i].pre) {
           ​int y=E[i].y;
           ​if (y==f) continue;
           ​dfs(y,x);
           ​sz[x]+=sz[y];
           ​if (sz[son[x]]<sz[y]) son[x]=y;
       ​}
    }
    void dfs2(int x); // 第二遍 dfs ,处理 top,ord,rev
    void work(int x,int topx) {
       ​top[x]=topx;
       ​++n2,ord[x]=n2,rev[n2]=x;
       ​dfs2(x);
    }
    void dfs2(int x) {
       ​if (!son[x]) return ;
       ​work(son[x],top[x]); // 优先遍历重儿子
       ​for (int i=last[x]; i; i=E[i].pre)
           ​if (!top[E[i].y]) work(E[i].y,E[i].y);
    }
    

    操作 1,2 是关于两节点 \(x,y\) 之间路径的
    先给出操作 1 的代码实现:

    void solve1(int x,int y,int v) {
        int tx=top[x],ty=top[y];
        while (tx!=ty) {
            if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
            update(ord[tx],ord[x],v,1,n,1); // 修改
            x=fa[tx],tx=top[x]; // 跳到下一条重链底端 
        }
        if (dep[x]>dep[y]) swap(x,y);
        update(ord[x],ord[y],v,1,n,1);
    }
    

    其实就是一直让深度较大的那个点往上跳
    每次跳过一整条重链,并修改这条链上的信息
    两点位于同一条重链上时停止,此时剩下的一小段路径都在这条链上,直接修改即可

    操作 2 也没啥区别,把修改换成查询,最后返回总和即可

    ll solve2(int x,int y) {
        int tx=top[x],ty=top[y]; ll ans=0;
        while (tx!=ty) {
            if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
            ans=(ans+query(ord[tx],ord[x],1,n,1))%P;
            x=fa[tx],tx=top[x];
        }
        if (dep[x]>dep[y]) swap(x,y);
        return (ans+query(ord[x],ord[y],1,n,1))%P;
    }
    

    操作 3,4 和子树相关,所以一次操作涉及的所有节点的 dfs 序是连续的一段
    直接在对应区间上做修改/查询即可

    void solve3(int x,int v) {
        update(ord[x],ord[x]+sz[x]-1,v,1,n,1);
    }
    ll solve4(int x) {
        return query(ord[x],ord[x]+sz[x]-1,1,n,1);
    }
    

    其中 updatequery 的写法取决于你用什么数据结构

    顺便讲一下树状数组怎么做区间修改和区间查询

    对于区间修改,差分一波
    \(c_i=a_i-a_{i-1}\) (规定 \(a_0=0\)
    若要将 \(a_l\sim a_r\) 全部加 \(x\) ,则 \(\Delta c_l=x,\Delta c_{r+1}=-x\)

    对于区间查询,只需考虑如何求前缀和 \(S(x)=\sum\limits_{i=1}^x a_i\)

    \(\begin{aligned} S(x)&=\sum\limits_{i=1}^x a_i\\ &=\sum\limits_{i=1}^x\sum\limits_{j=1}^ic_j\\ &=\sum\limits_{j=1}^x(x+1-j)c_j\\ &=(x+1)\sum\limits_{i=1}^xc_i-\sum\limits_{i=1}^xic_i \end{aligned}\)

    \(d_i=ic_i\) ,则 \(S(x)\) 可以用 \(c\)\(d\) 的前缀和表示
    用树状数组维护 \(c\)\(d\) 即可

    因此本题中树状数组和线段树均可使用,时间复杂度都是 \(O(n+m\log ^2n)\) (跑不满)

    3. 完整代码

    线段树(使用懒标记)版本
    线段树(标记永久化)和树状数组的写得太丑了 也懒得重写 就不放了(

    P3384 SGT ver.
    #include<stdio.h>
    #include<ctype.h>
    #define Tl T[p<<1]
    #define Tr T[p<<1|1]
    typedef long long ll;
    const int N=100010;
    int n,m,root,P,n2,cnt,opt,x,y,v,value[N],last[N];
    int fa[N],dep[N],sz[N],son[N],top[N],ord[N],rev[N];
    struct edge { int y,pre; }E[N<<1];
    inline void swap(int &x,int &y) { int t=x; x=y,y=t; }
    void read(int &x);
    
    struct SGT { int len; ll tag,sum; }T[N<<2];
    void build(int l,int r,int p);
    void pushup(int p);
    void pushdown(int p);
    void update(int x,int y,int v,int l,int r,int p);
    ll query(int x,int y,int l,int r,int p);
    
    void dfs(int x,int f) {
        fa[x]=f,dep[x]=dep[f]+1,sz[x]=1;
        for (int i=last[x]; i; i=E[i].pre) {
            int y=E[i].y;
            if (y==f) continue;
            dfs(y,x);
            sz[x]+=sz[y];
            if (sz[son[x]]<sz[y]) son[x]=y;
        }
    }
    void dfs2(int x);
    void work(int x,int topx) {
        top[x]=topx;
        ++n2,ord[x]=n2,rev[n2]=x;
        dfs2(x);
    }
    void dfs2(int x) {
        if (!son[x]) return ;
        work(son[x],top[x]);
        for (int i=last[x]; i; i=E[i].pre)
            if (!top[E[i].y]) work(E[i].y,E[i].y);
    }
    
    void solve1(int x,int y,int v) {
        int tx=top[x],ty=top[y];
        while (tx!=ty) {
            if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
            update(ord[tx],ord[x],v,1,n,1);
            x=fa[tx],tx=top[x];
        }
        if (dep[x]>dep[y]) swap(x,y);
        update(ord[x],ord[y],v,1,n,1);
    }
    ll solve2(int x,int y) {
        int tx=top[x],ty=top[y]; ll ans=0;
        while (tx!=ty) {
            if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
            ans=(ans+query(ord[tx],ord[x],1,n,1))%P;
            x=fa[tx],tx=top[x];
        }
        if (dep[x]>dep[y]) swap(x,y);
        return (ans+query(ord[x],ord[y],1,n,1))%P;
    }
    void solve3(int x,int v) {
        update(ord[x],ord[x]+sz[x]-1,v,1,n,1);
    }
    ll solve4(int x) {
        return query(ord[x],ord[x]+sz[x]-1,1,n,1);
    }
    
    int main() {
        read(n),read(m),read(root),read(P);
        for (int i=1; i<=n; ++i) read(value[i]);
        for (int i=1; i<n; ++i) {
            read(x),read(y);
            E[++cnt]={y,last[x]},last[x]=cnt;
            E[++cnt]={x,last[y]},last[y]=cnt;
        }
        dfs(root,0);
        work(root,root);
        build(1,n,1);
        while (m--) {
            read(opt),read(x);
            if (opt==1) read(y),read(v),solve1(x,y,v);
            if (opt==2) read(y),printf("%lld\n",solve2(x,y));
            if (opt==3) read(v),solve3(x,v);
            if (opt==4) printf("%lld\n",solve4(x));
        }
        return 0;
    }
    
    void read(int &x) {
        x=0; char ch=getchar();
        while (!isdigit(ch)) ch=getchar();
        while (isdigit(ch)) x=x*10+(ch^48),ch=getchar();
    }
    void build(int l,int r,int p) {
        T[p].len=r-l+1;
        if (l==r) {
            T[p].sum=value[rev[l]];
            return ;
        }
        int mid=(l+r>>1);
        build(l,mid,p<<1);
        build(mid+1,r,p<<1|1);
        pushup(p);
    }
    inline void pushup(int p) {
        T[p].sum=(Tl.sum+Tr.sum)%P;
    }
    inline void pushdown(int p) {
        int t=T[p].tag;
        Tl.tag=(Tl.tag+t)%P;
        Tl.sum=(Tl.sum+t*Tl.len)%P;
        Tr.tag=(Tr.tag+t)%P;
        Tr.sum=(Tr.sum+t*Tr.len)%P;
        T[p].tag=0;
    }
    void update(int x,int y,int v,int l,int r,int p) {
        if (x<=l&&y>=r) {
            T[p].tag=(T[p].tag+v)%P;
            T[p].sum=(T[p].sum+1ll*v*T[p].len)%P;
            return ;
        }
        pushdown(p);
        int mid=(l+r>>1);
        if (x<=mid) update(x,y,v,l,mid,p<<1);
        if (y>mid) update(x,y,v,mid+1,r,p<<1|1);
        pushup(p);
    }
    ll query(int x,int y,int l,int r,int p) {
        if (x<=l&&y>=r) return T[p].sum;
        pushdown(p);
        int mid=(l+r>>1); ll ans=0;
        if (x<=mid) ans=(ans+query(x,y,l,mid,p<<1))%P;
        if (y>mid) ans=(ans+query(x,y,mid+1,r,p<<1|1))%P;
        pushup(p);
        return ans;
    }
    
  • 相关阅读:
    康托展开
    Linux Command Line Basics
    hihoCoder 1401 Registration
    C++ 参考网站
    Linux 下的常用工具
    SQL 命令
    GNU MAKE 笔记
    一道基本的计算几何题
    uva 1451 平均值
    bzoj 1826 缓存交换
  • 原文地址:https://www.cnblogs.com/REKonib/p/15550441.html
Copyright © 2020-2023  润新知