首先可以了解一下启发式合并,这个可以看我之前的博客。虽然两者关系不大
该算法英文名为(dsu on tree),最先以成型的算法出现是在(Codeforces)的这篇博客上。
树上启发式合并可以在(O(nlogn))的时间复杂度内离线解决很多无修改子树询问。
先由一个例子引入:树上每个点有一种颜色,询问子树颜色个数。
在线算法我们可以用(dfs)序(+)主席树。
离线算法呢?
我们用(vis_i)表示子树内(i)是否出现,(cnt_i)表示颜色个数。这个东西是支持(O(1))修改的。
先考虑暴力,对每一子树(dfs)一遍统计答案。时间复杂度(O(n^2))。
(dfs)序(+)序列莫队,但复杂度是(O(nsqrt{n}))的。
树上启发式合并怎么做?
我们发现答案可以从儿子节点获取,但不能直接获取,这样空间复杂度是(O(n^2))的。
预处理重儿子(即子树节点最多的儿子)
先递归处理非重儿子的答案,并且不获取非重儿子的答案,即清空(vis)数组。
然后处理重儿子的答案,并且获取重儿子的答案。
最后再次递归计算非重儿子的答案,并且暴力合并得到该点的答案。
该算法的复杂度是什么?前面说了是(O(nlogn))的。
我们需要证明一个引理。
根节点出发的任意路径上轻边(不连向重儿子的边)条数(leq logn)。
证明考虑每次到非重儿子子树大小减少一半以上,最多减(logn)次。
统计一个点的答案是,重儿子的子树内点的遍历次数是不需计入该点的(那些点自己本身也要遍历一次)。
考虑每个点被遍历的次数,即为到根的轻边数,复杂度为(O(logn))。
总复杂度为(O(nlogn))。
一道例题:CF600E
和上面那题做法差不多,就当模板题做啦。
#include<cstdio>
#include<vector>
#define rep(i, a, b) for (register int i=(a); i<=(b); ++i)
#define per(i, a, b) for (register int i=(a); i>=(b); --i)
using namespace std;
const int N=100005;
int size[N], son[N], cnt[N], col[N], skip[N], Max;
long long ans[N], sum;
vector<int> G[N];
inline int read()
{
int x=0,f=1;char ch=getchar();
for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
return x*f;
}
void dfs(int u, int fa)
{
size[u]=1;
for (int v: G[u]) if (v^fa)
{
dfs(v, u); size[u]+=size[v];
if (!son[u] || size[v]>size[son[u]]) son[u]=v;
}
}
void modify(int u, int fa, int k)
{
cnt[col[u]]+=k;
if (~k && cnt[col[u]]>=Max)
{
if (cnt[col[u]]>Max) sum=0, Max=cnt[col[u]];
sum+=col[u];
}
for (int v: G[u]) if (v^fa && !skip[v]) modify(v, u, k);
}
void solve(int u, int fa, bool flag)
{
for (int v: G[u]) if (v^fa && v^son[u]) solve(v, u, 0);
if (son[u]) solve(son[u], u, 1), skip[son[u]]=1;
modify(u, fa, 1); ans[u]=sum;
if (son[u]) skip[son[u]]=0;
if (!flag) modify(u, fa, -1), Max=sum=0;
}
int main()
{
int n=read();
rep(i, 1, n) col[i]=read();
rep(i, 1, n-1)
{
int u=read(), v=read();
G[u].push_back(v); G[v].push_back(u);
}
dfs(1, 0); solve(1, 0, 0);
rep(i, 1, n) printf("%lld ", ans[i]);
return 0;
}
记(cnt_i)为(i)点子树内每个字母奇偶性的二进制状态。
只有(cnt_i=0/2^k)时合法,这个用(lowbit)检验即可。
然后就是树上启发式合并模板啦。
#include<cstdio>
#include<vector>
#define rep(i, a, b) for (register int i=(a); i<=(b); ++i)
#define per(i, a, b) for (register int i=(a); i>=(b); --i)
using namespace std;
const int N=500005;
int size[N], son[N], cnt[N], skip[N], dep[N], ans[N];
vector<pair<int, int> >q[N];
vector<int> G[N];
char s[N];
inline int read()
{
int x=0,f=1;char ch=getchar();
for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
return x*f;
}
void dfs(int u, int fa)
{
dep[u]=dep[fa]+1; size[u]=1;
for (int v: G[u])
{
dfs(v, u); size[u]+=size[v];
if (!son[u] || size[son[u]]<size[v]) son[u]=v;
}
}
bool check(int x){return !(x&(x-1));}
void modify(int u)
{
cnt[dep[u]]^=1<<(s[u]-'a');
for (int v: G[u]) if (!skip[v]) modify(v);
}
void solve(int u, int flag)
{
for (int v: G[u]) if (v^son[u]) solve(v, 0);
if (son[u]) solve(son[u], 1), skip[son[u]]=1;
modify(u); skip[son[u]]=0;
for (auto i: q[u]) ans[i.second]=check(cnt[i.first]);
if (!flag) modify(u);
}
int main()
{
int n=read(), m=read();
rep(i, 2, n) G[read()].push_back(i);
scanf("%s", s+1);
rep(i, 1, m) {int v=read(), h=read(); q[v].push_back(make_pair(h, i));}
dfs(1, 0); solve(1, 0);
rep(i, 1, m) puts(ans[i]?"Yes":"No");
return 0;
}
算法发明人出的题。据说坑了很多人
还是记一个上题的(cnt_i)一样的东西,不过记录的是到根的路径。
然后开一个桶(f_i)记录(cnt)为(i)的最大深度,然后按照点分治的思路统计答案。
然后统计答案的时候就需要用到(dsu on tree)了。
#include<cstdio>
#include<vector>
#define rep(i, a, b) for (register int i=(a); i<=(b); ++i)
#define per(i, a, b) for (register int i=(a); i>=(b); --i)
using namespace std;
inline void chkmax(int &x, int y){x<y?(x=y):0;}
const int N=500005;
vector<pair<int, int> > G[N];
int size[N], in[N], out[N], id[N], dep[N], son[N], tot;
int Xor[N], f[1<<22], ans[N];
inline int read()
{
int x=0,f=1;char ch=getchar();
for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
return x*f;
}
#define v i.first
#define w i.second
void dfs(int u)
{
size[u]=1; id[in[u]=++tot]=u;
for (auto i: G[u])
{
dep[v]=dep[u]+1; Xor[v]=Xor[u]^w;
dfs(v); size[u]+=size[v];
if (size[v]>size[son[u]]) son[u]=v;
}
out[u]=tot;
}
void solve(int u, int flag)
{
for (auto i: G[u]) if (v^son[u]) solve(v, 0), chkmax(ans[u], ans[v]);
if (son[u]) solve(son[u], 1), chkmax(ans[u], ans[son[u]]);
if (f[Xor[u]]) chkmax(ans[u], f[Xor[u]]-dep[u]);
rep(i, 0, 21) if (f[Xor[u]^(1<<i)])
chkmax(ans[u], f[Xor[u]^(1<<i)]-dep[u]);
chkmax(f[Xor[u]], dep[u]);
for (auto i: G[u]) if (v^son[u])
{
rep(j, in[v], out[v])
{
if (f[Xor[id[j]]])
chkmax(ans[u], f[Xor[id[j]]]+dep[id[j]]-(dep[u]<<1));
rep(k, 0, 21) if (f[Xor[id[j]]^(1<<k)])
chkmax(ans[u], f[Xor[id[j]]^(1<<k)]+dep[id[j]]-(dep[u]<<1));
}
rep(j, in[v], out[v]) chkmax(f[Xor[id[j]]], dep[id[j]]);
}
if (!flag) rep(i, in[u], out[u]) f[Xor[id[i]]]=0;
}
#undef v
#undef w
int main()
{
int n=read();
rep(i, 2, n)
{
int p=read(); char c=getchar();
G[p].push_back(make_pair(i, 1<<(c-'a')));
}
dep[1]=1; dfs(1); solve(1, 0);
rep(i, 1, (1<<22)-1) if (f[i]) printf("%d %d
", i, f[i]);
rep(i, 1, n) printf("%d ", ans[i]);
return 0;
}