树状数组的本职工作是修改点,查询区间和
我们可以先回顾一下姊妹篇:(一维)树状数组的实现
然后我们再回顾一下差分数组,差分数组可以实现修改区间,查询点
如果不用树状数组进行优化的话,修改是O(1),查询是O(n)的
我们要做的就是用树状数组把查询操作优化成对数级别的
这里直接给出树状数组的代码以及差分数组的代码:
先是树状数组的:
int a[maxn]; int c[maxn]; int lowbit(int x) { return x&(-x); } void update(int x,int y) { while(x<=n) { c[x]+=y; x+=lowbit(x); } } int sum(int x) { int ans=0; while(x>0) { ans+=c[x]; x-=lowbit(x); } return ans; }
然后是差分数组的:
int a[maxn]; int b[maxn]; int n,q; void update(int x,int y,int z) { b[x]+=z; b[y+1]-=z; } int sum(int x) { int ans=0; for(int i=1;i<=x;i++) ans+=b[i]; return ans; }
那么我们只要把他们结合在一起,就是传说中的差分树状数组了,其实很简单,就是把差分数组存到树状数组里面来维护,但是这样就会有一个问题,虽然查询的效率变好了
但是修改的效率由原来的O(1)变成了O(logn)的,所以有利有弊需要自己衡量
下面直接给出差分树状数组的源码,其实就是把上述的两份代码拼接在了一起
1 //aininot260 2 //修改区间,查询点 3 #include<iostream> 4 #include<cstring> 5 using namespace std; 6 const int maxn=100005; 7 const int maxq=100005; 8 int a[maxn]; 9 int b[maxn]; 10 int c[maxn]; 11 int lowbit(int x) 12 { 13 return x&(-x); 14 } 15 int n,q; 16 void c_update(int x,int y) 17 { 18 while(x<=n) 19 { 20 c[x]+=y; 21 x+=lowbit(x); 22 } 23 } 24 void update(int x,int y,int z) 25 { 26 b[x]+=z; 27 b[y+1]-=z; 28 c_update(x,z); 29 c_update(y+1,-z); 30 } 31 int sum(int x) 32 { 33 int ans=0; 34 while(x>0) 35 { 36 ans+=c[x]; 37 x-=lowbit(x); 38 } 39 return ans; 40 } 41 int main() 42 { 43 cin>>n; 44 for(int i=1;i<=n;i++) 45 cin>>a[i]; 46 b[1]=a[1]; 47 c_update(1,b[1]); 48 for(int i=2;i<=n;i++) 49 { 50 b[i]=a[i]-a[i-1]; 51 c_update(i,b[i]); 52 } 53 cin>>q; 54 while(q--) 55 { 56 int x; 57 cin>>x; 58 if(x==1) 59 { 60 int y,z,w; 61 cin>>y>>z>>w; 62 update(y,z,w); 63 } 64 if(x==2) 65 { 66 int y; 67 cin>>y; 68 cout<<sum(y)<<endl; 69 } 70 } 71 return 0; 72 }
接着要介绍的就是利用树状数组来实现区间修改和区间查询,这样就可以规避很多必须要写线段树的情况
写完之后发现,帅呆了。时间空间代码量全部吊打线段树
写这种BIT的时候我们需要维护三个数组
long long delta[maxn]; //差分数组 long long deltai[maxn]; //delta*i long long sum[maxn];//原始前缀和
初始化的时候,直接计算前缀和就可以了,两个树状数组都是用来存delta的
然后是修改区间操作,4次Update即可:
int y,z,w; cin>>y>>z>>w; update(delta,y,w); update(delta,z+1,-w); update(deltai,y,w*y); update(deltai,z+1,-w*(z+1));
查询区间和的操作,直接在前缀和基础上加减delta就可以了
long long suml=sum[y-1]+y*query(delta,y-1)-query(deltai,y-1); long long sumr=sum[z]+(z+1)*query(delta,z)-query(deltai,z); cout<<sumr-suml<<endl;
下面给出完整的代码:
1 #include<iostream> 2 #include<cstdio> 3 #include<algorithm> 4 using namespace std; 5 const int maxn=200005,maxq=200005; 6 int n,q; 7 int lowbit(int x) 8 { 9 return x&(-x); 10 } 11 long long delta[maxn]; //差分数组 12 long long deltai[maxn]; //delta*i 13 long long sum[maxn];//原始前缀和 14 void update(long long *c,int x,int y) 15 { 16 while(x<=n) 17 { 18 c[x]+=y; 19 x+=lowbit(x); 20 } 21 } 22 long long query(long long *c,int x) 23 { 24 long long ans=0; 25 while(x>0) 26 { 27 ans+=c[x]; 28 x-=lowbit(x); 29 } 30 return ans; 31 } 32 int main() 33 { 34 cin>>n; 35 for(int i=1;i<=n;i++) 36 { 37 int x; 38 cin>>x; 39 sum[i]=sum[i-1]+x; 40 } 41 cin>>q; 42 while(q--) 43 { 44 int x; 45 cin>>x; 46 if(x==1) 47 { 48 int y,z,w; 49 cin>>y>>z>>w; 50 update(delta,y,w); 51 update(delta,z+1,-w); 52 update(deltai,y,w*y); 53 update(deltai,z+1,-w*(z+1)); 54 } 55 if(x==2) 56 { 57 int y,z; 58 cin>>y>>z; 59 long long suml=sum[y-1]+y*query(delta,y-1)-query(deltai,y-1); 60 long long sumr=sum[z]+(z+1)*query(delta,z)-query(deltai,z); 61 cout<<sumr-suml<<endl; 62 } 63 } 64 return 0; 65 }
其实介绍到这里已经差不多了,但是精彩的还在后面,二维树状数组!