• 树形DP 学习笔记


    树形DP学习笔记

    ps: 本文内容与蓝书一致

    树的重心

    • 概念: 一颗树中的一个节点其最大子树的节点树最小
    • 解法:对与每个节点求他儿子的(size) ,上方子树的节点个数为(n-size_u) ,求对于每个节点子树的最大值,找出最小的那个就好了;

    (我觉得就不需要code了)


    树的直径

    • 概念:一颗带权树的最长路径
    • 解法:维护一个节点到叶子节点的最大距离(d1[i])和次大距离(d2[i]) ,最大距离就是$max {d1[i]+d2[i] } $

    code

    #include<iostream>
    #include<cstdio>
    using namespace std;
    const int N=1e4+5;
    int n;
    struct pp
    {
        int to,next;
    }w[2*N];
    int head[N],cnt;
    int d1[N],d2[N];
    int ans;
    void add(int x,int y)
    {
        cnt++;
        w[cnt].next=head[x];
        w[cnt].to=y;
        head[x]=cnt;
    }
    void dfs(int x,int fa)
    {
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if(t!=fa)
            {
                dfs(t,x);
                if(d1[t]+1>d1[x])
                {
                    d2[x]=d1[x];
                    d1[x]=d1[t]+1;
                }
                else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
            }
        }
        return ;
    }
    void find_ans(int x,int fa)
    {
        ans=max(ans,d1[x]+d2[x]);
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if(t!=fa) find_ans(t,x);
        }
        return;
    }
    int main()
    {
    #ifndef ONLINE_JUDGE
        freopen("diam.in","r",stdin);
        freopen("diam.out","w",stdout);
    #endif
        scanf("%d",&n);
        for(int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            add(x,y);
            add(y,x);
        }
        dfs(1,0);
        find_ans(1,0);
        printf("%d",ans);
        return 0;
    }
    

    例题

    P4480 逃学的小孩

    • 大概思路:求出树的直径以及其左右端点,再设(d[i])为树上节点(i)到左右端点距离更小的那个,然后求出(max {d[i]}),然后以这个值加上直径就是(ans)

    code

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #define ll long long
    using namespace std;
    const int N=2e5+5;
    struct pp
    {
        int next,to;
        ll qu;
    }w[N*2];
    int head[N],cnt;
    int n,m;
    bool v[N];
    ll d1[N],d2[N],dl[N],dr[N];
    int f1[N],f2[N];
    int r,l;
    ll ans,mans;
    void add(int x,int y,int z)
    {
        w[++cnt].next=head[x];
        w[cnt].qu=z;
        w[cnt].to=y;
        head[x]=cnt;
    }
    int read()
    {
        int f=1;
        char ch;
        while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
        int res=ch-'0';
        while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
        return res*f;
    }
    void dfs1(int x)
    {
        if(v[x]) return ;
        v[x]=1;
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if(!v[t])
            {
                dfs1(t);
                if(d1[t]+w[i].qu>d1[x])
                {
                    f2[x]=f1[x];
                    f1[x]=f1[t];
                    d2[x]=d1[x];
                    d1[x]=d1[t]+w[i].qu;
                }
                else if(d1[t]+w[i].qu>d2[x]) d2[x]=d1[t]+w[i].qu,f2[x]=f1[t];
            }
            
        }
        return;
    }
    void find_ans(int x)
    {
        if(v[x]) return;
        v[x]=1;
        if(ans<d1[x]+d2[x])
        {
            ans=d1[x]+d2[x];
            l=f1[x];
            r=f2[x];
        }
        for(int i=head[x];i;i=w[i].next) find_ans(w[i].to);
    }
    void dfs2(int x)
    {
        if(v[x]) return;
        v[x]=1;
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if(!v[t])
            {
                dl[t]=dl[x]+w[i].qu;
                dfs2(t);
            }
        }
        return;
    }
    void dfs3(int x)
    {
        if(v[x])return;
        v[x]=1;
        
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if(!v[t])
            {
                dr[t]=dr[x]+w[i].qu;
                dfs3(t);
            }
        }
        return;
    }
    void dfs_ans(int x)
    {
        if(v[x]) return;
        v[x]=1;
        mans=max(mans,min(dl[x],dr[x]));
        for(int i=head[x];i;i=w[i].next) dfs_ans(w[i].to);
        return;
    }
    int main()
    {
    #ifndef ONLINE_JUDGE
        freopen("Chris.in","r",stdin);
        freopen("Chris.out","w",stdout);
    #endif
        n=read();
        m=read();
        for(int i=1;i<=m;i++)
        {
            int x,y,z;
            x=read(),y=read(),z=read();
            add(x,y,z);
            add(y,x,z);
        }
        for(int i=1;i<=n;i++) f1[i]=i;
        dfs1(1);
        memset(v,0,sizeof(v));
        find_ans(1);
        memset(v,0,sizeof(v));
        dfs2(l);
        memset(v,0,sizeof(v));
        dfs3(r);
        memset(v,0,sizeof(v));
        dfs_ans(1);
        printf("%lld",ans+mans);
        return 0;
    }
    

    树的中心

    • 概念:给出一颗带权树,求一个节点,使得此节点到树中其他节点的最远距离最小;

    • 解法:如果是一颗没有负边权的树,那直接找到直径的中点就好;

      但是这里我们考虑有负边权的情况:

      有两种情况:

      1. (u)点向上的最长路径,设为(up[u]);
      2. (u)点向下,即(u)到叶节点的最远距离,设为(d1[u])(次远设为(d2[u]));

      (d1[u])(d2[u])都会求,问题是(up[u])该怎么求?

      还是分类讨论,设(u)的父亲为(x),(d1[x])来自于子节点(v);那对于(u):

      1. 如果(u!=v),那么(up[u]=max{d1[x],up[x]}+dis[x][t]);
      2. 如果(u==v),那么(up[u]=max{d2[x],up[x]}+dis[x][t]),这也是为什么要维护(d2[x])的原因;

    code

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    using namespace std;
    const int N=1e5+5;
    struct pp
    {
        int next,to;
    }w[2*N];
    int n,k;
    int head[N],cnt;
    int d1[N],d2[N],pre[N],u[N];
    int root,far;
    int read()
    {
        int f=1;
        char ch;
        while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
        int res=ch-'0';
        while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
        return res*f;
    }
    void add(int x,int y)
    {
        cnt++;
        w[cnt].next=head[x];
        w[cnt].to=y;
        head[x]=cnt;
    }
    bool cmp(int x,int y) {return x>y;}
    void dfs1(int x,int fa)
    {
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if(t!=fa)
            {
                dfs1(t,x);
                if(d1[t]+1>d1[x])
                {
                    pre[x]=t;
                    d2[x]=d1[x];
                    d1[x]=d1[t]+1;
                }
                else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
            }
        }
        return;
    }
    void dfs2(int x,int fa)
    {
        int minx=min(u[x],d1[x]);
        if(far<minx)
        {
            root=x;
            far=minx;
        }
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if (t!=fa)
            {
                if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
                else u[t]=max(d2[x],u[x])+1;
                dfs2(t,x);
            }
        }
        return ;
    }
    int main()
    {
        n=read(),k=read();
        for(int i=1;i<n;i++)
        {
            int x,y;
            x=read(),y=read();
            add(x,y);
            add(y,x);
        }
        dfs1(1,0);
        dfs2(1,0);
        printf("%d",root);
        return 0;
    }
    

    例题

    P5536核心城市

    • 思路:显然其中一定会有一个城市为这颗树的中心;那找出这个中心,把这颗无根树变为以它为根的有根树;再求出除根节点以外的每个节点所能到达的最大深度(deepfar[i]),这就是这个节点最远所能到达的距离;然后(sort)一下(deepfar[]),答案就是(deepfar[k+1]+1);

    code

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    using namespace std;
    const int N=1e5+5;
    struct pp
    {
        int next,to;
    }w[2*N];
    int n,k;
    int head[N],cnt;
    int d1[N],d2[N],pre[N],u[N];
    int fardeep[N];
    int root,far;
    int read()
    {
        int f=1;
        char ch;
        while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
        int res=ch-'0';
        while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
        return res*f;
    }
    void add(int x,int y)
    {
        cnt++;
        w[cnt].next=head[x];
        w[cnt].to=y;
        head[x]=cnt;
    }
    bool cmp(int x,int y) {return x>y;}
    void dfs1(int x,int fa)
    {
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if(t!=fa)
            {
                dfs1(t,x);
                if(d1[t]+1>d1[x])
                {
                    pre[x]=t;
                    d2[x]=d1[x];
                    d1[x]=d1[t]+1;
                }
                else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
            }
        }
        return;
    }
    void dfs2(int x,int fa)
    {
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if (t!=fa)
            {
                if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
                else u[t]=max(d2[x],u[x])+1;
                dfs2(t,x);
            }
        }
        return ;
    }
    void dfs3(int x,int fa)
    {
        int minx=min(u[x],d1[x]);
        if(far<minx)
        {
            root=x;
            far=minx;
        }
        for(int i=head[x];i;i=w[i].next) if(w[i].to!=fa) dfs3(w[i].to,x);
        return;
    }
    void dfs4(int x,int fa)
    {
        for(int i=head[x];i;i=w[i].next)
        {
            int t=w[i].to;
            if(t!=fa)
            {
                dfs4(w[i].to,x);
                fardeep[x]=max(fardeep[x],fardeep[t]+1);
            }
        }
    }
    int main()
    {
    #ifndef ONLINE_JUDGE
        freopen("XR-3.in","r",stdin);
        freopen("XR-3.out","w",stdout);
    #endif
        n=read(),k=read();
        for(int i=1;i<n;i++)
        {
            int x,y;
            x=read(),y=read();
            add(x,y);
            add(y,x);
        }
        dfs1(1,0);
        dfs2(1,0);
        dfs3(1,0);
        dfs4(root,0);
        sort(fardeep+1,fardeep+1+n,cmp);
        printf("%d",fardeep[k+1]+1);
        return 0;
    }
    

    上面都是有关树的一些经典题型,下面才是今天的主角——树型DP


    背包类树型DP

    (我觉得把,其实左右子树类树型DP可以归为这一类)

    例题

    选课

    书上的是时间复杂度为(n^3)的算法,这里介绍一个优化,可以讲其降为(n^2);

    • 泛化物品优化:具体是什么,请参考2009年国家集训队论文——徐持衡《浅谈几类背包问题》,其中有详细解释;

    • 而我对泛化物品优化的感性理解就是:"预留空间"——为在 (u) 到到根节点的路径上(包括u)的点预留空间。

      这样就可以在对 (u)DP的时候保证他所依赖的物品预先算进去了

      (dp[u][j])的意思就是在预留(u)及其到根节点的路径上的点的空间后,还剩下(j)的空间的最大价值;

    • 没有优化前,DP方程为:

    • 没有优化前,DP方程为:

    [dp[u][j]=max{dp[u][j],dp[u][j-k]+dp[v][k]} ]

    这样对于每个节点都要(n^2)暴力枚举(j)(k);

    经过优化,我们的DP方程就变为了:

    [egin{cases} dp[v][j]=dp[u][j](dfs前)\ dp[u][j]=max{dp[u][j],dp[v][j-w[v]]+val[v]}(回溯时) end{cases} ]

    这也是再泛化物品优化下,树型背包的基本DP方程;这样我们只需要(O(n))枚举(j)就好了;


    ps: 以下代码参考价值不大,建议参考[HAOI2010]软件安装

    code

    #include<iostream>
    #include<algorithm>
    #include<queue>
    #include<cstdio>
    #include<cstring>
    using namespace std;
    
    int n,m;
    struct edge
    {
        int next,to;
    }e[1000];
    int rt,head[1000],tot,val[1000],dp[1000][1000];
    void add(int x,int y)
    {
        e[++tot].next=head[x];
        head[x]=tot;
        e[tot].to=y;
    }
    void dfs(int u,int t)
    {
        if (t<=0) return ;
        for (int i=head[u]; i; i=e[i].next)
        {
            int v = e[i].to;
            for (int j=0; j<=t-1; ++j) //为v预留空间
                dp[v][j] = dp[u][j];
            dfs(v,t-1);//对于v的现有空间
            for (int j=1; j<=t; ++j) 
                dp[u][j] = max(dp[u][j],dp[v][j-1]+val[v]);//背包
        }
    }
    int main()
    {
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++)
        {
            int a;
            scanf("%d%d",&a,&val[i]);
            if(a)
              add(a,i);
            if(!a)add(0,i);
        }
        dfs(0,m);
        printf("%d",dp[0][m]);
    }
    

    选择类树型DP

    基本DP方程:

    [vin{son(u)} egin{cases} f[u][0]=sum f[v][1] \ f[u][1]=min{f[v][1],f[v][0]}+1 end{cases} ]

    例题

    P2016战略游戏

    直接套DP方程就好了;

    code

    #include<iostream>
    #include<cstdio>
    using namespace std;
    int n;
    int dp[1605][2];
    struct pp
    {
    	int next,to;
    }w[1600<<1];
    int head[1600],cnt;
    void add(int x,int y)
    {
    	cnt++;
    	w[cnt].to=y;
    	w[cnt].next=head[x];
    	head[x]=cnt;
    }
    void dfs(int x,int fa)
    {
    	dp[x][1]=1;
    	for(int i=head[x];i;i=w[i].next)
    	{
    		int t=w[i].to;
    		if(t==fa) continue;
    		dfs(t,x);
    		dp[x][0]+=dp[t][1];
    		dp[x][1]+=min(dp[t][0],dp[t][1]);
    	}
    	return;
    }
    int main()
    {
    	scanf("%d",&n);
    	for(int i=1;i<=n;i++)
    	{
    		int a,k;
    		scanf("%d%d",&a,&k);
    		for(int i=1;i<=k;i++)
    		{
    			int b;
    			scanf("%d",&b);
    			add(a,b);
    			add(b,a);
    		}
    	}
    	dfs(0,0);
    	printf("%d",min(dp[0][1],dp[0][0]));
    	return 0;
    }
    

    普通树型DP

    这种树型DP更加灵活,就不像前两种有基本固定的DP方程,所以还是直接来几道例题;(滑稽

    例题

    LOJ #10157. 皇宫看守

    乍一看题,啊哈,模板选择树型DP,开开心心打个代码,恭喜你0分;

    仔细一看这道题其实不是什么没有上司的舞会,而是一道覆盖DP题,区别在哪呢?

    这道题一条边两端至少要有一个点,可以有两个,而没有上司我舞会是一条边两端至多有一个点,可以没有;

    那好,这样的话一个节点u的最少经费就不能像选择DP一样单纯的由儿子选不选的而转移过来,因为他们本来互不冲突,而是必须被覆盖到(这里每个节点的覆盖半径为1),这样对于一个节点u的最少经费就可以由覆盖它的节点转移过来,这样的话就需要考虑三种情况:

    首先设(dp[u][0])表示被节点(u)被父亲覆盖且(u)不选,(dp[u][1])表示被自己的子节点覆盖且(u)不选,(dp[u][2])表示被自己覆盖;

    所以有状态转移方程:

    • 对于(dp[u][0]),因为(u)不选,所以对于(u)的子节点(v),要么被(son(v))所覆盖,要么被(v)自己覆盖:

    [dp[u][0]=sum min{dp[v][1],dp[v][2]} +a[f[u]]; ]

    • 对于(dp[u][1]),要保证(u)必须被一个子节点所覆盖到,还要保证(u)的子节点(v)在不被父亲覆盖的前提下被覆盖到,那显然(dp[u][1]),是由(dp[v][1])(dp[v][2])转移过来的,但是如何保证(dp[u][1])的转移中一定包含(dp[v][2])呢?

      这时候有个巧妙的办法,设个参数:

      [d=min{d,dp[v][2]-min{dp[v][1],dp[v][2]}} ]

      (d)的初始值为(0x7fffffff);

      这样对于(dp[u][1])就有状态转移方程:

      [dp[u][1]=sum min{dp[v][1],dp[v][2]}+d ]

    • 对于(dp[u][2]),那很显然它可以由子节点任意三种状态转移过来,但是对于(dp[v][0]),它已经加过一遍(a[u]),而对于(dp[u][2]),只能且必须加一遍(a[u]),那怎么办呢?单独特判由(dp[v][0])转移过来的情况,控制(a[u])只加一遍?显然是可以的,但是太麻烦了,那么另外考虑,这里可以看到(dp[v][0])只会往(dp[u][2])上转移,那么可以根据(dp[u][2])需求对(dp[v][0])状态转移方程改一改:

      [dp[u][0]=sum min{dp[v][1],dp[v][2]} ]

      (这里的(u)是对于(v)来说的)

      感性理解一下就是如果(dp[u][2])不由(dp[v][0])转移过来那要(dp[v][0])也没有什么用,那由(dp[v][0])转移过来,那在(dp[u][2])这加一遍(a[u])就够了,因为(dp[u][2])已经保证了(u)被选,所以不需要(dp[v][0])再保证一遍;

      这样对于(dp[u][2]),就有状态转移方程:

      [dp[u][2]=sum min{dp[v][1],dp[v][2],dp[v][0]} +a[u] ]

    总结下来就有三个状态转移方程:

    [egin{cases} dp[u][0]=sum min{dp[v][1],dp[v][2]};\ dp[u][1]=sum min{dp[v][1],dp[v][2]}+d ;(d=min{d,dp[v][2]-min{dp[v][1],dp[v][2]}})\ dp[u][2]=sum min{dp[v][1],dp[v][2],dp[v][0]} +a[u] end{cases} ]

    (所以,显然书上的状态转移方程是错的)

    不难发现,修改后的(dp[v][0])一定小于等于(dp[v][1]);所以写代码的时候我顺手把(dp[u][2])的转移方程改成了:

    [dp[u][2]=sum min{dp[v][2],dp[v][0]} +a[u] ]

    虽然题目早已经解决了,但我还是想再深究一下;这个方程啥意思?

    以我的感性理解就是(v)既然已经一定会被它爹(u)覆盖到,那就可以不需要保证(v)一定被它的儿子所覆盖,修改后的(dp[v][0])刚好就是这种情况;

    (好了,bb了这么多废话,就一点有用的东西,直接上代码吧)

    code

    #include <iostream>
    #include <cstdio>
    using namespace std;
    const int N = 1500 + 5;
    int dp[N][3];
    int v[N], n, root;
    struct pp {
        int next, to;
    } w[N];
    int head[N], cnt, du[N];
    void add(int x, int y) {
        cnt++;
        w[cnt].next = head[x];
        w[cnt].to = y;
        head[x] = cnt;
    }
    void dfs(int x) {
        int d = 0x7fffffff;
        for (int i = head[x]; i; i = w[i].next) {
            int t = w[i].to;
            dfs(t);
            dp[x][0] += min(dp[t][1], dp[t][2]);
            dp[x][1] += min(dp[t][1], dp[t][2]);
            d = min(d, dp[t][2] - min(dp[t][1], dp[t][2]));
            dp[x][2] += min(dp[t][2], dp[t][0]);
        }
        dp[x][1] += d;
        dp[x][2] += v[x];
    }
    int main() {
    #ifndef ONLINE_JUDGE
        freopen("guard.in", "r", stdin);
        freopen("guard.out", "w", stdout);
    #endif
        scanf("%d", &n);
        for (int i = 1; i <= n; i++) {
            int x, m;
            scanf("%d", &x);
            scanf("%d", &v[x]);
            scanf("%d", &m);
            for (int j = 1; j <= m; j++) {
                int y;
                scanf("%d", &y);
                add(x, y);
                du[y]++;
            }
        }
        for (int i = 1; i <= n; i++)
            if (!du[i])
                root = i;
        dfs(root);
        printf("%d", min(dp[root][1], dp[root][2]));
        return 0;
    }
    

    好了,差不多就结束了,虽然写这个一点耗时,但对于我这个蒟蒻来说加深了对于DP的理解,收获也不小,也不算浪费时间了吧(逃);


    PS: 2020.10.9 添加了我对泛化物品优化的理解

  • 相关阅读:
    XTU 1250 Super Fast Fourier Transform
    XTU 1249 Rolling Variance
    XTU 1243 2016
    CodeForces 710A King Moves
    CodeForces 710B Optimal Point on a Line
    HDU 4472 Count
    并查集
    高精度四件套
    Kruskal最小生成树
    [蓝桥杯]翻硬币(贪心的详解附带题目测试链接)
  • 原文地址:https://www.cnblogs.com/Wednesday-zfz/p/12209729.html
Copyright © 2020-2023  润新知