背景
题意
给定一个数列,要求单点修改、区间查询最大子段和。
解法
线段树单点修改、区间查询模板。维护每个区间和 (sum) 、区间最大子段和 (val) 、区间从左端点开始向右的最大子段和 (lmax) 、区间从右端点开始向左的最大子段和 (rmax) 。
(trick)
查询时先加入一个节点,后来查询到的节点直接与该节点比较更新即可。
细节
(1.) 建树和修改时各个信息均要赋值。(好傻逼的错误啊)
(2.) 查询中比较更新时注意各信息顺序,修改时要用到别的信息值的先更新。
代码
#include<bits/stdc++.h>
using namespace std;
inline int read()
{
int ret=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0')
{
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
ret=(ret<<1)+(ret<<3)+ch-'0';
ch=getchar();
}
return ret*f;
}
int n,q,a[500005],op,x,val,l,r,ans;
bool flag;
struct SegmentTree
{
int l;
int r;
int val;
int sum;
int lmax;
int rmax;
}t[2000005],tmp;
void build(int pos,int l,int r)
{
t[pos].l=l;
t[pos].r=r;
if(l==r)
{
t[pos].sum=a[l];
t[pos].lmax=a[l];
t[pos].rmax=a[l];
t[pos].val=a[l];
return;
}
int mid=(l+r)>>1;
build(pos<<1,l,mid);
build(pos<<1|1,mid+1,r);
t[pos].sum=t[pos<<1].sum+t[pos<<1|1].sum;
t[pos].lmax=max(t[pos<<1].lmax,t[pos<<1].sum+t[pos<<1|1].lmax);
t[pos].rmax=max(t[pos<<1|1].rmax,t[pos<<1].rmax+t[pos<<1|1].sum);
t[pos].val=max(max(t[pos<<1].val,t[pos<<1|1].val),t[pos<<1].rmax+t[pos<<1|1].lmax);
}
void modify(int pos,int x,int val)
{
if(t[pos].l==t[pos].r)
{
t[pos].sum=val;
t[pos].lmax=val;
t[pos].rmax=val;
t[pos].val=val;
return;
}
int mid=(t[pos].l+t[pos].r)>>1;
if(x<=mid)
modify(pos<<1,x,val);
else
modify(pos<<1|1,x,val);
t[pos].sum=t[pos<<1].sum+t[pos<<1|1].sum;
t[pos].lmax=max(t[pos<<1].lmax,t[pos<<1].sum+t[pos<<1|1].lmax);
t[pos].rmax=max(t[pos<<1|1].rmax,t[pos<<1].rmax+t[pos<<1|1].sum);
t[pos].val=max(max(t[pos<<1].val,t[pos<<1|1].val),t[pos<<1].rmax+t[pos<<1|1].lmax);
}
void query(int pos,int l,int r)
{
if(t[pos].l>=l&&t[pos].r<=r)
{
if(!flag)
{
flag=1;
tmp=t[pos];
}
else
{
tmp.lmax=max(tmp.lmax,tmp.sum+t[pos].lmax);
tmp.sum=tmp.sum+t[pos].sum;
tmp.val=max(max(tmp.val,t[pos].val),tmp.rmax+t[pos].lmax);
tmp.rmax=max(t[pos].rmax,tmp.rmax+t[pos].sum);
}
return;
}
int mid=(t[pos].l+t[pos].r)>>1;
if(l<=mid)
query(pos<<1,l,r);
if(r>mid)
query(pos<<1|1,l,r);
}
int main()
{
n=read();
q=read();
for(register int i=1;i<=n;i++)
a[i]=read();
build(1,1,n);
while(q--)
{
op=read();
if(op==2)
{
x=read();
val=read();
modify(1,x,val);
}
else
{
l=read();
r=read();
if(l>r)
swap(l,r);
flag=0;
query(1,l,r);
printf("%d
",tmp.val);
}
}
return 0;
}