• 动态DP学习笔记


    动态DP学习笔记

    约定:

    1. 若无特别说明,数组下标从1开始.
    2. 若无特别说明,矩阵的下标从0开始(这是为了和代码一致,而代码这样实现是为了节省内存)
    3. 在公式中,矩阵用粗斜体表示,如(m{A})
    4. 若无特别说明,对于有根树上的结点(x),(child(x))表示它儿子的集合,(son(x))表示它的重儿子
    5. 对于实数的乘法,一律用(+,cdot)表示,有字母时乘号可能省略。

    可能需要的前置知识:

    1. 线段树
    2. DFS序,轻重链剖分
    3. LCT
    4. 矩阵乘法

    问题引入

    动态DP是指在动态规划问题中修改参量,并询问修改后的DP值。如果每次修改完之后都要朴素的重新DP一遍,时间复杂度很高,无法接受。

    动态DP有很多解法。对于一些特殊动态DP问题,可以用离线处理,虚树,倍增等方法转化成普通的静态DP问题。这些方法虽然灵活方便且常数较小,但是可扩展性差。而本文讨论的是一种较为通用的解法,即把状态转移方程写成矩阵的形式,进而把修改操作转化为修改矩阵,查询操作转化为区间矩阵乘积,然后用数据结构维护矩阵

    序列上的动态DP

    给出一个可能有负数序列(a_i),支持单点修改,区间查询最大子段和。即给出([L,R]),求$$max_{L leq l leq r leq R} sum_{i=l}^{r}a_i$$

    线段树维护包含左端点,包含右端点和区间内的答案显然可做

    如果查询整个序列的子段和,容易写出DP方程。设(f_i)表示以(i)结尾的最大子段和,显然有(f_i=max(f_{i-1}+a_i,a_i)).但是这个方程无法快速修改和维护。

    这时,我们想到了矩阵优化DP。

    重新定义矩阵乘法

    我们知道,一般的矩阵乘法可以写成$$c_{i,j}=sum_{k=1}^m a_{i,k} cdot b_{k,j}$$的形式。
    我们知道,很多次加法可以变成乘法,那么什么运算可以变成加法?实际上是(max)(min)运算。它们有类似的性质,比如:
    (max(a,b)+c=max(a+c,b+c))
    $(a+b)cdot c=acdot c+b cdot c $.

    那么我们可以重新定义矩阵乘法:

    [c_{i,j}=max_{k=1}^m (a_{i,k} + b_{k,j}) ]

    很(wo)容(bu)易(hui)证明这样的矩阵乘法满足结合律,不满足交换律

    另外,容易发现在(max,+)矩阵乘法中,(-infty)充当了0的作用,0充当了1的作用,而单位矩阵是对角线为0,其他元素为(-infty)的矩阵。在(min,+)矩阵乘法中,则是(+infty)

    转化成数据结构问题

    根据矩阵乘法的定义,我们可以写出如下的式子:

    [egin{bmatrix} f_i \ 0end{bmatrix} =egin{bmatrix}a_i a_i\ -infty 0end{bmatrix}egin{bmatrix} f_{i-1} \ 0end{bmatrix} ]

    [egin{bmatrix} f_R \ 0end{bmatrix} =egin{bmatrix}a_{R} a_{R}\ -infty 0end{bmatrix}egin{bmatrix}a_{R-1} a_{R-1}\ -infty 0end{bmatrix}cdotsegin{bmatrix}a_{L+1} a_{L+1}\ -infty 0end{bmatrix}egin{bmatrix} f_{L} \ 0end{bmatrix} ]

    那么修改的时候就修改某个点对应的矩阵,查询的时候就是查询区间乘积。由于(f_L)的初始值为(a_L),我们可以直接查询([L,R])的区间矩阵乘积(m{S}),然后输出(max(m{S_{0,0}},m{S_{1,0}}))即可。

    树上的动态DP

    实际上,几乎没有在序列上的动态DP题目(如果有,一般也能用线段树等数据结构直接解决),大部分的动态DP题目都出现在树上。而树上的动态DP往往与链剖分紧密结合。

    接下来我们讨论一个经典模型:树上最大权独立集。

    LuoguP4719:给出一个(n)个点的树,每个点权值为(v_i),每次单点修改,查询整棵树的最大权独立集

    (f_{x,0},f_{x,1})分别表示(x)子树中,不选择(x)的最大权独立集大小,和选择(x)的最大权独立集大小。显然有转移:

    (f_{x,0}=sum_{y in child(x)} max(f_{y,0},f_{y,1}))
    (f_{x,1}=sum_{y in child(x)} f_{y,0})

    这个方程无法写成矩阵形式,考虑优化。

    树链剖分+线段树解法

    为了快速维护DP值,我们可以分轻重儿子维护DP值。

    (g_{x,0}=sum_{y in child(x)-{son(x)}} max(f_{y,0},f_{y,1}))

    (g_{x,1}=a_x+sum_{y in child(x)-{son(x)}} f_{y,0})

    g维护了所有轻儿子的DP贡献,那么有:

    (f_{x,0}=max(f_{son(x),0},f_{son(x),1})+g_{x,0})
    (f_{x,1}=f_{son(x),0}+g_{x,1})

    写成矩阵的形式(注意这里是max,+矩阵乘法)

    [egin{bmatrix}f_{x,0} \ f_{x,1} end{bmatrix}=egin{bmatrix}g_{x,0} g_{x,0} \ g_{x,1} -infin end{bmatrix} egin{bmatrix}f_{son(x),0} \ f_{son(x),1} end{bmatrix} ]

    不妨将转移矩阵记为(m{M_x}=egin{bmatrix}g_{x,0} g_{x,0} \ g_{x,1} -infin end{bmatrix})
    因为(x,son(x),son(son(x)))构成的是一条重链,所以对于重链上的点,可以类似序列动态DP的方法,用查询区间(m{M_x})乘积的方法求出它们的DP值,用线段树在DFS序上维护。

    但是还要考虑对轻链的影响.
    ddp.png

    如图所示,从每个点的(g)的统计范围可以看出,每次修改只会影响(x)到根的路径上,每条重链的底端的(g)值.((x)所在一条除外)。比如修改5会影响到2.于是我们先修改(g_{x,1})(即(m{{M_x}_{1,0}})),把它加上(a_{x})的增加量,因为它的定义里包含(a_x).

    然后沿着重链往上跳:
    对于每条重链的链顶,我们要减去原来的(f)对链顶父亲的影响。还要求出这条重链上新的DP值,显然只需要考虑重链底端的影响,它轻儿子已经修改完了,我们根据新的(g)重新赋值(m{M_x})矩阵,然后在线段树上单点修改。接着跳到链顶,在线段树上查询f值,然后更新链顶的父亲。

    ll get_f(int x,int k) {
    	//f[x]需要从x所在重链底端推上来,变成区间矩阵乘法
    	return T.query(dfn[x],dfn[btm[x]],1).a[k][0];//btm[x]表示x所在重链的底端
    }
    void change(int x,int v) {
    	g[x][1]+=v-val[x];//先修改x
    	val[x]=v;
    	while(x) {
            //对于重链底端,根据新的g重新赋值M[x]
    		mat[x].a[0][0]=g[x][0];
    		mat[x].a[0][1]=g[x][0];
    		mat[x].a[1][0]=g[x][1];
    		mat[x].a[1][1]=-INF;
    		T.update(dfn[x],mat[x],1);//单点修改
    		x=top[x];//对于链顶,要更新它父亲的g
    		g[fa[x]][0]-=max(f[x][0],f[x][1]);//减去旧的f
    		g[fa[x]][1]-=f[x][0];
    		f[x][0]=get_f(x,0);
    		f[x][1]=get_f(x,1);
    		g[fa[x]][0]+=max(f[x][0],f[x][1]);//加上新的f
    		g[fa[x]][1]+=f[x][0];
    		x=fa[x];//跳到上一条重链
    	}
    }
    

    在更新的实现中,也可以不用(f)(g)数组,可以直接修改(m{M})里的对应位置,但是要注意先减去对父亲的影响再修改。还要注意区分(m{M})和线段树节点里的矩阵,(m{M})里存的是(g)值,而线段树节点里存的实际上是(f)值。

    该做法的时间复杂度是(O(nlog^2 n))(矩阵乘法的复杂度看作常数).

    完整代码:

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #define INF 0x3f3f3f3f
    #define maxn 200000
    using namespace std;
    typedef long long ll;
    int n,m;
    struct edge {
    	int from;
    	int to;
    	int next;
    } E[maxn*2+5];
    int head[maxn+5];
    int esz=1;
    void add_edge(int u,int v) {
    	esz++;
    	E[esz].from=u;
    	E[esz].to=v;
    	E[esz].next=head[u];
    	head[u]=esz;
    }
    int fa[maxn+5],son[maxn+5],sz[maxn+5],top[maxn+5],btm[maxn+5]/*所在重链最底端*/,dfn[maxn+5],hash_dfn[maxn+5];
    void dfs1(int x,int f) {
    	sz[x]=1;
    	fa[x]=f;
    	for(int i=head[x]; i; i=E[i].next) {
    		int y=E[i].to;
    		if(y!=f) {
    			dfs1(y,x);
    			sz[x]+=sz[y];
    			if(sz[y]>sz[son[x]]) son[x]=y;
    		}
    	}
    }
    int tim=0;
    void dfs2(int x,int t) {
    	top[x]=t;
    	dfn[x]=++tim;
    	hash_dfn[dfn[x]]=x;
    	if(son[x]) {
    		dfs2(son[x],t);
    		btm[x]=btm[son[x]];//维护重链最底端节点
    	} else btm[x]=x;
    	for(int i=head[x]; i; i=E[i].next) {
    		int y=E[i].to;
    		if(y!=fa[x]&&y!=son[x]) {
    			dfs2(y,y);
    		}
    	}
    }
    
    
    struct matrix {
    	ll a[2][2];
    	inline void set(int x) {
    		for(int i=0; i<2; i++) {
    			for(int j=0; j<2; j++) a[i][j]=x;
    		}
    	}
    	friend matrix operator * (matrix p,matrix q) {
    		matrix ans;
    		ans.set(-INF);
    		for(int i=0; i<2; i++) {
    			for(int j=0; j<2; j++) {
    				for(int k=0; k<2; k++) {
    					ans.a[i][j]=max(ans.a[i][j],p.a[i][k]+q.a[k][j]);
    				}
    			}
    		}
    		return ans;
    	}
    } mat[maxn+5];
    ll val[maxn+5];
    ll f[maxn+5][2],g[maxn+5][2];
    void dfs3(int x) {
    	f[x][0]=0;
    	f[x][1]=val[x];
    	for(int i=head[x]; i; i=E[i].next) {
    		int y=E[i].to;
    		if(y!=fa[x]) {
    			dfs3(y);
    			f[x][0]+=max(f[y][0],f[y][1]);
    			f[x][1]+=f[y][0];
    		}
    	}
    	g[x][0]=0,g[x][1]=val[x];
    	for(int i=head[x]; i; i=E[i].next) {
    		int y=E[i].to;
    		if(y!=fa[x]&&y!=son[x]) {
    			g[x][0]+=max(f[y][0],f[y][1]);
    			g[x][1]+=f[y][0];
    		}
    	}
    	mat[x].a[0][0]=g[x][0];
    	mat[x].a[0][1]=g[x][0];
    	mat[x].a[1][0]=g[x][1];
    	mat[x].a[1][1]=-INF;
    }
    
    struct segment_tree {
    	struct node {
    		int l;
    		int r;
    		matrix v;
    	} tree[maxn*4+5];
    	void push_up(int pos) {
    		tree[pos].v=tree[pos<<1].v*tree[pos<<1|1].v;
    	}
    	void build(int l,int r,int pos) {
    		tree[pos].l=l;
    		tree[pos].r=r;
    		if(l==r) {
    			tree[pos].v=mat[hash_dfn[l]];
    			return;
    		}
    		int mid=(l+r)>>1;
    		build(l,mid,pos<<1);
    		build(mid+1,r,pos<<1|1);
    		push_up(pos);
    	}
    	void update(int upos,matrix &uval,int pos) {
    		if(tree[pos].l==tree[pos].r) {
    			tree[pos].v=uval;
    			return;
    		}
    		int mid=(tree[pos].l+tree[pos].r)>>1;
    		if(upos<=mid) update(upos,uval,pos<<1);
    		else update(upos,uval,pos<<1|1);
    		push_up(pos);
    	}
    	matrix query(int L,int R,int pos) {
    		if(L<=tree[pos].l&&R>=tree[pos].r) return tree[pos].v;
    		int mid=(tree[pos].l+tree[pos].r)>>1;
    		matrix ans;
    		ans.a[0][0]=ans.a[1][1]=0;
    		ans.a[0][1]=ans.a[1][0]=-INF;
    		if(L<=mid) ans=ans*query(L,R,pos<<1);
    		if(R>mid) ans=ans*query(L,R,pos<<1|1);
    		return ans;
    	}
    } T;
    ll get_f(int x,int k) {
    	//f[x]需要从x所在重链底端推上来,变成区间矩阵乘法
    	return T.query(dfn[x],dfn[btm[x]],1).a[k][0];
    }
    void change(int x,int v) {
    	g[x][1]+=v-val[x];
    	val[x]=v;
    	while(x) {
    		mat[x].a[0][0]=g[x][0];
    		mat[x].a[0][1]=g[x][0];
    		mat[x].a[1][0]=g[x][1];
    		mat[x].a[1][1]=-INF;
    		T.update(dfn[x],mat[x],1);
    		x=top[x];
    		g[fa[x]][0]-=max(f[x][0],f[x][1]);
    		g[fa[x]][1]-=f[x][0];
    		f[x][0]=get_f(x,0);
    		f[x][1]=get_f(x,1);
    		g[fa[x]][0]+=max(f[x][0],f[x][1]);
    		g[fa[x]][1]+=f[x][0];
    		x=fa[x];
    	}
    }
    
    int main() {
    	int u,v;
    	scanf("%d %d",&n,&m);
    	for(int i=1; i<=n; i++) scanf("%lld",&val[i]);
    	for(int i=1; i<n; i++) {
    		scanf("%d %d",&u,&v);
    		add_edge(u,v);
    		add_edge(v,u);
    	}
    	dfs1(1,0);
    	dfs2(1,1);
    	dfs3(1);
    	T.build(1,n,1);
    	for(int i=1; i<=m; i++) {
    		scanf("%d %d",&u,&v);
    		change(u,v);
    		printf("%lld
    ",max(get_f(1,0),get_f(1,1)));
    	}
    }
    

    LCT解法

    既然轻重链剖分可做,那么LCT也可做,只需要在Splay节点里维护子树矩阵乘积。把实链看成重链,虚链看成轻链即可。(g)维护的就是所有虚儿子的信息。初始的时候所有边都是虚边,fa指向原树上的父亲。

    查询的时候直接splay(1),然后输出矩阵信息即可。
    修改的时候要先accesssplay.实际上的修改操作在access中完成。考虑access的过程

    void access(int x){
        for(int y=0;x;y=x,x=fa(x)){
            splay(x);
            rson(x)=y;
            push_up(x);
        }
    }
    

    rson(x)=y,实际上就是原来的rson(x)变成了轻儿子,y变成了重儿子。因此(g)要加上(f_{rson(x)}),去掉(f_y)。这个过程和LCT维护子树信息是类似的。

    void access(int x) {
        //这里和树剖向上跳重链更新是类似的
        for(int y=0; x; y=x,x=fa(x)) {
            splay(x);
            //原来的rson(x)由实变虚
            if(rson(x)){
                mat[x].a[0][0]+=max(tree[rson(x)].v.a[0][0],tree[rson(x)].v.a[1][0]);
                mat[x].a[1][0]+=tree[rson(x)].v.a[0][0];
                //这里也可以不用f和g,直接写对应矩阵里的值 
            } 
            rson(x)=y;
            if(rson(x)){
                mat[x].a[0][0]-=max(tree[rson(x)].v.a[0][0],tree[rson(x)].v.a[1][0]);
                mat[x].a[1][0]-=tree[rson(x)].v.a[0][0];
            }
            mat[x].a[0][1]=mat[x].a[0][0];
            push_up(x);
        }
    }
    

    时间复杂度是(O(nlog n)).由于树的形态不变,不需要make_root操作,也就不需要翻转标记和push_down.因此动态DP中的LCT的常数并没有那么大,很多时候跑的比树剖快。

    完整代码:

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #define INF 0x3f3f3f3f
    #define maxn 200000
    using namespace std;
    typedef long long ll;
    int n,m;
    struct edge {
    	int from;
    	int to;
    	int next;
    } E[maxn*2+5];
    int head[maxn+5];
    int esz=1;
    void add_edge(int u,int v) {
    	esz++;
    	E[esz].from=u;
    	E[esz].to=v;
    	E[esz].next=head[u];
    	head[u]=esz;
    }
    
    struct matrix {
    	ll a[2][2];
    	matrix(){
    		a[0][0]=a[0][1]=a[1][0]=a[1][1]=-INF;
    	}
    	inline void set(int x) {
    		for(int i=0; i<2; i++) {
    			for(int j=0; j<2; j++) a[i][j]=x;
    		}
    	}
    	friend matrix operator * (matrix p,matrix q) {
    		matrix ans;
    		ans.set(-INF);
    		for(int i=0; i<2; i++) {
    			for(int j=0; j<2; j++) {
    				for(int k=0; k<2; k++) {
    					ans.a[i][j]=max(ans.a[i][j],p.a[i][k]+q.a[k][j]);
    				}
    			}
    		}
    		return ans;
    	}
    } mat[maxn+5];
    ll val[maxn+5];
    ll f[maxn+5][2],g[maxn+5][2];
    
    struct LCT {
    #define lson(x) (tree[x].ch[0])
    #define rson(x) (tree[x].ch[1])
    #define fa(x) (tree[x].fa)
    	struct node {
    		int ch[2];
    		int fa;
    		matrix v;
    	} tree[maxn+5];
    	inline bool is_root(int x) { //注意合并顺序
    		return !(lson(fa(x))==x||rson(fa(x))==x);
    	}
    	inline int check(int x) {
    		return rson(fa(x))==x;
    	}
    	void push_up(int x) {
    		tree[x].v=mat[x];
    		if(lson(x)) tree[x].v=tree[lson(x)].v*tree[x].v;
    		if(rson(x)) tree[x].v=tree[x].v*tree[rson(x)].v;
    	}
    	void rotate(int x) {
    		int y=tree[x].fa,z=tree[y].fa,k=check(x),w=tree[x].ch[k^1];
    		tree[y].ch[k]=w;
    		tree[w].fa=y;
    		if(!is_root(y)) tree[z].ch[check(y)]=x;
    		tree[x].fa=z;
    		tree[x].ch[k^1]=y;
    		tree[y].fa=x;
    		push_up(y);
    		push_up(x);
    	}
    	void splay(int x) {
    		while(!is_root(x)) {
    			int y=fa(x);
    			if(!is_root(y)) {
    				if(check(x)==check(y)) rotate(y);
    				else rotate(x);
    			}
    			rotate(x);
    		}
    	}
    	void access(int x) {
    		//access的时候可能由实变虚,或由虚变实,因此要更新f,g,方法类似LCT维护虚子树信息
    		//这里和树剖向上跳重链更新是类似的
    		for(int y=0; x; y=x,x=fa(x)) {
    			splay(x);
    			//原来的rson(x)由实变虚
    			if(rson(x)){
    				mat[x].a[0][0]+=max(tree[rson(x)].v.a[0][0],tree[rson(x)].v.a[1][0]);//这里也可以不用f和g,直接写对应矩阵里的值 
    				mat[x].a[1][0]+=tree[rson(x)].v.a[0][0];
    			} 
    			rson(x)=y;
    			if(rson(x)){
    				mat[x].a[0][0]-=max(tree[rson(x)].v.a[0][0],tree[rson(x)].v.a[1][0]);
    				mat[x].a[1][0]-=tree[rson(x)].v.a[0][0];
    			}
    			mat[x].a[0][1]=mat[x].a[0][0];
    			push_up(x);
    		}
    	}
    	void change(int x,int v) {
    		access(x);
    		splay(x);
    		mat[x].a[1][0]+=v-val[x];
    		push_up(x);
    		val[x]=v;
    	}
    	ll query(int x) {
    		splay(1);//查询前记得splay到根 
    		return max(tree[1].v.a[0][0],tree[1].v.a[1][0]);
    	}
    } T;
    
    void dfs(int x,int fa) {
    	f[x][0]=0;
    	f[x][1]=val[x];
    	for(int i=head[x]; i; i=E[i].next) {
    		int y=E[i].to;
    		if(y!=fa) {
    			dfs(y,x);
    			f[x][0]+=max(f[y][0],f[y][1]);
    			f[x][1]+=f[y][0];
    		}
    	}
    	mat[x].a[0][0]=mat[x].a[0][1]=f[x][0];//一开始全是轻边,f=g 
    	mat[x].a[1][0]=f[x][1];
    	mat[x].a[1][1]=-INF;
    	T.tree[x].v=mat[x];//初始化LCT
    	T.tree[x].fa=fa; //记得初始化fa
    }
    int main() {
    	int u,v;
    	scanf("%d %d",&n,&m);
    	for(int i=1; i<=n; i++) scanf("%lld",&val[i]);
    	for(int i=1; i<n; i++) {
    		scanf("%d %d",&u,&v);
    		add_edge(u,v);
    		add_edge(v,u);
    	}
    	dfs(1,0);
    	for(int i=1; i<=m; i++) {
    		scanf("%d %d",&u,&v);
    		T.change(u,v);
    		printf("%lld
    ",T.query(1));
    	}
    }
    

    全局平衡二叉树解法

    我们在前面提到,动态DP的树剖+线段树解法和LCT解法的常数都不是很优秀。全局平衡二叉树解法很小,且实现简洁。

    全局平衡二叉树,实际上结合了LCT和树剖的特点。它的结构类似一棵静态的LCT,但修改方法又类似树剖。

    在LCT中,一开始所有边都是虚边,那么初次access的复杂度可能就是(O(n))了,虽然均摊的总复杂度是正确的,但常数不是很优秀。因此我们不妨用树剖的思路,初始时就对轻重边进行划分之后不再改变,并且划分方法要保证暴力向上跳的复杂度尽量小。

    我们知道,在LCT里每棵Splay维护的是原树中的一条链。在全局平衡二叉树中,每棵平衡二叉树(BST)维护的是一条重链的全部节点,BST之间用fa指针链接,BST根的fa指向这条重链顶端的父亲所在重链,那么暴力沿着fa指针往上跳,对于每个根节点进行(g)的修改即可。为了让树高为(O(log n))级别以保证跳重链的复杂度,我们要找到重链按轻子树大小的带权重心,把它作为根,然后递归向下对两边重链建BST,分别设为左右儿子。

    //stk[l,r]里存储当前重链的全部节点
    //sumsz存储轻子树大小的前缀和
    int get_bst(int l,int r) {
        if(l>r) return 0;
        int mid=lower_bound(sumsz+l,sumsz+r+1,(sumsz[l-1]+sumsz[r])/2)-sumsz;//求带权重心
        int x=stk[mid];
        lson(x)=get_bst(l,mid-1);
        rson(x)=get_bst(mid+1,r);//递归建树,这样的二叉树是平衡的
        if(lson(x)) fa(lson(x))=x;//类似LCT,初始化fa和son 
        if(rson(x)) fa(rson(x))=x;
        push_up(x);
        return x;
    }
    

    建树的过程同样可以通过DFS实现,我们先递归重儿子,得到整条重链,然后调用get_bst()对它建出BST.

    int build(int x,int f) {
        int rt=0;
        stk[++tot]=x;
        sumsz[tot]+=lsz[x];
        if(son[x]) { //继续dfs重链
            sumsz[tot+1]+=sumsz[tot];
            rt=build(son[x],x);
        } else { //到了重链底部,可以建二叉树了
            rt=get_bst(1,tot);
            for(int i=1; i<=tot; i++) sumsz[i]=0;
            tot=0;
            return rt;
        }
        for(int i=head[x]; i; i=E[i].next) {
            int y=E[i].to;
            if(y!=f&&y!=son[x]) fa(build(y,x))=x;//对于轻链,递归下去建树,再用fa把它们连起来
        }
        return rt;
    }
    

    修改操作很简单,直接沿着fa暴力跳,只有到了每棵BST的根时才需要修改。这里相当于树剖时跳到重链顶端的操作。

    void update(int x) {
        while(x) { //这一部分和树剖跳重链类似
            int f=fa(x);
            if(f&&is_root(x)) {//只有到了BST根的时候,说明已经处理完了整条重链,跳轻链到fa(x)更新上一条重链 
                //删掉原来的f的影响
                mat[f][0][0]-=max(tree[x].v[0][0],tree[x].v[1][0]);
                mat[f][0][1]-=max(tree[x].v[0][0],tree[x].v[1][0]);
                mat[f][1][0]-=tree[x].v[0][0];
            }
            push_up(x);
            if(f&&is_root(x)) {
                //更新现在的f的影响
                mat[f][0][0]+=max(tree[x].v[0][0],tree[x].v[1][0]);
                mat[f][0][1]+=max(tree[x].v[0][0],tree[x].v[1][0]);
                mat[f][1][0]+=tree[x].v[0][0];
            }
            x=fa(x);
        }
    }
    

    查询操作直接输出根节点的矩阵值即可。

    时间复杂度(O(nlog n)),且常数很小。

    完整代码:

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #define INF 0x3f3f3f3f
    #define maxn 200000
    using namespace std;
    typedef long long ll;
    int n,m;
    struct edge {
    	int from;
    	int to;
    	int next;
    } E[maxn*2+5];
    int head[maxn+5];
    int esz=1;
    void add_edge(int u,int v) {
    	esz++;
    	E[esz].from=u;
    	E[esz].to=v;
    	E[esz].next=head[u];
    	head[u]=esz;
    }
    
    struct matrix {
    	ll a[2][2];
    	matrix() {
    		a[0][0]=a[0][1]=a[1][0]=a[1][1]=-INF;
    	}
    	inline void set(int x) {
    		for(int i=0; i<2; i++) {
    			for(int j=0; j<2; j++) a[i][j]=x;
    		}
    	}
    	friend matrix operator * (matrix p,matrix q) {
    		matrix ans;
    		ans.set(-INF);
    		for(int i=0; i<2; i++) {
    			for(int j=0; j<2; j++) {
    				for(int k=0; k<2; k++) {
    					ans.a[i][j]=max(ans.a[i][j],p.a[i][k]+q.a[k][j]);
    				}
    			}
    		}
    		return ans;
    	}
    	ll* operator [](int i) {
    		return a[i];
    	}
    } mat[maxn+5];
    ll val[maxn+5];
    ll f[maxn+5][2],g[maxn+5][2];
    int sz[maxn+5],lsz[maxn+5],son[maxn+5];
    
    void dfs1(int x,int fa) {
    	sz[x]=lsz[x]=1;
    	f[x][0]=0;
    	f[x][1]=val[x];
    	for(int i=head[x]; i; i=E[i].next) {
    		int y=E[i].to;
    		if(y!=fa) {
    			dfs1(y,x);
    			f[x][0]+=max(f[y][0],f[y][1]);
    			f[x][1]+=f[y][0];
    			sz[x]+=sz[y];
    			if(sz[son[x]]<sz[y]) son[x]=y;
    		}
    	}
    	g[x][0]=0,g[x][1]=val[x];
    	for(int i=head[x]; i; i=E[i].next) {
    		int y=E[i].to;
    		if(y!=fa&&y!=son[x]) {
    			g[x][0]+=max(f[y][0],f[y][1]);
    			g[x][1]+=f[y][0];
    			lsz[x]+=sz[y];
    		}
    	}
    	mat[x].a[0][0]=g[x][0];
    	mat[x].a[0][1]=g[x][0];
    	mat[x].a[1][0]=g[x][1];
    	mat[x].a[1][1]=-INF;
    }
    
    struct BST {
    #define fa(x) (tree[x].fa)
    #define lson(x) (tree[x].ch[0])
    #define rson(x) (tree[x].ch[1])
    	int root;
    	int tot;
    	int stk[maxn+5];//存储当前重链
    	int sumsz[maxn+5];//存储重链上点的lsz之和
    	struct node {
    		int fa;//全局平衡二叉树上的父亲
    		int ch[2];
    		matrix v;
    	} tree[maxn+5];
    	inline bool is_root(int x) { //注意合并顺序
    		return !(lson(fa(x))==x||rson(fa(x))==x);
    	}
    	void push_up(int x) {//很多函数和LCT是一样的
    		tree[x].v=mat[x];
    		if(lson(x)) tree[x].v=tree[lson(x)].v*tree[x].v;
    		if(rson(x)) tree[x].v=tree[x].v*tree[rson(x)].v;
    	}
    
    	int get_bst(int l,int r) {
    		if(l>r) return 0;
    		int mid=lower_bound(sumsz+l,sumsz+r+1,(sumsz[l-1]+sumsz[r])/2)-sumsz;//求带权重心
    		int x=stk[mid];
    		lson(x)=get_bst(l,mid-1);
    		rson(x)=get_bst(mid+1,r);//递归建树,这样的二叉树是平衡的
    		if(lson(x)) fa(lson(x))=x;//类似LCT,初始化fa和son 
    		if(rson(x)) fa(rson(x))=x;
    		push_up(x);
    		return x;
    	}
    	int build(int x,int f) {
    		int rt=0;
    		stk[++tot]=x;
    		sumsz[tot]+=lsz[x];
    		if(son[x]) { //继续dfs重链
    			sumsz[tot+1]+=sumsz[tot];
    			rt=build(son[x],x);
    		} else { //到了重链底部,可以建二叉树了
    			rt=get_bst(1,tot);
    			for(int i=1; i<=tot; i++) sumsz[i]=0;
    			tot=0;
    			return rt;
    		}
    		for(int i=head[x]; i; i=E[i].next) {
    			int y=E[i].to;
    			if(y!=f&&y!=son[x]) fa(build(y,x))=x;//对于轻链,递归下去建树,再用fa把它们连起来
    		}
    		return rt;
    	}
    	void update(int x) {
    		while(x) { //这一部分和树剖跳重链类似
    			int f=fa(x);
    			if(f&&is_root(x)) {//只有到了BST根的时候,说明已经处理完了整条重链,跳轻链到fa(x)更新上一条重链 
    				mat[f][0][0]-=max(tree[x].v[0][0],tree[x].v[1][0]);
    				mat[f][0][1]-=max(tree[x].v[0][0],tree[x].v[1][0]);
    				mat[f][1][0]-=tree[x].v[0][0];
    			}
    			push_up(x);
    			if(f&&is_root(x)) {
    				mat[f][0][0]+=max(tree[x].v[0][0],tree[x].v[1][0]);
    				mat[f][0][1]+=max(tree[x].v[0][0],tree[x].v[1][0]);
    				mat[f][1][0]+=tree[x].v[0][0];
    			}
    			x=fa(x);
    		}
    	}
    	void ini(){
    		dfs1(1,0);
    		root=build(1,0);
    	}
    	void change(int x,int v) {
    		mat[x][1][0]+=v-val[x];
    		val[x]=v;
    		update(x);
    	}
    	ll query(){
    		return max(tree[root].v[0][0],tree[root].v[1][0]);
    	}
    } T;
    
    
    int main() {
    	int u,v;
    	scanf("%d %d",&n,&m);
    	for(int i=1; i<=n; i++) scanf("%lld",&val[i]);
    	for(int i=1; i<n; i++) {
    		scanf("%d %d",&u,&v);
    		add_edge(u,v);
    		add_edge(v,u);
    	}
    	T.ini();
    //	T.debug();
    	for(int i=1; i<=m; i++) {
    		scanf("%d %d",&u,&v);
    		T.change(u,v);
    		printf("%lld
    ",T.query());
    	}
    }
    
    

    例题

    [NOIP2018]保卫王国

    给出一棵(n)个点树,有m组询问,每次询问给出两个点,规定他们必须选或必须不选。求树的最小权覆盖集。(n,m leq 10^5)

    此题有倍增+树形dp的做法,常数非常优秀,但思路比较难想到。

    显然最小权覆盖集=总点权和-最大权独立集
    看到最大权独立集,可以直接套上面的模板

    考虑如何处理询问。由于我们要权值最小,如果必须选某个点,就把它的点权修改为(-infty),如果必须不选,就修改为(+infty).代码实现上就把它修改成大于所有点权值之和的数即可,如(10^{10}).然后用板子求最大权独立集,再用总和减去。注意当我们把点权修改为(-infty)时,最小权覆盖集会包含(-infty),这时算出的和并不是真正答案,还要加上(v_x-(-infty)),其中(v_x)是被强制选的值。

    因为树剖和LCT两种实现动态DP的方式常数过大,没有O2的情况下会TLE,而众所周知NOIP是没有O2优化的。因此这里只给出全局平衡二叉树写法的代码。

    为了节约篇幅,代码见这里

    [LuoguP4426][AHOI2018]毒瘤

    给出一个(n)个点(m)条边的无向图,求独立集个数。
    (n leq 10^5,n-1 leq m leq n+10)

    注意到(|m-n|)很小,我们可以暴力枚举这些非树边((u,v))的状态,按两边选和不选有(0,0)(0,1)(1,0)三种。其实可以合并为2种:

    1. (u)强制不选,(v)可任意选
    2. (u)强制选,(v)强制不选

    那么直接暴力枚举每条边的状态,然后在树上修改,做动态DP即可。

    (f_{x,0},f_{x,1})分别表示(x)不选/选,(x)子树中的独立集个数,那么:
    (f_{x,0}=1+prod_{y in child(x)} (f_{y,0}+f_{y,1}))
    (f_{x,1}=1+prod_{y in child(x)} f_{y,0})

    最终答案为(f_{x,0}+f_{x,1})

    (g_{x,0}=1+prod_{y in child(x)-{son(x)}} (f_{y,0}+f_{y,1}))

    (g_{x,1}=1+prod_{y in child(x)-{son(x)}} f_{y,0})

    g维护了所有轻儿子的DP贡献,那么有:

    (f_{x,0}=(f_{son(x),0}+f_{son(x),1})cdot g_{x,0})
    (f_{x,1}=f_{son(x),0} cdot g_{x,1})

    写成矩阵的形式(注意这里是+,(cdot)矩阵乘法)

    [egin{bmatrix}f_{x,0} \ f_{x,1} end{bmatrix}=egin{bmatrix}g_{x,0} g_{x,0} \ g_{x,1} 0 end{bmatrix} egin{bmatrix}f_{son(x),0} \ f_{son(x),1} end{bmatrix} ]

    (m{M_x}=egin{bmatrix}g_{x,0} g_{x,0} \ g_{x,1} 0 end{bmatrix})。为了处理强制选和不选的情况,我们还需要对每个节点定义一个矩阵(m{C_x}),求区间矩阵积的时候把乘(m{M_x})变成乘(m{C_xM_x})

    注意到(egin{bmatrix} 0 0 \ 0 1end{bmatrix}egin{bmatrix}f_{x,0} \ f_{x,1} end{bmatrix}=egin{bmatrix}0 \ f_{x,1} end{bmatrix}),于是使得(f_{x,0}=0),那么(m{C_x}=egin{bmatrix} 0 0 \ 0 1end{bmatrix})就表示强制选(x).同理(m{C_x}=egin{bmatrix} 1 0 \ 0 0end{bmatrix})就表示强制不选(x),(m{C_x}=egin{bmatrix} 1 0 \ 0 1end{bmatrix})就表示选和不选(x)均可。于是枚举的时候单点修改即可。

    但是还有一个问题,在动态DP的过程中,我们需要把儿子的影响从父亲中消除,也就是说要做除法。但是万一(f_y=0),就会出现除0的问题。于是我们可以对于每个(f)(g)值,记录它们被乘进去了几个0,做除法的时候0的个数会减少。如果减到了0,就变成了它们的真实值。具体实现可以定义一个新的类,重载它的*,/运算符

    struct mynum { //为了消除下方g对上方g的影响,要支持撤回乘0操作
    	ll val;
    	int cnt;//记录被乘上去的0个数
    	mynum() {
    		val=cnt=0;
    	}
    	mynum(ll _val) {
    		if(_val==0) val=cnt=1;
    		else val=_val,cnt=0;
    	}
    	friend mynum operator * (mynum p,mynum q) {
    		mynum ans;
    		ans.val=p.val*q.val%mod;//把0的val设为1,这样乘的时候val就不变
    		ans.cnt=p.cnt+q.cnt;
    		return ans;
    	}
    	friend mynum operator / (mynum p,mynum q) {
    		mynum ans;
    		ans.val=p.val*inv(q.val)%mod;
    		ans.cnt=p.cnt-q.cnt;
    		return ans;
    	}
    	ll value() {
    		if(cnt==0) return val;
    		else return 0;
    	}
    };
    

    用LCT实现,复杂度(O(n+m+2^{m-n}log n)),常数还可以。

    为了节约篇幅,代码见这里

    总结

    我们介绍了动态DP的通用解法:把DP方程写成矩阵形式,然后用矩阵乘法维护信息。然后又把它应用到树上,和轻重链剖分,LCT与全局平衡二叉树结合。值得注意的是,动态DP有常数较大和代码量较大的缺点,对于许多题目,实际上可以不用这种通用解法解决,而是有更灵活的解法,需要结合实际问题分析。

  • 相关阅读:
    HDU 1010 Tempter of the Bone(DFS剪枝)
    HDU 1013 Digital Roots(九余数定理)
    HDU 2680 Choose the best route(反向建图最短路)
    HDU 1596 find the safest road(最短路)
    HDU 2072 单词数
    HDU 3790 最短路径问题 (dijkstra)
    HDU 1018 Big Number
    HDU 1042 N!
    NYOJ 117 求逆序数 (树状数组)
    20.QT文本文件读写
  • 原文地址:https://www.cnblogs.com/birchtree/p/12681885.html
Copyright © 2020-2023  润新知