题目链接:http://codeforces.com/contest/981/problem/G
题目大意:
有n个初始为空的‘魔法’可重集,向一个‘可重集’加入元素时,若该元素未出现过,则将其加入;否则该可重集中所有元素的个数都会翻倍。
例如将$2$加入${1,3}$会得到${1,2,3}$,将$2$加入${1,2,3,3}$会得到${1,1,2,2,3,3,3,3}$.
$q$次操作,每次操作要么向一个区间内的所有可重集加入某个元素,要么询问一个区间内可重集的大小之和。
$n,q ≤ 2×10^5$
题解:
发现对于出现过该元素的区间就是区间乘,没有出现过的就是区间加
操作1我们先把$[l,r]$全部乘上2,再把之前已经出现过当前元素的区间乘上2的逆元再+1,最后合并一下左右区间就好了。合并大概就是保证这个操作时间复杂度的关键了,只是我也不知道怎么算
发现这样我们要维护每个元素出现的区间,同时支持方便的合并,我们开n个set就好了
正是做完这题我发现我竟然不会写区间乘法线段树
#include<algorithm> #include<cstring> #include<cstdio> #include<iostream> #include<set> #define pa pair<int,int> #define mid ((l+r)>>1) using namespace std; typedef long long ll; const int N=2e5+15; const int mod=998244353; ll n,q,inv; ll sum[N<<2],add[N<<2],mul[N<<2]; set <pa> st[N]; inline ll read() { char ch=getchar(); ll s=0,f=1; while (ch<'0'||ch>'9') {if (ch=='-') f=-1;ch=getchar();} while (ch>='0'&&ch<='9') {s=(s<<3)+(s<<1)+ch-'0';ch=getchar();} return s*f; } ll qpow(ll a,ll b) { ll re=1; for (;b;b>>=1,a=a*a%mod) if (b&1) re=re*a%mod; return re; } void build(int o,int l,int r) { mul[o]=1;add[o]=sum[o]=0; if (l==r) return; build(o<<1,l,mid); build(o<<1|1,mid+1,r); } void pushup(int o,int l,int r) { //sum[o]=sum[o<<1]+sum[o<<1|1]; sum[o]=(sum[o<<1]*mul[o<<1]+add[o<<1]*(mid-l+1))%mod; sum[o]=(sum[o]+sum[o<<1|1]*mul[o<<1|1]+add[o<<1|1]*(r-mid))%mod; } void pushdown(int o,int l,int r) { if (mul[o]!=1) { ll p=mul[o]; mul[o]=1; //(sum[o<<1]*=p)%=mod; //(sum[o<<1|1]*=p)%=mod; (add[o<<1]*=p)%=mod; (add[o<<1|1]*=p)%=mod; (mul[o<<1]*=p)%=mod; (mul[o<<1|1]*=p)%=mod; } if (add[o]!=0) { ll p=add[o]; add[o]=0; // (sum[o<<1]+=p*(mid-l+1))%=mod; // (sum[o<<1|1]+=p*(r-mid))%=mod; (add[o<<1]+=p)%=mod; (add[o<<1|1]+=p)%=mod; } } void update(int o,int l,int r,int x,int y,ll z,int flag) { if (l>=x&&r<=y) { if (flag==1) { (mul[o]*=z)%=mod; (add[o]*=z)%=mod; // (sum[o]*=z)%=mod; } if (flag==2) { (add[o]+=z)%=mod; // (sum[o]+=(r-l+1)*z)%=mod; } return; } pushdown(o,l,r); if (x<=mid) update(o<<1,l,mid,x,y,z,flag); if (y>mid) update(o<<1|1,mid+1,r,x,y,z,flag); pushup(o,l,r); } void merge(int x,int L,int R) { set<pa>::iterator it; it=st[x].lower_bound(pa(L,L)); for (;it!=st[x].end();it++) { set<pa>::iterator lst=it;lst--; int l=(*lst).second+1; int r=(*it).first-1; int upl=max(L,l); int upr=min(R,r); if (upr>=upl) { update(1,1,n,upl,upr,inv,1); update(1,1,n,upl,upr,1,2); } if ((*it).first>=R) break; } int mergeL=L,mergeR=R; it=st[x].upper_bound(pa(L,L));it--; if ((*it).second>=mergeL) mergeL=(*it).first; it=st[x].upper_bound(pa(R,R));it--; if ((*it).second>=mergeR) mergeR=(*it).second; vector <pa> er; it=st[x].lower_bound(pa(mergeL,mergeL)); for (;it!=st[x].end();it++) { pa e=*it; if (e.first>=mergeL&&e.second<=mergeR) er.push_back(e); else break; } for (int i=0;i<er.size();i++) st[x].erase(er[i]); st[x].insert(pa(mergeL,mergeR)); } ll query(int o,int l,int r,int x,int y) { if (l>=x&&r<=y) return (sum[o]*mul[o]+add[o]*(r-l+1))%mod; // if (l>=x&&r<=y) return sum[o]%mod; pushdown(o,l,r); ll re=0; if (x<=mid) (re+=query(o<<1,l,mid,x,y))%=mod; if (y>mid) (re+=query(o<<1|1,mid+1,r,x,y))%=mod; pushup(o,l,r); return re; } int main() { inv=qpow(2,mod-2); //inv=(mod+1)/2; n=read();q=read(); for (int i=0;i<=n;i++) { st[i].insert(pa(0,0)); st[i].insert(pa(n+1,n+1)); } build(1,1,n); while (q--) { int opt=read(); if (opt==1) { int l=read(),r=read(),z=read(); update(1,1,n,l,r,2,1); merge(z,l,r); } if (opt==2) { int l=read(),r=read(); printf("%lld ",query(1,1,n,l,r)); } } return 0; }