「点分治」学习笔记
什么是点分治
拆开来说,“点”就是树上的每一个点,“分治”就是类似 (CDQ) 的分治,用来解决和处理带权无根树(无向树,即每个点都可以当根)上路径统计问题。
基本思路
对于一个树上的路径,无非分成两种情况:
- 经过根节点的一条路径
- 在一个子树里的路径
这样我们就可以利用分治的思想,先处理一颗大树,只统计经过根节点的路径,再处理它的每个子树。
STEP1
- 寻找树的重心
(树的重心:其最大子树大小比其他任何一个节点当根的最大子树大小都小)
为什么要找树的重心呢,用一个例子来说明:
如果树是一条长度为 (5) 的链状,我们用最上面的 (1) 结点来处理,复杂度就是 (5)。
但是,我们如果用其重心 ((3)) 来处理,复杂度就是 (2) 了。
所以求树的重心保证了我们的时间复杂度。
void FindRoot (int u, int fa) { // 寻找树的重心
size[u] = 1, maxson[u] = 0; // 子树大小和,最大子树大小
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa || vis[v]) continue; // vis数组很重要,我们后面再说
FindRoot (v, u);
size[u] += size[v]; // 统计子树大小
maxson[u] = max (maxson[u], size[v]); // 求出最大
}
maxson[u] = max (maxson[u], Tsize - size[u]); // 别忘了它上面的子树(它的爸爸那边的)
if (maxson[u] < maxson[root]) root = u; // 更新根节点
}
STEP2
- 计算并统计其子树节点到它的路径
这里就要用到我们的分治思想了。
首先我们先 (DFS) 一遍统计一下每个根节点到下面节点的路径长度。
void DFS (int u, int fa, int d) { // d为当前节点u到root的距离
w[++ cnt] = d; // 记录到数组里面,一会儿分治用
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa || vis[v]) continue;
DFS (v, u, d + e[i].w); // 计算距离
}
}
- 分治
我们将上面记录的距离数组,从小到大排序一边。
定义一个左指针 (l) ,一个右指针 (r), 如果说 (w[l] + w[r] > k) 将右指针左移, 如果 (w[l] + w[r] <= k) 那么 (r) 左边的所有路径都可以满足条件,将答案加上 (r - l + 1) 即可。
我们现在将每个根节点到自己的为 (0) 路径也算到答案里面了,所以我们最后需要减去一个 (n) 。
同时我们也将同一个子树里面结点的路径也加上了,所以我们在总的点分治里面,再减去其子树所特殊统计到的答案。
inline int Calc (int u, int d) {
register int sum = 0;
cnt = 0;
DFS (u, 0, d); // 统计路径
sort (w + 1, w + cnt + 1); // 从小到大排序
register int r = cnt;
for (register int l = 1; l <= cnt; l ++) {
while (w[l] + w[r] > k && r >= 1) r --; // 左移右指针
if (l > r) break;
sum += r - l + 1;
}
return sum;
}
STEP3
- 最后就是点分治的核心了
先加上到当前根节点的答案,便历其子节点,将其子节点基础距离为 (e[i].w)的答案减去。
别忘了在处理每颗子树的时候,也要更新一边重心。
void CDQ (int u) {
ans += Calc (u, 0);
vis[u] = 1;
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (vis[v]) continue;
ans -= Calc (v, e[i].w); // 减去冗余的答案
root = 0;
Tsize = size[v]; // 更新当前树的大小和根节点
FindRoot (v, u);
CDQ (root); // 处理每个子树
}
}
时间复杂度
好像差不多是 (O(nlog^2_n)), 证明略
模板题
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define DEBUG puts ("emmmm")
using namespace std;
const int maxn = 1e5 + 50, INF = 0x3f3f3f3f;
inline int read () {
register int x = 0, w = 1;
char ch = getchar();
for (; ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') w = -1;
for (; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
return x * w;
}
int n, k, root, Tsize, ans, cnt;
int size[maxn], maxson[maxn], w[maxn];
bool vis[maxn];
struct Edge {
int to, next, w;
}e[maxn << 1];
int tot, head[maxn];
void Add (int u, int v, int w) {
e[++tot].to = v;
e[tot].w = w;
e[tot].next = head[u];
head[u] = tot;
}
void FindRoot (int u, int fa) { // 寻找树的重心
size[u] = 1, maxson[u] = 0; // 子树大小和,最大子树大小
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa || vis[v]) continue; // vis数组很重要,我们后面再说
FindRoot (v, u);
size[u] += size[v]; // 统计子树大小
maxson[u] = max (maxson[u], size[v]); // 求出最大
}
maxson[u] = max (maxson[u], Tsize - size[u]); // 别忘了它上面的子树(它的爸爸那边的)
if (maxson[u] < maxson[root]) root = u; // 更新根节点
}
void DFS (int u, int fa, int d) { // d为当前节点u到root的距离
w[++ cnt] = d; // 记录到数组里面,一会儿分治用
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa || vis[v]) continue;
DFS (v, u, d + e[i].w); // 计算距离
}
}
inline int Calc (int u, int d) {
register int sum = 0;
cnt = 0;
DFS (u, 0, d); // 统计路径
sort (w + 1, w + cnt + 1); // 从小到大排序
register int r = cnt;
for (register int l = 1; l <= cnt; l ++) {
while (w[l] + w[r] > k && r >= 1) r --; // 左移右指针
if (l > r) break;
sum += r - l + 1;
}
return sum;
}
void CDQ (int u) {
ans += Calc (u, 0);
vis[u] = 1;
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (vis[v]) continue;
ans -= Calc (v, e[i].w); // 减去冗余的答案
root = 0;
Tsize = size[v]; // 更新当前树的大小和根节点
FindRoot (v, u);
CDQ (root); // 处理每个子树
}
}
int main () {
//freopen ("tree.in", "r", stdin);
//freopen ("tree.out", "w", stdout);
n = read();
for (register int i = 1; i <= n - 1; i ++) {
register int u = read(), v = read(), w = read();
Add (u, v, w), Add (v, u, w);
}
k = read();
maxson[root = 0] = INF, Tsize = n;
FindRoot (1, 0);
CDQ (root);
printf ("%d
", ans - n);
return 0;
}