这道题用树上莫队即可解决：先把树按欧拉序（每个结点在进入和离开时各记录一次）展开成序列，再在该序列上做普通的莫队。
// Offline path queries on a tree answered with Mo's algorithm over an Euler
// tour ("tree Mo").
//
// Every node has a gender (0 or 1) and a value f. For each query (a, b) the
// program prints the number of unordered node pairs on the path a..b whose
// genders differ and whose f values are equal.
//
// Input (stdin):  n, n genders, n values f, n - 1 edges, q, q pairs (a, b).
// Output (stdout): the q answers in input order, each followed by a space.
// Exit status is 1 (with no output) if the input is malformed or exceeds the
// static capacities below.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

using LL = long long;

const int N = 2e5 + 7;            // capacity of every static array
const int LOG = 20;               // number of binary-lifting levels
const int B = 500;                // Mo block width
const int MAX_NODES = (N - 1) / 2;  // the Euler tour stores 2 entries per node in id[1..2n]

int n, q;
int depth[N];      // depth[root] == 1; depth[0] == 0 acts as the sentinel parent
int f[N];          // per-node value, replaced by its rank after coordinate compression
int pa[N][LOG];    // pa[u][k] is the 2^k-th ancestor of u (0 when it does not exist)
int gender[N];     // 0 or 1
int in[N], out[N]; // first / second position of a node in the Euler tour
int id[N];         // id[p] is the node stored at tour position p
int idx;           // last used tour position
int op[N];         // +1 if toggling the node would add it to the window, -1 if it would remove it
LL ans[N];
int cnt[2][N];     // cnt[g][v]: nodes of gender g with compressed value v currently in the window
std::vector<int> oo;   // sorted distinct f values (compression table)
std::vector<int> G[N]; // adjacency lists

// One offline query expressed as a range [L, R] of Euler-tour positions.
struct Qus {
    int L, R, lca, id;

    // Mo ordering: by block of L, then by R inside a block.
    bool operator<(const Qus& rhs) const {
        if (L / B == rhs.L / B) return R < rhs.R;
        return L < rhs.L;
    }
} qus[N];

int l, r;  // current Mo window over tour positions
LL ret;    // number of matching pairs among the nodes currently in the window

// Toggles the node stored at tour position x. A node that occurs twice inside
// the window is therefore counted zero times, which is exactly what removes
// the off-path subtrees from the result.
inline void update(int x) {
    const int node = id[x];
    const int g = gender[node];
    const int value = f[node];
    // Pairs gained (op == +1) or lost (op == -1): partners of the opposite
    // gender that share the same value.
    ret += static_cast<LL>(op[node]) * cnt[g ^ 1][value];
    cnt[g][value] += op[node];
    op[node] = -op[node];
}

// Builds depth[], pa[][], in[], out[] and id[] for the tree containing u,
// with fa as the (possibly fictitious, 0) parent of u.
//
// The traversal uses an explicit stack instead of recursion so that a
// path-shaped tree with ~1e5 nodes cannot overflow the call stack. Children
// are visited in adjacency-list order, so the resulting tour is the same one
// a straightforward recursive DFS would produce.
void dfs(int u, int fa) {
    // Each frame is (node, index of the next neighbour to inspect).
    std::vector<std::pair<int, std::size_t>> frames;

    auto enter = [&frames](int node, int parent) {
        depth[node] = depth[parent] + 1;
        pa[node][0] = parent;
        // All ancestors were entered earlier, so their tables are complete.
        for (int k = 1; k < LOG; k++) pa[node][k] = pa[pa[node][k - 1]][k - 1];
        id[++idx] = node;
        in[node] = idx;
        frames.emplace_back(node, 0);
    };

    enter(u, fa);
    while (!frames.empty()) {
        const int node = frames.back().first;
        const std::size_t next = frames.back().second;
        if (next < G[node].size()) {
            // Advance the cursor before pushing: emplace_back may reallocate.
            ++frames.back().second;
            const int child = G[node][next];
            if (child != pa[node][0]) enter(child, node);
        } else {
            id[++idx] = node;
            out[node] = idx;
            frames.pop_back();
        }
    }
}

// Lowest common ancestor of u and v via binary lifting.
int getLca(int u, int v) {
    if (depth[u] < depth[v]) std::swap(u, v);
    const int diff = depth[u] - depth[v];
    for (int k = LOG - 1; k >= 0; k--)
        if (diff >> k & 1) u = pa[u][k];
    if (u == v) return u;
    for (int k = LOG - 1; k >= 0; k--)
        if (pa[u][k] != pa[v][k]) u = pa[u][k], v = pa[v][k];
    return pa[u][0];
}

// Reads one int; false on EOF or a non-numeric token.
static bool readInt(int& value) {
    return std::scanf("%d", &value) == 1;
}

// Reads one int and checks lo <= value <= hi.
static bool readIntInRange(int& value, int lo, int hi) {
    return readInt(value) && lo <= value && value <= hi;
}

int main() {
    if (!readIntInRange(n, 0, MAX_NODES)) return 1;

    for (int i = 1; i <= n; i++) {
        // gender indexes the first dimension of cnt, so it must be 0 or 1.
        if (!readIntInRange(gender[i], 0, 1)) return 1;
        op[i] = 1;
    }
    for (int i = 1; i <= n; i++) {
        if (!readInt(f[i])) return 1;
        oo.push_back(f[i]);
    }
    for (int i = 1; i < n; i++) {
        int a, b;
        if (!readIntInRange(a, 1, n) || !readIntInRange(b, 1, n)) return 1;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    // Coordinate-compress f so it can index cnt directly.
    std::sort(oo.begin(), oo.end());
    oo.erase(std::unique(oo.begin(), oo.end()), oo.end());
    for (int i = 1; i <= n; i++)
        f[i] = static_cast<int>(std::lower_bound(oo.begin(), oo.end(), f[i]) - oo.begin());

    if (n >= 1) dfs(1, 0);

    if (!readIntInRange(q, 0, N - 1)) return 1;
    for (int i = 1; i <= q; i++) {
        int a, b;
        if (!readIntInRange(a, 1, n) || !readIntInRange(b, 1, n)) return 1;
        const int lca = getLca(a, b);
        if (lca == a) {
            // a is an ancestor of b: the path is [in[a], in[b]] after cancellation.
            qus[i] = Qus{in[a], in[b], a, i};
        } else if (lca == b) {
            qus[i] = Qus{in[b], in[a], b, i};
        } else if (in[a] < in[b]) {
            // Otherwise [out[first], in[second]] covers the path except the LCA,
            // which is added separately when the answer is recorded.
            qus[i] = Qus{out[a], in[b], lca, i};
        } else {
            qus[i] = Qus{out[b], in[a], lca, i};
        }
    }

    l = 1, r = 0, ret = 0;
    std::sort(qus + 1, qus + 1 + q);
    for (int o = 1; o <= q; o++) {
        const int L = qus[o].L, R = qus[o].R, lca = qus[o].lca, who = qus[o].id;
        while (r < R) update(++r);
        while (l > L) update(--l);
        while (r > R) update(r--);
        while (l < L) update(l++);
        if (lca != id[L] && lca != id[R]) {
            // The LCA is outside the window: account for the pairs it forms
            // with the path nodes without actually inserting it.
            ans[who] = ret + cnt[gender[lca] ^ 1][f[lca]];
        } else {
            ans[who] = ret;
        }
    }

    for (int i = 1; i <= q; i++) std::printf("%lld ", ans[i]);
    return 0;
}