• 动态dp


    动态 dp

    动态 (dp) 是由猫坤大佬在 WC 2018 提出来的黑科技。

    他主要解决得是带修改的 (dp) 问题,主要思路是由矩阵乘法来维护 (dp) 转移

    我们先来看一道模板题 动态dp

    这道题,我们先来看不带修改的情况,

    转移和状态很容易就能列出来

    (f[i][0/1]) 表示 以 (i) 为根的子树中,且选 (不选) (i) 这个点的最大权值和

    转移方程可以写成:

    (f[i][0] = displaystyle sum_{to in son[i]} max (f[to][0],f[to][1]))

    (f[i][1] = displaystyle sum_{to in son[x]} f[to][0] + w[x])

    当我们修改一个点的权值,那么从他到根节点的路径上的点的 (dp) 值都会受到影响。

    所以,我们就可以只修改这一条链上的信息(用树链剖分来维护)。

    动态 (dp) 的思想就是将 重儿子和轻儿子的贡献分开来考虑。

    (g[i][0] = displaystyle sum _{to in 轻儿子} max(f[to][0],f[to][1])), $g[i][1] = displaystyle sum_{to in 轻儿子} f[to][0] +w[i] $.

    解释一下 (g[i][0]) 表示 选 或不选 (i) 的轻儿子的最大贡献和, (g[i][1]) 表示不选他轻儿子的价值和加上他本身的权值.

    这两个可以在求 (f[x][0/1]) 的时候顺便维护出来

    那么,上面的方程就可以写成。

    (f[i][0] = g[i][0] + max(f[son[i]][0],f[son][i][1]))

    (f[i][1] = g[i][1] + f[son[i]][0]) ((son[x]) 表示 (x) 的重儿子)

    定义广义矩阵乘法:

    矩阵 (c) 为 矩阵 A 和矩阵 B 的乘积

    那么,(c[i][j] = max( a[i][k] + b[k][j]))

    这个其实就是把普通的矩阵乘法的乘号改为加号,加号改为取 (max)

    这种矩阵乘法也满足普通矩阵乘法的性质,不相信的可以跑几组数据试试

    代码实现长这样:

    node operator *(node x,node y)
    {
        node ans;
        for(int i = 0; i <= n; i++)
            for(int j = 0; j <= n; j++)
                for(int k = 0; k <= n; k++)
                    ans.a[i][j] = max(ans.a[i][j],x.a[i][k] + y.a[k][j]);
        return ans;
    }
    

    然后,我们就可以根据这个广义矩阵乘法构造一个矩阵,即,

    [left[ egin{matrix} f[to] [0] \ f[to] [1] \ end{matrix} ight] ag{2} imes left[ egin{matrix} cdots \ cdots \ end{matrix} ight] = left[ egin{matrix} f[i][0] \ f[i][1] \ end{matrix} ight] ]

    其中 第二个矩阵是我们要 确定的转移矩阵。

    第二个矩阵构造出来长这样:

    [left[ egin{matrix} f[to] [0] \ f[to] [1] \ end{matrix} ight] ag{2} imes left[ egin{matrix} g[i][0] & g[i][1] \ g[i][0] & -infin \ end{matrix} ight] = left[ egin{matrix} f[i][0] \ f[i][1] \ end{matrix} ight] ]

    但,我们发现一个问题,我们矩阵中没有这个 (f) 值,这,我们一开始的矩阵就没法转移啊。

    其实,对于每一个重链的末尾节点,他的转移矩阵就是他的 (f) 值,

    也就是说重链的末端节点保存了 f的准确值 -by treaker

    这样,我们就可以由这个矩阵来推出这一条重链上每个点的信息。

    我们的暴力做法就是一直跳他的父亲就是一直跳他的父亲节点,再把矩阵乘起来。

    但,这样的复杂度还是不能接受。我们可以把这个矩阵放到线段树上来维护。

    线段树的每个叶子节点存这个点的转移矩阵,每个节点表示左右儿子矩阵的乘积,

    这样就可以用 O(log n) 的时间得到我们想要的信息。

    在结合树剖套线段树的做法,就能做到维护这棵树的信息。

    注意:矩阵乘法要右乘,因为我们是从 dfn序大的地方跳到 dfn序小的地方,在线段树上对应的是从右往左。

    我们上面的转移矩阵就需要重构一下,变成

    [left[ egin{matrix} g[i][0] & g[i][0] \ g[i][1] & - infin \ end{matrix} ight] ag{2} imes left[ egin{matrix} f[son[x]][0]\ f[son[x]][1]\ end{matrix} ight] = left[ egin{matrix} f[i][0] \ f[i][1] \ end{matrix} ight] ]

    修改操作,上文我们提到了 ‘当我们修改一个点的权值,那么从他到根节点的路径上的点的 (dp) 值都会受到影响’。

    我们修改的就是从这个点到根节点的路径的矩阵,我们把这条链琛出来,发现他是重链和轻边交替在一起的。

    我们仔细观察一下我们构造的矩阵,发现里面维护的是轻儿子的信息,不会涉及到重儿子。

    对于重链,我们不用修改,但对于重链链顶 $top[x] $和轻边 $fa[top[x]] $交替的地方,我们需要单点修改。

    因为此时链顶 (top[x]) 属于 $fa[top[x]] $轻儿子,他的修改会对他父亲的转移矩阵造成影响。

    我们关键要算出他改变之后对他父亲转移矩阵的影响。

    他的 (f) 值我们可以由下面的推出来,然后可以根据 (g) 数组的定义算出他改变之后对他父亲的影响。

    就这样在更新,每次跳重链,直到跳到根节点,我们的修改操作就大工告成了。

    注意一下,我们不能修改之后在统计他的 轻儿子的值,那样可能不对,我们就只能通过增量法来修改。

    记录一下原来状态的矩阵,在记录修改之后的矩阵,两者之差就是对他父亲转移矩阵的贡献。

    每次修改操作的时间复杂度为 O((log^2 n))

    具体代码长这样

    void modify(int x,int val)
    {
        base[dfn[x]].a[1][0] += val - w[x];//增量法统计他修改的贡献
        w[x] = val;
        while(x)
        {
            node Old = get_node(top[x]);//记录一下链顶修改之前的矩阵
            chenge(1,1,n,dfn[x]);//修改当前这个节点的转移矩阵
            node New = get_node(top[x]);//得到链顶修改之后的转移矩阵
            int fx = dfn[fa[top[x]]];
            base[fx].a[0][0] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);//算他修改对他父亲转移矩阵的影响
            base[fx].a[0][1] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);
            base[fx].a[1][0] += New.a[0][0] - Old.a[0][0];
            x = fa[top[x]];//跳链修改
        }
    }
    

    总代码:

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cstring>
    using namespace std;
    const int inf = 233333333;
    const int N = 1e5+10;
    int n,m,tot,x,y,num,val,u,v;
    int head[N],top[N],dep[N],siz[N],fa[N],son[N],ord[N],end[N],f[N][2],g[N][2],dfn[N],w[N];
    struct bian
    {
        int to,net;
    }e[N<<1];
    struct node
    {
        int a[2][2];
        node() {memset(a,-127/3,sizeof(a));}
    }tr[N<<2],base[N];
    inline int read()
    {
        int s = 0,w = 1; char ch = getchar();
        while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
        while(ch >= '0' && ch <= '9'){s =s * 10+ch - '0'; ch = getchar();}
        return s * w;
    }
    node operator *(node x,node y)//广义矩阵乘法
    {
        node ans;
        for(int i = 0; i <= 1; i++)
        {
            for(int j = 0; j <= 1; j++)
            {
                for(int k = 0; k <= 1; k++)
                {
                    ans.a[i][j] = max(ans.a[i][j],x.a[i][k] + y.a[k][j]);
                }
            }
        }
        return ans;
    }
    void add(int x,int y)
    {
        e[++tot].to = y;
        e[tot].net = head[x];
        head[x] = tot;
    }
    void get_tree(int x)//树剖预处理
    {
        dep[x] = dep[fa[x]] + 1; siz[x] = 1;
        for(int i = head[x]; i; i = e[i].net)
        {
            int to = e[i].to;
            if(to == fa[x]) continue;
            fa[to] = x;
            get_tree(to);
            siz[x] += siz[to];
            if(siz[son[x]] < siz[to] || son[x] == -1) son[x] = to;
        }
    }
    void dfs(int x,int topp)
    {
        top[x] = topp; dfn[x] = ++num; ord[num] = x;//END 记录这条重链的链尾,ord记录当前 当前这个节点的树上编号
        if(son[x] == -1)
        {
            end[topp] = num;
            return;
        }
        dfs(son[x],topp);
        for(int i = head[x]; i; i = e[i].net)
        {
            int to = e[i].to;
            if(to == fa[x] || to == son[x]) continue;
            dfs(to,to);
        }
    }
    void dp(int x,int fa)//预处理为没修改之前的 f 值与 g 值
    {
        g[x][1] = f[x][1] = w[x];
        for(int i = head[x]; i; i = e[i].net)
        {
            int to = e[i].to;
            if(to == fa) continue;
            dp(to,x);
            f[x][0] += max(f[to][0], f[to][1]);
            f[x][1] += f[to][0];
            if(to != son[x])
            {
                g[x][0] += max(f[to][0], f[to][1]);
                g[x][1] += f[to][0];
            }
        }
    }
    void up(int o)
    {
        tr[o] = tr[o<<1] * tr[o<<1|1];
    }
    void build(int o,int L,int R)//线段树建树
    {
        if(L == R)
        {
        	int tmp = ord[L];
            tr[o].a[0][0] = tr[o].a[0][1] = g[tmp][0];//构造转移矩阵
            tr[o].a[1][0] = g[tmp][1]; 
            base[L] = tr[o];
            return;
        }
        int mid = (L + R)>>1;
        build(o<<1,L,mid);
        build(o<<1|1,mid+1,R);
        up(o);
    }
    void chenge(int o,int l,int r,int x)//单点修改
    {
        if(l == r)
        {
            tr[o] = base[l];
            return;
        }
        int mid = (l + r)>>1;
        if(x <= mid) chenge(o<<1,l,mid,x);
        if(x > mid) chenge(o<<1|1,mid+1,r,x);
        up(o);
    }
    node query(int o,int l,int r,int L,int R)//区间查询
    {
        if(L <= l && R >= r) return tr[o];
        int mid = (l + r)>>1;
        if(R <= mid) return query(o<<1,l,mid,L,R);
        if(L > mid) return query(o<<1|1,mid+1,r,L,R);
        return query(o<<1,l,mid,L,R) * query(o<<1|1,mid+1,r,L,R);
    }
    node get_node(int x)//得到链顶的 f 值
    {
        return query(1,1,n,dfn[x],end[top[x]]);
    }
    void modify(int x,int val)
    {
        base[dfn[x]].a[1][0] += val - w[x];//增量法统计他修改的贡献
        w[x] = val;
        while(x)
        {
            node Old = get_node(top[x]);//记录一下链顶修改之前的矩阵
            chenge(1,1,n,dfn[x]);//修改当前这个节点的转移矩阵
            node New = get_node(top[x]);//得到链顶修改之后的转移矩阵
            int fx = dfn[fa[top[x]]];
            base[fx].a[0][0] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);//算他修改对他父亲转移矩阵的影响
            base[fx].a[0][1] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);
            base[fx].a[1][0] += New.a[0][0] - Old.a[0][0];
            x = fa[top[x]];//跳链修改
        }
    }
    int main()
    {
        n = read(); m = read();
        for(int i = 1; i <= n; i++)
        {
            w[i] = read();
            son[i] = -1;
        }
        for(int i = 1; i <= n-1; i++)
        {
        	u = read(); v = read();
        	add(u,v); add(v,u);
        }
        get_tree(1); dfs(1,1); dp(1,0); build(1,1,n);//预处理
        for(int i = 1; i <= m; i++)
        {
            x = read(); val = read();
            modify(x,val);//修改操作
            node ans = get_node(1);//得到新答案
            printf("%d
    ",max(ans.a[0][0],ans.a[1][0]));
        }
        return 0;
    }
    
  • 相关阅读:
    判断输入的年份是闰年还是平年!!!
    键盘接收数,接收运算符号进行运算!!
    eclipse项目上如何传到码云上!新手,简单易懂,希望对你有所帮助。
    jquery dialog弹出框简单写法和一些属性的应用,写的不好,大佬勿喷!谢谢!
    新手冒泡排序,随机生成十个数。
    新手java九九乘法表
    Lambda表达式
    如何删除gitee仓库的文件
    Collection方法
    java冒泡排序的几种写法
  • 原文地址:https://www.cnblogs.com/genshy/p/13606010.html
Copyright © 2020-2023  润新知