【题目大意】
给出一棵树,求有多少对点(u,v)满足其路径上不存在两个点a,b满足(a,b)=1
n<=10^5
【题解】
考虑找出所有不符合的点对,共有n*ln(n)对,他们要么是祖先->儿子边,要么是不是。
考虑祖先->儿子边,那么一个点在祖先以上,一个点在儿子以下的点对全部无法访问。
考虑另外一种边,就是LCA不是两个端点的,这就比较好统计了,两个点在这两棵子树的点对无法访问。
考虑用DFS序,这样子树就是连续的一段(祖先以上是连续两段)
然后就是一个二维覆盖问题,用扫描线+线段树即可解决。
复杂度O(nln(n)logn)
注意。。扫描线数组要开到 4 * n * ln(n) 不然。。会奇怪的WA/RE。。。
# include <stdio.h> # include <string.h> # include <iostream> # include <algorithm> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef long double ld; # define RG register # define ST static const int M = 2e5 + 10, N = 1e5 + 10, Max = 8 * M; const int mod = 998244353; int n, head[N], nxt[M], to[M], tot = 0; inline void add(int u, int v) { ++tot; nxt[tot] = head[u]; head[u] = tot; to[tot] = v; } inline void adde(int u, int v) { add(u, v), add(v, u); } int in[N], out[N], DFN = 0; int dep[N], fa[N][19]; inline void dfs(int x, int fat = 0) { in[x] = ++DFN; dep[x] = dep[fat] + 1; fa[x][0] = fat; for (int i=1; i<=18; ++i) fa[x][i] = fa[fa[x][i-1]][i-1]; for (int i=head[x]; i; i=nxt[i]) { if(to[i] == fat) continue; dfs(to[i], x); } out[x] = DFN; } inline int lca(int u, int v) { if(dep[u] < dep[v]) swap(u, v); for (int i=18; ~i; --i) if((dep[u] - dep[v]) & (1<<i)) u = fa[u][i]; if(u == v) return u; for (int i=18; ~i; --i) if(fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i]; return fa[u][0]; } inline int jump(int u, int anc) { for (int i=18; ~i; --i) if(dep[fa[u][i]] > dep[anc]) u = fa[u][i]; return u; } struct pa { int x, yl, yr, d; pa() {} pa(int x, int yl, int yr, int d) : x(x), yl(yl), yr(yr), d(d) {} friend bool operator < (pa a, pa b) { return a.x < b.x; } }p[Max * 4]; int pn = 0; inline void ADD(int xl, int xr, int yl, int yr) { p[++pn] = pa(xl, yl, yr, 1); p[++pn] = pa(xr+1, yl, yr, -1); } inline void doit(int x, int y) { int par = lca(x, y); // if(par == -1) cout << x << ' ' << y << endl; if(dep[x] > dep[y]) swap(x, y); if(x == par) { int pars = jump(y, par); ADD(1, in[pars] - 1, in[y], out[y]); ADD(in[y], out[y], out[pars] + 1, n); } else { if(in[x] > in[y]) swap(x, y); ADD(in[x], out[x], in[y], out[y]); } } struct SMT { int w[Max], tag[Max]; # define ls (x<<1) # define rs (x<<1|1) inline void set() { memset(w, 0, sizeof w); memset(tag, 0, sizeof tag); } inline int gs(int x, int l, int r) { if(tag[x]) return r-l+1; else return w[x]; } inline void edt(int x, int l, int r, int L, int R, int d) { if(L > R) return ; if(L <= l && r <= R) {tag[x] += d; return ;} int mid = l+r>>1; if(L <= mid) edt(ls, l, mid, L, R, d); if(R > mid) edt(rs, mid+1, r, L, R, d); w[x] = gs(ls, l, mid) + gs(rs, mid+1, r); } inline int sum(int x, int l, int r, int L, int R) { if(L > R) return 0; if(tag[x]) return min(R, r) - max(L, l) + 1; if(L <= l && r <= R) return gs(x, l, r); int mid = l+r>>1, ret = 0; if(L <= mid) ret += sum(ls, l, mid, L, R); if(R > mid) ret += sum(rs, mid+1, r, L, R); return ret; } # undef ls # undef rs }T; int main() { // freopen("A.in", "r", stdin); // freopen("A.out", "w", stdout); cin >> n; for (int i=1, u, v; i<n; ++i) { scanf("%d%d", &u, &v); adde(u, v); } dfs(1, 0); for (int i=1; i<=n; ++i) for (int j=i+i; j<=n; j+=i) doit(i, j); sort(p+1, p+pn+1); T.set(); ll ans = (ll)n * (n-1) / 2; for (int i=1, j=1; i<=n; ++i) { while(j<=pn && p[j].x == i) T.edt(1, 1, n, p[j].yl, p[j].yr, p[j].d), ++j; ans -= T.sum(1, 1, n, i+1, n); } cout << ans; return 0; }