首先按照套路,我们按位考虑:对每一位,问题就变成统计路径异或值在该位上为 \(0\) 和 \(1\) 两种情况的路径数目,最后乘上 \(2\) 的相应次幂。
那么记 \(dp_{i,j,0/1}\) 表示从 \(i\) 出发、走到以 \(i\) 为根的子树中每一个节点的路径里,异或值在第 \(j\) 位上为 \(0\) 或 \(1\) 的路径各有多少条,然后 dfs 一遍,用儿子的答案更新父亲的答案,最后累加就好了。注意要考虑长度为 \(0\) 的路径(即单个节点本身)。
#include <bits/stdc++.h>
// ---- Shorthand macros -------------------------------------------------------

// The `register` storage class was deprecated in C++11 and removed in C++17,
// so `reg` now expands to nothing. It was only ever an optimizer hint; the
// loop macros below keep using it and stay valid under every standard.
#define reg
#define ll long long
// From this point on every `int` in the file is a 64-bit `long long`. This is
// why the entry point has to be spelled `signed main()`.
#define int long long
#define ull unsigned long long
#define db double

// Pair / vector type aliases.
#define pi pair<int, int>
#define pl pair<ll, ll>
#define vi vector<int>
#define vl vector<ll>
#define vpi vector<pi>
#define vpl vector<pl>

// Container helpers.
#define pb push_back
#define er erase
// Signed size of a container. Argument and result are parenthesized so the
// macro is safe with compound arguments and inside larger expressions.
#define SZ(x) ((int) (x).size())
#define lb lower_bound
#define ub upper_bound
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
#define mkp make_pair

// memset wrappers; they rely on sizeof, so they must be given a real array,
// not a pointer.
#define ms(data_name) memset(data_name, 0, sizeof(data_name))
#define msn(data_name, num) memset(data_name, num, sizeof(data_name))

// Ascending loops: For -> 1..j, For0 -> 0..j-1, Forx -> j..k,
// Forstep -> j..k advancing by st.
#define For(i, j) for(reg int (i) = 1; (i) <= (j); ++(i))
#define For0(i, j) for(reg int (i) = 0; (i) < (j); ++(i))
#define Forx(i, j, k) for(reg int (i) = (j); (i) <= (k); ++(i))
#define Forstep(i , j, k, st) for(reg int (i) = (j); (i) <= (k); (i) += (st))

// Descending loops: fOR -> j..1, fOR0 -> j-1..0, fORx -> k..j.
#define fOR(i, j) for(reg int (i) = (j); (i) >= 1; (i)--)
#define fOR0(i, j) for(reg int (i) = (j) - 1; (i) >= 0; (i)--)
#define fORx(i, j, k) for(reg int (i) = (k); (i) >= (j); (i)--)

// Walks the edge records of vertex u through the head[] / nxt[] arrays
// (-1 terminates a list).
#define tour(i, u) for(reg int (i) = head[(u)]; (i) != -1; (i) = nxt[(i)])
using namespace std;
// Buffered stdin reader state: `B` is a 1 MiB buffer, `S` is the read cursor
// and `T` points one past the last valid byte. `ch` holds the character most
// recently fetched by the reader functions.
char ch, B[1 << 20], *S = B, *T = B;
// Evaluates to the next input byte. When the buffer is exhausted (S == T) it is
// refilled with fread first; if that refill reads nothing, the result is 0.
// Being a zero-argument function-like macro, it takes over the name of the
// <cstdio> getc for the rest of the file.
#define getc() (S == T && (T = (S = B) + fread(B, 1, 1 << 20, stdin), S == T) ? 0 : *S++)
// True when `c` is an ASCII decimal digit.
#define isd(c) (c >= '0' && c <= '9')
int rdint() {
int aa, bb;
while(ch = getc(), !isd(ch) && ch != '-');
ch == '-' ? aa = bb = 0 : (aa = ch - '0', bb = 1);
while(ch = getc(), isd(ch))
aa = aa * 10 + ch - '0';
return bb ? aa : -aa;
}
// Reads the next decimal integer (optionally preceded by '-') from the buffered
// input as a long long, skipping leading characters that are neither a digit
// nor '-'.
//
// Returns 0 when the input ends before a number starts: getc() yields 0 from
// then on, so a skip loop that ignored it would never terminate.
ll rdll() {
    // Skip separators; stop at a digit, a minus sign, or end of input (0).
    do {
        ch = getc();
    } while (ch != 0 && !isd(ch) && ch != '-');
    if (ch == 0)
        return 0;

    const bool negative = (ch == '-');
    ll value = negative ? 0 : ch - '0';
    // Accumulate digits; the first non-digit ends the number and is consumed.
    while (ch = getc(), isd(ch))
        value = value * 10 + (ch - '0');
    return negative ? -value : value;
}
const int mod = 998244353;
// const int mod = 1e9 + 7;

// Integer modulo `mod`.
// The arithmetic operators expect both operands to lie in [0, mod) and return
// a value in that range. The constructors store their argument as given,
// without reducing it.
struct mod_t {
    // Maps a value from [-mod, mod) into [0, mod) by adding `mod` to negatives.
    // Written as a comparison instead of a sign-bit shift so the result does
    // not depend on the width of `int` or on how negatives are right-shifted.
    static int norm(int x) {
        return x < 0 ? x + mod : x;
    }
    int x;
    // Value-initialize so a default-constructed mod_t is a well-defined zero.
    mod_t() : x(0) { }
    mod_t(int v) : x(v) { }
    // mod_t(ll v) : x(v) { }
    mod_t(char v) : x(v) { }
    mod_t operator +(const mod_t &rhs) const {
        return mod_t(norm(x + rhs.x - mod));
    }
    mod_t operator -(const mod_t &rhs) const {
        return mod_t(norm(x - rhs.x));
    }
    mod_t operator *(const mod_t &rhs) const {
        // Multiply in 64 bits, then cast back explicitly so the int constructor
        // is chosen unambiguously whatever the width of `int` is.
        return mod_t(static_cast<int>(static_cast<long long>(x) * rhs.x % mod));
    }
};
// Capacity of the per-vertex arrays; vertices are numbered from 1.
const int MAXN = 2e5 + 10;
// n: number of vertices; a[v]: value stored at vertex v.
// dp[u][j][b]: as filled by dfs, the number of downward paths starting at u
// (the one-vertex path included) whose XOR has bit j equal to b, for j in 0..19.
int n, a[MAXN], dp[MAXN][20][2];
// Running total that work() prints at the end.
ll ans = 0;
// Adjacency lists stored as linked edge records: E is the number of records in
// use, head[v] the most recently added record of v (-1 when v has none),
// nxt[e] the next record in the same list, pnt[e] the vertex record e leads to.
// The edge arrays are twice as large because each edge is stored both ways.
int E, head[MAXN], nxt[MAXN << 1], pnt[MAXN << 1];
// Resets the graph to "no edges": every adjacency list becomes empty (-1) and
// the edge-record counter restarts from zero.
inline void clear() {
    std::fill(std::begin(head), std::end(head), -1);
    E = 0;
}
// Prepends the directed edge x -> y to the adjacency list of x.
inline void addedge(int x, int y) {
    const int id = E++;   // slot claimed for the new edge record
    pnt[id] = y;
    nxt[id] = head[x];    // the previous list head becomes the successor
    head[x] = id;
}
inline void dfs(int u, int f) {
For0(i, 20)
if(a[u] & (1 << i))
dp[u][i][1] = 1;
else
dp[u][i][0] = 1;
tour(i, u) {
int v = pnt[i];
if(v != f) {
dfs(v, u);
For0(j, 20) {
int tmp = (a[u] >> j) & 1;
ans += (dp[u][j][1] * dp[v][j][0] + dp[u][j][0] * dp[v][j][1]) << j;
dp[u][j][tmp ^ 0] += dp[v][j][0];
dp[u][j][tmp ^ 1] += dp[v][j][1];
}
}
}
}
inline void work() {
n = rdint();
For(i, n) {
a[i] = rdint();
ans += a[i];
}
clear();
Forx(i, 2, n) {
int u = rdint(), v = rdint();
addedge(u, v);
addedge(v, u);
}
dfs(1, -1);
printf("%lld
", ans);
}
// Program entry point. Spelled `signed` because this file redefines `int` as
// `long long`, and main must return a plain int.
signed main() {
    // Uncomment to read from a local file instead of standard input:
    // freopen("input.txt", "r", stdin);
    work();
    // Falling off the end of main is equivalent to returning 0.
}