用途
竞赛中一些树上问题会涉及到从树中取出一些点进行询问。若数据保证总点数(M)小于等于一个数例如(2*10^5)之类的时候,我们可以通过构造一颗虚树来减少点的数量,从而保证总复杂度符合要求。
构造
虚树的构造利用了一个栈。现有两种复杂度相同但常数不同的构造方法。实际运行中时间差别肉眼可见。
在一棵树上建虚树仅需将各点及其两两的lca连接起来即可。
对于一棵树,显然叶子节点的个数大于非叶子节点的个数(两个点的树除外)。因此,如果我们对(k)个点建一颗虚树,其中两两的lca总共数量小于(k)。这是构造虚树的复杂度保证,虚树中最多(2k-1)个点,所以虚树中点的空间复杂度为(O(n))。
方法一
每次将所有关键点(即要询问中的(k)个点)加入一个数组中,按(dfn)序进行排序,再对相邻两个求(lca),加在数组后面,再进行一次排序并去重,用栈维护一条链。
int sta[maxn<<2],top;
void build() {
sort(a+1,a+all+1,cmp);
int tmp=all;
for (int i=1;i<all;++i) a[++tmp]=lca(a[i],a[i+1]);
all=tmp;
sort(a+1,a+all+1,cmp);
top=0;
for (int i=1;i<=all;++i) {
if (!top) sta[++top]=a[i]; else {
while (top && lca(a[i],sta[top])!=sta[top]) --top;
if (top) ans+=dist(sta[top],a[i]);
sta[++top]=a[i];
}
}
}
方法二
注意到方法一中进行了两次排序。其实还有一种方法只需对数组进行一次排序,减小了算法的常数。
同样按(dfn)序排序数组,我们用栈维护一条链。每次求当前栈顶(top)和当前节点(now)的(lca),记为(l)。若(now)在当前栈顶的下方,即(lca=) 栈顶元素,那么直接在栈上加入当前点。若(dep[l]<dep[top]),那么我们在栈中不断弹出点,直到当前栈顶元素深度小于等于(l)的深度,再向栈中插入(l)和(now)。原因是,我们按照(dfn)序排好后,若是第一种情况,说明子树还在继续处理,否则子树已经处理完,可以弹出,
处理另一边。最后把栈清空即可。
每次弹出元素时都要把当前栈顶和第二个连边。
int sta[maxn<<2],top;
void build() {
sort(a+1,a+all+1,cmp);
sta[top=1]=1;
for (int i=1;i<=all;++i) {
int now=a[i],l=lca(now,sta[top]);
if (l==sta[top]) {
sta[++top]=now;
continue;
}
while (dep[l]<dep[sta[top]]) {
if (dep[l]>=dep[sta[top-1]]) {
int last=sta[top--];
ans+=dist(last,l);
if (l!=sta[top]) sta[++top]=l;
sta[++top]=now;
break;
}
ans+=dist(sta[top],sta[top-1]),--top;
}
}
while (top) ans+=dist(sta[top],sta[top-1]),--top;
}
例题
bzoj2286 消耗战
题目大意
给到一棵树,每次询问其中一些点,求切断它们到根节点路径的最小代价。切断一条边的代价为其边权。总询问节点数小于(5*10^5).
做法
看到总询问节点数小于(5*10^5),我们想到构建虚树求解。虚树的边权为原树上那一条链的最小边权,树形dp,取切断当前节点所有出边和切断当前节点的一条入边的最小值。还是比较简单的题目。
代码
#include<cstdio>
#include<cctype>
#include<algorithm>
#include<cstring>
#define F(x) for (giant i=h[x],v=e[i].v,w=e[i].w;i;i=e[i].next,v=e[i].v,w=e[i].w)
using namespace std;
typedef long long giant;
giant read() {
giant x=0,f=1;
char c=getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=-1;
for (;isdigit(c);c=getchar()) x=x*10+c-'0';
return x*f;
}
const giant maxn=2.5e5+10;
const giant maxj=21;
const giant inf=1e9+10;
giant d[maxn],dep[maxn],f[maxn][maxj],g[maxn][maxj],c[maxn],first[maxn],dfn=0;
giant lca(giant x,giant y) {
if (dep[x]<dep[y]) swap(x,y);
for (giant j=maxj-1;j>=0;--j) if (dep[f[x][j]]>=dep[y]) x=f[x][j];
if (x==y) return x;
for (giant j=maxj-1;j>=0;--j) if (f[x][j]!=f[y][j]) x=f[x][j],y=f[y][j];
return f[x][0];
}
giant mw(giant x,giant y) {
giant ret=inf;
if (dep[x]<dep[y]) swap(x,y);
for (giant j=maxj-1;j>=0;--j) if (dep[f[x][j]]>=dep[y]) ret=min(ret,g[x][j]),x=f[x][j];
return ret;
}
bool spe[maxn];
giant r[maxn],up[maxn],val[maxn];
struct edge {
giant v,w,next;
};
struct graph {
edge e[maxn<<2];
giant h[maxn],tot;
graph ():tot(0) {}
void add(giant u,giant v,giant w=0) {
e[++tot]=(edge){v,w,h[u]};
h[u]=tot;
}
void addt(giant u,giant v,giant w=0) {
add(u,v,w),add(v,u,w);
}
void dfs(giant x,giant fa) {
f[x][0]=fa,dep[x]=dep[fa]+1,first[x]=++dfn;
F(x) if (v!=fa) d[v]=d[x]+w,val[v]=min(val[x],w),g[v][0]=w,dfs(v,x);
}
giant dp(giant x,giant fa) {
giant tmp=0;
if (spe[x]) return val[x];
F(x) if (v!=fa) {
tmp+=dp(v,x);
}
return x==1?tmp:min(val[x],tmp);
}
} a,b;
giant sta[maxn],top;
bool cmp(giant x,giant y) {
return first[x]<first[y];
}
void build(giant all) {
sort(c+1,c+all+1,cmp);
for (giant i=1;i<all;++i) b.h[lca(c[i],c[i+1])]=0;
sta[top=1]=1;
for (giant i=1;i<=all;++i) {
giant now=c[i],l=lca(now,sta[top]);
if (l==sta[top]) {
sta[++top]=now;
continue;
}
while (dep[l]<dep[sta[top]]) {
if (dep[l]>=dep[sta[top-1]]) {
giant last=sta[top--];
b.addt(last,l);
if (l!=sta[top]) sta[++top]=l;
sta[++top]=now;
break;
}
b.addt(sta[top],sta[top-1]);
--top;
}
}
while (top>1) {
b.addt(sta[top],sta[top-1]);
--top;
}
}
int main() {
#ifndef ONLINE_JUDGE
freopen("test.in","r",stdin);
#endif
giant n=read();
for (giant i=1;i<=n;++i) val[i]=inf;
for (giant i=1;i<n;++i) {
giant u=read(),v=read(),w=read();
a.addt(u,v,w);
}
a.dfs(1,0);
val[1]=0;
for (giant j=1;j<maxj;++j) for (giant i=1;i<=n;++i) f[i][j]=f[f[i][j-1]][j-1],g[i][j]=min(g[i][j-1],g[f[i][j-1]][j-1]);
giant m=read();
while (m--) {
giant all=read();
b.tot=b.h[1]=0;
for (giant i=1;i<=all;++i) b.h[c[i]=read()]=0,spe[c[i]]=true;;
build(all);
up[1]=inf;
r[1]=0;
printf("%lld
",b.dp(1,0));
for (giant i=1;i<=all;++i) spe[c[i]]=false;
}
}
bzoj3572 世界树
题目大意
给出一棵树,每次给出一些关键点。每个点被离它最近的关键点管辖。问每个关键点管辖多少个节点,其中关键点管辖自己。总询问点数小于等于(3*10^5).
做法
建立虚树。对于虚树上的节点,要么是关键点,要么是关键点的(lca),因此先求出虚树上点的归属,再在虚边上分割原树上的点。我分了四种情况讨论。所以我发现对于我多讨论几种情况比直接写一条简便的式子不容易错。代码有点长。
代码
#include<cstdio>
#include<cctype>
#include<algorithm>
#define F(x) for (int i=h[x],v=e[i].v,w=e[i].w;i;i=e[i].next,v=e[i].v,w=e[i].w)
using namespace std;
int read() {
int x=0,f=1;
char c=getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=-1;
for (;isdigit(c);c=getchar()) x=x*10+c-'0';
return x*f;
}
const int maxn=3e5+10;
const int maxj=21;
const int inf=1e9+10;
int f[maxn][maxj],c[maxn],cd[maxn],dep[maxn],size[maxn],first[maxn],dfn=0,theroot;
int all,val[maxn],ans[maxn],near[maxn],d[maxn],n;
bool spe[maxn];
struct edge {
int v,w,next;
};
int jump(int x,int y) {
for (int j=0;j<maxj;++j) if (y&(1<<j)) x=f[x][j];
return x;
}
int todep(int x,int y) {
return jump(x,dep[x]-y);
}
int lca(int x,int y) {
if (dep[x]<dep[y]) swap(x,y);
for (int j=maxj-1;j>=0;--j) if (dep[f[x][j]]>=dep[y]) x=f[x][j];
if (x==y) return x;
for (int j=maxj-1;j>=0;--j) if (f[x][j]!=f[y][j]) x=f[x][j],y=f[y][j];
return f[x][0];
}
int dist(int x,int y) {
return dep[x]+dep[y]-2*dep[lca(x,y)]+1;
}
struct graph {
edge e[maxn<<1];
int h[maxn],tot;
graph ():tot(0) {}
inline void add(int u,int v,int w=0) {
e[++tot]=(edge){v,w,h[u]};
h[u]=tot;
}
void dfs(int x,int fa) {
f[x][0]=fa;
size[x]=1;
dep[x]=dep[fa]+1;
first[x]=++dfn;
F(x) if (v!=fa) {
dfs(v,x);
size[x]+=size[v];
}
}
void init(int x,int fa) {
if (spe[x]) {
d[x]=0;
near[x]=x;
} else {
d[x]=inf;
near[x]=0;
}
val[x]=size[x];
F(x) if (v!=fa) init(v,x);
}
void dfirst(int x,int fa,int up) {
F(x) if (v!=fa) dfirst(v,x,w);
if (fa) {
if (d[fa]>d[x]+up) {
d[fa]=d[x]+up;
near[fa]=near[x];
} else if (d[fa]==d[x]+up) {
near[fa]=min(near[fa],near[x]);
}
}
}
void dsecond(int x,int fa,int up) {
if (fa) {
if (d[x]>d[fa]+up) {
d[x]=d[fa]+up;
near[x]=near[fa];
} else if (d[x]==d[fa]+up) {
near[x]=min(near[x],near[fa]);
}
}
F(x) if (v!=fa) dsecond(v,x,w);
}
void dans(int x,int fa,int up) {
if (!fa) {
ans[near[x]]+=n-size[x];
} else {
int dis=dep[x]-dep[fa];
int down=jump(x,dis-1);
val[fa]-=size[down];
if (near[fa]==near[x]) {
ans[near[x]]+=size[down]-size[x];
} else {
int tl=todep(x,dep[near[fa]]);
if (tl==near[fa]) { // on a single line
int tmp=dep[near[x]]+dep[near[fa]];
if (tmp&1) {
int tod=(tmp+1)/2;
int now=todep(x,tod);
ans[near[x]]+=size[now]-size[x];
ans[near[fa]]+=size[down]-size[now];
} else {
int tod=tmp/2;
if (near[fa]<near[x]) ++tod;
int now=todep(x,tod);
ans[near[x]]+=size[now]-size[x];
ans[near[fa]]+=size[down]-size[now];
}
} else { // zag
int tmp=dist(near[x],near[fa]);
if (tmp%2==0) {
int tod=dep[near[x]]-tmp/2+1;
int now=todep(x,tod);
ans[near[x]]+=size[now]-size[x];
ans[near[fa]]+=size[down]-size[now];
} else {
int tod=dep[near[x]]-tmp/2;
if (near[fa]<near[x]) ++tod;
int now=todep(x,tod);
ans[near[x]]+=size[now]-size[x];
ans[near[fa]]+=size[down]-size[now];
}
}
}
}
F(x) if (v!=fa) dans(v,x,w);
}
void dend(int x,int fa) {
ans[near[x]]+=val[x];
F(x) if (v!=fa) dend(v,x);
}
void solve() {
init(theroot,0);
dfirst(theroot,0,0);
dsecond(theroot,0,0);
dans(theroot,0,0);
dend(theroot,0);
}
} a,b;
bool cmp(int x,int y) {
return first[x]<first[y];
}
int sta[maxn],top;
void build() {
sort(c+1,c+all+1,cmp);
theroot=c[1];
for (int i=1;i<all;++i) {
int tmp=lca(c[i],c[i+1]);
b.h[tmp]=0;
if (first[tmp]<first[theroot]) theroot=tmp;
}
sta[top=1]=theroot;
int st=theroot==c[1]?2:1;
for (int i=st;i<=all;++i) {
int now=c[i],l=lca(now,sta[top]);
if (l==sta[top]) {
sta[++top]=now;
continue;
}
while (dep[l]<dep[sta[top]]) {
if (dep[l]>=dep[sta[top-1]]) {
int last=sta[top--];
b.add(l,last,dep[last]-dep[l]);
if (l!=sta[top]) sta[++top]=l;
sta[++top]=now;
break;
}
b.add(sta[top-1],sta[top],dep[sta[top]]-dep[sta[top-1]]);
--top;
}
}
while (top>1) {
b.add(sta[top-1],sta[top],dep[sta[top]]-dep[sta[top-1]]);
--top;
}
}
int main() {
#ifndef ONLINE_JUDGE
freopen("test.in","r",stdin);
freopen("wd.out","w",stdout);
#endif
n=read();
for (int i=1;i<n;++i) {
int u=read(),v=read();
a.add(u,v),a.add(v,u);
}
a.dfs(1,0);
for (int j=1;j<maxj;++j) for (int i=1;i<=n;++i) f[i][j]=f[f[i][j-1]][j-1];
int m=read();
while (m--) {
all=read();
b.tot=0;
for (int i=1;i<=all;++i) {
c[i]=read();
spe[c[i]]=true;
cd[i]=c[i];
b.h[c[i]]=0;
ans[c[i]]=0;
}
build();
b.solve();
for (int i=1;i<=all;++i) printf("%d ",ans[cd[i]]);
puts("");
for (int i=1;i<=all;++i) spe[c[i]]=false;
}
}
bzoj3611 大工程
题目大意
给出一棵树,每次给出一些关键点,将它们两两连起来。求所有的路径和,最长路径和最短路径。数据保证总询问点数小于等于(2*10^6).
做法
建出虚树后,问题就转化成了在树上求点对距离和、最远和最近点对距离。进行树形dp,其中有利用一种巧妙的转移。用(toit[i])表示(i)的子树到(i)的距离,(mx[i])表示(i)的子树中到(i)最远的关键点距离,(mi[i])表示(i)子树中到(i)最近的关键点距离,若自己是关键点则为0.
代码
#include<cstdio>
#include<cctype>
#include<algorithm>
#define F(x) for (giant i=h[x],v=e[i].v,w=e[i].w;i;i=e[i].next,v=e[i].v,w=e[i].w)
using namespace std;
typedef long long giant;
giant read() {
giant x=0,f=1;
char c=getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=-1;
for (;isdigit(c);c=getchar()) x=x*10+c-'0';
return x*f;
}
const giant maxn=1.5e6+10;
const giant maxj=21;
const giant inf=1e15+10;
giant size[maxn],toit[maxn],mx[maxn],mi[maxn],dep[maxn],f[maxn][maxj];
giant thesum,themin,themax,first[maxn],dfn=0,c[maxn];
bool spe[maxn];
giant lca(giant x,giant y) {
if (dep[x]<dep[y]) swap(x,y);
for (giant j=maxj-1;j>=0;--j) if (dep[f[x][j]]>=dep[y]) x=f[x][j];
if (x==y) return x;
for (giant j=maxj-1;j>=0;--j) if (f[x][j]!=f[y][j]) x=f[x][j],y=f[y][j];
return f[x][0];
}
struct edge {
giant v,w,next;
};
struct graph {
edge e[maxn<<1];
giant h[maxn],tot;
graph ():tot(0) {}
void _add(giant u,giant v,giant w=0) {
e[++tot]=(edge){v,w,h[u]};
h[u]=tot;
}
void add(giant u,giant v,giant w=0) {
_add(u,v,w),_add(v,u,w);
}
void dfs(giant x,giant fa) {
f[x][0]=fa,dep[x]=dep[fa]+1,first[x]=++dfn;
F(x) if (v!=fa) dfs(v,x);
}
void dp(giant x,giant fa) { // see here
size[x]=spe[x];
toit[x]=0;
mx[x]=spe[x]?0:-inf;
mi[x]=spe[x]?0:inf;
F(x) if (v!=fa && v!=x) {
dp(v,x);
giant tmp=toit[v]+w*size[v];
thesum+=toit[x]*size[v]+tmp*size[x];
size[x]+=size[v];
toit[x]+=tmp;
themax=max(themax,mx[x]+mx[v]+w);
themin=min(themin,mi[x]+mi[v]+w);
mx[x]=max(mx[x],mx[v]+w);
mi[x]=min(mi[x],mi[v]+w);
}
}
} a,b;
bool cmp(giant x,giant y) {
return first[x]<first[y];
}
giant sta[maxn],top;
giant dist(giant x,giant y) {
return abs(dep[x]-dep[y]);
}
void build(giant all) {
sort(c+1,c+all+1,cmp);
for (giant i=1;i<all;++i) {
int tmp=lca(c[i],c[i+1]);
b.h[tmp]=size[tmp]=toit[tmp]=mi[tmp]=mx[tmp]=0;
}
sta[top=1]=1;
for (giant i=1;i<=all;++i) {
giant now=c[i],l=lca(now,sta[top]);
if (l==sta[top]) {
sta[++top]=now;
continue;
}
while (dep[l]<dep[sta[top]]) {
if (dep[l]>=dep[sta[top-1]]) {
giant last=sta[top--];
b.add(last,l,dist(last,l));
if (l!=sta[top]) sta[++top]=l;
sta[++top]=now;
break;
}
b.add(sta[top],sta[top-1],dist(sta[top],sta[top-1])),--top;
}
}
while (top>1) {
b.add(sta[top],sta[top-1],dist(sta[top],sta[top-1]));
--top;
}
}
int main() {
#ifndef ONLINE_JUDGE
freopen("test.in","r",stdin);
freopen("bd.out","w",stdout);
#endif
giant n=read();
for (giant i=1;i<n;++i) {
giant u=read(),v=read();
a.add(u,v);
}
a.dfs(1,0);
for (giant j=1;j<maxj;++j) for (giant i=1;i<=n;++i) f[i][j]=f[f[i][j-1]][j-1];
giant m=read();
while (m--) {
giant all=read();
b.h[0]=b.h[1]=b.tot=0;
for (giant i=1;i<=all;++i) spe[c[i]=read()]=true,b.h[c[i]]=mx[c[i]]=mi[c[i]]=size[c[i]]=toit[c[i]]=0;
build(all);
thesum=0,themin=inf,themax=-inf;
b.dp(1,0);
printf("%lld %lld %lld
",thesum,themin,themax);
for (giant i=1;i<=all;++i) spe[c[i]]=false;
}
}
bzoj3991 寻宝游戏
题目大意
给出一棵树,和一个空集,向集合中不断插入或删除树上的点,动态求集合中任意一个点出发走遍集合回到此点的路径和。
做法
其实这题跟虚树没什么太大关系,不过也可以说是动态维护一颗虚树。题意即求所有点按(dfn)序排序,两两之间求距离,再加上从最后走到第一个点的距离。用(set)维护(dfn)序即可。依然是分了一大堆情况讨论,因为(set)有各种跳出去的可能。记得开long long.
代码
#include<cstdio>
#include<cctype>
#include<algorithm>
#include<set>
#define F(x) for (giant i=h[x],v=e[i].v,w=e[i].w;i;i=e[i].next,v=e[i].v,w=e[i].w)
using namespace std;
typedef long long giant;
giant read() {
giant x=0,f=1;
char c=getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=-1;
for (;isdigit(c);c=getchar()) x=x*10+c-'0';
return x*f;
}
const giant maxn=1e5+10;
const giant maxj=20;
struct edge {
giant v,w,next;
} e[maxn<<1];
giant h[maxn],tot=0;
giant first[maxn],dfn=0,f[maxn][maxj],d[maxn],dep[maxn];
void add(giant u,giant v,giant w) {
e[++tot]=(edge){v,w,h[u]};
h[u]=tot;
}
void dfs(giant x,giant fa) {
f[x][0]=fa;
first[x]=++dfn;
dep[x]=dep[fa]+1;
F(x) if (v!=fa) d[v]=d[x]+w,dfs(v,x);
}
bool status[maxn];
giant ans,cnt;
typedef pair<giant,giant> pai;
typedef set<pai> Set;
typedef Set::iterator itt;
Set s;
giant lca(giant x,giant y) {
if (dep[x]<dep[y]) swap(x,y);
for (giant j=maxj-1;j>=0;--j) if (dep[f[x][j]]>=dep[y]) x=f[x][j];
if (x==y) return x;
for (giant j=maxj-1;j>=0;--j) if (f[x][j]!=f[y][j]) x=f[x][j],y=f[y][j];
return f[x][0];
}
giant dist(giant x,giant y) {
return d[x]+d[y]-2*d[lca(x,y)];
}
void Insert(giant x) {
pai k(first[x],x);
if (!cnt) s.insert(k),ans=0; else
if (cnt==1) {
itt it=s.begin();
ans=2*dist(it->second,x);
s.insert(k);
} else {
itt bfe=s.end(),bff=s.begin();
--bfe;
s.insert(k);
itt it=s.find(k);
itt afe=s.end();
--afe;
if (it==s.begin()) {
ans-=dist(bff->second,bfe->second);
ans+=dist(it->second,bff->second);
ans+=dist(it->second,bfe->second);
} else if (it==afe) {
ans+=dist(bfe->second,it->second);
ans+=dist(bff->second,it->second);
ans-=dist(bff->second,bfe->second);
} else {
itt pre=it,suc=it;
--pre;
++suc;
ans-=dist(pre->second,suc->second);
ans+=dist(pre->second,it->second);
ans+=dist(it->second,suc->second);
}
}
++cnt;
}
void Erase(giant x) {
pai k(first[x],x);
if (cnt==1) s.clear(); else
if (cnt==2) ans=0; else {
itt it=s.find(k),ed=s.end(),fis=s.begin();
--ed;
if (it==s.begin()) {
++fis;
ans-=dist(it->second,fis->second);
ans-=dist(ed->second,it->second);
ans+=dist(fis->second,ed->second);
} else if (it==ed) {
--ed;
ans-=dist(it->second,fis->second);
ans-=dist(ed->second,it->second);
ans+=dist(fis->second,ed->second);
} else {
itt pre=it,suc=it;
--pre;
++suc;
ans-=dist(pre->second,it->second);
ans-=dist(it->second,suc->second);
ans+=dist(pre->second,suc->second);
}
}
s.erase(k);
--cnt;
}
void print(Set s) {
for (itt it=s.begin();it!=s.end();++it) printf("{%lld, %lld} ",it->first,it->second);
puts("");
}
void print(itt it) {
printf("first = %lld , second = %lld
",it->first,it->second);
}
int main() {
#ifndef ONLINE_JUDGE
freopen("test.in","r",stdin);
freopen("sc.out","w",stdout);
#endif
giant n=read(),m=read();
for (giant i=1;i<n;++i) {
giant u=read(),v=read(),w=read();
add(u,v,w),add(v,u,w);
}
dfs(1,0);
for (giant j=1;j<maxj;++j) for (giant i=1;i<=n;++i) f[i][j]=f[f[i][j-1]][j-1];
ans=cnt=0;
while (m--) {
giant x=read();
if (status[x]) Erase(x);
else Insert(x);
status[x]^=true;
printf("%lld
",ans);
}
}
总结
虚树中把树上点抽离出来建新树,并在新树上解决问题,以保证复杂度的思想其实很容易想,代码也较简单,在题目中有保证总询问点数的树上问题中十分有用。