• 树链剖分算法详解


    学OI也有一段时间了,感觉该搞点东西了。

    于是学习了树()链()剖(pou)分(粪)

    当然,学习这个算法是需要先学习线段树的。不懂的还是再过一段时间吧。


    如果碰到一道题,要对一颗树的两个点中的最短路径、以u为根的子树之类的东西进行修改或者查询,那么大概就是树链剖分的题了。

    树链剖分就是把一颗树的节点按照新的顺序扔到一颗线段树里面,然后保证一条树链上的点在线段树中尽可能连续。

    为什么是尽可能?因为在一棵树中,怎么搞也无法保证对于每一个节点,他的父亲编号都是它的-1,所以是尽可能。那么怎么尽可能呢?

    有很多算法,今天提到的就是树链剖分。我们把一颗树上的所有链分成轻链重链,然后就可以对于每一段连续的重链进行线段树上的修改了。

    而划分轻链和重链的依据是:对于每一个节点u,v是它的儿子,v有一个大小,就是size,代表以v为根的子树的大小。我们选取u最大的儿子为重(zhong)儿子,其余儿子为轻儿子。以连向重儿子的边为重边,剩下的边为轻边。

    然后所有重边连成的链叫做重链,(并不存在轻链)比如下图,红色的链是重链(注意,对于一个叶子节点,如果连向它的是一条轻链,那么他自己就是一条重链)

    这样,我们把一棵树划分成了重链和轻链,我们能保证所有重链都不重不漏的包含了所有的点。

    那么这些重链有什么用?在划分重链的过程中用到的DFS,这个DFS能保证,对于每一条重链,他们的DFS序是连续的!

    这样,我们就可以用线段树(或者其他数据结构)维护了!

     现在,我们把熟练剖分化成两个部分:

    1、把树上的所有点划分重链,然后求出它们的DFS序,以这个顺序扔到线段树里面。

    2、在线段树上进行维护。

    所以,如何实现划分重链?我们需要用两个DFS,第一个DFS找到所有点的重儿子,第二个DFS将所有重儿子连成重链。

    第一个DFS:size是以当前点为根的子树的大小,f是当前点的父亲,son是当前点的重儿子。

    inline void getson(int u,int fa){//获取每个节点的重儿子 
        size[u]=1;
        for(int e=head[u];e;e=nxt[e])
            if(to[e]!=fa){
                depth[to[e]]=depth[u]+1;
                f[to[e]]=u; 
                getson(to[e],u);
                size[u]+=size[to[e]];//记录以每个节点为根的树的大小 
                if(!son[u] || size[son[u]]<size[to[e]])  son[u]=to[e];//判断后将这个点变为重儿子 
            }
        return ;
    }

    第二个DFS:

    inline void getdfn(int u,int t){//连成重链,其中我们可以保证,对于每一条重链,它们的dfn值是连续的。t记录的是当前链的链首 
        top[u]=t;//top记录当前链链首 
        dfn[u]=++cnt;//记录dfn值,也是在线段树中的位置 
        link[cnt]=u;//dfn的逆运算,用于建树时的初始赋值 
        if(!son[u])  return ;//如果当前点没有重儿子,说明是这条重链的结束。 
        getdfn(son[u],t);//继续走这条重链 
        for(int e=head[u];e;e=nxt[e])//这个相当于走每一条轻链 
            if(to[e]!=son[u] && to[e]!=f[u])
                getdfn(to[e],to[e]);//重新开始走每一条重链 
        return ;
    }

    然后,对于线段树的建树,是独立的,我们不用考虑链的关系。(input是输入文件)

    inline void build(int i,int l,int r){//平凡的建树 
        tree[i].l=l,tree[i].r=r;
        if(l==r){
            tree[i].sum=input[link[l]]%mod;//link的作用 
            return ;
        }
        int mid=(l+r)>>1;
        build(i<<1,l,mid);
        build(i<<1|1,mid+1,r);
        tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%mod;
        return ;
    }

    最后是修改,查询和修改很像,一起说了。

    我们要把u到v路径上所有的点都+k,那么我们就把u,v中深的那个,它到它所在重边的顶端+k。

    然后跳过一条轻边,重复上面的步骤,知道u,v到一条重边上。

    最后把u到v,+k就可以了。

    inline void treeadd(int x,int y,int z){//将题中对树的修改转化成对线段树的修改 
        int tx=top[x],ty=top[y]; 
        while(tx!=ty){//如果两个点不在一条重链上 
            if(depth[tx]<depth[ty])  swap(x,y),swap(tx,ty);//保证x的重链首元素在下方 
            add(1,dfn[tx],dfn[x],z);//从x一直修改到x所在重链的收元素,因为他们在一条重链中,所以在线段树中的位置是连续的。 
            x=f[tx];//走过一条轻链,到上面一个重链的末尾 
            tx=top[x],ty=top[y];//分别更新x、y的重链顶端,准备下一次更新 
        }
        if(depth[x]<depth[y])  swap(x,y);//现在x、y都到了一条重链上了,然后要保证x在下面。 
        add(1,dfn[y],dfn[x],z);//再只用更新他们所在的链就可以了。 
        return ;
    }
    inline int treesum(int x,int y){//将题中对树查询得指令改为对线段树的查询。 
        int ans=0;
        int tx=top[x],ty=top[y];
        while(tx!=ty){//这一段和修改几乎一样,就是把原本对每一个区间的修改,变为了查询,其实都一样。 
            if(depth[tx]<depth[ty])  swap(tx,ty),swap(x,y);
            ans=(ans+query(1,dfn[tx],dfn[x]))%mod;
            x=f[tx];
            tx=top[x],ty=top[ty];
        }
        if(depth[x]<depth[y])  swap(x,y);
        return (ans+query(1,dfn[y],dfn[x]))%mod;
    }

    对于线段树上的维护,和朴素的线段树一样,就不多说了。

    如果题目中说要将以i为根的子树+k,那就直接在线段树上从dfn[i]到dfn[i]+size[i],+k就可以了。

    具体看AC代码:(洛谷模板题)

    #include <iostream>
    #include <cstdio>
    #include <algorithm>
    #include <cstdlib>
    #include <cstring>
    #define in(a) a=read()
    #define REP(i,k,n)  for(int i=k;i<=n;i++)
    #define MAXN 100010
    using namespace std;
    inline int read(){
        int x=0,f=1;
        char ch=getchar();
        for(;!isdigit(ch);ch=getchar())
            if(ch=='-')
                f=-1;
        for(;isdigit(ch);ch=getchar())
            x=x*10+ch-'0';
        return x*f;
    }
    int n,m,r,mod,input[MAXN];
    int total,head[MAXN],to[MAXN<<1],nxt[MAXN<<1];
    int size[MAXN],depth[MAXN],f[MAXN],son[MAXN];
    int cnt,dfn[MAXN],link[MAXN],top[MAXN];
    struct node{
        int l,r,sum,lt;
    }tree[MAXN<<2];
    inline void adl(int a,int b){
        total++;
        to[total]=b;
        nxt[total]=head[a];
        head[a]=total;
        return ;
    }
    inline void getson(int u,int fa){//获取每个节点的重儿子 
        size[u]=1;
        for(int e=head[u];e;e=nxt[e])
            if(to[e]!=fa){
                depth[to[e]]=depth[u]+1;
                f[to[e]]=u; 
                getson(to[e],u);
                size[u]+=size[to[e]];//记录以每个节点为根的树的大小 
                if(!son[u] || size[son[u]]<size[to[e]])  son[u]=to[e];//判断后将这个点变为重儿子 
            }
        return ;
    }
    inline void getdfn(int u,int t){//连成重链,其中我们可以保证,对于每一条重链,它们的dfn值是连续的。t记录的是当前链的链首 
        top[u]=t;//top记录当前链链首 
        dfn[u]=++cnt;//记录dfn值,也是在线段树中的位置 
        link[cnt]=u;//dfn的逆运算,用于建树时的初始赋值 
        if(!son[u])  return ;//如果当前点没有重儿子,说明是这条重链的结束。 
        getdfn(son[u],t);//继续走这条重链 
        for(int e=head[u];e;e=nxt[e])//这个相当于走每一条轻链 
            if(to[e]!=son[u] && to[e]!=f[u])
                getdfn(to[e],to[e]);//重新开始走每一条重链 
        return ;
    }
    inline void build(int i,int l,int r){//平凡的建树 
        tree[i].l=l,tree[i].r=r;
        if(l==r){
            tree[i].sum=input[link[l]]%mod;//link的作用 
            return ;
        }
        int mid=(l+r)>>1;
        build(i<<1,l,mid);
        build(i<<1|1,mid+1,r);
        tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%mod;
        return ;
    }
    inline void pushdown(int i){//平凡的pushdown 
        if(!tree[i].lt)  return ;
        tree[i<<1].lt+=tree[i].lt;
        tree[i<<1|1].lt+=tree[i].lt;
        int mid=(tree[i].l+tree[i].r)>>1;
        tree[i<<1].sum=(tree[i<<1].sum+(mid-tree[i].l+1)*tree[i].lt)%mod;
        tree[i<<1|1].sum=(tree[i<<1|1].sum+(tree[i].r-mid)*tree[i].lt)%mod;
        tree[i].lt=0;
        return ;
    }
    inline void add(int i,int l,int r,int k){//平凡的区间修改 
        if(tree[i].l>=l && tree[i].r<=r){
            tree[i].sum=(tree[i].sum+(tree[i].r-tree[i].l+1)*k)%mod;
            tree[i].lt+=k;
            return ;
        }
        pushdown(i);
        if(tree[i<<1].r>=l)  add(i<<1,l,r,k);
        if(tree[i<<1|1].l<=r) add(i<<1|1,l,r,k);
        tree[i].sum=(tree[i<<1].sum+tree[i<<1|1].sum)%mod;
        return ;
    }
    inline int query(int i,int l,int r){//平凡的区间查询 
        if(tree[i].l>=l && tree[i].r<=r)  return tree[i].sum;
        int sum=0;
        pushdown(i);
        if(tree[i<<1].r>=l)  sum=(sum+query(i<<1,l,r))%mod;
        if(tree[i<<1|1].l<=r)  sum=(sum+query(i<<1|1,l,r))%mod;
        return sum;
    }
    inline void treeadd(int x,int y,int z){//将题中对树的修改转化成对线段树的修改 
        int tx=top[x],ty=top[y]; 
        while(tx!=ty){//如果两个点不在一条重链上 
            if(depth[tx]<depth[ty])  swap(x,y),swap(tx,ty);//保证x的重链首元素在下方 
            add(1,dfn[tx],dfn[x],z);//从x一直修改到x所在重链的收元素,因为他们在一条重链中,所以在线段树中的位置是连续的。 
            x=f[tx];//走过一条轻链,到上面一个重链的末尾 
            tx=top[x],ty=top[y];//分别更新x、y的重链顶端,准备下一次更新 
        }
        if(depth[x]<depth[y])  swap(x,y);//现在x、y都到了一条重链上了,然后要保证x在下面。 
        add(1,dfn[y],dfn[x],z);//再只用更新他们所在的链就可以了。 
        return ;
    }
    inline int treesum(int x,int y){//将题中对树查询得指令改为对线段树的查询。 
        int ans=0;
        int tx=top[x],ty=top[y];
        while(tx!=ty){//这一段和修改几乎一样,就是把原本对每一个区间的修改,变为了查询,其实都一样。 
            if(depth[tx]<depth[ty])  swap(tx,ty),swap(x,y);
            ans=(ans+query(1,dfn[tx],dfn[x]))%mod;
            x=f[tx];
            tx=top[x],ty=top[ty];
        }
        if(depth[x]<depth[y])  swap(x,y);
        return (ans+query(1,dfn[y],dfn[x]))%mod;
    }
    int main(){
        in(n),in(m),in(r),in(mod);
        REP(i,1,n)  in(input[i]);
        int a,b;
        REP(i,1,n-1)  in(a),in(b),adl(a,b),adl(b,a);
        depth[r];
        getson(r,0);
        getdfn(r,r);
        build(1,1,n);
        int p,x,y,z;
        REP(i,1,m){
            in(p);
            if(p==1)  in(x),in(y),in(z),treeadd(x,y,z);
            if(p==2)  in(x),in(y),printf("%d
    ",treesum(x,y));
            if(p==3)  in(x),in(z),add(1,dfn[x],dfn[x]+size[x]-1,z);//我们会发现,在树链剖分中,i这颗子树里面所有的节点的dfn都是连续的,我们修改u的子树就是将u到u+size-1修改就可以了。 
            if(p==4)  in(x),printf("%d
    ",query(1,dfn[x],dfn[x]+size[x]-1));//查询同上。 
        }
        return 0;
    }
  • 相关阅读:
    nyoj 95 众数问题(set)
    nyoj 93 汉诺塔(三)(stack)
    hdu 1010 Tempter of the Bone
    nyoj 55 懒省事的小明(priority_queue优先队列)
    nyoj 31 5个数求最值
    poj 1256 Anagram
    next_permutation函数
    nyoj 19 擅长排列的小明(深搜,next_permutation)
    nyoj 8 一种排序(用vector,sort,不用set)
    nyoj 5 Binary String Matching(string)
  • 原文地址:https://www.cnblogs.com/jason2003/p/9818242.html
Copyright © 2020-2023  润新知