设节点个数大于 $\sqrt n$ 的颜色为关键颜色,那么可以证明关键颜色最多有 $\sqrt n$ 个.
对于每个关键颜色,暴力预处理出该颜色到查询中另一个颜色的距离和.
对于不是关键颜色的询问,直接建立虚树进行统计即可.
由于不是关键颜色,节点数最多为 $\sqrt n$ ,那么时间复杂度是 $O(2\times n\sqrt n)$.
总时间复杂度为 $O(n\sqrt n)$,这个就叫做根号分治.
/*
 * Offline colour-pair distance sums on a weighted tree (square-root split on
 * colour frequency).
 *
 * Input (stdin), in the order main() reads it:
 *   n
 *   col[1..n]            colour of every vertex
 *   n-1 lines "a b c"    undirected edge a-b with weight c
 *   Q
 *   Q lines "a b"        a pair of colours
 *
 * Output: for every query, the sum of tree distances over all pairs
 * (vertex of colour a, vertex of colour b).  When a == b every unordered pair
 * is counted twice by that sum, so it is halved before printing.
 *
 * A colour is "heavy" when it has at least m = floor(sqrt(n)) vertices.
 *   - If the more frequent colour of a query is heavy, a rerooting DP over the
 *     whole tree yields, for every vertex, its distance sum to that colour;
 *     the answer is that value summed over the partner colour's vertices.
 *   - Otherwise both colours are light and the query is answered on the
 *     virtual tree spanned by their vertices.
 *
 * NOTE(review): colours are used directly as array indices and the heavy pass
 * only scans colours 1..n, so colours are assumed to lie in [1, n] and n < N
 * -- confirm against the problem statement.
 */
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cmath>
#include <map>

#define setIO(s) std::freopen(s ".in", "r", stdin), std::freopen(s ".out", "w", stdout)

typedef long long ll;

const int N = 100003;

namespace IO {
    const int BUF_SIZE = 100000;
    char buf[BUF_SIZE];
    char *p1 = buf, *p2 = buf;

    // Next byte of stdin through a block buffer, or EOF once input is exhausted.
    inline int nc() {
        if (p1 == p2) {
            p1 = buf;
            p2 = buf + std::fread(buf, 1, BUF_SIZE, stdin);
            if (p1 == p2) return EOF;
        }
        return *p1++;
    }

    // Reads a non-negative decimal integer; returns 0 if the input ends first.
    inline int readint() {
        int c = nc();
        while (c != EOF && c < 48) c = nc();   // skip separators, but never spin on EOF
        int x = 0;
        while (c > 47) {
            x = x * 10 + (c ^ 48);
            c = nc();
        }
        return x;
    }

    // Same as readint() but accumulates into a 64-bit value.
    inline ll readll() {
        int c = nc();
        while (c != EOF && c < 48) c = nc();
        ll x = 0;
        while (c > 47) {
            x = x * 10 + (c ^ 48);
            c = nc();
        }
        return x;
    }
}

std::vector<int> G[N];    // virtual-tree adjacency, parent -> children
std::vector<int> ty[N];   // ty[c] = vertices of colour c
int n, edges, tim, toop;
ll dis[N];                // per-vertex distance sum produced by the current DP
ll depth[N];              // weighted depth measured from vertex 1
int col[N];               // colour of each vertex
int tax[N];               // tax[c] = number of vertices of colour c
int A[N];                 // key vertices of the current light query (1-based)
// Number of vertices of the colour being processed inside each subtree.
// Deliberately not called `size`: that name collides with std::size (C++17)
// as soon as namespace std is opened at file scope.
int cnt[N];
int S[N];                 // stack used while building the virtual tree
int hd[N], nex[N << 1], to[N << 1];   // forward-star adjacency of the real tree
ll val[N << 1];                       // edge weights, parallel to to[]
int top[N], dfn[N], fa[N], dep[N], son[N], siz[N];   // heavy-light decomposition data

// Orders vertices by DFS entry time.
bool cmp(int a, int b) {
    return dfn[a] < dfn[b];
}

// Appends the directed edge u -> v of weight c to the real tree.
inline void addedge(int u, int v, int c) {
    ++edges;
    nex[edges] = hd[u];
    hd[u] = edges;
    to[edges] = v;
    val[edges] = c;
}

// First HLD pass: parent, level, entry time, subtree size, heavy child and
// weighted depth of every vertex in the subtree of u.
void dfs1(int u, int ff) {
    fa[u] = ff;
    dep[u] = dep[ff] + 1;
    dfn[u] = ++tim;
    siz[u] = 1;
    for (int e = hd[u]; e; e = nex[e]) {
        const int v = to[e];
        if (v == ff) continue;
        depth[v] = depth[u] + val[e];
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
}

// Second HLD pass: assigns the chain head tp to u and to its heavy path.
void dfs2(int u, int tp) {
    top[u] = tp;
    if (son[u]) dfs2(son[u], tp);
    for (int e = hd[u]; e; e = nex[e]) {
        const int v = to[e];
        if (v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}

// Lowest common ancestor of x and y by climbing chains.
inline int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) x = fa[top[x]];
        else y = fa[top[y]];
    }
    return dep[x] < dep[y] ? x : y;
}

// Weighted distance between x and y in the real tree.
inline ll Dis(int x, int y) {
    return depth[x] + depth[y] - 2 * depth[LCA(x, y)];
}

// Heavy pass, bottom-up: cnt[u] = vertices of colour cur in subtree(u),
// dis[u] = sum of distances from u to those vertices.
void solve1(int u, int ff, int cur) {
    cnt[u] = (col[u] == cur);
    dis[u] = 0;
    for (int e = hd[u]; e; e = nex[e]) {
        const int v = to[e];
        if (v == ff) continue;
        solve1(v, u, cur);
        cnt[u] += cnt[v];
        dis[u] += dis[v] + 1ll * cnt[v] * val[e];
    }
}

// Heavy pass, top-down reroot: afterwards dis[v] is the distance sum from v to
// every vertex of colour cur in the whole tree.
void solve(int u, int ff, int cur) {
    for (int e = hd[u]; e; e = nex[e]) {
        const int v = to[e];
        if (v == ff) continue;
        // Moving u -> v brings cnt[v] vertices closer and the rest farther.
        dis[v] = dis[u] - 1ll * cnt[v] * val[e] + 1ll * (tax[cur] - cnt[v]) * val[e];
        solve(v, u, cur);
    }
}

// Adds the virtual-tree edge u -> v.
inline void addvir(int u, int v) {
    G[u].push_back(v);
}

// Feeds the next key vertex (callers pass them in increasing dfn) to the
// stack-based virtual-tree builder.
inline void insert(int x) {
    if (toop < 2) {
        S[++toop] = x;
        return;
    }
    const int lca = LCA(x, S[toop]);
    if (lca != S[toop]) {
        // Close every chain on the stack that lies strictly below the LCA.
        while (toop > 1 && dep[S[toop - 1]] >= dep[lca]) {
            addvir(S[toop - 1], S[toop]);
            --toop;
        }
        if (S[toop] != lca) {
            addvir(lca, S[toop]);
            S[toop] = lca;
        }
    }
    S[++toop] = x;
}

// Light pass, bottom-up on the virtual tree; same quantities as solve1().
// ff is unused because virtual edges only point from parent to child.
void pre(int u, int ff, int cur) {
    (void)ff;
    cnt[u] = (col[u] == cur);
    dis[u] = 0;
    for (std::size_t k = 0; k < G[u].size(); ++k) {
        const int v = G[u][k];
        pre(v, u, cur);
        cnt[u] += cnt[v];
        dis[u] += dis[v] + 1ll * cnt[v] * Dis(v, u);
    }
}

// Light pass, top-down reroot on the virtual tree; same recurrence as solve().
void work(int u, int ff, int cur) {
    (void)ff;
    for (std::size_t k = 0; k < G[u].size(); ++k) {
        const int v = G[u][k];
        const ll len = Dis(u, v);
        dis[v] = dis[u] - 1ll * cnt[v] * len + 1ll * (tax[cur] - cnt[v]) * len;
        work(v, u, cur);
    }
}

// Tears down the virtual tree hanging below u and resets its per-vertex state.
void clear(int u) {
    cnt[u] = 0;
    dis[u] = 0;
    for (std::size_t k = 0; k < G[u].size(); ++k) clear(G[u][k]);
    G[u].clear();
}

struct Node {
    int a, b;
};

std::vector<int> P[N];        // P[c] = partner colours of the queries answered through heavy colour c
std::vector<ll> answer[N];    // answer[c][k] = result of the k-th such query
int point[N];                 // next unread index of answer[c]
ll memoValue[N];              // memoValue[p] = cached sum for partner colour p ...
int memoOwner[N];             // ... valid only while memoOwner[p] is the current heavy colour

int main() {
    using namespace IO;
    // setIO("input");

    n = readint();
    const int m = static_cast<int>(std::sqrt(static_cast<double>(n)));

    for (int i = 1; i <= n; ++i) {
        col[i] = readint();
        ++tax[col[i]];
        ty[col[i]].push_back(i);
    }
    for (int i = 1; i < n; ++i) {
        const int a = readint();
        const int b = readint();
        const int c = readint();
        addedge(a, b, c);
        addedge(b, a, c);
    }
    dfs1(1, 0);
    dfs2(1, 1);

    const int Q = readint();
    std::vector<Node> ask(static_cast<std::size_t>(Q) + 1);
    for (int i = 1; i <= Q; ++i) {
        ask[i].a = readint();
        ask[i].b = readint();
        // Keep the more frequent colour in .a so one test decides heavy/light.
        if (tax[ask[i].a] < tax[ask[i].b]) std::swap(ask[i].a, ask[i].b);
        if (tax[ask[i].a] >= m) P[ask[i].a].push_back(ask[i].b);
    }

    // Heavy colours: one full-tree rerooting DP each, then answer their queries.
    for (int c = 1; c <= n; ++c) {
        if (tax[c] < m || P[c].empty()) continue;   // nothing to answer through c
        solve1(1, 0, c);
        solve(1, 0, c);
        for (std::size_t j = 0; j < P[c].size(); ++j) {
            const int partner = P[c][j];
            // Cache per partner so a repeated query costs O(1) instead of
            // another scan over all of the partner's vertices.
            if (memoOwner[partner] != c) {
                ll total = 0;
                for (std::size_t k = 0; k < ty[partner].size(); ++k)
                    total += dis[ty[partner][k]];
                memoOwner[partner] = c;
                memoValue[partner] = total;
            }
            answer[c].push_back(memoValue[partner]);
        }
    }

    // NOTE(review): answers are separated by a single space in this copy of
    // the program; the judge may expect one per line -- confirm.
    for (int cas = 1; cas <= Q; ++cas) {
        const int a = ask[cas].a;
        const int b = ask[cas].b;

        if (tax[a] >= m) {
            const ll total = answer[a][point[a]++];
            std::printf("%lld ", a == b ? total / 2 : total);
            continue;
        }

        // Light query: build the virtual tree over the vertices of both colours.
        int tmp = 0;
        for (std::size_t k = 0; k < ty[a].size(); ++k) A[++tmp] = ty[a][k];
        for (std::size_t k = 0; k < ty[b].size(); ++k) A[++tmp] = ty[b][k];
        std::sort(A + 1, A + 1 + tmp, cmp);
        tmp = static_cast<int>(std::unique(A + 1, A + 1 + tmp) - (A + 1));

        toop = 0;
        if (tmp == 0 || A[1] != 1) S[++toop] = 1;   // vertex 1 always roots the virtual tree
        for (int i = 1; i <= tmp; ++i) insert(A[i]);
        while (toop > 1) {
            addvir(S[toop - 1], S[toop]);
            --toop;
        }

        pre(1, 0, b);
        work(1, 0, b);

        ll re = 0;
        for (std::size_t k = 0; k < ty[a].size(); ++k) re += dis[ty[a][k]];

        // Drop this query's virtual edges; otherwise they leak into the next
        // light query and its vertices get counted more than once.
        clear(1);

        std::printf("%lld ", a == b ? re / 2 : re);
    }
    return 0;
}