[2020-CCPC Changchun Onsite]-F. Strange Memory(dsu on tree)
题面:
题意:
给定一个含有(mathit n)个节点的数,求下式的值。
[sumlimits_{i=1}^nsumlimits_{j=i+1}^n [a_i oplus a_j = a_{operatorname{lca}(i, j)}] (i oplus j).
]
思路:
观察数据:(1 leq a_i leq 10^6),那么从根节点到叶子节点构成的路径以及它们的子路径不会对答案产生贡献。
因为在同一个链上时,设(dep_u>dep_v),那么(lca(u,v)=v),因为(a_u>0),所以(a_uoplus a_v ot=a_v)。
则可以推出答案值来源于一个节点作为根的子树中根节点的不同儿子子树之间的贡献。
那么问题可以转化为求出每一个子树的lca为子树根的子节点们对答案的贡献总和。
子树问题考虑到使用dsu on tree算法,进行轻重链剖分。
因为答案是点对的异或值总和,考虑到将其二进制拆分,按位选贡献,开桶维护个数。
时间复杂度:(O(n*log^2(n)))
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e5 + 10;
int cnt[maxn * 15][18][2];
vector<int>v[maxn];
int n;
int a[maxn];
ll sum = 0ll;
int son[maxn], siz[maxn];
int Son;
ll base[30];
void dfs(int x, int fa)
{
siz[x] = 1;
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa) {
continue;
}
dfs(to, x);
siz[x] += siz[to];
if (siz[to] > siz[son[x]]) {
son[x] = to;
}
}
}
void add(int x, int fa, int val, int num)
{
if (val == 0)
for (int i = 0; i < 18; ++i) {
sum += base[i] * cnt[a[x] ^ num][i][!((x >> i) & 1)];
}
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa || to == Son) {
continue;
}
add(to, x, val, num);
}
if (val != 0)
for (int i = 0; i < 18; ++i) {
cnt[a[x]][i][(x >> i) & 1] += val;
}
}
void dfs2(int x, int fa, int opt)
{
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa) {
continue;
}
if (to != son[x]) {
dfs2(to, x, 0);
}
}
if (son[x]) {
dfs2(son[x], x, 1);
Son = son[x];
}
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa || to == Son) {
continue;
}
add(to, x, 0, a[x]);
add(to, x, 1, a[x]);
}
for (int i = 0; i < 18; ++i) {
cnt[a[x]][i][(x >> i) & 1]++;
}
Son = 0;
if (!opt) {
for (int i = 0; i < 18; ++i) {
cnt[a[x]][i][(x >> i) & 1]--;
}
for (int i = 0; i < v[x].size(); ++i) {
int to = v[x][i];
if (to == fa || to == Son) {
continue;
}
add(to, x, -1, a[x]);
}
}
}
int main()
{
base[0] = 1ll;
for (int i = 1; i <= 20; ++i) {
base[i] = base[i - 1] * 2ll;
}
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
}
for (int i = 1; i <= n - 1; ++i) {
int x, y;
scanf("%d %d", &x, &y);
v[x].push_back(y);
v[y].push_back(x);
}
dfs(1, 0);
dfs2(1, 0, 0);
printf("%lld
", sum );
return 0;
}