考场上切了不考虑没有逆元的情况(出题人真良心).
把概率都乘到一起后发现求的就是线段树上每个节点保存的权值和的平方的和.
这个的修改和查询都可以通过打标记来实现.
考场代码:
#include <cstdio> #include <algorithm> #define lson (now<<1) #define rson (now<<1|1) #define ll long long #define setIO(s) freopen(s".in","r",stdin) , freopen(s".out","w",stdout) using namespace std; char *p1, *p2, buf[100000]; namespace IO { #define nc() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1 ++ ) int rd() { int x = 0, f = 1; char c = nc(); while (c < 48) { if (c == '-') f = -1; c = nc(); } while (c > 47) { x = (((x << 2) + x) << 1) + (c ^ 48), c = nc(); } return x * f; } }; const int mod=998244353,N=120005; int arr[N],n,Q; inline ll qpow(ll base,ll k) { ll tmp=1; for(;k;base=base*base%mod,k>>=1)if(k&1)tmp=tmp*base%mod; return tmp; } inline ll inv(ll k) { return qpow(k,mod-2); } struct Node { int len; ll sum,sqr,sumlen,sqrlen,lazy; }t[N<<2]; inline void pushup(int l,int r,int now) { int mid=(l+r)>>1; t[now].sum=t[lson].sum; t[now].sqr=t[lson].sqr; t[now].sumlen=t[lson].sumlen; t[now].sqrlen=t[lson].sqrlen; if(r>mid) { t[now].sum=(t[now].sum+t[rson].sum)%mod; t[now].sqr=(t[now].sqr+t[rson].sqr)%mod; t[now].sumlen=(t[now].sumlen+t[rson].sumlen)%mod; t[now].sqrlen=(t[now].sqrlen+t[rson].sqrlen)%mod; } t[now].sqr=(t[now].sqr+(ll)t[now].sum*t[now].sum)%mod; t[now].sumlen=(t[now].sumlen+t[now].len*t[now].sum%mod)%mod; t[now].sqrlen=(t[now].sqrlen+(ll)t[now].len*t[now].len%mod)%mod; } inline void mark(int l,int r,int now,ll v) { t[now].lazy+=v, t[now].lazy%=mod; t[now].sqr=(t[now].sqr+((v*v)%mod)*t[now].sqrlen%mod+2ll*v*t[now].sumlen%mod)%mod; t[now].sumlen=(t[now].sumlen+(v*t[now].sqrlen)%mod)%mod; t[now].sum=(t[now].sum+(t[now].len*v)%mod)%mod; } inline void pushdown(int l,int r,int now) { int mid=(l+r)>>1; if(t[now].lazy) { mark(l,mid,lson,t[now].lazy); if(r>mid) mark(mid+1,r,rson,t[now].lazy); t[now].lazy=0; } } void build(int l,int r,int now) { t[now].len=r-l+1; if(l==r) { t[now].sum=arr[l]; t[now].sqr=(ll)arr[l]*arr[l]%mod; t[now].sumlen=t[now].len*t[now].sum%mod; t[now].sqrlen=(ll)t[now].len*t[now].len%mod; return; } int mid=(l+r)>>1; if(l<=mid) build(l,mid,lson); if(r>mid) build(mid+1,r,rson); pushup(l,r,now); } void update(int l,int r,int now,int L,int R,ll v) { if(l>=L&&r<=R) { mark(l,r,now,v); return; } pushdown(l,r,now); int mid=(l+r)>>1; if(L<=mid) update(l,mid,lson,L,R,v); if(R>mid) update(mid+1,r,rson,L,R,v); pushup(l,r,now); } int main() { using namespace IO; int i,j,cas; // setIO("b"); n=rd(),Q=rd(); for(i=1;i<=n;++i) arr[i]=rd(); build(1,n,1); for(cas=1;cas<=Q;++cas) { int opt,l,r,v; opt=rd(); if(opt==1) { l=rd(),r=rd(),v=rd(), update(1,n,1,l,r,v); } if(opt==2) { ll a=t[1].sqr,b=t[1].sum; printf("%lld ",a*inv(b)%mod); } } return 0; }