不太清楚是不是动态dp……?
这个维护其实和最大连续子段差不多,维护l[x][y],r[x][y],m[x][y]分别表示包含左儿子的01个数为(x,y)的区间个数,包含右儿子的01个数为(x,y)的区间个数,和01个数为(x,y)的所有区间个数
x表示1的个数情况,0表示0个,1表示1个,2表示>=2的偶数个,3表示>=3的奇数个
y表示0的个数情况,0表示0个,1表示1个,2表示>=2个
转移的话合并ls.r,rs.l即可,注意是乘法,注意细节,转移很难写……
#include<iostream>
#include<cstdio>
using namespace std;
const int N=100005;
int n,m,a[N];
struct xds
{
long long l[4][3],r[4][3],m[4][3],s[2];
}t[N<<2];
int read()
{
int r=0,f=1;
char p=getchar();
while(p>'9'||p<'0')
{
if(p=='-')
f=-1;
p=getchar();
}
while(p>='0'&&p<='9')
{
r=r*10+p-48;
p=getchar();
}
return r*f;
}
int wk1(int x)
{
return (x<=1)?x:(x%2+2);
}
int wk0(int x)
{
return (x<=1)?x:2;
}
xds operator + (const xds &a,const xds &b)
{
xds c;
c.s[0]=a.s[0]+b.s[0];
c.s[1]=a.s[1]+b.s[1];
for(int i=0;i<=3;i++)
for(int j=0;j<=2;j++)
{
c.l[i][j]=a.l[i][j];
c.r[i][j]=b.r[i][j];
c.m[i][j]=a.m[i][j]+b.m[i][j];
}
for(int i=0;i<=3;i++)
for(int j=0;j<=2;j++)
if(a.r[i][j])
for(int k=0;k<=3;k++)
for(int l=0;l<=2;l++)
if(b.l[k][l])
c.m[wk1(i+k)][wk0(j+l)]+=a.r[i][j]*b.l[k][l];
for(int i=0;i<=3;i++)
for(int j=0;j<=2;j++)
{
c.l[wk1(a.s[1]+i)][wk0(a.s[0]+j)]+=b.l[i][j];
c.r[wk1(b.s[1]+i)][wk0(b.s[0]+j)]+=a.r[i][j];
}
return c;
}
void build(int ro,int l,int r)
{
if(l==r)
{
int x=(a[l]==1),y=(a[l]==0);
t[ro].s[0]=y,t[ro].s[1]=x;
t[ro].m[x][y]=t[ro].l[x][y]=t[ro].r[x][y]=1;
return;
}
int mid=(l+r)>>1;
build(ro<<1,l,mid);
build(ro<<1|1,mid+1,r);
t[ro]=t[ro<<1]+t[ro<<1|1];
}
void update(int ro,int l,int r,int p)
{
if(l==r)
{
int x=t[ro].s[1],y=t[ro].s[0];
t[ro].m[x][y]=t[ro].l[x][y]=t[ro].r[x][y]=0;
swap(t[ro].s[0],t[ro].s[1]);
t[ro].m[y][x]=t[ro].l[y][x]=t[ro].r[y][x]=1;
return;
}
int mid=(l+r)>>1;
if(p<=mid)
update(ro<<1,l,mid,p);
else
update(ro<<1|1,mid+1,r,p);
t[ro]=t[ro<<1]+t[ro<<1|1];
}
xds ques(int ro,int l,int r,int x,int y)
{
if(l==x&&r==y)
return t[ro];
int mid=(l+r)>>1;
if(y<=mid)
return ques(ro<<1,l,mid,x,y);
else if(x>mid)
return ques(ro<<1|1,mid+1,r,x,y);
else
return ques(ro<<1,l,mid,x,mid)+ques(ro<<1|1,mid+1,r,mid+1,y);
}
int main()
{
n=read();
for(int i=1;i<=n;i++)
a[i]=read();
build(1,1,n);
m=read();
while(m--)
{
int o=read();
if(o==1)
{
int x=read();
update(1,1,n,x);
}
else
{
int l=read(),r=read();
xds x=ques(1,1,n,l,r);
printf("%lld
",x.m[0][0]+x.m[0][1]+x.m[0][2]+x.m[2][0]+x.m[2][1]+x.m[2][2]+x.m[3][2]);
}
}
return 0;
}