线段树
九条可怜是一个喜欢数据结构的女孩子,在常见的数据结构中,可怜最喜欢的就是线段树。
线段树的核心是懒标记,下面是一个带懒标记的线段树的伪代码,其中 `tag` 数组为懒标记:
![](https://s2.ax1x.com/2019/04/02/AyHyRJ.md.png)
其中函数 $ exttt{Lson}( ext{Node})$ 表示 $ ext{Node}$ 的左儿子,$ exttt{Rson}( ext{Node})$ 表示 $ ext{Node}$ 的右儿子。
现在可怜手上有一棵 $[1, n]$ 上的线段树,编号为 $1$。这棵线段树上的所有节点的 `tag` 均为 $0$。接下来可怜进行了 $m$ 次操作,操作有两种:
- $1 l r$,假设可怜当前手上有 $t$ 棵线段树,可怜会把每棵线段树复制两份(`tag` 数组也一起复制),原先编号为 $i$ 的线段树复制得到的两棵编号为 $2i − 1$ 与 $2i$,在复制结束后,可怜手上一共有 $2t$ 棵线段树。接着,可怜会对所有编号为奇数的线段树进行一次 $ exttt{Modify}( ext{root}, 1, n, l, r)$。
- $2$,可怜定义一棵线段树的权值为它上面有多少个节点 `tag` 为 $1$。可怜想要知道她手上所有线段树的权值和是多少。
Sol
这题很妙。
考虑用概率来表示出现次数。
注意到如果一个点的祖先有标记,那么他也有可能被影响。
可以记a为i这一个节点的1的出现概率,b为i和i的祖先中有一个出现出现1的概率。
这两个互相转移一下。
大致思路是把点分成四类:经过的,覆盖的,pushdown的和覆盖的点的儿子。
最后一类需要区间加。
#include<cstdio> #include<iostream> #include<cstdlib> #include<cstring> #include<algorithm> #include<cmath> #define maxn 100005 #define ll long long #define mod 998244353 using namespace std; int n,m; ll ny=499122177; struct node{ ll a,b,sum,bc,ba; }tr[maxn*8]; void wh(int k){ tr[k].sum=(tr[k*2].sum+tr[k*2+1].sum+tr[k].a)%mod; } void ch1(int k){ tr[k].a=tr[k].a*ny%mod;tr[k].b=tr[k].b*ny%mod;wh(k); } void ch2(int k){ tr[k].a=ny*(tr[k].a+1)%mod;tr[k].b=ny*(tr[k].b+1)%mod;wh(k); } void ch3(int k){ tr[k].a=(tr[k].b+tr[k].a)*ny%mod;tr[k].a%=mod;wh(k); } void ch4(int k){ tr[k].b=ny*(tr[k].b+1)%mod; tr[k].bc=tr[k].bc*ny%mod;tr[k].ba=(tr[k].ba+1)*ny%mod; } void down(int k){ int ls=k*2,rs=k*2+1; if(tr[k].bc!=1){ ll &t=tr[k].bc; tr[ls].b=tr[ls].b*t%mod;tr[ls].bc=tr[ls].bc*t%mod;tr[ls].ba=tr[ls].ba*t%mod; tr[rs].b=tr[rs].b*t%mod;tr[rs].bc=tr[rs].bc*t%mod;tr[rs].ba=tr[rs].ba*t%mod; t=1; } if(tr[k].ba!=0){ ll &t=tr[k].ba; tr[ls].b=(tr[ls].b+t)%mod;tr[ls].ba=(tr[ls].ba+t)%mod; tr[rs].b=(tr[rs].b+t)%mod;tr[rs].ba=(tr[rs].ba+t)%mod; t=0; } } void add(int k,int l,int r,int li,int ri){ if(l>=li&r<=ri){ ch2(k); //if(l<r)ch4(k*2),ch4(k*2+1); tr[k].ba=(tr[k].ba+1)*ny%mod; tr[k].bc=tr[k].bc*ny%mod; return; } down(k); int mid=l+r>>1; if(li<=mid)add(k*2,l,mid,li,ri); else ch3(k*2); if(ri>mid)add(k*2+1,mid+1,r,li,ri); else ch3(k*2+1); ch1(k); } int main(){ cin>>n>>m; ll num=1; for(int i=1;i<maxn;i++)tr[i].bc=1; for(int i=1,op,l,r;i<=m;i++){ scanf("%d",&op); if(op==1){ scanf("%d%d",&l,&r); add(1,1,n,l,r); num=num*2%mod; } else { ll ans=tr[1].sum*num%mod; ans=(ans+mod)%mod; printf("%lld ",ans); } } return 0; }