概述:
参考神犇yyb的博客
问题:如何做到(O(nlogn)-O(1))复杂度求解(k)次祖先?
常规倍增是(O(nlogn)-O(logn))的,重链剖分是(O(nlogn)-O(logn))的,欧拉序st表能在(O(nlogn)-O(1))复杂度内求两点LCA,但并不能查出k次祖先是谁
长链剖分
方法和树剖十分类似,代码也几乎相同,但我们每次不是挑子树最大的儿子作为重链,而是挑最大深度最大的儿子作为重链
长链剖分有如下性质:
1.所有重链长度之和是(O(n))级别
显然每个点最多在一条重链内
2.如果x和k次祖先y不在同一重链内,那么y点长链的链长(所在重链末尾节点到它的距离),一定大于等于k
如果小于k,那么x-y这条链更长,与长链剖分的前提——挑最大深度的儿子相悖
继续考虑怎么利用性质
这个做法需要分类讨论
在常规的重链剖分中,如果k级祖先和它在同一重链内(用深度判断(dep[top_{x}]-dep_[x]ge k)),我们可以在(O(1))时间找到k级祖先(维护重链剖分序,同一重链上的点一定连续)
把这个想法拓展到长链剖分,我们去掉了x与k级祖先在同一重链上的情况
现在x和k级祖先不在同一重链上
有一个想法:我们找到x点的(r)级祖先,使得(r>k/2),我们能够(O(1))时间内求出x点的(r)级祖先z。然后考虑z的(k-r)级祖先,用上面的方法提到的check一下。如果不彳亍,说明z和y不在同一链内,且z的链头T深度比y大
由长链剖分性质1可知,重链长度之和一定是(O(n))级别,我们对于每个链头,暴力处理出跳([1,链长])长度时的祖先!!容易发现这个预处理复杂度是(O(n))的
而我们找到的(r>k/2),利用上面预处理出的数组就可以(O(1))找到y了
还剩一个问题,这个(r)级祖先怎么选,才能(O(1))找到呢?倍增就行了!我们令r一定是2的幂次,对于询问k,我们取k的最高位(highbit(k))即可
总结一下:每个点倍增预处理(O(nlogn)),长链剖分(O(n)),链头的处理(O(n)),每次询问(O(1))
几道题
给你一棵树,对于每个点x,子树内所有点到它都有一个距离,询问出现次数最多的距离,输出这个距离(点数相同时输出最小的距离)
首先有个非常裸的重链剖分dsu做法,每次先处理轻子树,然后把轻子树桶信息清空,再进入重子树,保留桶信息,再遍历一遍轻子树把信息丢进桶,然后处理答案。我们只有添加or清空操作,用桶维护,记录最大值并更新。时间(O(nlogn))
#include<bits/stdc++.h>
using namespace std;
#define r(x) read(x)
#define ll long long
#define it set<string>::iterator
const int N1=1e6+7;
template <typename _T> void read(_T &ret)
{
ret=0; _T fh=1; char c=getchar();
while(c<'0'||c>'9'){ if(c=='-') fh=-1; c=getchar(); }
while(c>='0'&&c<='9'){ ret=ret*10+c-'0'; c=getchar(); }
ret=ret*fh;
}
struct EDGE{
int to[N1*2],nxt[N1*2],head[N1],cte;
void ae(int u,int v)
{ cte++; to[cte]=v, nxt[cte]=head[u]; head[u]=cte; }
}e;
int n,ma;
int sz[N1],son[N1],dep[N1],bar[N1],ans[N1];
void dfs0(int u,int dad)
{
sz[u]=1;
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad) continue;
dep[v]=dep[u]+1;
dfs0(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
void push(int x,int w)
{
bar[dep[x]]+=w;
if(bar[dep[x]]>bar[ma]) ma=dep[x];
else if(bar[dep[x]]==bar[ma]&&dep[x]<ma) ma=dep[x];
}
void inbar(int u,int dad,int w)
{
push(u,w);
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad) continue;
inbar(v,u,w);
}
}
void dfs1(int u,int dad)
{
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j];
if(v==dad||v==son[u]) continue;
dfs1(v,u);
inbar(v,u,-1);
}
ma=0;
if(son[u]){
dfs1(son[u],u);
}
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j];
if(v==dad||v==son[u]) continue;
inbar(v,u,1);
}
push(u,1);
ans[u]=ma-dep[u];
}
int main(){
// freopen("1.in","r",stdin);
// freopen(".out","w",stdout);
read(n);
int x,y;
for(int i=1;i<n;i++) read(x), read(y), e.ae(x,y), e.ae(y,x);
dep[1]=1; dfs0(1,-1);
dfs1(1,-1);
// for(int i=1;i<=n;i++) printf("%d
",son[i]);
for(int i=1;i<=n;i++) printf("%d
",ans[i]);
return 0;
}
回到长链剖分,我们考虑在长链剖分序上DP,同一条链上的点一定连续。
考虑在长链剖分序上DP,我们记(f(x,j))表示距离x点距离为j的点个数。每个点在继承重儿子信息时,指针移位即可(它们一定连续)。然后暴力合并轻儿子记录的信息。因为每条长链只会在链头被暴力合并一次,总时间复杂度(O(n))
#include<bits/stdc++.h>
using namespace std;
#define r(x) read(x)
#define ll long long
#define it set<string>::iterator
const int N1=1e6+7;
template <typename _T> void read(_T &ret)
{
ret=0; _T fh=1; char c=getchar();
while(c<'0'||c>'9'){ if(c=='-') fh=-1; c=getchar(); }
while(c>='0'&&c<='9'){ ret=ret*10+c-'0'; c=getchar(); }
ret=ret*fh;
}
struct EDGE{
int to[N1*2],nxt[N1*2],head[N1],cte;
void ae(int u,int v)
{ cte++; to[cte]=v, nxt[cte]=head[u]; head[u]=cte; }
}e;
int n;
int len[N1],son[N1],*f[N1],ans[N1],*tot;
void dfs0(int u,int dad)
{
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad) continue;
dfs0(v,u);
if(len[v]>len[son[u]]) son[u]=v;
}
len[u]=len[son[u]]+1;
}
int dfs1(int u,int dad)
{
int ma=0;
if(son[u]) f[son[u]]=tot++, ma=dfs1(son[u],u)+1;
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad||v==son[u]) continue;
f[v]=tot++;
v=dfs1(v,u);
}
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad||v==son[u]) continue;
for(int i=0;i<len[v];i++){
f[u][i+1]+=f[v][i];
if(f[u][i+1]>f[u][ma]) ma=i+1;
else if(f[u][i+1]==f[u][ma]&&i+1<ma) ma=i+1;
}
}
f[u][0]=1;
if(f[u][0]>f[u][ma]) ma=0;
else if(f[u][0]==f[u][ma]&&0<ma) ma=0;
// ans[u]=ma;
ans[u]=ma;
return ma;
}
int main(){
// freopen("1.in","r",stdin);
// freopen(".out","w",stdout);
read(n);
int x,y;
for(int i=1;i<n;i++) read(x), read(y), e.ae(x,y), e.ae(y,x);
dfs0(1,-1);
tot=(int*)malloc(N1*sizeof(int));
f[1]=tot++;
dfs1(1,-1);
// for(int i=1;i<=n;i++) printf("%d
",son[i]);
for(int i=1;i<=n;i++) printf("%d
",ans[i]);
return 0;
}