先让我们探索一下两条非树边以及树边能构成简单环的条件是什么,你会发现将第一条非树边的两个点在树上形成的链记为 (W_1),另一条即为 (W_2),那么当且仅当 (W_1, W_2) 有交时才能满足条件。因为当 (W_1, W_2) 没交时,那么中间一定会经过一些树边不被这两条链覆盖,但因为点只能走一次因此我们就不能回来了,所以这种情况下是不行的;而当 (W_1, W_2) 有交时很容易就可以构造出一条合法的简单环。
于是现在问题转化成统计有多少对链 (W_1, W_2) 满足其有交。这种链交问题我们一般拆成左右两条直上直下的链,这样可以减少讨论。令 (W_1, W_2) 中链顶更高的为 (W_1),不难发现两条链有交当且仅当 (W_1) 会穿过 (W_2) 的链顶与其链上儿子中间组成的这条边,因此对于每条链 (W_i) 我们在其链顶与其在链上的儿子组成的边上权值 (+1),那么和每条链交的链的数量就是其在链上的边上的权值之和,这个我们可以直接倍增计算得出,这部分复杂度 (O(n log n))。于此同时你会发现,如果两条链链顶不同,这样只会计算一次链交,如果两条链顶相同,这样会计算两次,因此我们还需要将这部分多余的减去。
不难发现我们将所有链挂在链顶上,那么链顶相同且有交的链当且仅当是都经过了这个链顶的一个儿子的这样一对链。于是我们把经过每个儿子的链的数量统计出来,令第 (i) 个儿子的数量为 (c_i),则多算的链数应该是:(dbinom{c_i}{2}),减去即可,这部分复杂度 (O(n))。
但是回过头来会发现一个问题,有没有可能两条链在左边边交一次右边边也交一次呢,这样就记重了。事实上是有的,当且仅当两条链的 (LCA) 相同并且两条链都相同地经过 (LCA) 的两个儿子。这部分记重我们只需要将所有链一样地挂在 (LCA) 上每次用 (map) 统计经过两个儿子的链的数量,令其中一种的数量为 (x),则多算的部分应该是:(dbinom{x}{2}),减去即可。这部分复杂度 (O(n log n))。
一些坑点
- 就算原来的链也是直上直下的也会计算重复,也需要减去算重的部分;于此同时需要注意第二次去重时不要统计经过其自身的和一个儿子的重复部分。
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define rep(i, l, r) for(int i = l; i <= r; ++i)
#define dep(i, l, r) for(int i = r; i >= l; --i)
#define Next(i, u) for(int i = h[u]; i; i = e[i].next)
const int N = 200000 + 5;
const int K = 20 + 5;
struct edge{
int v, next;
}e[N << 1];
struct node{
int u, v;
}a[N << 1];
long long ans;
int n, m, u, v, tot, cnt, Lca, h[N], dep[N], tmp[N], f[N][K], dp[N][K];
vector <node> G[N];
map <int, int> M[N];
int read(){
char c; int x = 0, f = 1;
c = getchar();
while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
void add(int u, int v){
e[++tot].v = v, e[tot].next = h[u], h[u] = tot;
e[++tot].v = u, e[tot].next = h[v], h[v] = tot;
}
void dfs(int u, int fa){
f[u][0] = fa, dep[u] = dep[fa] + 1;
Next(i, u) if(e[i].v != fa) dfs(e[i].v, u);
}
int LCA(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
dep(i, 0, 20) if(dep[f[x][i]] >= dep[y]) x = f[x][i];
if(x == y) return x;
dep(i, 0, 20) if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int find(int x, int y){
dep(i, 0, 20) if(dep[f[x][i]] > dep[y]) x = f[x][i];
return x;
}
int calc(int x, int y){
int ans = 0;
dep(i, 0, 20) if(dep[f[x][i]] >= dep[y]) ans += dp[x][i], x = f[x][i];
return ans;
}
signed main(){
n = read(), m = read();
rep(i, 1, n - 1) u = read(), v = read(), add(u, v);
dfs(1, 0);
rep(j, 1, 20) rep(i, 1, n) f[i][j] = f[f[i][j - 1]][j - 1];
rep(i, n, m){
u = read(), v = read(), Lca = LCA(u, v); if(dep[u] < dep[v]) swap(u, v);
G[Lca].push_back((node){find(u, Lca), find(v, Lca)});
if(v != Lca) a[++cnt] = (node){u, Lca}, a[++cnt] = (node){v, Lca};
else a[++cnt] = (node){u, v};
}
rep(i, 1, n){
if(!G[i].size()) continue;
for(int j = 0; j < G[i].size(); ++j){
if(G[i][j].u > G[i][j].v) swap(G[i][j].u, G[i][j].v);
if(G[i][j].u != i && G[i][j].v != i) ans -= M[G[i][j].u][G[i][j].v];
++tmp[G[i][j].u], ++tmp[G[i][j].v], ++M[G[i][j].u][G[i][j].v];
}
Next(j, i) ans -= 1ll * tmp[e[j].v] * (tmp[e[j].v] - 1) / 2;
for(int j = 0; j < G[i].size(); ++j) --tmp[G[i][j].u], --tmp[G[i][j].v], --M[G[i][j].u][G[i][j].v];
}
rep(i, 1, cnt) ++dp[find(a[i].u, a[i].v)][0];
rep(j, 1, 20) rep(i, 1, n) dp[i][j] = dp[i][j - 1] + dp[f[i][j - 1]][j - 1];
rep(i, 1, cnt) ans += calc(a[i].u, a[i].v) - 1;
printf("%lld", ans);
return 0;
}
值得一提的时其实在第一次计算时不需要倍增去统计,我们直接利用差分的技巧统计链底到根的和减去链顶到根的和即可。