咕咕了。。。于是借鉴了小粉兔的做法ORZ。。。
其实就是维护最大子段和的线段树,但上面又多了一些操作。。。。QWQ
维护8个信息:1/0的个数(sum),左/右边起1/0的最长长度(ls,rs),整段区间中1/0的连续最长长度(mx)。
于是对于各个操作,我们有了一些tag。。。
tg1[]是区间赋值标记,没有标记时为-1,有标记时为0或1;tg2[]是区间取反标记,没有标记时为 0,有标记时为1。
注意标记下传时要先传tg1[],再传tg2[],否则取反标记会被赋值标记覆盖
#include<cstdio> #include<iostream> #define R register int #define ls (tr<<1) #define rs (tr<<1|1) const int N=262144; using namespace std; inline int g() { R ret=0,fix=1; register char ch; while(!isdigit(ch=getchar())) fix=ch=='-'?-1:fix; do ret=ret*10+(ch^48); while(isdigit(ch=getchar())); return ret*fix; } int n,q,a[N]; struct node{ int sum0,sum1,ls0,ls1,rs0,rs1,mx0,mx1; node(int s0=0,int s1=0,int ls0=0,int ls1=0,int rs0=0,int rs1=0,int mx0=0,int mx1=0): sum0(s0),sum1(s1),ls0(ls0),ls1(ls1),rs0(rs0),rs1(rs1),mx0(mx0),mx1(mx1) {} }; inline node upd(node l,node r) { return node(l.sum0+r.sum0,l.sum1+r.sum1, l.sum1?l.ls0:l.sum0+r.ls0,l.sum0?l.ls1:l.sum1+r.ls1, r.sum1?r.rs0:r.sum0+l.rs0,r.sum0?r.rs1:r.sum1+l.rs1, max(max(l.mx0,r.mx0),l.rs0+r.ls0), max(max(l.mx1,r.mx1),l.rs1+r.ls1)); } node t[N]; int len[N],tg1[N],tg2[N]; inline void push(int tr,int typ) { node& tmp=t[tr]; if(typ==0) tg2[tr]=0,tg1[tr]=0,tmp=node(0,len[tr],0,len[tr],0,len[tr],0,len[tr]); else if(typ==1) tg2[tr]=0,tg1[tr]=1,tmp=node(len[tr],0,len[tr],0,len[tr],0,len[tr],0); else if(typ==2) tg2[tr]^=1,swap(tmp.sum0,tmp.sum1),swap(tmp.ls0,tmp.ls1),swap(tmp.rs0,tmp.rs1),swap(tmp.mx0,tmp.mx1); } inline void spread(int tr) { if(~tg1[tr]) push(ls,tg1[tr]),push(rs,tg1[tr]); if(tg2[tr]) push(ls,2),push(rs,2); tg1[tr]=-1,tg2[tr]=0; } inline void build(int tr,int l,int r) { len[tr]=r-l+1,tg1[tr]=-1; if(l==r) {R tmp=g(); t[tr]=node(tmp,tmp^1,tmp,tmp^1,tmp,tmp^1,tmp,tmp^1); return ;} R md=l+r>>1; build(ls,l,md),build(rs,md+1,r); t[tr]=upd(t[ls],t[rs]); } inline void change(int tr,int l,int r,int LL,int RR,int d) { if(LL<=l&&r<=RR) {push(tr,d); return ;} spread(tr); R md=l+r>>1; if(LL<=md) change(ls,l,md,LL,RR,d); if(RR>md) change(rs,md+1,r,LL,RR,d); t[tr]=upd(t[ls],t[rs]); } inline node query(int tr,int l,int r,int LL,int RR) { if(LL<=l&&r<=RR) return t[tr]; spread(tr); R md=l+r>>1; register node ret=node(); if(LL<=md) ret=query(ls,l,md,LL,RR); if(RR>md) ret=upd(ret,query(rs,md+1,r,LL,RR)); return ret; } signed main() { n=g(),q=g(); build(1,1,n); for(R i=1;i<=q;++i) { R op=g(),l=g()+1,r=g()+1; if(op<3) change(1,1,n,l,r,op); else {register node tmp=query(1,1,n,l,r); printf("%d ",op==3?tmp.sum0:tmp.mx0);} } }
2019.04.27