非常好的一道题,可以说是树形dp的一道基础题
首先不难发现,:如果我们把有关系的两个点用有向边相连,那么就会形成一个接近树的结构。如果这是一棵完美的树,我们就可以直接在树上打背包了
但是这并不是一棵完美的树,甚至并不是一棵树,因为:
首先,由于题中有n个点,还有n条边,所以有很大的几率出现环!
而且,如果出现了环,那么很有可能整个图并不连通,这样一来根本无法跑dp
所以我们要采取一些策略:
首先,对于出现环的情况,根据题意,此时环中的所有点要么都选,要么都不选,所以我们可以进行tarjan缩点,然后在新图上进行dp
至于整个图不连通的情况,我们可以虚拟一个超级原点向所有入度为0的点连边,这样就可以形成一棵真正的树,这样跑树形dp就可以了。
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
using namespace std;
struct Edge
{
int next;
int to;
}edge[105],e[105];
int head[105];
int h[105];
int cot=1;
int cnt=1;
void init()
{
memset(head,-1,sizeof(head));
memset(h,-1,sizeof(h));
cnt=1;
cot=1;
}
void adde(int l,int r)
{
e[cot].next=h[l];
e[cot].to=r;
h[l]=cot++;
}
void add(int l,int r)
{
edge[cnt].next=head[l];
edge[cnt].to=r;
head[l]=cnt++;
}
int n,m;
int w[105];
int v[105];
int nv[105];
int nw[105];
int posi[105];
int src_cnt=0;
int src_num[105];
int my_stack[105];
int dfn[105];
int low[105];
int dp[105][505];
int tot=0;
int ttop=0;
bool used[105];
int s[105];
void tarjan(int rt)
{
my_stack[++ttop]=rt;
dfn[rt]=low[rt]=++tot;
for(int i=head[rt];i!=-1;i=edge[i].next)
{
int to=edge[i].to;
if(!dfn[to])
{
tarjan(to);
low[rt]=min(low[rt],low[to]);
}else if(!posi[to])
{
low[rt]=min(low[rt],dfn[to]);
}
}
if(dfn[rt]==low[rt])
{
int t=0;
src_cnt++;
while(t!=rt)
{
t=my_stack[ttop--];
posi[t]=src_cnt;
src_num[src_cnt]++;
nv[src_cnt]+=v[t];
nw[src_cnt]+=w[t];
}
}
}
void dfs(int x)
{
s[x]=nw[x];
dp[x][nw[x]]=nv[x];
for(int i=h[x];i!=-1;i=e[i].next)
{
int to=e[i].to;
dfs(to);
for(int j=s[x];j>=nw[x];j--)
{
for(int k=0;k<=s[to];k++)
{
if(j+k>m)
{
break;
}
dp[x][j+k]=max(dp[x][j+k],dp[x][j]+dp[to][k]);
}
}
s[x]+=s[to];
}
}
int main()
{
scanf("%d%d",&n,&m);
init();
for(int i=1;i<=n;i++)
{
scanf("%d",&w[i]);
}
for(int i=1;i<=n;i++)
{
scanf("%d",&v[i]);
}
for(int i=1;i<=n;i++)
{
int f;
scanf("%d",&f);
if(f)
{
add(f,i);
}
}
for(int i=1;i<=n;i++)
{
if(!dfn[i])
{
tarjan(i);
}
}
for(int i=1;i<=n;i++)
{
for(int j=head[i];j!=-1;j=edge[j].next)
{
int to=edge[j].to;
if(posi[to]!=posi[i])
{
adde(posi[i],posi[to]);
used[posi[to]]=1;
}
}
}
for(int i=1;i<=src_cnt;i++)
{
if(!used[i])
{
adde(0,i);
}
}
dfs(0);
int ans=0;
for(int i=0;i<=m;i++)
{
ans=max(ans,dp[0][i]);
}
printf("%d
",ans);
return 0;
}