• 虚树学习笔记


    虚树学习笔记

    [SDOI2011]消耗战

    Link

    题意

    给一棵(n)个点,带边权的树。

    (m)组询问,每组有(k_i)个关键点,你需要切断一些边,使得每个点都到不了根节点,求最小代价。

    (n <= 2.5 cdot 10^5, m <= 5 cdot 10^5,sum k_i <= 5 cdot 10^5)

    Solve 1

    对于每组询问,做一个(dp),设(f[x])表示切断(x)和他的子树所需最小代价,转移分两种

    1. (x)是关键点,答案为(x)到根路径最小值
    2. (x)不是关键点,答案为切断所有儿子的值和第(1)种取(min)

    复杂度(O(nm))

    (can we do better?)

    Solve 2

    要用到所讲的虚树。

    我们发现转移过程中,对于转移有贡献的只有关键点以及他们之间的祖先,于是我们可以简化树的结构。

    把关键点按(dfs)序排序,相邻两个求出(lca)并建边。最后在虚树上做(dp),复杂度(O(n log n + sum k_i log n))

    具体实现用一个栈维护一条树链,排序后一次加入点。

    设当前加入的点(u)

    • 如果(top <=1) ,(stk[++top] = u)
    • (l = lca(u,stk[top])),如果(l == stk[top]),那么(u)应该接在(stk[top])底下,(stk[++top] = u)
    • 否则说明(u)已经是一个新的子树,持续弹栈直到(dfn[stk[top-1]] < dfn[l] <= dfn[stk[top]]),如果(l != stk[top]),把(stk[top])接在(l)后面,(stk[top] = l),最后(stk[++top] = u)
    void insert(int u){
    	if(top <= 1) return stk[++top] = u,void();
    	int l = lca(u,stk[top]);
    	if(l == stk[top]) return stk[++top] = u,void();
    	while(top > 1 && dfn[l] <= dfn[stk[top-1]]){
    		add(stk[top-1],stk[top]); top--;
    	}
    	if(l != stk[top]) add(l,stk[top]),stk[top] = l;
    	stk[++top] = u;
    	return ;
    }
    

    Code

    #include<bits/stdc++.h>
    #define int long long
    #define N 1000015
    #define rep(i,a,n) for (int i=a;i<=n;i++)
    #define per(i,a,n) for (int i=n;i>=a;i--)
    #define inf 0x3f3f3f3f3f3f3f3f
    #define pb push_back
    #define mp make_pair
    #define pii pair<int,int>
    #define fi first
    #define se second
    #define lowbit(i) ((i)&(-i))
    #define VI vector<int>
    #define all(x) x.begin(),x.end()
    using namespace std;
    int n,m,a[N],Min[N],k,dfn[N],clk;
    vector<pii> e[N];
    VI g[N];
    void dfs(int u,int fa){
    	dfn[u] = ++clk;
    	for(auto I:e[u]){
    		int v = I.fi,w = I.se;
    		if(v == fa) continue;
    		Min[v] = min(Min[u],w);
    		dfs(v,u);
    	}
    }
    bool cmp(int u,int v){
    	return dfn[u] < dfn[v];
    }
    namespace LCA{
    	int fa[N][24],dep[N];
    	void Dfs(int u,int f){
    		fa[u][0] = f; dep[u] = dep[f]+1;
    		for(auto I:e[u]){
    			int v = I.fi;
    			if(v == f) continue;
    			Dfs(v,u);
    		}
    	}
    	void init(){
    		rep(j,1,21){
    			rep(i,1,n){
    				fa[i][j] = fa[fa[i][j-1]][j-1];
    			}
    		}
    	}
    	int lca(int u,int v){
    		if(dep[u] < dep[v]) swap(u,v);
    		int t = dep[u] - dep[v];
    		per(i,0,21){
    			if((1<<i)&t) u = fa[u][i];
    		}
    		if(u == v) return u;
    		per(i,0,21){
    			if(fa[u][i] != fa[v][i]) u = fa[u][i],v = fa[v][i];
    		}
    		return fa[u][0];
    	}
    }
    using namespace LCA;
    int stk[N],top;
    void add(int u,int v){
    	//printf("%lld -> %lld
    ",u,v);
    	g[u].pb(v);
    }
    void insert(int u){
    	if(top <= 1) return stk[++top] = u,void();
    	int l = lca(u,stk[top]);
    	if(l == stk[top]) return stk[++top] = u,void();
    	while(top > 1 && dfn[l] <= dfn[stk[top-1]]){
    		add(stk[top-1],stk[top]); top--;
    	}
    	if(l != stk[top]) add(l,stk[top]),stk[top] = l;
    	stk[++top] = u;
    	return ;
    }
    void build(){
    	top = 0;
    	stk[++top] = 1;
    	rep(i,1,k) insert(a[i]);
    	while(top > 1) add(stk[top-1],stk[top]),top--;
    }
    bool gkp[N];
    int dp(int u){
    	int res = 0;
    	if(g[u].size() == 0){
    		//printf("u: %lld val: %lld
    ",u,Min[u]);
    		return Min[u];
    	}
    	for(auto v:g[u]){
    		res += dp(v);
    	}
    	g[u].clear();
    	if(!gkp[u]) return min(res,Min[u]);
    	//printf("u: %lld val: %lld
    ",u,res);
    	return Min[u];
    }
    
    signed main(){
    	//freopen(".in","r",stdin);
    	//freopen(".out","w",stdout);
     	scanf("%lld",&n);
     	memset(Min,0x3f,sizeof Min);
     	rep(i,2,n){
     		int u,v,w; scanf("%lld%lld%lld",&u,&v,&w);
     		e[u].pb(mp(v,w)); e[v].pb(mp(u,w));
     	}
     	dfs(1,0);
     	// rep(i,1,n) printf("%lld ", Min[i]);
     	// printf("
    ");
     	Dfs(1,0); init();
     	// rep(i,1,n){
     	// 	rep(j,i+1,n){
     	// 		printf("(i,j): (%lld,%lld) lca: %lld
    ",i,j,lca(i,j));
     	// 	}
     	// }
     	scanf("%lld",&m);
     	rep(_,1,m){
     		scanf("%lld",&k); rep(i,1,k) scanf("%lld",&a[i]),gkp[a[i]] = 1;
     		sort(a+1,a+k+1,cmp); //puts("sort finished");
     		build(); //puts("build finished");
     		printf("%lld
    ",dp(1));
     		rep(i,1,k) gkp[a[i]] = 0;
     	}
    	return 0;
    }
    
  • 相关阅读:
    Python day43 :pymysql模块/查询,插入,删除操作/SQL注入完全问题/事务/模拟登录注册服务器/视图/函数/存储过程
    docker
    Linux 05
    Linux04
    Linux 03
    Linux 02
    go语言
    go语言
    go语言
    Linux
  • 原文地址:https://www.cnblogs.com/czdzx/p/14317271.html
Copyright © 2020-2023  润新知