题目描述
给定一棵 (n) 个结点的有根树 (T),结点从 (1) 开始编号,根结点为 (1) 号结点,每个结点有一个正整数权值 (v_i) 。
设 (x) 号结点的子树内(包含 (x) 自身)的所有结点编号为 (c_1,c_2,dots,c_k),定义 (x) 的价值为:
(val(x)=(v_{c_1}+d(c_1,x)) oplus (v_{c_2}+d(c_2,x)) oplus cdots oplus (v_{c_k}+d(c_k, x)))
其中 (d(x,y)) 表示树上 (x) 号结点与 (y) 号结点间唯一简单路径所包含的边数,(d(x,x) = 0)。⊕ 表示异或运算。
请你求出 $sumlimits_{i=1}^n val(i) $ 的值。
输入格式
第一行一个正整数 (n) 表示树的大小。
第二行 (n) 个正整数表示 (v_i) 。
接下来一行 (n-1) 个正整数,依次表示 (2) 号结点到 (n) 号结点,每个结点的父亲编号。
输出格式
仅一行一个整数表示答案。
题解
很显然的想法是考虑如何让父亲节点利用儿子的(val)信息从而快速计算出父亲的(val)
考场上首先想的是如何处理这个所有点的点权+1后异或和的变化
我们按二进制位来考虑
假设现在有一个数,二进制末三位为(011),那么给它+1之后变成(100),相当于它的二进制第1位,第2位,第3位都取反了
然后很容易发现一个规律,如果一个数字 (a equiv 2^k-1 mod 2^k),那么a+1之后第(k+1)位就会取反
那么我们可以把当前子树内所有点的权值装进(log n)个桶里,(buc[i][j])表示当前权值(即点权+当前深度)( equiv i mod 2^j)的点有多少个
为了方便,我们设(v_i=c_i+d(i,1)),即第(v_i)是自己的权值+自己到根的距离,把这个权值放进上面的桶里面
那么假设我们在计算(val_x),那么子树中所有点现在的点权应该是(v_i-dis(x,1)),我们要对于每个(j),找出有多少个点满足(v_i-dis(x,1)-1 equiv 2^j - 1mod 2^j)
化简一下就是(v_i equiv dis(x,1) mod 2^j)
所以我们对于每个(j)找到(buc[dis(x,1)mod 2^j][j]),那么第(j+1)位实际上要取反这么多次 这个就很好处理了
整理一下思路:首先把(val_x)赋为所有儿子(val)的异或和,然后对于每个(j)找到子树中的所有点(不包括自己)有多少个点满足(v_iequiv dis(x,1)mod 2^j),然后如果是奇数个就把第(j+1)位取反,否则不变
最后把(val_x)异或上(x)的点权
最后关于怎么维护这个桶,用树上启发式合并啊
时间复杂度(O(nlog^2 n)),一个是启发式合并的log,另一个是枚举二进制位的log
不过常数很小,洛谷上最慢点跑了1.3s,相信少爷机上更快
代码体感是挺好写的
考场上桶开小了导致100pts->30pts 心态爆炸
#include <bits/stdc++.h>
#define N 530005
using namespace std;
inline void read(int &num) {
int x = 0, f = 1; char ch = getchar();
for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1;
for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ '0');
num = x * f;
}
int n, a[N];
int head[N], pre[N<<1], to[N<<1], sz;
int d[N], siz[N], son[N], fa[N];
int buc[N<<1][22], ans[N]; //开两倍大!!!
int two[22];
long long Ans;
inline void addedge(int u, int v) {
pre[++sz] = head[u]; head[u] = sz; to[sz] = v;
pre[++sz] = head[v]; head[v] = sz; to[sz] = u;
}
void dfs1(int x) {
siz[x] = 1;
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (y == fa[x]) continue;
fa[y] = x; d[y] = d[x] + 1;
dfs1(y);
siz[x] += siz[y];
if (!son[x] || siz[son[x]] < siz[y]) son[x] = y;
}
}
void calc(int x, int tp) {
for (int i = 0; i <= 21; i++) {
buc[a[x]&two[i]][i] += tp;
}
for (int i = head[x]; i; i = pre[i]) {
if (to[i] != fa[x]) calc(to[i], tp);
}
}
void getans(int x) {
for (int i = head[x]; i; i = pre[i]) {
if (to[i] != fa[x]) ans[x] ^= ans[to[i]];
}
for (int i = 0; i <= 21; i++) {
ans[x] ^= ((buc[d[x]&two[i]][i] & 1) << i);
}
ans[x] ^= (a[x] - d[x]);
}
void dfs(int x, int cl) {
for (int i = head[x]; i; i = pre[i]) {
if (to[i] == fa[x] || to[i] == son[x]) continue;
dfs(to[i], 1);
}
if (son[x]) dfs(son[x], 0);
for (int i = head[x]; i; i = pre[i]) {
if (to[i] == fa[x] || to[i] == son[x]) continue;
calc(to[i], 1);
}
getans(x);
if (cl) {
for (int i = head[x]; i; i = pre[i]) {
if (to[i] == fa[x]) continue;
calc(to[i], -1);
}
} else {
for (int i = 0; i <= 21; i++) {
buc[a[x]&two[i]][i]++;
}
}
}
int main() {
read(n);
for (int i = 1; i <= n; i++) read(a[i]);
for (int i = 2, x; i <= n; i++) {
read(x); addedge(i, x);
}
for (int i = 0; i <= 21; i++) two[i] = (1 << i) - 1;
dfs1(1);
for (int i = 1; i <= n; i++) a[i] += d[i];
dfs(1, 0);
for (int i = 1; i <= n; i++) {
Ans += ans[i];
}
printf("%lld
", Ans);
return 0;
}