题目大意
给出一个有根树,(1)为根,若某个节点的儿子全是叶子,你可以将该节点的儿子全部剪掉,这样的操作可以进行多次。定义这棵树的价值为:将树上所有叶子按照(dfs)序排序后,所有叶子点权之和-相邻两叶子路径上点权最大值。现在你要通过剪枝使得这棵树价值最大。
(nleq 100000)
分析
设(f_i)表示(i)作为最后一个叶子时的最大价值。暴力枚举原树(没有剪枝)相邻的两个叶子,显然左链上每个点的(f)都可以转移到右链上,我们暴力处理出这条路径,分类讨论点权最大值在左链还是右链,进行状态转移。可以发现,一条边被枚举的次数最多是(2),一次是从上个子树进入这个子树,一次是从这个子树进入下个子树,所以暴力枚举复杂度其实是(O(n))的,就可以随便做了。
重点!!!!!!
暴力枚举一棵树相邻两叶子的路径复杂度是(O(n))。
Code
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 100007;
inline int read()
{
int x = 0, f = 0;
char c = getchar();
for (; c < '0' || c > '9'; c = getchar()) if (c == '-') f = 1;
for (; c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) + (c ^ '0');
return f ? -x : x;
}
int n, dfn, tot, ans, w[N], p[N], fa[N], dep[N], ord[N], leaf[N], f[N];
int len, arr[N], mx1[N], mx2[N];
vector<int> son[N];
void dfs(int u)
{
ord[++dfn] = u;
int sz = son[u].size();
for (int i = 0; i < sz; i++) fa[son[u][i]] = u, dep[son[u][i]] = dep[u] + 1, dfs(son[u][i]);
}
int getlca(int u, int v)
{
if (dep[u] < dep[v]) swap(u, v);
while (dep[fa[u]] >= dep[v]) u = fa[u];
if (u == v) return u;
while (fa[u] != fa[v]) u = fa[u], v = fa[v];
return fa[u];
}
int main()
{
//freopen("cut.in", "r", stdin);
n = read();
for (int i = 1; i <= n; i++)
{
w[i] = read(), p[i] = read();
for (int j = 1, a; j <= p[i]; j++) a = read(), son[i].push_back(a);
}
dep[1] = 1, dfs(1);
for (int i = 1; i <= n; i++) if (!son[ord[i]].size()) leaf[++tot] = ord[i];
for (int i = 1; i <= n; i++) f[i] = -0x3f3f3f3f;
int x = leaf[1];
while (x != 1) f[x] = w[x], x = fa[x];
for (int i = 2; i <= tot; i++)
{
int a = leaf[i - 1], b = leaf[i], c = getlca(a, b), dist = dep[a] + dep[b] - 2 * dep[c] + 1;
arr[len = 1] = a; while (fa[a] != c) a = fa[a], arr[++len] = a;
arr[++len] = c;
arr[dist--] = b; while (fa[b] != c) b = fa[b], arr[dist--] = b;
mx1[len] = w[c], dist = dep[leaf[i - 1]] + dep[leaf[i]] - 2 * dep[c] + 1;
for (int j = len - 1; j >= 1; j--) mx1[j] = max(mx1[j + 1], w[arr[j]]);
for (int j = len + 1; j <= dist; j++) mx1[j] = max(mx1[j - 1], w[arr[j]]);
mx2[0] = -0x3f3f3f3f; for (int j = 1; j <= len - 1; j++) mx2[j] = max(mx2[j - 1], f[arr[j]] - mx1[j + 1]);
for (int j = len + 1, k = len, maxf = -0x3f3f3f3f; j <= dist; j++)
{
while (k > 1 && mx1[k] <= mx1[j - 1]) k--, maxf = max(maxf, f[arr[k]]);
f[arr[j]] = max(f[arr[j]], maxf - mx1[j - 1] + w[arr[j]]);
if (k > 1) f[arr[j]] = max(f[arr[j]], mx2[k - 1] + w[arr[j]]);
}
}
x = leaf[tot];
while (x != 1) ans = max(ans, f[x]), x = fa[x];
printf("%d
", ans);
return 0;
}