( ext{Problem})
( ext{Solution})
把形如 ((a,ka)) 的路径提出来
那么覆盖这些路径的路径为不合法路径
如果能不重不漏的找出这些路径,然后用总路径减去就是答案
为了方便计算,我们限定路径用 (dfn) 序表示 ((x,y)) ,并规定 (x < y)
即树上两点构成的路径 ((x,y)) 满足 (dfn[x] < dfn[y])
然后如何确定那些路径 ((a,b)) 覆盖了最先找出来的路径 ((u,v))
其实很好办,自己画画图就知道了
其中要分两类讨论,记 (end_x) 为子树 (x) 中 (dfn) 序最大的点的 (dfn) 序,即 (end_x = dfn_x + siz_x - 1)
那么
于是我们确定了不合法路径 ((a,b)) 的范围,那怎么去掉重复路径呢?
很妙啊!
因为路径像是平面上的有序数对,于是我们把它弄到平面上,然后发现不合法路径的范围是一个又一个矩阵
那么总数就是矩阵面积的并
扫描线解决即可
( ext{Code})
#include<cstdio>
#include<algorithm>
#define LL long long
#define ls (p << 1)
#define rs (ls | 1)
using namespace std;
const int N = 1e5 + 5;
int n, h[N], m;
struct line{
int x, y0, y1, v;
}l[4000005];
inline bool cmp(line x, line y){return x.x < y.x ? 1 :(x.x == y.x ? x.v < y.v : 0);}
struct edge{int to, nxt;}e[N * 2];
inline void add(int x, int y)
{
static int tot = 0;
e[++tot] = edge{y, h[x]}, h[x] = tot;
}
int dep[N], f[N][20], dfn[N], siz[N];
void dfs(int x)
{
static int dfc = 0;
dfn[x] = ++dfc, siz[x] = 1;
for(int i = 1; i <= 17; i++)
if (f[x][i - 1]) f[x][i] = f[f[x][i - 1]][i - 1];
else break;
for(int i = h[x]; i; i = e[i].nxt)
{
int v = e[i].to;
if (dep[v]) continue;
dep[v] = dep[x] + 1, f[v][0] = x, dfs(v), siz[x] += siz[v];
}
}
int sum[N << 2], tag[N << 2];
inline void pushup(int l, int r, int p)
{
if (tag[p] > 0) sum[p] = r - l + 1;
else if (l == r) sum[p] = 0;
else sum[p] = sum[ls] + sum[rs];
}
void update(int l, int r, int p, int x, int y, int v)
{
if (x > r || y < l) return;
if (x <= l && r <= y)
{
tag[p] += v;
pushup(l, r, p);
return;
}
int mid = (l + r) >> 1;
if (x <= mid) update(l, mid, ls, x, y, v);
if (y > mid) update(mid + 1, r, rs, x, y, v);
pushup(l, r, p);
}
int main()
{
freopen("a.in", "r", stdin), freopen("a.out", "w", stdout);
scanf("%d", &n);
for(int i = 1, x, y; i < n; i++) scanf("%d%d", &x, &y), add(x, y), add(y, x);
dep[1] = 1, dfs(1);
for(int i = 1, x, y, t; i <= n; i++)
for(int j = i + i; j <= n; j += i)
{
x = i, y = j;
if (dfn[x] > dfn[y]) swap(x, y);
if (dfn[x] + siz[x] - 1 >= dfn[y])
{
t = y;
for(int k = 17; k >= 0; k--)
if (f[t][k] && dep[f[t][k]] > dep[x]) t = f[t][k];
if (dfn[t] > 1)
{
l[++m] = line{1, dfn[y], dfn[y] + siz[y] - 1, 1};
l[++m] = line{dfn[t], dfn[y], dfn[y] + siz[y] - 1, -1};
}
if (dfn[t] + siz[t] <= n)
{
l[++m] = line{dfn[y], dfn[t] + siz[t], n, 1};
l[++m] = line{dfn[y] + siz[y], dfn[t] + siz[t], n, -1};
}
}
else{
l[++m] = line{dfn[x], dfn[y], dfn[y] + siz[y] - 1, 1};
l[++m] = line{dfn[x] + siz[x], dfn[y], dfn[y] + siz[y] - 1, -1};
}
}
sort(l + 1, l + m + 1, cmp);
LL ans = 0;
for(int i = 1, j; i <= m; i++)
{
ans += 1LL * sum[1] * (l[i].x - l[i - 1].x);
for(j = i; j <= m && l[j].x == l[i].x; j++) update(1, n, 1, l[j].y0, l[j].y1, l[j].v);
i = j - 1;
}
printf("%lld
", 1LL * n * (n - 1) / 2 - ans);
}