链接:https://ac.nowcoder.com/acm/contest/6226/C
来源:牛客网
题目描述
修修去年种下了一棵树,现在它已经有n个结点了。
修修非常擅长数数,他很快就数出了包含每个点的连通点集的数量。
澜澜也想知道答案,但他不会数数,于是他把问题交给了你。
输入描述:
第一行一个整数n (1≤ n ≤ 106),接下来n-1行每行两个整数ai,bi表示一条边 (1≤ ai,bi≤ n)。
输出描述:
输出n行,每行一个非负整数。第i行表示包含第i个点的连通点集的数量对109+7取模的结果。
示例1
输入
6
1 2
1 3
2 4
4 5
4 6
输出
12
15
7
16
9
9
题解
拿到题画了画, 明白了啥是联通点集, 即所有点(包含k点)和边连成一棵树, 即以k点为根的子树
那解样例, 自然发现读于根u的答案是通过子树统计的
ans[u] = 1(只有自己)
for v 子树
ans[u] = ans[u] + ans[v] * ans[u](子树答案和当前根u的答案组合)
= ans[u] * (ans[v] + 1)
这样就只能求根的答案
但是我们发现, 对于子树的根, 就来自父节点的贡献没算!!
这不就是换根dp吗?
当子树的根做主根的时候, 那么原根的答案就为(把这棵子树当作是最后对原根贡献的子树)
//ans[fa] = ans[fa] * (ans[u] + 1) (把这棵子树当作是最后对原根贡献的子树)
//则这棵子树没贡献之前的 ans[fa]为
res = ans[fa] / (ans[u] + 1)
那么, 就可以直接算子树根的答案了
ans[u] = ans[u] + res * ans[u]
= ans[u] * (ans[fa] / (ans[u] + 1) + 1)
= ans[u] * (ans[fa] * power(ans[u] + 1, mod - 2) + 1)
//注意当ans[fa] = mod - 1, 就没有意义,至于为什么要提, 接下来说
这不就可以换根, 两个dfs ac了吗?
呵呵, 这题要取模的, 就存在ans[i] = 0的情况
那么谁让ans[i] == 0了呢?
ans[u] = ans[u] * (ans[v] + 1)
显然是子树 ans[v] == mod - 1, 我们记 cnt[i] 表示以 i为根的树含 ans[j] == mod - 1的数量
(且当子树为零树时, 我们就只标记, 就不算ans[u]了, 就不会出现除数为0的情况)
(且换根时, 由于ans[u] == mod - 1, ans[u]就没对ans[fa]贡献)
对于 cnt[i] > 0 的 ans[i] = 0, 但是我们在换根的时候就要考虑了,
父根cnt[fa] > 0, 则就没贡献, 但是如果 cnt[fa] = 1, 且ans[u] == mod - 1 呢?
没错我们要考虑父根只有一颗零树, 且当前子树就是这棵时的情况
那么就处理完了
具体看代码
代码
#include <bits/stdc++.h>
#define all(n) (n).begin(), (n).end()
#define se second
#define fi first
#define pb push_back
#define mp make_pair
#define sqr(n) (n)*(n)
#define rep(i,a,b) for(int i=a;i<=(b);++i)
#define per(i,a,b) for(int i=a;i>=(b);--i)
#define IO ios::sync_with_stdio(0);cin.tie(0)
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
typedef pair<ll, ll> PLL;
typedef vector<int> VI;
typedef double db;
const int N = 1e6 + 5;
const int mod = 1e9 + 7;
int n, cnt[N];
int h[N], to[N << 1], ne[N << 1], tot;
ll ans[N];
void add(int u, int v) {
ne[++tot] = h[u]; h[u] = tot; to[tot] = v;
}
void calc(int u, ll res) {
if (res == mod - 1) ++cnt[u];
else ans[u] = ans[u] * (res % mod + 1) % mod;
}
void dfs1(int u, int f) {
ans[u] = 1;
for (int i = h[u]; i; i = ne[i]) {
int y = to[i];
if (y == f) continue;
dfs1(y, u);
calc(u, cnt[y] ? 0 : ans[y]);
}
}
ll power(ll a, int b) {
ll res = 1;
for (; b; a = a * a % mod, b >>= 1)
if (b & 1) res = res * a % mod;
return res;
}
void dfs2(int u, int f) {
if (f) {
if (!cnt[u] && ans[u] == mod - 1) {
if (cnt[f] == 1) calc(u, ans[f]);
} else if (cnt[f] == 0)
calc(u, ans[f] * power(cnt[u] ? 1 : ans[u] % mod + 1, mod - 2));
}
for (int i = h[u]; i; i = ne[i]) {
int y = to[i];
if (y == f) continue;
dfs2(y, u);
}
}
int main() {
IO;
cin >> n;
rep (i, 2, n) {
int u, v; cin >> u >> v;
add(u, v); add(v, u);
}
dfs1(1, 0); dfs2(1, 0);
rep (i, 1, n) cout << (cnt[i] ? 0 : ans[i]) << '
';
return 0;
}