题意:
有n个点的一棵树。其中树上有m条已知的链,每条链有一个权值。从中选出任意个不相交的链使得链的权值和最大。
思路:
树形DP。设dp[i]表示i的子树下的最优权值和,sum[i]表示不考虑i点时子树的最优权值和,即(j是i的儿子),显然dp[i]>=sum[i]。那么问题是考虑i点时dp[i]的值是多少,假设有一条链通过i,且端点a和b都在i的子树里,即LCA(a,b)=i,如果考虑加上这条链的权值,那么a->i, b->i的路上的点v都不能有链经过它们(题目要求链不相交),那么-dp[v],但至少有sum[v],即,其中v是某条链上的点。那么怎么快速求出sigma的值呢,想到树状数组维护前缀和。那么怎么遍历呢,用DFS序遍历,思想和“粮食分配”一样,在L[v]上修改,在R[v]上恢复。
#include <bits/stdc++.h> using namespace std; const int N = 1e5 + 5; const int D = 20; struct Chain { int u, v, w; }; vector<Chain> chains[N]; vector<int> edges[N]; int dp[N], sum[N]; int n, m; int tim; void init() { for (int i=1; i<=n; ++i) { edges[i].clear (); chains[i].clear (); } } int rt[N][D], dep[N]; void init_LCA() { for (int j=1; j<D; ++j) { for (int i=1; i<=n; ++i) { rt[i][j] = rt[i][j-1] ? rt[rt[i][j-1]][j-1] : 0; } } } int LCA(int u, int v) { if (dep[u] < dep[v]) swap (u, v); for (int i=0; i<D; ++i) { if ((dep[u] - dep[v]) >> i & 1) { u = rt[u][i]; } } if (u == v) return u; for (int i=D-1; i>=0; --i) { if (rt[u][i] != rt[v][i]) { u = rt[u][i]; v = rt[v][i]; } } return rt[u][0]; } struct BIT { int C[N]; int n; void init(int n) { this->n = n; memset (C, 0, sizeof (C)); } void updata(int i, int x) { for (; i<=n; i+=i&-i) C[i] += x; } int query(int i) { int ret = 0; for (; i>0; i-=i&-i) ret += C[i]; return ret; } }bsum, bdp; int L[N], R[N]; void DFS(int u, int pa) { L[u] = tim++; dep[u] = dep[pa] + 1; rt[u][0] = pa; for (auto v: edges[u]) { if (v == pa) continue; DFS (v, u); } R[u] = tim; } void DFS(int u) { sum[u] = 0; for (auto v: edges[u]) { if (v == rt[u][0]) continue; DFS (v); sum[u] += dp[v]; } dp[u] = sum[u]; for (auto chain: chains[u]) { int a = chain.u, b = chain.v, c = chain.w; int tmp = bsum.query (L[a]) - bdp.query (L[a]) + bsum.query (L[b]) - bdp.query (L[b]); dp[u] = max (dp[u], sum[u] + tmp + c); } bsum.updata (L[u], sum[u]); bsum.updata (R[u], -sum[u]); bdp.updata (L[u], dp[u]); bdp.updata (R[u], -dp[u]); } void prepare() { dep[0] = 0; tim = 1; DFS (1, 0); init_LCA (); bsum.init (n); bdp.init (n); } int main() { int T; scanf ("%d", &T); while (T--) { init (); scanf ("%d%d", &n, &m); for (int i=1; i<n; ++i) { int u, v; scanf ("%d%d", &u, &v); edges[u].push_back (v); edges[v].push_back (u); } prepare (); for (int i=1; i<=m; ++i) { int u, v, w; scanf ("%d%d%d", &u, &v, &w); int lca = LCA (u, v); chains[lca].push_back ((Chain) {u, v, w}); } DFS (1); printf ("%d ", dp[1]); } return 0; }