题目大意
一开始有一棵线段树,然后有一个操作序列,问执行这个操作序列的所有子集时线段树上有标记的节点个数和。
题解
其实我们把它除以(2^m)后发现就是有标记节点的期望个数。
然后套路的根据期望的线性性,我们要统计所有点有标记的概率和。
然后我们来讨论一些情况:
1、当前节点和修改区间没有交且当前节点的父亲节点也没有交,那么这个点的标记就不会动。
2、当前节点被修改区间包含且父亲节点也被包含,那根本碰不到这个节点,也不会动。
3、当前节点被修改区间包含且父亲节点没有被包含,那么这个节点一定会有标记。
4、当前节点和修改区间有交但不包含,那么这个点一定没有标记。
5、这个点和修改区间没有交但是父亲有,那么这个点有没有标记取决于这个点的祖先节点(包括自己)有没有标记。
我们观察到第5种情况需要考虑到祖先节点是否有标记,所以我们设一个(f)表示这个点有标记的概率,(g)表示这个点的祖先节点有(包括自己)标记的概率。
对于第一种情况不做讨论。
对于第二种情况,(f)是没有变化的,(g=g*0.5+0.5)。
对于第三种情况,(f=f*0.5+0.5 g=g*0.5+0.5);
对于第四种情况(f=f*0.5 g=g*0.5)。
对于第五种情况(f=f*0.5+g*0.5 g=g*0.5+g*0.5)。
这些都可以按照线段树的操作去维护。
代码
#include<iostream>
#include<cstdio>
#define ls tr[cnt].l
#define rs tr[cnt].r
#define N 100009
using namespace std;
typedef long long ll;
const int mod=998244353;
ll now,inv2;
int tot,n,m,rot;
inline int rd(){
int x=0;char c=getchar();bool f=0;
while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
return f?-x:x;
}
inline ll power(ll x,ll y){
ll ans=1;
while(y){if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;}
return ans;
}
inline void MOD(ll &x){x=x>=mod?x-mod:x;}
struct node{
int l,r;
ll multag,addtag,f,g,sum;
}tr[N<<1];
inline void pushup(int cnt){(tr[cnt].sum=tr[ls].sum+tr[rs].sum+tr[cnt].f)%=mod;}
inline void gx(int cnt){
tr[cnt].f=(tr[cnt].g+tr[cnt].f)*inv2%mod;
pushup(cnt);
}
inline void gan(int cnt,ll x,ll y){
MOD(tr[cnt].g=tr[cnt].g*x%mod+y);
tr[cnt].multag=tr[cnt].multag*x%mod;
MOD(tr[cnt].addtag=tr[cnt].addtag*x%mod+y);
}
inline void pushdown(int cnt){
if(tr[cnt].multag==1&&!tr[cnt].addtag)return;
gan(ls,tr[cnt].multag,tr[cnt].addtag);
gan(rs,tr[cnt].multag,tr[cnt].addtag);
tr[cnt].multag=1;tr[cnt].addtag=0;
}
void upd(int cnt,int l,int r,int L,int R){
if(l>=L&&r<=R){
MOD(tr[cnt].f=tr[cnt].f*inv2%mod+inv2);
gan(cnt,inv2,inv2);pushup(cnt);return;
}
tr[cnt].f=tr[cnt].f*inv2%mod;
tr[cnt].g=tr[cnt].g*inv2%mod;
int mid=(l+r)>>1;
pushdown(cnt);
if(mid>=L)upd(ls,l,mid,L,R);
if(mid<R)upd(rs,mid+1,r,L,R);
if(mid<L)gx(ls);if(mid>=R)gx(rs);
pushup(cnt);
}
void build(int &cnt,int l,int r){
cnt=++tot;
tr[cnt].multag=1;
if(l==r)return;
int mid=(l+r)>>1;
build(ls,l,mid);build(rs,mid+1,r);
}
int main(){
n=rd();m=rd();inv2=power(2,mod-2);
build(rot,1,n);
int l,r,opt;now=1;
for(int i=1;i<=m;++i){
opt=rd();
if(opt==1){
l=rd();r=rd();
upd(rot,1,n,l,r);
MOD(now=now+now);
}
else printf("%lld
",tr[rot].sum*now%mod);
}
return 0;
}