1116考试T1
题目大意:
哺噜国里有 n 个城市,有的城市之间有高速公路相连。在最开始时,哺噜国里有 n-1条高速公路,且任意两座城市之间都存在一条由高速公路组成的通路。由于高速公路的维护成本很高,为了减少哺噜国的财政支出,将更多的钱用来哺育小哺噜,秀秀女王决定关闭一些高速公路。但是为了保证哺噜国居民的正常生活,不能关闭太多的高速公路,要保证每个城市可以通过高速公路与至少 k 个城市(包括自己)相连。
在得到了秀秀女王的指令后,交通部长华华决定先进行预调研。华华想知道在满足每个城市都可以与至少 k 个城市相连的前提下,有多少种关闭高速公路的方案(可以一条也不关)。两种方案不同,当且仅当存在一条高速公路在一个方案中被关闭,而在另外一个方案中没有被关闭。
由于方案数可能很大,你只需输出不同方案数对 786433 取模后的结果即可。 (n <= 5000, k <= n)
树形DP.
(f[i][j])表示以(i)为根的子树内, (i)所在的联通块大小为(j)的方案数.
假设当前在以(x)为根的子树内, 现在要把子树(y)的答案合并上, 那么转移一下:
(f[x][i + j] = f[x][i] * f[y][j]) 表示链接((x, y))这条边.
(f[x][i] = f[x][i] * f[y][j] (j >= k)) 表示断开((x, y))这条边.
初值(f[x][1] = 1).
这么枚举总的复杂度是(O(n^ 3))的, 我们需要优化.
我们发现(i, j)不必枚举到(n), (i)只需枚举到(displaystyle sum_{d = 1}^{p - 1} siz[d]), (j)只需枚举到(siz[p]), (p)是当前要合并的子树, 那么就有(p - 1)个已经合并过的子树.
经分析复杂度是(O(n ^ 2))的.
分析(题解原话) : 对于点u,枚举i,j总次数等于以点u为根的子树中选取无序点对,使得它们的lca为u的无序点对数。树上每个点对只会在lca处算一次。
#include <bits/stdc++.h>
using namespace std;
inline long long read() {
long long s = 0, f = 1; char ch;
while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
return s * f;
}
const int N = 5005, mod = 786433;
int n, K, cnt, ans;
int f[N][N], siz[N], head[N];
struct edge { int to, nxt; } e[N << 1];
void add(int x, int y) {
e[++ cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y;
}
void get_tree(int x, int fa) {
siz[x] = 1;
for(int i = head[x]; i ; i = e[i].nxt) {
int y = e[i].to; if(y == fa) continue;
get_tree(y, x); siz[x] += siz[y];
}
}
void get_f(int x, int fa) {
f[x][1] = 1;
int lim = 1, h[siz[x] + 1];
for(int i = head[x]; i ; i = e[i].nxt) {
int y = e[i].to; if(y == fa) continue;
get_f(y, x);
for(int j = 1;j <= lim + siz[y]; j++) h[j] = 0;
for(int j = 1;j <= lim; j++)
for(int k = 1;k <= siz[y]; k++) {
h[j + k] += 1ll * f[x][j] * f[y][k] % mod;
if(k >= K) h[j] += 1ll * f[x][j] * f[y][k] % mod;
}
lim += siz[y];
for(int j = 1;j <= lim; j++) f[x][j] = h[j];
}
}
int main() {
n = read(); K = read();
for(int i = 1, x, y;i < n; i++)
x = read(), y = read(), add(x, y), add(y, x);
get_tree(1, 0); get_f(1, 0);
for(int i = K;i <= siz[1]; i++) ans = (ans + f[1][i]) % mod;
printf("%d", ans);
return 0;
}