设 (k) 的答案为 (g(k)),直接计算 (g(k)) 貌似很难,设 (f(k)) 为 (kmidgcd(x,y)) 的 ((x,y),xleq y) 个数。(这里定义 (gcd(x,y)) 为 (x) 到 (y) 最短路径的点权 (gcd))
可以莫比乌斯反演一下,有:
[f(d)=sum_{d|n}g(n)Rightarrow g(d)=sum_{d|n}mu(frac{n}{d})f(n)
]
假如我们已经处理出了 (f),筛出 (mu) 后就能在 (mathcal{O}(nlog n)) 的复杂度内算出 (g)。
如何算 (f) ?注意到 (leq 2 imes 10^5) 的最大约数个数是 (leq 240) 的,可以暴力找出所有约数,并在这些约数为编号的图中加入这条边。
对于每个数 (i) 为编号的图中统计 (f_i),用并查集,每次合并连通块时候 (f_i) 加上跨合并连通块的这条边左右两边点数的乘积。
总复杂度为 (mathcal{O}(nlog n+nsqrt a))。
#include<iostream>
#include<cstdio>
#include<vector>
typedef long long ll;
template <typename T> T Max(T x, T y) { return x > y ? x : y; }
template <typename T> T Min(T x, T y) { return x < y ? x : y; }
template <typename T>
T& read(T& r) {
r = 0; bool w = 0; char ch = getchar();
while(ch < '0' || ch > '9') w = ch == '-' ? 1 : 0, ch = getchar();
while(ch >= '0' && ch <= '9') r = r * 10 + (ch ^ 48), ch = getchar();
return r = w ? -r : r;
}
inline int gcd(int x, int y) { return !y ? x : gcd(y, x % y); }
const int N = 200005;
int n, mx, a[N];
ll f[N];
std::vector<int>vec[N];
struct DSU {
int fa[N], siz[N];
int find(int x) { return fa[x] = fa[x] == x ? x : find(fa[x]); }
void merge(int t, int x, int y) {
int fx = find(x), fy = find(y);
if(fx == fy) return ;
f[t] += 1ll * siz[fx] * siz[fy];
fa[fx] = fy;
siz[fy] += siz[fx];
}
}dsu;
int prime[N], pct, mu[N];
int lu[N], lv[N];
bool vis[N];
void init() {
vis[1] = 1; mu[1] = 1;
for(int i = 2; i <= mx; ++i) {
if(!vis[i]) {
prime[++pct] = i;
mu[i] = -1;
}
for(int j = 1; j <= pct && i * prime[j] <= mx; ++j) {
vis[i * prime[j]] = 1;
if(i % prime[j] == 0) { mu[i * prime[j]] = 0; break; }
mu[i * prime[j]] = -mu[i];
}
}
}
int main() {
read(n);
for(int i = 1; i <= n; ++i) {
read(a[i]);
mx = Max(mx, a[i]);
for(int j = 1; j * j <= a[i]; ++j) {
if(a[i] % j) continue ;
++f[j];
if(j * j != a[i]) ++f[a[i]/j];
}
}
init();
for(int i = 1; i < n; ++i) {
read(lu[i]); read(lv[i]);
int g = gcd(a[lu[i]], a[lv[i]]);
for(int j = 1; j * j <= g; ++j)
if(g % j == 0) {
vec[j].push_back(i);
if(j * j != g) vec[g / j].push_back(i);
}
}
for(int i = 1; i <= mx; ++i) {
for(auto x : vec[i]) {
dsu.fa[lu[x]] = lu[x];
dsu.fa[lv[x]] = lv[x];
dsu.siz[lu[x]] = 1;
dsu.siz[lv[x]] = 1;
}
for(auto x : vec[i])
dsu.merge(i, lu[x], lv[x]);
}
for(int i = 1; i <= mx; ++i) {
ll ans = 0;
for(int j = i; j <= mx; j += i) ans += mu[j / i] * f[j];
if(ans) printf("%d %lld
", i, ans);
}
return 0;
}