如果对于每个询问跑一次$dp$,那么$dp[i]$为断开$i$这棵子树的最小花费。
这样的复杂度为$O(n*m)$,过于臃肿。
所以我们要对于每次询问降低这次询问的复杂度。
我们可以发现$m$个关键点,最多有$m-1$个$lca$。
简单证明一下,如果有两个点,会有$1$个$lca$点,如果有三个点,则第三个点会和上一个$lca$产生一个$lca$。
所以以这$2*m-1$个点构建一棵树,在这个树上跑$dp$
虚树的构建推荐一个巨巨的博客
1 #include <bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int maxn = 250010; 5 const ll inf = 2e18 + 10; 6 struct node { 7 int s, e, next; 8 ll w; 9 }edge[maxn * 2]; 10 int head[maxn], len; 11 void init() { 12 memset(head, -1, sizeof(head)); 13 len = 0; 14 } 15 void add(int s, int e, ll w) { 16 edge[len] = { s,e,head[s],w }; 17 head[s] = len++; 18 } 19 int fat[maxn], son[maxn], siz[maxn], top[maxn], tid[maxn], dep[maxn], dfx; 20 ll a[maxn], st[maxn], dp[maxn], Min[maxn], stop; 21 vector<int>mp[maxn]; 22 void dfs1(int x, int fa, int d) { 23 siz[x] = 1, fat[x] = fa; 24 dep[x] = d, son[x] = -1; 25 for (int i = head[x]; i != -1; i = edge[i].next) { 26 int y = edge[i].e; 27 if (y == fa)continue; 28 Min[y] = min(Min[x], edge[i].w); 29 dfs1(y, x, d + 1); 30 siz[x] += siz[y]; 31 if (son[x] == -1 || siz[son[x]] < siz[y]) 32 son[x] = y; 33 } 34 } 35 void dfs2(int x, int c) { 36 top[x] = c; 37 tid[x] = ++dfx; 38 if (son[x] == -1)return; 39 dfs2(son[x], c); 40 for (int i = head[x]; i != -1; i = edge[i].next) { 41 int y = edge[i].e; 42 if (fat[x] == y || y == son[x])continue; 43 dfs2(y, y); 44 } 45 } 46 int LCA(int x, int y) { 47 while (top[x] != top[y]) { 48 if (dep[top[x]] < dep[top[y]]) 49 swap(x, y); 50 x = fat[top[x]]; 51 } 52 if (dep[x] > dep[y])swap(x, y); 53 return x; 54 55 } 56 bool cmp(int x, int y) { 57 return tid[x] < tid[y]; 58 } 59 void dfs3(int x) { 60 if (mp[x].size() == 0) { 61 dp[x] = Min[x]; 62 return; 63 } 64 ll sum = 0; 65 for (int i = 0; i < mp[x].size(); i++) { 66 int y = mp[x][i]; 67 dfs3(y); 68 sum += dp[y]; 69 } 70 mp[x].clear(); 71 dp[x] = min(Min[x], sum); 72 } 73 void insert(int x) { 74 if (stop == 1) { 75 st[++stop] = x; 76 return; 77 } 78 int lca = LCA(x, st[stop]); 79 if (lca == st[stop]) 80 return; 81 while (stop > 1 && tid[st[stop - 1]] >= tid[lca]) 82 mp[st[stop - 1]].push_back(st[stop]), stop--; 83 if (lca != st[stop]) 84 mp[lca].push_back(st[stop]), st[stop] = lca; 85 st[++stop] = x; 86 } 87 int main() { 88 init(); 89 int n, m, t, x, y, z; 90 scanf("%d", &n); 91 for (int i = 1; i < n; i++) { 92 scanf("%d%d%d", &x, &y, &z); 93 add(x, y, z); 94 add(y, x, z); 95 } 96 Min[1] = inf; 97 dfs1(1, 0, 1); 98 dfs2(1, 1); 99 scanf("%d", &m); 100 while (m--) { 101 scanf("%d", &t); 102 for (int i = 1; i <= t; i++) 103 scanf("%d", &a[i]); 104 sort(a + 1, a + 1 + t, cmp); 105 st[stop = 1] = 1; 106 for (int i = 1; i <= t; i++) 107 insert(a[i]); 108 while (stop > 1) 109 mp[st[stop - 1]].push_back(st[stop]), stop--; 110 dfs3(1); 111 printf("%lld ", dp[1]); 112 } 113 }