题意:给你一棵 n 个节点的树,有 q 次询问。每次询问给出 k 个点、组数上限 m 和根 r(本次询问以 r 为根),要把这 k 个点分成至多 m 组,每组要满足两个条件:1. 每组至少一个点;2. 组内任一点都不能是组内其它点的祖先。问这样的分组方案有多少种(答案对 10^9+7 取模)?
思路:https://blog.csdn.net/BUAA_Alchemist/article/details/86765501
代码:
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

// For each query (k vertices, limit m, root r) this program counts the ways to
// split the k vertices into at most m non-empty groups such that no vertex in a
// group is an ancestor of another vertex of the same group when the tree is
// rooted at r. Answers are printed modulo 1e9+7, one per line.
//
// Approach: let h[v] be the number of other query vertices on the path v -> r
// (v's query-vertex ancestors under root r). Processing vertices in increasing
// h, each vertex either opens a new group or joins one of the (j - h) groups
// that contain none of its ancestors.

using LL = long long;

const LL mod = 1000000007;
const int maxn = 100010;  // upper bound on vertex count (+ slack); vertices are 1..n
const int LOG = 20;       // width of the binary-lifting table

vector<int> G[maxn];  // adjacency lists of the undirected tree
vector<int> a;        // vertices of the current query
int dfn[maxn];        // 1-based preorder index in the tree rooted at 1
int sz[maxn];         // subtree size in the tree rooted at 1
int tot;              // preorder counter used by dfs
int t;                // highest binary-lifting level in use
int n;

// Lowest set bit of x (Fenwick-tree step).
inline int lowbit(int x) { return x & (-x); }

// Adds the undirected edge x-y.
void add(int x, int y) {
    G[x].push_back(y);
    G[y].push_back(x);
}

// Fills dfn[] and sz[] for the tree rooted at `root` so that the subtree of x
// occupies preorder positions [dfn[x], dfn[x] + sz[x] - 1].
// Iterative on purpose: a recursive DFS nests up to n frames on a path-shaped tree.
void dfs(int root) {
    vector<int> parent(n + 1, 0);  // 0 is not a vertex, so it doubles as "no parent"
    vector<int> order;             // vertices in preorder
    order.reserve(n);
    vector<int> stk{root};
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        dfn[x] = ++tot;
        sz[x] = 1;
        order.push_back(x);
        // Children pushed here are popped (with their whole subtrees) before
        // anything that was already on the stack, so subtrees stay contiguous.
        for (int y : G[x]) {
            if (y == parent[x]) continue;
            parent[y] = x;
            stk.push_back(y);
        }
    }
    // Reverse preorder visits every child before its parent.
    for (int i = (int)order.size() - 1; i >= 0; i--) {
        int x = order[i];
        if (parent[x]) sz[parent[x]] += sz[x];
    }
}

int dep[maxn];     // depth in the tree rooted at 1 (dep[1] = 1; dep[0] = 0 is a sentinel)
int f[maxn][LOG];  // f[x][j] = 2^j-th ancestor of x, or 0 past the root

// BFS from vertex 1 filling dep[] and the binary-lifting table f[][].
// BFS order guarantees an ancestor's row of f is complete before it is read.
void bfs() {
    vector<int> q{1};
    dep[1] = 1;
    for (size_t head = 0; head < q.size(); head++) {
        int x = q[head];
        for (int y : G[x]) {
            if (dep[y]) continue;
            dep[y] = dep[x] + 1;
            f[y][0] = x;
            for (int j = 1; j <= t; j++) f[y][j] = f[f[y][j - 1]][j - 1];
            q.push_back(y);
        }
    }
}

// Lowest common ancestor of x and y in the tree rooted at 1.
int lca(int x, int y) {
    if (dep[x] > dep[y]) swap(x, y);
    // Lift y to x's depth; dep[0] == 0 prevents jumping past the root.
    for (int i = t; i >= 0; i--)
        if (dep[f[y][i]] >= dep[x]) y = f[y][i];
    if (x == y) return x;
    for (int i = t; i >= 0; i--)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}

// Fenwick tree over preorder positions 1..n. Used as "range add on a subtree,
// point query": after add(dfn[v], 1), add(dfn[v] + sz[v], -1) for each marked v,
// ask(dfn[u]) = number of marked vertices that are u or an ancestor of u
// (tree rooted at 1).
struct BIT {
    int c[maxn];
    int ask(int x) {
        int s = 0;
        for (; x > 0; x -= lowbit(x)) s += c[x];
        return s;
    }
    void add(int x, int y) {
        for (; x <= n; x += lowbit(x)) c[x] += y;
    }
};
BIT tr;
int h[maxn];    // h[i]: query-vertex ancestors of a[i] when rooted at r
int vis[maxn];  // 1 iff the vertex belongs to the current query

int main() {
    int T;
    if (scanf("%d%d", &n, &T) != 2) return 0;
    t = 1;
    while ((1 << t) < n) t++;  // 2^t >= n exceeds any depth difference
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v);
    }
    dfs(1);
    bfs();

    vector<LL> dp;  // rolling dp: dp[j] = ways to place the processed vertices into exactly j groups
    while (T--) {
        int k, m, r;
        scanf("%d%d%d", &k, &m, &r);
        a.resize(k);
        // Assumes the k query vertices are distinct (a duplicate would be counted twice).
        for (int i = 0; i < k; i++) {
            scanf("%d", &a[i]);
            vis[a[i]] = 1;
            tr.add(dfn[a[i]], 1);
            tr.add(dfn[a[i]] + sz[a[i]], -1);
        }
        // Marked vertices on the path a[i] .. r, minus a[i] itself.
        for (int i = 0; i < k; i++) {
            int l = lca(a[i], r);
            h[i] = tr.ask(dfn[a[i]]) + tr.ask(dfn[r]) - 2 * tr.ask(dfn[l]) + vis[l] - 1;
        }
        sort(h, h + k);

        dp.assign(m + 1, 0);
        dp[0] = 1;
        for (int i = 0; i < k; i++) {
            // j descending so dp[j - 1] and dp[j] still hold the previous row.
            for (int j = min(i + 1, m); j >= 1; j--)
                dp[j] = (dp[j - 1] + dp[j] * max(0, j - h[i])) % mod;
            dp[0] = 0;  // at least one vertex placed => zero groups is impossible
        }
        LL ans = 0;
        for (int j = 1; j <= m; j++) ans = (ans + dp[j]) % mod;
        printf("%lld\n", ans);

        // Undo this query's marks so the BIT and vis[] are clean for the next one.
        for (int x : a) {
            tr.add(dfn[x], -1);
            tr.add(dfn[x] + sz[x], 1);
            vis[x] = 0;
        }
        a.clear();
    }
    return 0;
}