简介
虚树可以解决一些只涉及树上一部分关键节点的问题. 对于一棵树 $T$ 的一个点集 $S$, 在预处理 $\mathrm{lca}$ (下文用欧拉序 + ST 表, 单次查询 $O(1)$) 之后, 可以在 $O(|S|\log|S|)$ 的时间复杂度内求出 $S$ 的虚树, 瓶颈在于按 dfs 序排序.
虚树包括根节点、所有询问点, 以及所有询问点两两之间的 $\mathrm{lca}$.
代码
// Store the tree as a linked forward-star adjacency list.
// Edge indices start at 2 (pe starts at 1), so index 0 terminates a list.
struct tg{
    // t: head vertex, pr: previous edge out of the same tail, v: weight
    struct te{int t,pr,v;}edge[nsz*2];
    // hd[u]: most recently added edge out of u (0 = none); pe: last used edge index
    int hd[nsz],pe=1;
    // Add the directed edge f -> t with weight v.
    // Standard list-initialization instead of the non-standard C99 compound literal (te){...}.
    void adde(int f,int t,int v){edge[++pe]=te{t,hd[f],v};hd[f]=pe;}
    // Add an undirected edge as two directed edges.
    void adddb(int f,int t,int v){adde(f,t,v);adde(t,f,v);}
    // Iterate over edges i out of p in graph g; v is the head vertex of edge i.
    #define forg(g,p,i,v) for(int i=g.hd[p],v=g.edge[i].t;i;i=g.edge[i].pr,v=g.edge[i].t)
    // Reset the edge counter (intended for g2 only). hd[] is not cleared here;
    // sol() zeroes hd[p] for every virtual-tree node it visits.
    void clear(){
        pe=1;
    }
}g1,g2;
//get lca (with sparse table)
namespace nlca{
void dfs(int u,int fa){
eul[++pe]=u,vis[u]=pe,d[u]=d[fa]+1;
forg(g1,u,i,v){
if(v==fa)continue;
dfs(v,u);
eul[++pe]=u;
}
}
int dmin(int a,int b){return d[a]<=d[b]?a:b;}
void rmq(){
rep(i,1,pe)stt[i][0]=eul[i];
rep(j,1,l2n[pe]){
rep(i,1,pe+1-(1<<j)){
stt[i][j]=dmin(stt[i][j-1],stt[i+(1<<(j-1))][j-1]);
}
}
}
int stqu(int a,int b){
int l=l2n[b-a+1];
return dmin(stt[a][l],stt[b-(1<<l)+1][l]);
}
void eulinit(){
int l=0;
rep(i,1,n*3){
if(i==(1<<(l+1)))++l;
l2n[i]=l;
}
dfs(1,0);
rmq();
}
int lca(int a,int b){
int x=vis[a],y=vis[b];
if(x>y)swap(x,y);
return stqu(x,y);
}
}
// Build the virtual tree.
// line[1..k]: the query vertices.

// Comparator: order vertices by DFS entry time (first Euler-tour index vis[]).
bool cmp(int a,int b){
    return vis[b]>vis[a];
}
// Stack used by build(): holds the current root-to-vertex chain of the virtual tree.
int stk[nsz];
int top=0;
// Build the virtual tree of line[1..k] plus the root 1 into g2 (edges parent -> child).
// line[] is sorted in place by DFS order.
void build(){
    g2.clear(),top=0;
    sort(line+1,line+k+1,cmp);
    stk[++top]=1;
    rep(i,1,k){
        int u=line[i];
        // Skip the root and repeated vertices (repeats are adjacent after sorting).
        // Pushing them again would emit a self-loop edge and make sol() recurse forever.
        if(u==stk[top])continue;
        int l=nlca::lca(u,stk[top]);
        // While the vertex under the top is l or a descendant of l, it is the
        // top's virtual parent: emit that edge and pop.
        while(top>1&&vis[stk[top-1]]>=vis[l]){
            g2.adde(stk[top-1],stk[top],1);
            --top;
        }
        // l is not on the stack yet: it becomes the parent of the last popped vertex.
        if(l!=stk[top]){
            g2.adde(l,stk[top],1);
            stk[top]=l;
        }
        stk[++top]=u;
    }
    // Link the remaining chain.
    while(top>1){
        g2.adde(stk[top-1],stk[top],1);
        --top;
    }
}
// DFS over the virtual tree (template skeleton).
void sol(int p){
    for(int i=g2.hd[p];i;i=g2.edge[i].pr){
        int v=g2.edge[i].t;
        sol(v);
        //do something...
    }
    // Clear this node's adjacency so the next query starts with an empty virtual tree.
    g2.hd[p]=0;
}
例题: BZOJ2286 [Sdoi2011]消耗战
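建立虚树之后 dp 即可. 设 $mind_v$ 为原树中根到 $v$ 路径上的最小边权, 对虚树上每个点 $p$, $dp_p=\sum_{v\in son(p)} c_v$, 其中若 $v$ 是询问点则 $c_v=mind_v$, 否则 $c_v=\min(mind_v, dp_v)$. 答案为 $dp_1$.
注意: 询问点本身必须被断开 (只能取 $mind_v$), 而 lca 等辅助节点可断可不断 (取 $\min(mind_v, dp_v)$). 因此需要标记输入的询问点 (代码中的 `used` 数组).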
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<set>
#include<map>
using namespace std;
// Loop helpers. The `register` storage class was dropped: it was removed in
// C++17, and some compilers (e.g. clang) reject it in that mode.
#define rep(i,l,r) for(int i=(l);i<=(r);++i)    // i = l, l+1, ..., r
#define repdo(i,l,r) for(int i=(l);i>=(r);--i)  // i = l, l-1, ..., r
#define il inline
typedef double db;
typedef long long ll;
//---------------------------------------
const int nsz=250050;   // array bound on the number of vertices
const ll ninf=1e18;     // large sentinel used as +infinity for edge weights
int n,m,k,line[nsz];    // n: vertices, m: queries, line[1..k]: current query vertices
int used[nsz];          // used[v] = 1 iff v is a query vertex of the current query
// Graph stored as a linked forward-star adjacency list.
// Edge indices start at 2 (pe starts at 1), so index 0 terminates a list.
struct tg{
    // t: head vertex, pr: previous edge out of the same tail, v: weight
    struct te{int t,pr,v;}edge[nsz*2];
    // hd[u]: most recently added edge out of u (0 = none); pe: last used edge index
    int hd[nsz],pe=1;
    // Add the directed edge f -> t with weight v.
    // Standard list-initialization instead of the non-standard C99 compound literal (te){...}.
    void adde(int f,int t,int v){edge[++pe]=te{t,hd[f],v};hd[f]=pe;}
    // Add an undirected edge as two directed edges.
    void adddb(int f,int t,int v){adde(f,t,v);adde(t,f,v);}
    // Iterate over edges i out of p in graph g; v is the head vertex of edge i.
    #define forg(g,p,i,v) for(int i=g.hd[p],v=g.edge[i].t;i;i=g.edge[i].pr,v=g.edge[i].t)
    // Reset the edge counter (intended for g2 only). hd[] is not cleared here;
    // sol() zeroes hd[p] for every virtual-tree node it visits.
    void clear(){
        pe=1;
    }
}g1,g2;
// mind[v]: minimum edge weight on the tree path from root 1 to v (filled by nlca::dfs).
// mind[1] = ninf so the root never limits its children's minimum.
ll mind[nsz]={0,ninf};
// l2n[i] = floor(log2(i)), used by the sparse table.
int l2n[nsz*3+50];
// Euler tour: eul[1..pe] visit sequence, vis[u] first position of u, d[u] depth.
int eul[nsz*3];
int pe=0;
int vis[nsz];
int d[nsz];
// stt[i][j]: shallowest vertex among eul[i .. i+2^j-1].
int stt[nsz*3][21];
namespace nlca{
void dfs(int u,int fa){
eul[++pe]=u,vis[u]=pe,d[u]=d[fa]+1;
forg(g1,u,i,v){
if(v==fa)continue;
mind[v]=min(mind[u],(ll)g1.edge[i].v);
dfs(v,u);
eul[++pe]=u;
}
}
int dmin(int a,int b){return d[a]<=d[b]?a:b;}
void rmq(){
rep(i,1,pe)stt[i][0]=eul[i];
rep(j,1,l2n[pe]){
rep(i,1,pe+1-(1<<j)){
stt[i][j]=dmin(stt[i][j-1],stt[i+(1<<(j-1))][j-1]);
}
}
}
int stqu(int a,int b){
int l=l2n[b-a+1];
return dmin(stt[a][l],stt[b-(1<<l)+1][l]);
}
void eulinit(){
int l=0;
rep(i,1,n*3){
if(i==(1<<(l+1)))++l;
l2n[i]=l;
}
dfs(1,0);
rmq();
}
int lca(int a,int b){
int x=vis[a],y=vis[b];
if(x>y)swap(x,y);
return stqu(x,y);
}
}
// Comparator: order vertices by DFS entry time (first Euler-tour index vis[]).
bool cmp(int a,int b){
    return vis[b]>vis[a];
}
// Stack used by build(): holds the current root-to-vertex chain of the virtual tree.
int stk[nsz];
int top=0;
// Build the virtual tree of line[1..k] plus the root 1 into g2 (edges parent -> child).
// line[] is sorted in place by DFS order.
void build(){
    g2.clear(),top=0;
    sort(line+1,line+k+1,cmp);
    stk[++top]=1;
    rep(i,1,k){
        int u=line[i];
        // Skip the root and repeated vertices (repeats are adjacent after sorting).
        // Pushing them again would emit a self-loop edge and make sol() recurse forever.
        if(u==stk[top])continue;
        int l=nlca::lca(u,stk[top]);
        // While the vertex under the top is l or a descendant of l, it is the
        // top's virtual parent: emit that edge and pop.
        while(top>1&&vis[stk[top-1]]>=vis[l]){
            g2.adde(stk[top-1],stk[top],1);
            --top;
        }
        // l is not on the stack yet: it becomes the parent of the last popped vertex.
        if(l!=stk[top]){
            g2.adde(l,stk[top],1);
            stk[top]=l;
        }
        stk[++top]=u;
    }
    // Link the remaining chain.
    while(top>1){
        g2.adde(stk[top-1],stk[top],1);
        --top;
    }
}
ll dp[nsz];
void sol(int p){
dp[p]=0;
forg(g2,p,i,v){
sol(v);
dp[p]+=(used[v]?mind[v]:min(mind[v],dp[v]));
}
g2.hd[p]=0;
}
// Read the weighted tree and the queries. For each query, build the virtual
// tree of the given vertices, run the DP, and print dp[1].
int main(){
    ios::sync_with_stdio(0),cin.tie(0);
    cin>>n;
    int a,b,c;
    rep(i,2,n){
        cin>>a>>b>>c;
        g1.adddb(a,b,c);
    }
    nlca::eulinit();
    cin>>m;
    rep(i,1,m){
        cin>>k;
        rep(j,1,k)cin>>line[j],used[line[j]]=1;
        // g2's edge counter is reset inside build() via g2.clear(); the global
        // Euler-tour counter `pe` is not touched here.
        build();
        sol(1);
        cout<<dp[1]<<'\n';
        rep(j,1,k)used[line[j]]=0;
    }
    return 0;
}