CF1034C Region Separation
题目大意
给定一棵 (n) 个点的树。每个点有权值,记为 (a_{1dots n})。
你想砍树。你可以砍任意轮,每轮你选择一些边(至少一条)断开,需要满足每轮结束后每个连通块的权值和是相等的。
求有多少种砍树方案。两种方案不同,当且仅当轮数不同或者某一轮砍的边不同。答案对 (10^9 + 7) 取模。
数据范围:(1leq nleq 10^6),(1leq a_ileq 10^9)。
本题题解
取任意点为根(不妨取 (1)),把树变成有根树。设 (S) 为所有节点的点权和,(s_i) 表示以 (i) 为根的子树内的点权和。
考虑只进行一轮划分,有多少种方案。设把树划分为了 (k) 块,那么存在方案的一个必要条件是:(k) 是 (S) 的约数。此时每块内的点权和是 (frac{S}{k})。
我们从叶子开始,类似树形 DP 的过程,每当当前连通块的和达到 (frac{S}{k}),就把当前连通块割下来。具体来说,对所有 (i = 2dots n),如果 (s_{i} mod frac{S}{k} = 0),就切断 (i) 和 (mathrm{fa}(i)) 之间的边。容易发现,如果有解,我们这样会求出唯一的一组解。如果无解,按此方法最终得到的连通块数会小于 (k)。
通过如上讨论,我们已经可以回答【只进行一轮划分,把树划分为 (k) 块的方案数】,设为 (f(k)),则:
也就是说,(f(k)) 只能等于 (0) 或 (1),并且它为 (1) 当且仅当:(k) 是 (S) 的约数,且恰有 (k) 个 (i) 满足 (s_i mod frac{S}{k} = 0)。
暴力求出 (f(1)dots f(n)) 的时间复杂度是 (mathcal{O}(n^2)) 的,我们想要更快。
考虑 (s_i mod frac{S}{k} = 0) 这个要求。转化一下:
也就是说,这个要求等价于 (k) 是 (frac{S}{gcd(S, s_i)}) 的倍数。
那么对每个 (s_i),设 (v = frac{S}{gcd(S, s_i)}),它可以对 $k = v, 2v, 3v, dots $ 产生贡献,另外注意 (k leq n)。于是我们用桶存一下,然后累加一遍,就能在调和级数的时间复杂度内,求出所有 (f(k)) 了。时间复杂度 (mathcal{O}(nlog n))。
然后考虑不止划分一轮的问题。设 (mathrm{dp}(k)) 表示划分了若干轮,得到 (k) 个连通块的方案数。根据划分方案的唯一性,它上一轮划分出的连通块数,能且仅能是 (k) 的约数。于是可以写出转移:
答案就是 (sum_{k = 1}^{n} mathrm{dp}(k))。这个 DP 的时间复杂度也是调和级数: (mathcal{O}(nlog n)) 的。
时间复杂度 (mathcal{O}(n(log n + log a))),其中 (log a) 来自求 (gcd)。
参考代码
片段:
const int MAXN = 1e6;
const int MOD = 1e9 + 7;
int n, a[MAXN + 5], fa[MAXN + 5];
ll s[MAXN + 5], S;
int f[MAXN + 5], dp[MAXN + 5];
ll gcd(ll x, ll y) { return (!y) ? x : gcd(y, x % y); }
int main() {
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
s[i] = a[i];
S += a[i];
}
for (int i = 2; i <= n; ++i) {
cin >> fa[i];
}
for (int i = n; i >= 2; --i) {
s[fa[i]] += s[i];
}
assert(s[1] == S);
for (int i = 1; i <= n; ++i) {
ll v = S / gcd(S, s[i]);
if (v <= n) {
f[v]++;
}
}
for (int i = n; i >= 1; --i) {
for (int j = i + i; j <= n; j += i) {
f[j] += f[i];
}
}
for (int i = 1; i <= n; ++i) {
f[i] = (f[i] == i); // 恰好能划分为 i 块
}
dp[1] = 1;
int ans = 0;
for (int i = 1; i <= n; ++i) {
if (!f[i]) {
dp[i] = 0;
continue;
}
for (int j = i + i; j <= n; j += i) {
dp[j] = (dp[j] + dp[i]) % MOD;
}
ans = (ans + dp[i]) % MOD;
}
cout << ans << endl;
return 0;
}