个人感觉挺有意思的,然而被颜神D无聊惹(~ ̄▽ ̄)~
这题我们可以首先试图去统计以每一个点作为 w 点所能对答案造成的贡献是多少。不难发现,当且仅当 u 和 v 都在 w 所在边双的一侧的时候不能构成一个合法的三元组,因为它们要到达 w 均需经过一条共同的割边。那么因为原图是一棵树,所以我们连接两个点的时候就是在把这两个点所在的边双一直到根所在的边双都合并为一个。
考虑如何在合并答案的时候计算出答案的变化。若我们合并的是 S,T 这两个集合,我们可以先减去由 S 和 T 中的点作为 w 点时对答案造成的贡献。1.u 和 v 均为 w 所在边双中的点,这个直接用边双大小统计就可以了;2.一个在边双外部,一个在边双内部。这个也可以直接用边双大小进行统计。
比较难想到的是如何统计两个点都在边双外部的情况(在边双的两侧)。这个直接统计并不是很方便,但是不难发现如果统计在点双外部且在两侧的情况是很多的,而在点双外部且在同一侧的情况则单一很多。全部的选择就是点双外部的点钟随便选两个,我们可以把在同一侧的情况减去得到合法的解。维护数组 w[u] 表示 u 联通块中的一个点所能匹配到的同一侧的两个点有多少种方案。非法的情况即为 w[u] * s[u] (u 联通块的大小)。合并的时候 w 数组怎么合并呢?令 u 为 v 的父亲,则 w[u] + w[v] 这样统计的话会把 v 所在的子树内的点对 & v 点外部(父亲子树)的点对统计两次。减去就好啦。
#include <bits/stdc++.h> using namespace std; #define maxn 1000000 #define int long long int n, ans, s[maxn], w[maxn], dep[maxn]; int size[maxn], fa[maxn], f[maxn]; int read() { int x = 0, k = 1; char c; c = getchar(); while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); } while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); return x * k; } struct edge { int cnp, to[maxn], last[maxn], head[maxn]; edge() { cnp = 2; } void add(int u, int v) { to[cnp] = v, last[cnp] = head[u], head[u] = cnp ++; to[cnp] = u, last[cnp] = head[v], head[v] = cnp ++; } }E1; void dfs(int u) { size[u] = 1; dep[u] = dep[fa[u]] + 1; for(int i = E1.head[u]; i; i = E1.last[i]) { int v = E1.to[i]; if(v == fa[u]) continue; fa[v] = u; dfs(v); size[u] += size[v]; w[u] += size[v] * size[v]; } w[u] += (n - size[u]) * (n - size[u]); ans -= w[u]; } int Cal(int u) { return max(s[u] * (s[u] - 1) * (s[u] - 2), 0LL); } int find(int x) { return f[x] == x ? x : f[x] = find(f[x]); } void merge(int u, int v) { ans -= (n - s[u]) * (n - s[u]) * s[u] - w[u] * s[u]; ans -= (n - s[v]) * (n - s[v]) * s[v] - w[v] * s[v]; ans -= (n - s[u]) * s[u] * (s[u] - 1) * 2; ans -= (n - s[v]) * s[v] * (s[v] - 1) * 2; ans -= Cal(u) + Cal(v); f[v] = u, s[u] += s[v]; w[u] += w[v] - size[v] * size[v] - (n - size[v]) * (n - size[v]); ans += (n - s[u]) * (n - s[u]) * s[u] - w[u] * s[u] + Cal(u); ans += (n - s[u]) * s[u] * (s[u] - 1) * 2; } signed main() { n = read(); for(int i = 1; i < n; i ++) { int u = read(), v = read(); E1.add(u, v); } ans = n * (n - 1) * (n - 1); dfs(1); for(int i = 1; i <= n; i ++) f[i] = i, s[i] = 1; int q = read(); printf("%lld ", ans); for(int i = 1; i <= q; i ++) { int u = read(), v = read(); u = find(u), v = find(v); while(u != v) { if(dep[u] < dep[v]) swap(u, v); int fu = find(fa[u]); merge(fu, u); u = fu; } printf("%lld ", ans); } return 0; }