Once there was a rooted tree. The tree contained n nodes, which were numbered 1,…,n. The node numbered 1 was the root of the tree. Besides, every node i was assigned a number a**i. Your were surprised to find that there were several pairs of nodes (i,j) satisfying a**i⊕a**j=alca(i,j),
where ⊕ denotes the bitwise XOR operation, and lca(i,j) is the lowest common ancestor of i and j, or formally, the lowest (i.e. deepest) node that has both i and j as descendants.
Unfortunately, you cannot remember all such pairs, and only remember the sum of i⊕j for all different pairs of nodes (i,j) satisfying the above property. Note that (i,j) and (j,i) are considered the same here. In other words, you will only be able to recall n∑i=1n∑j=i+1a**i⊕a**j=alca(i,j).
You are assumed to calculate it now in order to memorize it better in the future.
首先暴力做法就是双重遍历点求LCA然后判断,复杂度显然爆炸,考虑如何优化。既然枚举点不可,那么可以考虑枚举LCA,求LCA为根的子树对于答案的贡献,加起来就是最终的答案了。这样就自然想到了dsu on tree。由于异或的性质,可知要满足的条件\(a[x]\ xor\ a[y] = a[LCA]\)等价于\(a[y]=a[LCA]\ xor\ a[x]\),那么可以维护一个vector数组vector<int> vec[N],vec[i]存储值为i的点的下标。那么在遍历LCA的时候先暴力搜轻儿子,再搜重儿子且保留信息(存在vec数组),再依次搜轻儿子,对于每个轻儿子为根的子树先更新答案再更新vec数组(这样保证产生贡献的两个点的最近公共祖先一定是枚举到的LCA)。但这样的话每次需要扫一遍vector,复杂度也是无法接受的。注意到异或运算的话每一位都是独立的,可以考虑分别计算每一位的贡献然后加起来。设cnt[i, j, k]表示值为i,下标第j位为k的数的个数。只有两个下标第j位不相同,这一位才对答案有贡献,因此获取k的时候需要用x的第j位和1异或。
#include <iostream>
#include <cstring>
#define ll long long
const int N = 1e5 + 5;
using namespace std;
int n, a[N], head[N], ver[2 * N], Next[2 * N], tot = 0, sz[N], son[N];
long long ans = 0;
int hson = 0;
int cnt[N * 15][21][2];//第一维一定要开够
void add(int x, int y) {
ver[++tot] = y, Next[tot] = head[x], head[x] = tot;
}
void dfs1(int x, int pre) {
sz[x] = 1;
int mxsz = -1;
for(int i = head[x]; i; i = Next[i]) {
int y = ver[i];
if(y == pre) continue;
dfs1(y, x);
sz[x] += sz[y];
if(sz[y] > mxsz) {
son[x] = y;
mxsz = sz[y];
}
}
}
void update(int x, int pre, int lca) {
int tmp = a[x] ^ a[lca];
for(int i = 0; i <= 20; i++) {
ans += (1ll << i) * (cnt[tmp][i][!((x >> i) & 1)]);
}
for(int i = head[x]; i; i = Next[i]) {
int y = ver[i];
if(y == pre || y == hson) continue;
update(y, x, lca);
}
}
void modify(int x, int pre, int v) {
for(int i = 0; i <= 20; i++) {
cnt[a[x]][i][(x >> i) & 1] += v;//x的信息必须在这里添加才不会产生影响
}
for(int i = head[x]; i; i = Next[i]) {
int y = ver[i];
if(y == pre || y == hson) continue;
modify(y, x, v);
}
}
void dfs2(int x, int pre, bool keep) {
for(int i = head[x]; i; i = Next[i]) {
int y = ver[i];
if(y == pre || y == son[x]) continue;
dfs2(y, x, 0);
}
if(son[x]) {
dfs2(son[x], x, 1);
hson = son[x];
}
for(int i = head[x]; i; i = Next[i]) {//统计以a[x]为lca的答案
int y = ver[i];
if(y == pre || y == son[x]) continue;//因为重儿子已经add过一遍了,不能再添加了
//必须先计算再添加 否则可能出现v在u到假设的lca的链上,这样假设的lca就不是真正的lca了
update(y, x, x);
modify(y, x, 1);
}
for(int i = 0; i <= 20; i++) {
cnt[a[x]][i][(x >> i) & 1]++;//x的信息必须在这里添加才不会产生影响
}
hson = 0;
if(!keep) {
modify(x, pre, -1);
}
}
signed main() {
cin >> n;
memset(cnt, 0, sizeof(cnt));
for(int i = 1; i <= n; i++) {
cin >> a[i];
}
for(int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
dfs1(1, 0);
dfs2(1, 0, 1);
cout << ans;
return 0;
}