线段树合并
应用范围:将子树的信息合并给父亲节点,并且权值线段树的下标值域和节点数相近。
CF600E Lomsat gelral
题意:一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。
(1 <= n<=1e5)
解法:线段树合并,这个东西的时空复杂度都很玄学,姑且认为时间为(O(nlogn)),空间为(常数( imes log_n imes n)),常数一般为(4-8)。
#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 100100
#define ll long long
int n;
int fir[maxn], nxt[maxn * 2], vv[maxn * 2];
int tot = 0;
void add(int u, int v)
{
nxt[++tot] = fir[u];
fir[u] = tot;
vv[tot] = v;
}
int cnt = 0;
int root[maxn], col[maxn];
int lz[maxn * 17 * 2], rz[maxn * 17 * 2], sum[maxn * 17 * 2];
ll ans[maxn * 17 * 2];
void pushup(int a)
{
if(sum[lz[a]] < sum[rz[a]])
{
sum[a] = sum[rz[a]];
ans[a] = ans[rz[a]];
}
if(sum[lz[a]] > sum[rz[a]])
{
sum[a] = sum[lz[a]];
ans[a] = ans[lz[a]];
}
if(sum[lz[a]] == sum[rz[a]])
{
sum[a] = sum[lz[a]];
ans[a] = ans[lz[a]] + ans[rz[a]];
}
return;
}
int merge(int a, int b, int l, int r)
{
if(a == 0) return b;
if(b == 0) return a;
if(l == r)
{
sum[a] += sum[b];
ans[a] = l;
return a;
}
int mid = (l + r) >> 1;
lz[a] = merge(lz[a], lz[b], l, mid);
rz[a] = merge(rz[a], rz[b], mid + 1, r);
pushup(a);
return a;
}
void update(int &a, int l, int r, int v)
{
if(!a) a = ++cnt;
int mid = (l + r) >> 1;
if(l == r)
{
sum[a] += 1;
ans[a] = l;
return;
}
if(mid >= v) update(lz[a], l, mid, v);
if(mid < v) update(rz[a], mid + 1, r, v);
pushup(a);
return;
}
void dfs(int u, int fa)
{
for(int i = fir[u]; i; i = nxt[i])
{
int v = vv[i];
if(v == fa) continue;
dfs(v, u);
merge(root[u], root[v], 1, 100000);
}
update(root[u], 1, 100000, col[u]);
ans[u] = ans[root[u]];
}
int main()
{
scanf("%d", &n); cnt = n;
for(int i = 1; i <= n; i++)
{
scanf("%d", &col[i]); root[i] = i;
}
for(int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
add(u, v); add(v, u);
}
dfs(1, 0);
for(int i = 1; i <= n; i++) printf("%lld ", ans[i]);
return 0;
}
雨天的尾巴
注意(ans)要在(dfs)时计算,不然当前节点的(root)可能被父亲节点继承,然后就炸了。
#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 100100
#define ll long long
int n;
int fir[maxn], nxt[maxn * 2], vv[maxn * 2];
int tot = 0;
void add(int u, int v)
{
nxt[++tot] = fir[u];
fir[u] = tot;
vv[tot] = v;
}
int cnt = 0;
int root[maxn], col[maxn];
int lz[maxn * 17 * 2], rz[maxn * 17 * 2], sum[maxn * 17 * 2];
ll ans[maxn * 17 * 2];
void pushup(int a)
{
if(sum[lz[a]] < sum[rz[a]])
{
sum[a] = sum[rz[a]];
ans[a] = ans[rz[a]];
}
if(sum[lz[a]] > sum[rz[a]])
{
sum[a] = sum[lz[a]];
ans[a] = ans[lz[a]];
}
if(sum[lz[a]] == sum[rz[a]])
{
sum[a] = sum[lz[a]];
ans[a] = ans[lz[a]] + ans[rz[a]];
}
return;
}
int merge(int a, int b, int l, int r)
{
if(a == 0) return b;
if(b == 0) return a;
if(l == r)
{
sum[a] += sum[b];
ans[a] = l;
return a;
}
int mid = (l + r) >> 1;
lz[a] = merge(lz[a], lz[b], l, mid);
rz[a] = merge(rz[a], rz[b], mid + 1, r);
pushup(a);
return a;
}
void update(int &a, int l, int r, int v)
{
if(!a) a = ++cnt;
int mid = (l + r) >> 1;
if(l == r)
{
sum[a] += 1;
ans[a] = l;
return;
}
if(mid >= v) update(lz[a], l, mid, v);
if(mid < v) update(rz[a], mid + 1, r, v);
pushup(a);
return;
}
void dfs(int u, int fa)
{
for(int i = fir[u]; i; i = nxt[i])
{
int v = vv[i];
if(v == fa) continue;
dfs(v, u);
merge(root[u], root[v], 1, 100000);
}
update(root[u], 1, 100000, col[u]);
ans[u] = ans[root[u]];
}
int main()
{
scanf("%d", &n); cnt = n;
for(int i = 1; i <= n; i++)
{
scanf("%d", &col[i]); root[i] = i;
}
for(int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
add(u, v); add(v, u);
}
dfs(1, 0);
for(int i = 1; i <= n; i++) printf("%lld ", ans[i]);
return 0;
}