虚树,顾名思义,就是假的树.
在树形dp中有很大的优化作用.
虚树主要针对于树中关键点的询问.我们仅仅对关键点及其lca建一棵树.这样只要保证sigmak在时间复杂度内即可.
以下是建树的模板
q=read(); for(int i=1;i<=q;++i) { num=read(); for(int j=1;j<=num;++j) b[j]=read(),vis[b[j]]=true;//标记关键点. sort(b+1,b+num+1,cmp);//按照dfn排序 stak[top=1]=b[1];//强行加入第一个点. for(int j=2;j<=num;++j) { int now=b[j]; int lc=lca(now,stak[top]); while(1) { if(deep[lc]>=deep[stak[top-1]])//如果lca为top,或top-1,或在两者之间 { if(lc!=stak[top])//不等于top { add2(lc,stak[top]);//先连边 if(lc!=stak[top-1]) stak[top]=lc;//如果在两者之间,去掉top,加入lca else --top;//否则为top-1,直接去掉top即可. } break; } else {add2(stak[top-1],stak[top]);top--;}//lca在top-1之上,top-1向top连边,去掉top1 } stak[++top]=now;//最后把now加入栈中. } while(--top) add2(stak[top],stak[top+1]);//最后将最右链加入加入虚树 dfs(stak[1]);//从最上面的点开始dfs
这里用栈维护了虚树的最右链,dfs中记得将虚树的信息清空即可.
我觉得最难得不是虚树的建立,毕竟这就是一个模板,而是建立虚树后的dp转移...头大...
[SDOI2011]消耗战
这个题要求所有的关键点都不能到达1号点的最小代价.
看到sigma(ki)<=500000,就知道要用到虚树(要养成好习惯).
我们先考虑从普通的dp入手,再探索虚树上应该如何dp.
我们设f[i]表示以i为根的子树内的关键点都不与1联通的最小代价.
考虑当前x的状态如何转移.
首先如果x是关键点,那f[x]只能等于v(fa[x],x).也就是必须切断x的父亲与x的联系。这样x及其子树都不可能与1联通.
倘若x不是关键点,那f[x]=min(sum[x],v(fa[x],x)).sum[x]=sigmaf[y].(y=x.son)
好了,这样普通的dp就只能达到这种地步了.
如果我们把这种dp放到虚树上会是什么样呢?由于我们将许多没用的点都抽离出去了,所以如果一个点是关键带你的话,我们无法做到查询
v(fa[x],x)的值.那我们思考当想要将x的关键点拦截的话,付出他的最小代价究竟是什么,是点x到1的最小的边权.
那我们在之前的dfs中预处理出来这个东西.之后按照上面的转移即可.
#include<bits/stdc++.h> #define ll long long #define min(a,b) a<b?a:b using namespace std; const int N=500500; int link1[N],tot1,link2[N],tot2,n,deep[N],f[N][25],q; int b[N],num,dfn[N],stak[N],top; ll minv[N]; bool vis[N]; struct edge{int y,next;ll v;}a1[N<<1],a2[N<<1]; inline int read() { int x=0,ff=1; char ch=getchar(); while(!isdigit(ch)) {if(ch=='-') ff=-1;ch=getchar();} while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} return x*ff; } inline void add1(int x,int y,int v) { a1[++tot1].y=y; a1[tot1].v=v; a1[tot1].next=link1[x]; link1[x]=tot1; } inline void add2(int x,int y) { a2[++tot2].y=y; a2[tot2].next=link2[x]; link2[x]=tot2; } inline void dfs1(int x,int fa) { dfn[x]=++num; for(int i=link1[x];i;i=a1[i].next) { int y=a1[i].y; if(y==fa) continue; deep[y]=deep[x]+1; f[y][0]=x; for(int j=1;j<=20;++j) f[y][j]=f[f[y][j-1]][j-1]; minv[y]=min(minv[x],a1[i].v); dfs1(y,x); } } inline int lca(int a,int b) { if(deep[a]>deep[b]) swap(a,b); for(int i=20;i>=0;--i) if(deep[f[b][i]]>=deep[a]) b=f[b][i]; if(a==b) return a; for(int i=20;i>=0;--i) if(f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i]; return f[a][0]; } inline bool cmp(int a,int b) {return dfn[a]<dfn[b];} inline ll dfs2(int x) { ll sum=0,dp; for(int i=link2[x];i;i=a2[i].next) { int y=a2[i].y; sum+=dfs2(y); } if(vis[x]) dp=minv[x]; else dp=min(minv[x],sum); if(vis[x]) vis[x]=false; link2[x]=0; return dp; } int main() { // freopen("1.in","r",stdin); n=read(); for(int i=1;i<n;++i) { int x=read(),y=read(),v=read(); add1(x,y,v);add1(y,x,v); } minv[1]=1e18; dfs1(1,0);q=read(); while(q--) { num=read(); for(int i=1;i<=num;++i) { b[i]=read(); vis[b[i]]=true; } sort(b+1,b+num+1,cmp); stak[top=1]=b[1]; for(int i=2;i<=num;++i) { int now=b[i]; int lc=lca(now,stak[top]); while(1) { if(deep[lc]>=deep[stak[top-1]]) { if(lc!=stak[top]) { add2(lc,stak[top]); if(lc!=stak[top-1]) stak[top]=lc; else top--; } break; } else {add2(stak[top-1],stak[top]);top--;} } stak[++top]=now; } while(--top) add2(stak[top],stak[top+1]); cout<<dfs2(stak[1])<<endl; tot2=0; } return 0; }
[HEOI2014]大工程
这种题真的一搞一上午啊,还是我太菜了.....
我们看到k的范围自然就想到了虚树.
那就让我们先考虑普通的dp:
第一问,是所有关键点两两匹配的总长度之和.二三问分别是最长和最小长度.
第一问直接统计每条边的贡献,第二三问用求直径的思想。
我们设sum[x],mx[x],mn[x],size[x]分别表示以x为根的树中,所有关键点到x的路径和,最大值,最小值,和个数.
对于ans1,我们考虑当前处理到y这个儿子.
ans1+=sum[x]*size[y]+(sum[y]+dis(x,y)*size[y])*size[x].这个意思就是之前的子树中每条边都出来与y中的子树匹配.
mx,与mn就不加述说了.
我之前一直在思考如果是关键点的话,怎么特殊处理.因为我们的做法其实枚举了每一个lca,将两端拼接起来的.
可是观察上面的转移,如果我们将关键点的size[x]初始化为1,那size[x]里就为累计一下(sum[y]+dis(x,y)的代价,其实就等同于x与所有关键点的匹配.
在普通树里,dis(x,y)是1,而在虚树里dis(x,y)是deep[y]-deep[x]。之后将其转移即可.
#include<bits/stdc++.h> #define ll long long using namespace std; const int N=1000010; int n,q,link1[N],tot1,link2[N],tot2,deep[N],f[N][25],b[N],num,dfn[N]; int stak[N],top; ll ans1,ans2,ans3,sum[N],mx[N],mn[N],size[N]; bool vis[N]; struct edge{int y,next;}a1[N<<1],a2[N<<1]; inline int read() { int x=0,ff=1; char ch=getchar(); while(!isdigit(ch)) {if(ch=='-') ff=-1;ch=getchar();} while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} return x*ff; } inline bool cmp(int x,int y){return dfn[x]<dfn[y];} inline void add1(int x,int y) { a1[++tot1].y=y; a1[tot1].next=link1[x]; link1[x]=tot1; } inline void add2(int x,int y) { a2[++tot2].y=y; a2[tot2].next=link2[x]; link2[x]=tot2; } inline void dfs1(int x) { dfn[x]=++num; for(int i=link1[x];i;i=a1[i].next) { int y=a1[i].y; if(y==f[x][0]) continue; deep[y]=deep[x]+1; f[y][0]=x; for(int j=1;j<=20;++j) f[y][j]=f[f[y][j-1]][j-1]; dfs1(y); } } inline int lca(int a,int b) { if(deep[a]>=deep[b]) swap(a,b); for(int i=20;i>=0;--i) if(deep[f[b][i]]>=deep[a]) b=f[b][i]; if(a==b) return a; for(int i=20;i>=0;--i) if(f[a][i]!=f[b][i]) a=f[a][i],b=f[b][i]; return f[a][0]; } inline void dfs2(int x) { sum[x]=0;mx[x]=0;mn[x]=(vis[x]?0:1e18);size[x]=(vis[x]?1:0); for(int i=link2[x];i;i=a2[i].next) { int y=a2[i].y; dfs2(y); ll dis=deep[y]-deep[x]; ans1+=(sum[y]+dis*size[y])*size[x]+sum[x]*size[y]; ans2=max(ans2,mx[x]+mx[y]+dis); ans3=min(ans3,mn[x]+mn[y]+dis); sum[x]+=sum[y]+dis*size[y]; mx[x]=max(mx[x],mx[y]+dis); mn[x]=min(mn[x],mn[y]+dis); size[x]+=size[y]; } if(vis[x]) vis[x]=false; link2[x]=0; } int main() { freopen("1.in","r",stdin); n=read(); for(int i=1;i<n;++i) { int x=read(),y=read(); add1(x,y);add1(y,x); } deep[1]=1;dfs1(1); q=read(); for(int i=1;i<=q;++i) { num=read(); for(int j=1;j<=num;++j) b[j]=read(),vis[b[j]]=true; sort(b+1,b+num+1,cmp); stak[top=1]=b[1]; for(int j=2;j<=num;++j) { int now=b[j]; int lc=lca(now,stak[top]); while(1) { if(deep[lc]>=deep[stak[top-1]]) { if(lc!=stak[top]) { add2(lc,stak[top]); if(lc!=stak[top-1]) stak[top]=lc; else --top; } break; } else {add2(stak[top-1],stak[top]);top--;} } stak[++top]=now; } while(--top) add2(stak[top],stak[top+1]); ans1=0;ans2=0;ans3=1e18; dfs2(stak[1]); printf("%lld %lld %lld ",ans1,ans3,ans2); tot2=0; } return 0; }