最近学了动态(DP),非常精妙的方法啊
写一篇博客记录一下
例题1
这道题(Luogu1352),相信大家都做过.
可以很快地写出转移方程.
令(f[i][1])表示第i个节点选择的最大快乐度,(f[i][0])表示第i个节点不选的快乐度,(f[u][0]=sum_{vin sonu}max(f[u][0],f[u][1]),f[u][1]=sum_{vin sonu}f[u][0])
现在,增加一个操作,就是修改点权,然后依然查询最大独立集.
于是,难度瞬间变成了黑题.
现在考虑如何维护.
首先,因为要修改,常用的是树剖.
我们可以在前文所述的(f)数组上增加一个(g)数组
然后一条一条重链(dp).
这里我们在处理一条重链的时候,先处理与它相连的所有重链,最后处理它.(g)数组维护所有轻儿子的信息,而(f)数组维护(g)和重儿子的信息(也就是所有儿子).因此可以很快地写出转移方程
(g[u][0]=sum_{vin lightsonu}max(f[u][0],f[u][1]),g[u][1]=sum_{vin lightsonu}f[u][0])
令(v)为(u)的重儿子,则
(f[u][0]=g[u][0]+min(f[v][1],f[v][0]),f[u][1]=f[v][0]+g[u][1])
于是可以支持用线段树修改
一个常见的黑科技是把转移方程写成矩阵的形式,用线段树(树链剖分)维护矩阵乘积即可.
但是这个矩阵不是通常意义下的矩阵.
我们常写的矩阵是这样的
(a[i][j]=sum_{k=1}^na[i][k]*a[k][j])
而现在的矩阵是这样的
(a[i][j]=max{a[i][k]*a[k][j]})
矩阵的所有性质都可行(结合律,交换律,分配律...)
因此我们就做完了这个问题,时间复杂度(O(2^3nlog^2n))
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#define N (100010)
#define M (200010)
#define inf (0x7f7f7f7f)
#define rg register int
#define Label puts("NAIVE")
typedef long double ld;
typedef long long LL;
typedef unsigned long long ull;
using namespace std;
inline char read(){
static const int IN_LEN=1000000;
static char buf[IN_LEN],*s,*t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x){
static bool iosig;
static char c;
for(iosig=false,c=read();!isdigit(c);c=read()){
if(c=='-')iosig=true;
if(c==-1)return;
}
for(x=0;isdigit(c);c=read())x=((x+(x<<2))<<1)+(c^'0');
if(iosig)x=-x;
}
inline char readchar(){
static char c;
for(c=read();!isalpha(c);c=read())
if(c==-1)return 0;
return c;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN],*ooh=obuf;
inline void print(char c) {
if(ooh==obuf+OUT_LEN)fwrite(obuf,1,OUT_LEN,stdout),ooh=obuf;
*ooh++=c;
}
template<class T>
inline void print(T x){
static int buf[30],cnt;
if(x==0)print('0');
else{
if(x<0)print('-'),x=-x;
for(cnt=0;x;x/=10)buf[++cnt]=x%10+48;
while(cnt)print((char)buf[cnt--]);
}
}
inline void flush(){fwrite(obuf,1,ooh-obuf,stdout);}
struct Matrix{
LL a[2][2];
Matrix(){memset(a,0,sizeof(a));}
Matrix operator *(Matrix x){
Matrix res;
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
for(int k=0;k<2;k++)
res.a[i][j]=max(res.a[i][j],a[i][k]+x.a[k][j]);
return res;
}
}a[N<<4],val[N],ans;
int n,m,w[N];
int fi[M],ne[M],b[M],E,ind;
int top[N],fa[N],siz[N],son[N],dfn[N],rdfn[N],ed[N];
LL f[N][2];
void add(int x,int y){
ne[++E]=fi[x],fi[x]=E,b[E]=y;
}
void dfs1(int u,int pre){
int maxsiz=-1,ma=0; fa[u]=pre;
for(int i=fi[u];i;i=ne[i]){
int v=b[i];
if(v==pre)continue;
dfs1(v,u);
if(siz[v]>maxsiz)maxsiz=siz[v],ma=v;
siz[u]+=siz[v];
}
son[u]=ma,siz[u]++;
}
void dfs2(int u){
dfn[u]=++ind,rdfn[ind]=u;
if(!son[u]){ed[u]=u;return;}
top[son[u]]=top[u],dfs2(son[u]),ed[u]=ed[son[u]];
for(int i=fi[u];i;i=ne[i]){
int v=b[i];
if(v==son[u]||v==fa[u])continue;
top[v]=v,dfs2(v);
}
}
void dp(int u){
for(int i=fi[u];i;i=ne[i]){
int v=b[i];
if(v==fa[u])continue;
dp(v),f[u][0]+=max(f[v][0],f[v][1]);
f[u][1]+=f[v][0];
}
f[u][1]+=1ll*w[u];
}
void build(int l,int r,int x){
if(l==r){
int u=rdfn[l],g0=0,g1=w[u];
for(int i=fi[u];i;i=ne[i]){
int v=b[i];
if(v==fa[u]||v==son[u])continue;
g0+=max(f[v][0],f[v][1]),g1+=f[v][0];
}
a[x].a[0][0]=a[x].a[0][1]=g0;
a[x].a[1][0]=g1,a[x].a[1][1]=-inf;
val[l]=a[x];
return;
}
int mid=(l+r)>>1;
build(l,mid,x*2),build(mid+1,r,x*2+1);
a[x]=a[x*2]*a[x*2+1];
}
void change(int k,int l,int r,int x){
if(l==r){
a[x]=val[l];
return;
}
int mid=(l+r)>>1;
if(k<=mid)change(k,l,mid,x*2);
else change(k,mid+1,r,x*2+1);
a[x]=a[x*2]*a[x*2+1];
}
Matrix query(int l,int r,int L,int R,int x){
if(l==L&&r==R)return a[x];
int mid=(L+R)>>1;
if(r<=mid)return query(l,r,L,mid,x*2);
else if(l>mid)return query(l,r,mid+1,R,x*2+1);
else return query(l,mid,L,mid,x*2)*query(mid+1,r,mid+1,R,x*2+1);
}
void update(int u,int t){
int pos=dfn[u];
val[pos].a[1][0]+=t-w[u],w[u]=t;
Matrix pre,now;
while(u){
pre=query(dfn[top[u]],dfn[ed[u]],1,n,1);
change(pos,1,n,1);
now=query(dfn[top[u]],dfn[ed[u]],1,n,1);
u=fa[top[u]],pos=dfn[u];
val[pos].a[0][0]+=max(now.a[0][0],now.a[1][0])-max(pre.a[0][0],pre.a[1][0]);
val[pos].a[0][1]=val[pos].a[0][0];
val[pos].a[1][0]+=now.a[0][0]-pre.a[0][0];
}
}
int main(){
read(n),read(m);
for(int i=1;i<=n;i++)read(w[i]);
for(int i=1,x,y;i<n;i++){
read(x),read(y);
add(x,y),add(y,x);
}
dfs1(1,0),top[1]=1,dfs2(1),fa[1]=0;
dp(1),build(1,n,1);
while(m--){
int x,y;
read(x),read(y);
update(x,y);
ans=query(1,dfn[ed[1]],1,n,1);
printf("%lld
",max(ans.a[0][0],ans.a[1][0]));
}
}
例题2
这里是题目
题意
给定一棵树,每个点有点权,求某棵子树内的一个连通块,使它的权值和最大.带修改.
分析
还是动态(dp).令(f[u])表示以(u)为根的子树中,选(u)时的最大权值和.
那么,(f[u]=max(0,w[u]+sum f[v])).令(s[u])表示以u为根的子树中最大权值和.那么,(s[u]=max(s[v],f[u])).答案就是(s[x]),(x)为询问的节点.
由于要支持修改,因此使用常见的套路,将轻重链分开.令(g[u]=f[u]-f[heavyson[u]]),则(f[u]=f[heavyson[u]]+g[u])
废话,这样除了大常数还有什么用
于是我们发现它变成了区间最大子段和的形式,可以用线段树维护这个东西.
现在我们还需要维护(s).考虑对于每个节点建一个堆维护(s).由于要资瓷删除旧版本,因此用两个(priority\_queue)维护就好了.
代码如下
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
#define N (200010)
#define M (N<<1)
#define inf (0x7f7f7f7f)
#define rg register int
typedef long double ld;
typedef long long LL;
typedef unsigned long long ull;
using namespace std;
inline char read(){
static const int IN_LEN=1000000;
static char buf[IN_LEN],*s,*t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x){
static bool iosig;
static char c;
for(iosig=false,c=read();!isdigit(c);c=read()){
if(c=='-')iosig=true;
if(c==-1)return;
}
for(x=0;isdigit(c);c=read())x=((x+(x<<2))<<1)+(c^'0');
if(iosig)x=-x;
}
inline char readchar(){
static char c;
for(c=read();!isalpha(c);c=read())
if(c==-1)return 0;
return c;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN],*ooh=obuf;
inline void print(char c) {
if(ooh==obuf+OUT_LEN)fwrite(obuf,1,OUT_LEN,stdout),ooh=obuf;
*ooh++=c;
}
template<class T>
inline void print(T x){
static int buf[30],cnt;
if(x==0)print('0');
else{
if(x<0)print('-'),x=-x;
for(cnt=0;x;x/=10)buf[++cnt]=x%10+48;
while(cnt)print((char)buf[cnt--]);
}
}
inline void flush(){fwrite(obuf,1,ooh-obuf,stdout);}
struct heap{
priority_queue<LL>p,q;
void push(LL x){p.push(x);}
void erase(LL x){q.push(x);}
LL top(){
while(!q.empty()&&p.top()==q.top())p.pop(),q.pop();
return(p.empty()?0:p.top());
}
}mx[N];
struct seg{
LL lm,rm,mx,sum;
seg(){lm=rm=mx=sum=0;}
seg operator +(seg x){
seg res;
res.sum=sum+x.sum;
res.lm=max(lm,sum+x.lm);
res.rm=max(x.rm,rm+x.sum);
res.mx=max(max(x.mx,mx),x.lm+rm);
return res;
}
};
struct xds{int l,r;seg v;}a[N<<3];
int n,m,fi[N],ne[M],b[M],w[N],E;
int dep[N],dfn[N],rdfn[N],siz[N],son[N],ed[N],top[N],fa[N],ind;
LL g[N],f[N],s[N];
void dfs1(int u,int pre){
siz[u]=1,fa[u]=pre;
for(int i=fi[u];i;i=ne[i]){
int v=b[i];
if(v==pre)continue;
dfs1(v,u),siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u){
dfn[u]=++ind,rdfn[ind]=u;
if(!son[u]){ed[u]=u;return;}
top[son[u]]=top[u],dfs2(son[u]),ed[u]=ed[son[u]];
for(int i=fi[u];i;i=ne[i]){
int v=b[i];
if(v==son[u]||v==fa[u])continue;
top[v]=v,dfs2(v);
}
}
void dp(int u){
for(int i=fi[u];i;i=ne[i]){
int v=b[i];
if(v!=fa[u])dp(v);
if(v!=son[u])g[u]+=f[v],mx[u].push(s[v]);
}
g[u]+=w[u],f[u]=max(g[u]+f[son[u]],(LL)0);
s[u]=max(s[son[u]],max(f[u],mx[u].top()));
}
void build(int l,int r,int x){
a[x].l=l,a[x].r=r;
if(l==r){
int t=rdfn[l];
a[x].v.lm=a[x].v.rm=max(g[t],0ll);
a[x].v.mx=max(g[t],mx[t].top());
a[x].v.sum=g[t]; return;
}
int mid=(l+r)>>1;
build(l,mid,x*2),build(mid+1,r,x*2+1);
a[x].v=a[x*2].v+a[x*2+1].v;
}
void change(int k,int x){
if(a[x].l==a[x].r){
int t=rdfn[a[x].l];
a[x].v.lm=a[x].v.rm=max(g[t],0ll);
a[x].v.mx=max(g[t],mx[t].top());
a[x].v.sum=g[t]; return;
}
int mid=(a[x].l+a[x].r)>>1;
if(k<=mid)change(k,x*2);else change(k,x*2+1);
a[x].v=a[x*2].v+a[x*2+1].v;
}
seg query(int l,int r,int x){
if(a[x].l==l&&a[x].r==r)return a[x].v;
int mid=(a[x].l+a[x].r)>>1;
if(r<=mid)return query(l,r,x*2);
else if(l>mid)return query(l,r,x*2+1);
else return query(l,mid,x*2)+query(mid+1,r,x*2+1);
}
void modify(int u,int st,LL val){
seg pre,now;
while(u){
if(u!=st)mx[u].erase(pre.mx),mx[u].push(now.mx);
pre=query(dfn[top[u]],dfn[ed[u]],1);
g[u]+=val,change(dfn[u],1);
now=query(dfn[top[u]],dfn[ed[u]],1);
val=now.lm-f[top[u]],f[top[u]]=now.lm,u=fa[top[u]];
}
}
void add(int x,int y){ne[++E]=fi[x],fi[x]=E,b[E]=y;}
int main(){
read(n),read(m);
for(int i=1;i<=n;i++)read(w[i]);
for(int i=1,x,y;i<n;i++){
read(x),read(y);
add(x,y),add(y,x);
}
dfs1(1,0),top[1]=1,dfs2(1),dp(1),build(1,n,1);
while(m--){
char ch=readchar(); int x; read(x);
if(ch=='M'){LL val;read(val),modify(x,x,val-w[x]),w[x]=val;}
else print(query(dfn[x],dfn[ed[x]],1).mx),print('
');
}
return flush(),0;
}