• 树链剖分[模板]


    传送门

    树链剖分

    树链剖分就是把一颗树分成很多条链,然后把链上的数据进行瞎搞操作(本题是用线段树区间修改)

    一步一步慢慢讲:

    1. 从根节点开始对整颗树进行一次遍历

    求出每个节点子树的大小,父节点,深度和重儿子

    重儿子指 儿子子树大小最大 的儿子节点

    (做这些都是为了后面瞎搞)

    2. 再来一次遍历..

    这次求出每个节点 所在重链的顶端,在dfs序中的编号(要让重链的节点编号连续)以及这个编号的节点的值(节点是原树的节点)

    重链是指一段连续的重儿子连在一起形成的链,还要注意对于非重儿子节点,它本身是重链的起点

    (做这些还是为了后面瞎搞)

    3.瞎搞

    本题用是线段树搞..

    那就建立线段树

    把节点在dfs序中的编号从 1 ~ n 建立线段树

    注意了,第2次遍历时重链的编号是连续的,所以可以用线段树直接存重链的信息

    题目要区间加,区间求和,子树加,子树求和

    区间操作就用线段树对重链慢慢搞就好了

    重要的是对子树操作

    还是注意第2次遍历时是深度优先遍历

    所以可以发现,

    对于一颗子树,整颗子树的编号刚好是从根节点的编号 ~ 根节点编号+子树的大小-1(子树的大小还包括本身)

    所以根本不用管每个节点具体的编号,直接用线段树瞎搞就OK了

    具体怎么搞还是看代码吧:

    (讲得还是很清楚的.....吧)

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #include<cstring>
    #include<vector>
    using namespace std;
    int n,m,roo,mo;//n为节点数,m为操作数,roo为根节点,mo为取模数
    int va[100005];//节点的值
    vector <int> v[100005];//v[x][i]的值表示点x到点v[x][i]有一条边
    int siz[100005],fa[100005],dep[100005],son[100005];
    //siz[x]表示以x为根节点(包括本身)的子树的节点数
    //fa[x]表示点x的父节点
    //dep[x]表示x的深度(根节点深度为1)
    //son[x]表示点x的重儿子的编号
    //第一遍dfs确定siz,fa,dep和son
    void dfs1(int x,int f,int deep)//x为节点编号,f为父节点,deep为深度
    {
        siz[x]=1; fa[x]=f; dep[x]=deep;
        int len=v[x].size(),masiz=0;//masiz为当前点x最大的的子树大小
        for(int i=0;i<len;i++)
        {
            int u=v[x][i];
            if(u==f) continue;//如果为父节点就跳过,否则u就是儿子节点
            dfs1(u,x,deep+1);//向下深搜
            siz[x]+=siz[u];//siz[x]等于当前x点的子树大小再加上儿子节点u的大小
            if(siz[u]>masiz)//如果儿子节点u的子树大小大于之前儿子节点子树大小的最大值
            {
                masiz=siz[u];//更新masiz
                son[x]=u;//更新重儿子
            }
        }
    }
    int top[100005],id[100005],val[100005],cnt;
    //top[x]表示x所在的重链的顶端
    //id[x]表示节点x在dfs序中的编号(线段树需要用到)
    //val[x]表示点x在线段树中的值
    //第2遍dfs确定top,id和val
    void dfs2(int x,int topp)//x为节点编号,topp表示x所在的重链的顶端
    {
        id[x]=++cnt;
        top[x]=topp;
        val[cnt]=va[x];
        if(son[x]==0) return;//如果没有儿子直接返回
        dfs2(son[x],topp);//优先向重儿子dfs,保证重链编号连续
        int len=v[x].size();
        for(int i=0;i<len;i++)
        {
            int u=v[x][i];
            if(u==son[x]||u==fa[x]) continue;//如果节点u是父节点或重儿子则跳过(重儿子搜过了)
            dfs2(u,u);//此时u是轻儿子,轻儿子所在的重链以自己为开端
        }
    }
    //线段树
    int t[400005],laz[400005];//t是线段树的节点,laz为懒标记
      //建树
    inline void build(int o,int l,int r)//o是线段树的节点编号,l和r分别为线段树区间的左右端点
    {
        if(l==r)//如果区间为一个点
        {
            t[o]=val[l];//更新节点值
            return;//下面没有节点了,返回
        }
        //否则
        int mid=l+r>>1;
         //建树
        build(o*2,l,mid);
        build(o*2+1,mid+1,r);
        //更新节点值
        t[o]=(t[o*2]+t[o*2+1])%mo;
    }
      //下传懒标记
    inline void push(int o,int l,int r)//o,l,r同上
    {
        int mid=l+r>>1;
         //更新儿子的值
        t[o*2]=(t[o*2]+laz[o]*(mid-l+1))%mo;
        t[o*2+1]=(t[o*2+1]+laz[o]*(r-mid))%mo;
         //更新儿子的懒标记的值
        laz[o*2]=(laz[o*2]+laz[o])%mo;
        laz[o*2+1]=(laz[o*2+1]+laz[o])%mo;
        laz[o]=0;//原节点的懒标记已下传,清零
    }
      //更新区间值
    inline void change(int o,int l,int r,int ql,int qr,int x)
    {
        //o,l,r同上,ql和qr为要更新的左右端点,x为要增加的值
        if(l>qr||r<ql) return;//如果当前区间与要更新的区间无关,返回
        if(l>=ql&&r<=qr)//如果当前区间在要更新的区间内部,更新
        {
            t[o]=(t[o]+x*(r-l+1))%mo;//更新节点值
            laz[o]+=x;//更新懒标记
            return;//当前区间在要更新的区间内部,不需要向下更新
        }
         //当前区间不完全在更新的区间内
        int mid=l+r>>1;
        push(o,l,r);//下传懒标记并更新儿子的值
          //尝试更新左右儿子节点
        change(o*2,l,mid,ql,qr,x);
        change(o*2+1,mid+1,r,ql,qr,x);
        t[o]=(t[o*2]+t[o*2+1])%mo;//更新节点值
    }
      //查询区间值
    inline int query(int o,int l,int r,int ql,int qr)
    {
        //o,l,r同上,ql和qr为要查询的区间
        if(l>qr||r<ql) return 0;//如果当前区间与要查询的区间无关,返回0
        if(l>=ql&&r<=qr) return t[o];//如果当前区间在要查询的区间内部
        //不需要继续向下,直接返回当前区间的值
         //当前区间不完全在更新的区间内
        int mid=l+r>>1;
        push(o,l,r);//下传懒标记并更新儿子的值
        int res=(query(o*2,l,mid,ql,qr)+query(o*2+1,mid+1,r,ql,qr))%mo;
        t[o]=(t[o*2]+t[o*2+1])%mo;//更新节点值(之前的push函数可能会更新儿子的值)
        return res;
    }
    //将树从x到y结点最短路径上所有节点的值都加上z
    inline void ins1(int x,int y,int z)//x,y,z意义如上
    {
        while(top[x]!=top[y])//若x和y不在同一条重链上
        {
            if(dep[top[x]]<dep[top[y]]) swap(x,y);
            //首先保证点x所在的重链顶端的深度较大才能保证最终x和y一定能“走”到同一条重链
            change(1,1,n,id[top[x]],id[x],z);
            //更新点x所在的重链的值
            //(因为第2遍dfs后一条重链的编号是连续的所以可以直接将整个区间更新)
            x=fa[top[x]];//这条链更新完了就向下一条链更新
        }
        //此时点x和点y已处于同一条重链
        if(dep[x]>dep[y]) swap(x,y);//保证点x的深度更小
        change(1,1,n,id[x],id[y],z);//剩下只要更新id[x]到id[y]的区间就好了
    }
    //将以x为根节点的子树内所有节点值都加上z
    inline void ins2(int x,int z)//x,z意义同上
    {
        //在第2遍dfs时已经使点x的后代节点的编号(id[])分别从id[x]+1到di[x]+siz[x]-1
        //减1是因为siz[x]还包括自己
        change(1,1,n,id[x],id[x]+siz[x]-1,z);//直接更新x的子树(包括自己)
    }
    //求树从x到y结点最短路径上所有节点的值之和
    inline int q1(int x,int y)//x,y意义如上
    {
        int res=0;//res储存结果
         //像更新操作一样,也是用线段树查询一整个区间
        while(top[x]!=top[y])
        {
            if(dep[top[x]]<dep[top[y]]) swap(x,y);
            res=(res+query(1,1,n,id[top[x]],id[x]))%mo;
            x=fa[top[x]];
        }
        if(dep[x]>dep[y]) swap(x,y);
        res=(res+query(1,1,n,id[x],id[y]))%mo;
        return res;
    }
    //求以x为根节点的子树内所有节点值之和
    inline int q2(int x)
    {
        //像更新操作一样,也是用线段树查询一整个区间
        return query(1,1,n,id[x],id[x]+siz[x]-1);
    }
    //终于来到了主程序...
    int main()
    {
        int a,b,c;
        cin>>n>>m>>roo>>mo;
        for(int i=1;i<=n;i++)
            scanf("%d",&va[i]),va[i]%=mo;
        for(int i=1;i<n;i++)
        {
            scanf("%d%d",&a,&b);
            v[a].push_back(b);
            v[b].push_back(a);
        }
        dfs1(roo,0,1);//第1遍深搜
        dfs2(roo,roo);//第2遍深搜
        build(1,1,n);//建树
        int k;
        while(m--)
        {
            scanf("%d",&k);
            if(k==1)
            {
                scanf("%d%d%d",&a,&b,&c);
                c%=mo;
                ins1(a,b,c);
            }
            if(k==2)
            {
                scanf("%d%d",&a,&b);
                printf("%d
    ",q1(a,b));
            }
            if(k==3)
            {
                scanf("%d%d",&a,&b);
                b%=mo;
                ins2(a,b);
            }
            if(k==4)
            {
                scanf("%d",&a);
                printf("%d
    ",q2(a));
            }
        }
        return 0;
    }

    感觉好像注释有点多了....

  • 相关阅读:
    Java多线程
    SpringCloud
    Java 多线程
    MySQL
    MySQL
    SpringCloud
    SpringCloud
    SpringCloud
    SpringBoot
    Spring MVC
  • 原文地址:https://www.cnblogs.com/LLTYYC/p/9526741.html
Copyright © 2020-2023  润新知