题目大意:给一棵$n(nleqslant10^5)$个点的树,有$q(qleqslant10^5)$次询问,每次询问给出$k,m,r$表示把以下$k$个点分成不超过$m$组,使得在以$r$为根的情况下,组内的任意两个结点不存在祖先关系。$sum kleqslant10^5,mleqslant300$
题解:针对一次询问,可以想到一个$DP$,$f_{i,j}$表示处理到第$i$个点(假设为$u$),$u$的祖先都已经处理完,这$i$个点放到$j$组的方案数。$f_{i,j}=f_{i-1,j-1}+f_{i-1,j-father_u}$,$father_u$表示这$k$个点中$u$的祖先的个数。
现在考虑如何处理$DP$顺序,发现可以深度由浅到深处理,$dfs$序也是可以的。而后再考虑如何处理$father$,我想到了几种写法:
- 用树状数组,给子树加,然后查询需要的点,这种写法因为换根需要分类讨论,复杂度$O(klog_2n)$。
- 树剖后单点加,询问需要的点到根的和(这种写法在写之前被我以为需要分类讨论),复杂度$O(klog_2^2n)$。
- 建虚树,直接$dfs$,复杂度$O(klog_2k)$。
我写的时候不想分类讨论,就选择了第$3$种,在我交的时候成功跑到$rank$倒一。写题解的时候突然发现这种写法似乎复杂度最优秀?自带大常数
卡点:树剖求$LCA$时$top$更新错
C++ Code:
#include <algorithm> #include <cstdio> #include <iostream> #define maxn 100010 const int mod = 1e9 + 7; inline void reduce(int &x) { x += x >> 31 & mod; } int head[maxn], cnt; struct Edge { int to, nxt; } e[maxn << 1]; inline void addedge(int a, int b) { e[++cnt] = (Edge) { b, head[a] }; head[a] = cnt; e[++cnt] = (Edge) { a, head[b] }; head[b] = cnt; } namespace Tree { int sz[maxn], dep[maxn], fa[maxn]; int dfn[maxn], out[maxn], top[maxn], idx; #define f top int find(int x) { return x == f[x] ? x : (f[x] = find(f[x])); } #undef f void dfs(int u) { int son = 0; top[u] = u; dfn[u] = ++idx, sz[u] = 1; for (int i = head[u], v; i; i = e[i].nxt) { v = e[i].to; if (v != fa[u]) { dep[v] = dep[u] + 1, fa[v] = u; dfs(v), sz[u] += sz[v]; if (sz[v] > sz[son]) son = v; } } out[u] = idx; if (son) top[son] = u; } inline int LCA(int x, int y) { while (top[x] != top[y]) { if (dep[top[x]] > dep[top[y]]) x = fa[top[x]]; else y = fa[top[y]]; } return dep[x] < dep[y] ? x : y; } } using Tree::dfn; using Tree::out; int n, q, k, m, r; int Mark[maxn], f[305], idx; inline bool cmp(int x, int y) { return dfn[x] < dfn[y]; } void dfs(int u, int fa = 0, int num = 0) { if (Mark[u]) { ++idx; for (int i = std::min(m, idx); ~i; --i) { if (i > num) f[i] = (static_cast<long long> (i - num) * f[i] + f[i - 1]) % mod; else f[i] = 0; } ++num; } for (int i = head[u], v; i; i = e[i].nxt) { v = e[i].to; if (v != fa) dfs(v, u, num); } head[u] = Mark[u] = 0; } void solve() { static int List[maxn << 1], tot, S[maxn], top; std::cin >> k >> m >> r; for (int i = 0; i < k; ++i) { std::cin >> List[i]; Mark[List[i]] = 1; } tot = k; if (!Mark[r]) List[tot++] = r; std::sort(List, List + tot, cmp); for (int i = tot - 1; i; --i) List[tot++] = Tree::LCA(List[i], List[i - 1]); tot = (std::sort(List, List + tot, cmp), std::unique(List, List + tot) - List); top = 0; for (int I = 0, i = *List; I < tot; i = List[++I]) { while (top && out[S[top]] < dfn[i]) --top; if (top) addedge(S[top], i); S[++top] = i; } f[idx = 0] = 1, dfs(r); int ans = 0; for (int i = 1; i <= m; ++i) reduce(ans += f[i] - mod); std::cout << ans << ' '; cnt = 0; __builtin_memset(f, 0, m + 1 << 2); } int main() { std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0); std::cin >> n >> q; for (int i = 1, a, b; i < n; ++i) { std::cin >> a >> b; addedge(a, b); } Tree::dfs(1); for (int i = 1; i <= n; ++i) Tree::find(i); __builtin_memset(head, 0, n + 1 << 2), cnt = 0; while (q --> 0) solve(); return 0; }