动态 dp
动态 (dp) 是由猫坤大佬在 WC 2018 提出来的黑科技。
他主要解决得是带修改的 (dp) 问题,主要思路是由矩阵乘法来维护 (dp) 转移
我们先来看一道模板题 动态dp
这道题,我们先来看不带修改的情况,
转移和状态很容易就能列出来
设 (f[i][0/1]) 表示 以 (i) 为根的子树中,且选 (不选) (i) 这个点的最大权值和
转移方程可以写成:
(f[i][0] = displaystyle sum_{to in son[i]} max (f[to][0],f[to][1]))
(f[i][1] = displaystyle sum_{to in son[x]} f[to][0] + w[x])
当我们修改一个点的权值,那么从他到根节点的路径上的点的 (dp) 值都会受到影响。
所以,我们就可以只修改这一条链上的信息(用树链剖分来维护)。
动态 (dp) 的思想就是将 重儿子和轻儿子的贡献分开来考虑。
设 (g[i][0] = displaystyle sum _{to in 轻儿子} max(f[to][0],f[to][1])), $g[i][1] = displaystyle sum_{to in 轻儿子} f[to][0] +w[i] $.
解释一下 (g[i][0]) 表示 选 或不选 (i) 的轻儿子的最大贡献和, (g[i][1]) 表示不选他轻儿子的价值和加上他本身的权值.
这两个可以在求 (f[x][0/1]) 的时候顺便维护出来
那么,上面的方程就可以写成。
(f[i][0] = g[i][0] + max(f[son[i]][0],f[son][i][1]))
(f[i][1] = g[i][1] + f[son[i]][0]) ((son[x]) 表示 (x) 的重儿子)
定义广义矩阵乘法:
矩阵 (c) 为 矩阵 A 和矩阵 B 的乘积
那么,(c[i][j] = max( a[i][k] + b[k][j]))
这个其实就是把普通的矩阵乘法的乘号改为加号,加号改为取 (max)
这种矩阵乘法也满足普通矩阵乘法的性质,不相信的可以跑几组数据试试
代码实现长这样:
node operator *(node x,node y)
{
node ans;
for(int i = 0; i <= n; i++)
for(int j = 0; j <= n; j++)
for(int k = 0; k <= n; k++)
ans.a[i][j] = max(ans.a[i][j],x.a[i][k] + y.a[k][j]);
return ans;
}
然后,我们就可以根据这个广义矩阵乘法构造一个矩阵,即,
其中 第二个矩阵是我们要 确定的转移矩阵。
第二个矩阵构造出来长这样:
但,我们发现一个问题,我们矩阵中没有这个 (f) 值,这,我们一开始的矩阵就没法转移啊。
其实,对于每一个重链的末尾节点,他的转移矩阵就是他的 (f) 值,
也就是说重链的末端节点保存了 f的准确值 -by treaker
这样,我们就可以由这个矩阵来推出这一条重链上每个点的信息。
我们的暴力做法就是一直跳他的父亲就是一直跳他的父亲节点,再把矩阵乘起来。
但,这样的复杂度还是不能接受。我们可以把这个矩阵放到线段树上来维护。
线段树的每个叶子节点存这个点的转移矩阵,每个节点表示左右儿子矩阵的乘积,
这样就可以用 O(log n) 的时间得到我们想要的信息。
在结合树剖套线段树的做法,就能做到维护这棵树的信息。
注意:矩阵乘法要右乘,因为我们是从 dfn序大的地方跳到 dfn序小的地方,在线段树上对应的是从右往左。
我们上面的转移矩阵就需要重构一下,变成
修改操作,上文我们提到了 ‘当我们修改一个点的权值,那么从他到根节点的路径上的点的 (dp) 值都会受到影响’。
我们修改的就是从这个点到根节点的路径的矩阵,我们把这条链琛出来,发现他是重链和轻边交替在一起的。
我们仔细观察一下我们构造的矩阵,发现里面维护的是轻儿子的信息,不会涉及到重儿子。
对于重链,我们不用修改,但对于重链链顶 $top[x] $和轻边 $fa[top[x]] $交替的地方,我们需要单点修改。
因为此时链顶 (top[x]) 属于 $fa[top[x]] $轻儿子,他的修改会对他父亲的转移矩阵造成影响。
我们关键要算出他改变之后对他父亲转移矩阵的影响。
他的 (f) 值我们可以由下面的推出来,然后可以根据 (g) 数组的定义算出他改变之后对他父亲的影响。
就这样在更新,每次跳重链,直到跳到根节点,我们的修改操作就大工告成了。
注意一下,我们不能修改之后在统计他的 轻儿子的值,那样可能不对,我们就只能通过增量法来修改。
记录一下原来状态的矩阵,在记录修改之后的矩阵,两者之差就是对他父亲转移矩阵的贡献。
每次修改操作的时间复杂度为 O((log^2 n))
具体代码长这样
void modify(int x,int val)
{
base[dfn[x]].a[1][0] += val - w[x];//增量法统计他修改的贡献
w[x] = val;
while(x)
{
node Old = get_node(top[x]);//记录一下链顶修改之前的矩阵
chenge(1,1,n,dfn[x]);//修改当前这个节点的转移矩阵
node New = get_node(top[x]);//得到链顶修改之后的转移矩阵
int fx = dfn[fa[top[x]]];
base[fx].a[0][0] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);//算他修改对他父亲转移矩阵的影响
base[fx].a[0][1] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);
base[fx].a[1][0] += New.a[0][0] - Old.a[0][0];
x = fa[top[x]];//跳链修改
}
}
总代码:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int inf = 233333333;
const int N = 1e5+10;
int n,m,tot,x,y,num,val,u,v;
int head[N],top[N],dep[N],siz[N],fa[N],son[N],ord[N],end[N],f[N][2],g[N][2],dfn[N],w[N];
struct bian
{
int to,net;
}e[N<<1];
struct node
{
int a[2][2];
node() {memset(a,-127/3,sizeof(a));}
}tr[N<<2],base[N];
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s =s * 10+ch - '0'; ch = getchar();}
return s * w;
}
node operator *(node x,node y)//广义矩阵乘法
{
node ans;
for(int i = 0; i <= 1; i++)
{
for(int j = 0; j <= 1; j++)
{
for(int k = 0; k <= 1; k++)
{
ans.a[i][j] = max(ans.a[i][j],x.a[i][k] + y.a[k][j]);
}
}
}
return ans;
}
void add(int x,int y)
{
e[++tot].to = y;
e[tot].net = head[x];
head[x] = tot;
}
void get_tree(int x)//树剖预处理
{
dep[x] = dep[fa[x]] + 1; siz[x] = 1;
for(int i = head[x]; i; i = e[i].net)
{
int to = e[i].to;
if(to == fa[x]) continue;
fa[to] = x;
get_tree(to);
siz[x] += siz[to];
if(siz[son[x]] < siz[to] || son[x] == -1) son[x] = to;
}
}
void dfs(int x,int topp)
{
top[x] = topp; dfn[x] = ++num; ord[num] = x;//END 记录这条重链的链尾,ord记录当前 当前这个节点的树上编号
if(son[x] == -1)
{
end[topp] = num;
return;
}
dfs(son[x],topp);
for(int i = head[x]; i; i = e[i].net)
{
int to = e[i].to;
if(to == fa[x] || to == son[x]) continue;
dfs(to,to);
}
}
void dp(int x,int fa)//预处理为没修改之前的 f 值与 g 值
{
g[x][1] = f[x][1] = w[x];
for(int i = head[x]; i; i = e[i].net)
{
int to = e[i].to;
if(to == fa) continue;
dp(to,x);
f[x][0] += max(f[to][0], f[to][1]);
f[x][1] += f[to][0];
if(to != son[x])
{
g[x][0] += max(f[to][0], f[to][1]);
g[x][1] += f[to][0];
}
}
}
void up(int o)
{
tr[o] = tr[o<<1] * tr[o<<1|1];
}
void build(int o,int L,int R)//线段树建树
{
if(L == R)
{
int tmp = ord[L];
tr[o].a[0][0] = tr[o].a[0][1] = g[tmp][0];//构造转移矩阵
tr[o].a[1][0] = g[tmp][1];
base[L] = tr[o];
return;
}
int mid = (L + R)>>1;
build(o<<1,L,mid);
build(o<<1|1,mid+1,R);
up(o);
}
void chenge(int o,int l,int r,int x)//单点修改
{
if(l == r)
{
tr[o] = base[l];
return;
}
int mid = (l + r)>>1;
if(x <= mid) chenge(o<<1,l,mid,x);
if(x > mid) chenge(o<<1|1,mid+1,r,x);
up(o);
}
node query(int o,int l,int r,int L,int R)//区间查询
{
if(L <= l && R >= r) return tr[o];
int mid = (l + r)>>1;
if(R <= mid) return query(o<<1,l,mid,L,R);
if(L > mid) return query(o<<1|1,mid+1,r,L,R);
return query(o<<1,l,mid,L,R) * query(o<<1|1,mid+1,r,L,R);
}
node get_node(int x)//得到链顶的 f 值
{
return query(1,1,n,dfn[x],end[top[x]]);
}
void modify(int x,int val)
{
base[dfn[x]].a[1][0] += val - w[x];//增量法统计他修改的贡献
w[x] = val;
while(x)
{
node Old = get_node(top[x]);//记录一下链顶修改之前的矩阵
chenge(1,1,n,dfn[x]);//修改当前这个节点的转移矩阵
node New = get_node(top[x]);//得到链顶修改之后的转移矩阵
int fx = dfn[fa[top[x]]];
base[fx].a[0][0] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);//算他修改对他父亲转移矩阵的影响
base[fx].a[0][1] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);
base[fx].a[1][0] += New.a[0][0] - Old.a[0][0];
x = fa[top[x]];//跳链修改
}
}
int main()
{
n = read(); m = read();
for(int i = 1; i <= n; i++)
{
w[i] = read();
son[i] = -1;
}
for(int i = 1; i <= n-1; i++)
{
u = read(); v = read();
add(u,v); add(v,u);
}
get_tree(1); dfs(1,1); dp(1,0); build(1,1,n);//预处理
for(int i = 1; i <= m; i++)
{
x = read(); val = read();
modify(x,val);//修改操作
node ans = get_node(1);//得到新答案
printf("%d
",max(ans.a[0][0],ans.a[1][0]));
}
return 0;
}