树形DP学习笔记
ps: 本文内容与蓝书一致
树的重心
- 概念: 一颗树中的一个节点其最大子树的节点树最小
- 解法:对与每个节点求他儿子的(size) ,上方子树的节点个数为(n-size_u) ,求对于每个节点子树的最大值,找出最小的那个就好了;
(我觉得就不需要code了)
树的直径
- 概念:一颗带权树的最长路径
- 解法:维护一个节点到叶子节点的最大距离(d1[i])和次大距离(d2[i]) ,最大距离就是$max {d1[i]+d2[i] } $
code
#include<iostream>
#include<cstdio>
using namespace std;
const int N=1e4+5;
int n;
struct pp
{
int to,next;
}w[2*N];
int head[N],cnt;
int d1[N],d2[N];
int ans;
void add(int x,int y)
{
cnt++;
w[cnt].next=head[x];
w[cnt].to=y;
head[x]=cnt;
}
void dfs(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa)
{
dfs(t,x);
if(d1[t]+1>d1[x])
{
d2[x]=d1[x];
d1[x]=d1[t]+1;
}
else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
}
}
return ;
}
void find_ans(int x,int fa)
{
ans=max(ans,d1[x]+d2[x]);
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa) find_ans(t,x);
}
return;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("diam.in","r",stdin);
freopen("diam.out","w",stdout);
#endif
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(1,0);
find_ans(1,0);
printf("%d",ans);
return 0;
}
例题
P4480 逃学的小孩
- 大概思路:求出树的直径以及其左右端点,再设(d[i])为树上节点(i)到左右端点距离更小的那个,然后求出(max {d[i]}),然后以这个值加上直径就是(ans) ;
code
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
const int N=2e5+5;
struct pp
{
int next,to;
ll qu;
}w[N*2];
int head[N],cnt;
int n,m;
bool v[N];
ll d1[N],d2[N],dl[N],dr[N];
int f1[N],f2[N];
int r,l;
ll ans,mans;
void add(int x,int y,int z)
{
w[++cnt].next=head[x];
w[cnt].qu=z;
w[cnt].to=y;
head[x]=cnt;
}
int read()
{
int f=1;
char ch;
while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
int res=ch-'0';
while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
return res*f;
}
void dfs1(int x)
{
if(v[x]) return ;
v[x]=1;
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(!v[t])
{
dfs1(t);
if(d1[t]+w[i].qu>d1[x])
{
f2[x]=f1[x];
f1[x]=f1[t];
d2[x]=d1[x];
d1[x]=d1[t]+w[i].qu;
}
else if(d1[t]+w[i].qu>d2[x]) d2[x]=d1[t]+w[i].qu,f2[x]=f1[t];
}
}
return;
}
void find_ans(int x)
{
if(v[x]) return;
v[x]=1;
if(ans<d1[x]+d2[x])
{
ans=d1[x]+d2[x];
l=f1[x];
r=f2[x];
}
for(int i=head[x];i;i=w[i].next) find_ans(w[i].to);
}
void dfs2(int x)
{
if(v[x]) return;
v[x]=1;
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(!v[t])
{
dl[t]=dl[x]+w[i].qu;
dfs2(t);
}
}
return;
}
void dfs3(int x)
{
if(v[x])return;
v[x]=1;
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(!v[t])
{
dr[t]=dr[x]+w[i].qu;
dfs3(t);
}
}
return;
}
void dfs_ans(int x)
{
if(v[x]) return;
v[x]=1;
mans=max(mans,min(dl[x],dr[x]));
for(int i=head[x];i;i=w[i].next) dfs_ans(w[i].to);
return;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("Chris.in","r",stdin);
freopen("Chris.out","w",stdout);
#endif
n=read();
m=read();
for(int i=1;i<=m;i++)
{
int x,y,z;
x=read(),y=read(),z=read();
add(x,y,z);
add(y,x,z);
}
for(int i=1;i<=n;i++) f1[i]=i;
dfs1(1);
memset(v,0,sizeof(v));
find_ans(1);
memset(v,0,sizeof(v));
dfs2(l);
memset(v,0,sizeof(v));
dfs3(r);
memset(v,0,sizeof(v));
dfs_ans(1);
printf("%lld",ans+mans);
return 0;
}
树的中心
-
概念:给出一颗带权树,求一个节点,使得此节点到树中其他节点的最远距离最小;
-
解法:如果是一颗没有负边权的树,那直接找到直径的中点就好;
但是这里我们考虑有负边权的情况:
有两种情况:
- 从(u)点向上的最长路径,设为(up[u]);
- 从(u)点向下,即(u)到叶节点的最远距离,设为(d1[u])(次远设为(d2[u]));
(d1[u])和(d2[u])都会求,问题是(up[u])该怎么求?
还是分类讨论,设(u)的父亲为(x),(d1[x])来自于子节点(v);那对于(u):
- 如果(u!=v),那么(up[u]=max{d1[x],up[x]}+dis[x][t]);
- 如果(u==v),那么(up[u]=max{d2[x],up[x]}+dis[x][t]),这也是为什么要维护(d2[x])的原因;
code
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=1e5+5;
struct pp
{
int next,to;
}w[2*N];
int n,k;
int head[N],cnt;
int d1[N],d2[N],pre[N],u[N];
int root,far;
int read()
{
int f=1;
char ch;
while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
int res=ch-'0';
while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
return res*f;
}
void add(int x,int y)
{
cnt++;
w[cnt].next=head[x];
w[cnt].to=y;
head[x]=cnt;
}
bool cmp(int x,int y) {return x>y;}
void dfs1(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa)
{
dfs1(t,x);
if(d1[t]+1>d1[x])
{
pre[x]=t;
d2[x]=d1[x];
d1[x]=d1[t]+1;
}
else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
}
}
return;
}
void dfs2(int x,int fa)
{
int minx=min(u[x],d1[x]);
if(far<minx)
{
root=x;
far=minx;
}
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if (t!=fa)
{
if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
else u[t]=max(d2[x],u[x])+1;
dfs2(t,x);
}
}
return ;
}
int main()
{
n=read(),k=read();
for(int i=1;i<n;i++)
{
int x,y;
x=read(),y=read();
add(x,y);
add(y,x);
}
dfs1(1,0);
dfs2(1,0);
printf("%d",root);
return 0;
}
例题
P5536核心城市
- 思路:显然其中一定会有一个城市为这颗树的中心;那找出这个中心,把这颗无根树变为以它为根的有根树;再求出除根节点以外的每个节点所能到达的最大深度(deepfar[i]),这就是这个节点最远所能到达的距离;然后(sort)一下(deepfar[]),答案就是(deepfar[k+1]+1);
code
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=1e5+5;
struct pp
{
int next,to;
}w[2*N];
int n,k;
int head[N],cnt;
int d1[N],d2[N],pre[N],u[N];
int fardeep[N];
int root,far;
int read()
{
int f=1;
char ch;
while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
int res=ch-'0';
while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
return res*f;
}
void add(int x,int y)
{
cnt++;
w[cnt].next=head[x];
w[cnt].to=y;
head[x]=cnt;
}
bool cmp(int x,int y) {return x>y;}
void dfs1(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa)
{
dfs1(t,x);
if(d1[t]+1>d1[x])
{
pre[x]=t;
d2[x]=d1[x];
d1[x]=d1[t]+1;
}
else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
}
}
return;
}
void dfs2(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if (t!=fa)
{
if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
else u[t]=max(d2[x],u[x])+1;
dfs2(t,x);
}
}
return ;
}
void dfs3(int x,int fa)
{
int minx=min(u[x],d1[x]);
if(far<minx)
{
root=x;
far=minx;
}
for(int i=head[x];i;i=w[i].next) if(w[i].to!=fa) dfs3(w[i].to,x);
return;
}
void dfs4(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa)
{
dfs4(w[i].to,x);
fardeep[x]=max(fardeep[x],fardeep[t]+1);
}
}
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("XR-3.in","r",stdin);
freopen("XR-3.out","w",stdout);
#endif
n=read(),k=read();
for(int i=1;i<n;i++)
{
int x,y;
x=read(),y=read();
add(x,y);
add(y,x);
}
dfs1(1,0);
dfs2(1,0);
dfs3(1,0);
dfs4(root,0);
sort(fardeep+1,fardeep+1+n,cmp);
printf("%d",fardeep[k+1]+1);
return 0;
}
上面都是有关树的一些经典题型,下面才是今天的主角——树型DP
背包类树型DP
(我觉得把,其实左右子树类树型DP可以归为这一类)
例题
选课
书上的是时间复杂度为(n^3)的算法,这里介绍一个优化,可以讲其降为(n^2);
-
泛化物品优化:具体是什么,请参考2009年国家集训队论文——徐持衡《浅谈几类背包问题》,其中有详细解释;
-
而我对泛化物品优化的感性理解就是:"预留空间"——为在 (u) 到到根节点的路径上(包括u)的点预留空间。
这样就可以在对 (u)DP的时候保证他所依赖的物品预先算进去了;
(dp[u][j])的意思就是在预留(u)及其到根节点的路径上的点的空间后,还剩下(j)的空间的最大价值;
-
没有优化前,DP方程为:
-
没有优化前,DP方程为:
这样对于每个节点都要(n^2)暴力枚举(j)和(k);
经过优化,我们的DP方程就变为了:
这也是再泛化物品优化下,树型背包的基本DP方程;这样我们只需要(O(n))枚举(j)就好了;
ps: 以下代码参考价值不大,建议参考[HAOI2010]软件安装
code
#include<iostream>
#include<algorithm>
#include<queue>
#include<cstdio>
#include<cstring>
using namespace std;
int n,m;
struct edge
{
int next,to;
}e[1000];
int rt,head[1000],tot,val[1000],dp[1000][1000];
void add(int x,int y)
{
e[++tot].next=head[x];
head[x]=tot;
e[tot].to=y;
}
void dfs(int u,int t)
{
if (t<=0) return ;
for (int i=head[u]; i; i=e[i].next)
{
int v = e[i].to;
for (int j=0; j<=t-1; ++j) //为v预留空间
dp[v][j] = dp[u][j];
dfs(v,t-1);//对于v的现有空间
for (int j=1; j<=t; ++j)
dp[u][j] = max(dp[u][j],dp[v][j-1]+val[v]);//背包
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
int a;
scanf("%d%d",&a,&val[i]);
if(a)
add(a,i);
if(!a)add(0,i);
}
dfs(0,m);
printf("%d",dp[0][m]);
}
选择类树型DP
基本DP方程:
例题
P2016战略游戏
直接套DP方程就好了;
code
#include<iostream>
#include<cstdio>
using namespace std;
int n;
int dp[1605][2];
struct pp
{
int next,to;
}w[1600<<1];
int head[1600],cnt;
void add(int x,int y)
{
cnt++;
w[cnt].to=y;
w[cnt].next=head[x];
head[x]=cnt;
}
void dfs(int x,int fa)
{
dp[x][1]=1;
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t==fa) continue;
dfs(t,x);
dp[x][0]+=dp[t][1];
dp[x][1]+=min(dp[t][0],dp[t][1]);
}
return;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
int a,k;
scanf("%d%d",&a,&k);
for(int i=1;i<=k;i++)
{
int b;
scanf("%d",&b);
add(a,b);
add(b,a);
}
}
dfs(0,0);
printf("%d",min(dp[0][1],dp[0][0]));
return 0;
}
普通树型DP
这种树型DP更加灵活,就不像前两种有基本固定的DP方程,所以还是直接来几道例题;(滑稽
例题
LOJ #10157. 皇宫看守
乍一看题,啊哈,模板选择树型DP,开开心心打个代码,恭喜你0分;
仔细一看这道题其实不是什么没有上司的舞会,而是一道覆盖DP题,区别在哪呢?
这道题一条边两端至少要有一个点,可以有两个,而没有上司我舞会是一条边两端至多有一个点,可以没有;
那好,这样的话一个节点u的最少经费就不能像选择DP一样单纯的由儿子选不选的而转移过来,因为他们本来互不冲突,而是必须被覆盖到(这里每个节点的覆盖半径为1),这样对于一个节点u的最少经费就可以由覆盖它的节点转移过来,这样的话就需要考虑三种情况:
首先设(dp[u][0])表示被节点(u)被父亲覆盖且(u)不选,(dp[u][1])表示被自己的子节点覆盖且(u)不选,(dp[u][2])表示被自己覆盖;
所以有状态转移方程:
- 对于(dp[u][0]),因为(u)不选,所以对于(u)的子节点(v),要么被(son(v))所覆盖,要么被(v)自己覆盖:
-
对于(dp[u][1]),要保证(u)必须被一个子节点所覆盖到,还要保证(u)的子节点(v)在不被父亲覆盖的前提下被覆盖到,那显然(dp[u][1]),是由(dp[v][1])和(dp[v][2])转移过来的,但是如何保证(dp[u][1])的转移中一定包含(dp[v][2])呢?
这时候有个巧妙的办法,设个参数:
[d=min{d,dp[v][2]-min{dp[v][1],dp[v][2]}} ](d)的初始值为(0x7fffffff);
这样对于(dp[u][1])就有状态转移方程:
[dp[u][1]=sum min{dp[v][1],dp[v][2]}+d ] -
对于(dp[u][2]),那很显然它可以由子节点任意三种状态转移过来,但是对于(dp[v][0]),它已经加过一遍(a[u]),而对于(dp[u][2]),只能且必须加一遍(a[u]),那怎么办呢?单独特判由(dp[v][0])转移过来的情况,控制(a[u])只加一遍?显然是可以的,但是太麻烦了,那么另外考虑,这里可以看到(dp[v][0])只会往(dp[u][2])上转移,那么可以根据(dp[u][2])需求对(dp[v][0])状态转移方程改一改:
[dp[u][0]=sum min{dp[v][1],dp[v][2]} ](这里的(u)是对于(v)来说的)
感性理解一下就是如果(dp[u][2])不由(dp[v][0])转移过来那要(dp[v][0])也没有什么用,那由(dp[v][0])转移过来,那在(dp[u][2])这加一遍(a[u])就够了,因为(dp[u][2])已经保证了(u)被选,所以不需要(dp[v][0])再保证一遍;
这样对于(dp[u][2]),就有状态转移方程:
[dp[u][2]=sum min{dp[v][1],dp[v][2],dp[v][0]} +a[u] ]
总结下来就有三个状态转移方程:
(所以,显然书上的状态转移方程是错的)
不难发现,修改后的(dp[v][0])一定小于等于(dp[v][1]);所以写代码的时候我顺手把(dp[u][2])的转移方程改成了:
虽然题目早已经解决了,但我还是想再深究一下;这个方程啥意思?
以我的感性理解就是(v)既然已经一定会被它爹(u)覆盖到,那就可以不需要保证(v)一定被它的儿子所覆盖,修改后的(dp[v][0])刚好就是这种情况;
(好了,bb了这么多废话,就一点有用的东西,直接上代码吧)
code
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 1500 + 5;
int dp[N][3];
int v[N], n, root;
struct pp {
int next, to;
} w[N];
int head[N], cnt, du[N];
void add(int x, int y) {
cnt++;
w[cnt].next = head[x];
w[cnt].to = y;
head[x] = cnt;
}
void dfs(int x) {
int d = 0x7fffffff;
for (int i = head[x]; i; i = w[i].next) {
int t = w[i].to;
dfs(t);
dp[x][0] += min(dp[t][1], dp[t][2]);
dp[x][1] += min(dp[t][1], dp[t][2]);
d = min(d, dp[t][2] - min(dp[t][1], dp[t][2]));
dp[x][2] += min(dp[t][2], dp[t][0]);
}
dp[x][1] += d;
dp[x][2] += v[x];
}
int main() {
#ifndef ONLINE_JUDGE
freopen("guard.in", "r", stdin);
freopen("guard.out", "w", stdout);
#endif
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
int x, m;
scanf("%d", &x);
scanf("%d", &v[x]);
scanf("%d", &m);
for (int j = 1; j <= m; j++) {
int y;
scanf("%d", &y);
add(x, y);
du[y]++;
}
}
for (int i = 1; i <= n; i++)
if (!du[i])
root = i;
dfs(root);
printf("%d", min(dp[root][1], dp[root][2]));
return 0;
}
好了,差不多就结束了,虽然写这个一点耗时,但对于我这个蒟蒻来说加深了对于DP的理解,收获也不小,也不算浪费时间了吧(逃);
PS: 2020.10.9 添加了我对泛化物品优化的理解