题面
题解
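缩点 + 树形 DP
依赖关系可以看作有向边:若软件 i 依赖软件 z,则从 z 向 i 连一条边。
因为依赖关系中可能有环,先用 Tarjan 缩点:环上的软件互相依赖,要么全装,要么全不装,所以一个强连通分量可以合并成一个点(重量、价值分别求和)。
由于每个软件至多依赖一个软件,缩点后得到的是一个森林,图可能不连通。
我们可以新建一个虚拟根结点(重量和价值都为 0),向每个连通块的根(没有入边的点)连边,把森林变成一棵树。
然后在这棵树上做树形背包 DP:选一个点的子树中的点之前,必须先选这个点本身。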
Code
#include<bits/stdc++.h>
#define LL long long
#define RG register
using namespace std;
// Reads the next (optionally negative) decimal integer from stdin into x.
// Any characters before the number (spaces, newlines, other text) are skipped.
// If end-of-input is reached before a digit is found, x is left as 0 and the
// function returns, instead of looping forever waiting for a digit.
template<class T> inline void read(T &x) {
    x = 0;
    int c = getchar();          // int, not char: must be able to represent EOF
    bool negative = false;
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == '-') {
        negative = true;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    if (negative) x = -x;
}
// Writes x to stdout in decimal, with no trailing separator.
// Digits are peeled off without negating x, so the most negative value of T
// (e.g. INT_MIN) is printed correctly instead of overflowing on `-x`.
template<class T> inline void write(T x) {
    if (x == 0) {
        putchar('0');
        return;
    }
    if (x < 0) putchar('-');
    char digits[40];            // enough for any integer up to 128 bits
    int len = 0;
    while (x != 0) {
        T d = x % 10;           // truncating division: d has the sign of x
        digits[len++] = static_cast<char>('0' + (d < 0 ? -d : d));
        x /= 10;
    }
    while (len > 0) putchar(digits[--len]);
}
int n, m;                       // n = number of items, m = budget used as the knapsack capacity
constexpr int N = 110;          // array bound for vertices (items, SCCs and the virtual root)
constexpr int M = 510;          // array bound for budgets 0..m
int ww[N], vv[N];               // per-item weight / value as read from input
int x[N], y[N], tot;            // the tot dependency edges x[i] -> y[i], kept for the SCC rebuild

// Adjacency list stored as linked lists of edges inside g:
// last[u] is the index of the most recently added edge out of u (0 = none),
// g[e].next chains to the previous edge of the same vertex.
struct node {
    int to, next;
} g[N << 1];
int last[N], gl;                // gl = number of edges in g; edges are 1-indexed

// Adds the directed edge x -> y in front of x's adjacency list.
void add(int x, int y) {
    // Member-wise assignment keeps this standard C++ (no C99 compound literal).
    ++gl;
    g[gl].to = y;
    g[gl].next = last[x];
    last[x] = gl;
}
// State for Tarjan's strongly-connected-components algorithm.
int dfn[N], low[N], cnt;        // discovery time, low-link value, global timestamp counter
int w[N], v[N], js;             // summed weight / value of each SCC, number of SCCs found so far
int color[N];                   // color[i] = index (1..js) of the SCC containing vertex i
std::stack<int> s;              // vertices whose SCC has not been closed yet
bool vis[N];                    // vis[i] != 0 while i is on s

// Runs Tarjan's SCC algorithm from vertex x over the adjacency list (last/g).
// Each finished SCC gets a new index js; its vertices' ww/vv are summed into w[js]/v[js].
void tarjan(int x) {
    dfn[x] = low[x] = ++cnt;
    s.push(x);
    vis[x] = 1;
    for (int e = last[x]; e != 0; e = g[e].next) {
        const int to = g[e].to;
        if (dfn[to] == 0) {
            tarjan(to);
            low[x] = std::min(low[x], low[to]);
        } else if (vis[to]) {
            // Back/cross edge to a vertex still on the stack: same SCC candidate.
            low[x] = std::min(low[x], dfn[to]);
        }
    }
    if (low[x] != dfn[x]) return;   // x is not the root of its SCC

    // x is an SCC root: pop everything above (and including) x into SCC js.
    ++js;
    int top;
    do {
        top = s.top();
        s.pop();
        vis[top] = 0;
        color[top] = js;
        w[js] += ww[top];
        v[js] += vv[top];
    } while (top != x);
}
// f[u][j]: best value recorded for a selection inside u's subtree that contains u,
// indexed by budget j. Only w[u] <= j <= m is ever written; all other entries stay 0.
int f[N][M];

// Tree knapsack over the condensed dependency tree, rooted at u (fa = parent, skipped).
// Nothing below u may be chosen unless u itself is, so w[u] of every budget is
// reserved for u and children share the remaining cap - w[u].
// NOTE(review): a still-zero f[u][cap - give] with cap - give > w[u] combines with a
// child as if u were skipped. This never beats also taking u as long as item values
// are non-negative — assumed from the input constraints, TODO confirm.
void dfs(int u, int fa) {
    // Post-order: every child subtree is solved before u.
    for (int e = last[u]; e; e = g[e].next)
        if (g[e].to != fa) dfs(g[e].to, u);

    if (w[u] > m) return;           // u alone exceeds the budget: f[u] stays all zero
    f[u][w[u]] = v[u];

    for (int e = last[u]; e; e = g[e].next) {
        const int son = g[e].to;
        if (son == fa) continue;
        // Grouped 0/1 knapsack: budgets go downwards so f[u][cap - give] (give > 0)
        // still holds the value from before this child was merged.
        for (int cap = m; cap > w[u]; --cap) {
            const int room = cap - w[u];    // budget left for children once u is taken
            for (int give = 0; give <= room; ++give)
                f[u][cap] = std::max(f[u][cap], f[u][cap - give] + f[son][give]);
        }
    }
}
// Reads n items (weights, values, and the single item each one depends on, 0 = none),
// condenses dependency cycles with Tarjan, hangs the resulting forest under a
// virtual root and prints the best total value achievable within budget m.
int main() {
    read(n);
    read(m);
    for (int i = 1; i <= n; i++) read(ww[i]);
    for (int i = 1; i <= n; i++) read(vv[i]);
    // Item i depends on z: store the edge z -> i (and remember it for the rebuild).
    for (int i = 1; i <= n; i++) {
        int z;
        read(z);
        if (!z) continue;
        x[++tot] = z;
        y[tot] = i;
        add(z, i);
    }
    for (int i = 1; i <= n; i++)
        if (!dfn[i]) tarjan(i);

    // Rebuild the graph on SCC indices. From here on vis[c] means
    // "SCC c has an incoming edge from another SCC".
    std::memset(last, 0, sizeof(last));
    std::memset(vis, 0, sizeof(vis));
    gl = 0;
    for (int i = 1; i <= tot; i++)
        if (color[x[i]] != color[y[i]]) {
            add(color[x[i]], color[y[i]]);
            vis[color[y[i]]] = 1;
            // printf("%d %d\n", color[x[i]], color[y[i]]);
        }
    // Virtual root js + 1 (weight 0, value 0) adopts every SCC without an incoming edge.
    for (int i = 1; i <= js; i++)
        if (!vis[i]) add(js + 1, i);

    dfs(js + 1, 0);
    int ans = 0;
    for (int i = 0; i <= m; i++)
        ans = std::max(ans, f[js + 1][i]);
    printf("%d\n", ans);
    return 0;
}