树状数组
树状数组:一个数组,支持单点修改和区间查询。复杂度(O(nlogn))
lowbit
(lowbit)函数用于求某个数的二进制表示中的最低的一位(1)
例:(6_{10}=110_{2}~~~~~~~lowbit(6)=10_{2}=2)
求(lowbit)的两种方法
((1))
int lowbit(x){
return x-(x&(x-1));
}
例:(1001100_{2}-1_{10})
(~~~=1001011_{2})
显然减了1之后,最后一位1及以后相当于按位取反,与预算之后,最后一位1及后边的0都会变成零,在与原数相减就会得到(lowbit)值
((2))
int lowbit(x){
return x&-x;//-x为x按位取反再加一
}
例:(1001100->0110011->+1->0110100)
(~~~~~~~~0110100)两数按位与之后为(100_{2})
树状数组思想
(A[i])为原数组,(C[i])为前缀和。什么区间呢?
假设(n=6),
(sum_{i=1}^{6}=(A[1]+A[2]+A[3]+A[4])+(A[5]+A[6])),而(6)的二进制是(110),(110=100+10),转化成十进制就是(6=4+2),和上面的求和公式很像。所以我们把(n)拆成几个区间来求和,按照(n)的二进制来拆分。
(C[i])表示的就是从第(i)的元素往前(lowbit(i))个元素的和。
功能及实现
单点修改
修改每个点,它的后面的数的前缀和都会改变,所以每次修改都要维护前缀和
void add(int x,int k){//在x的位置增加k
while(x<=n){
tree[x]+=k;
x+=lowbit(x);
}
}
区间查询
查询一段区间([x,y])的和,可以转化为求([1,y]-[1,x-1])的值
int getsum(int k){
int ans=0;
while(k){
ans+=tree[k];
k-=lowbit(k);
}
}
区间修改
利用查分思想,每次加入当前数与前一个数的差,这样区间修改就变成单点修改
int main(){
...
int x,y=0;
for(int i=1;i<=n;++i){
cin>>x;
add(i,x-y);//add与上面一样
y=x;
}
...
}
单点查询
差分,单点查询变区间查询
求逆序对
在求逆序对的问题中,树状数组中维护的是在第(i)个数插入之前,有多少比他大的数插进来了。
做法有很多。
因为一般给的数会很大,还会有负数,树状数组无法维护,所以需要离散化。
sort(temp+1,temp+1+n);
int cnt=unique(temp+1,temp+1+n)-temp-1;
for(int i=1;i<=n;++i) ~a[i]=lower_bound(temp+1,temp+1+cnt,a[i])-temp;
显然离散化之后,新的数,也就是在原来数组中按大小排序后的位置。手推可知,新数组的逆序对即为旧数组的逆序对。
也可以建结构体,两个变量分别表示原值和序号。按变量大小排序后,对序号建树状数组,其逆序对即为原序列逆序对。但有一点需要注意,若两数大小相同,需将原序列中排位靠前的数后插入。
bool cmp(const Dier &x,const Dier &y){
if(x.x==y.x) return x.k<y.k;//因为我是用逆序插入,若顺序的话,需要反过来
return x.x<y.x;
}
因为树状数组维护的是在(i)之前有多少比(i)大的数,所以每次插入的值为1,表示有一个数。插入可以逆序也可以顺序。顺序的话,就是用总共插入的数减去在它前面的数;逆序的话,直接求在它之前的数即可
//顺序
for(int i=1;i<=n;++i){
ans+=getsum(n)-getsum(a[i].k);
add(a[i].k,1);
}
//逆序
for(int i=n;i;--i){
add(a[i].k,1);
ans+=getsum(a[i].k-1);
}
完整代码
//单点修改,区间查询
#include<cstdio>
using namespace std;
int n,m,tree[500005];
inline int lowbit(int x){
return x&-x;
}
inline void add(int x,int k){
while(x<=n){
tree[x]+=k;
x+=lowbit(x);
}
}
inline int sum(int k){
int ans=0;
while(k>0){
ans+=tree[k];
k-=lowbit(k);
}
return ans;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1,x;i<=n;++i) scanf("%d",&x),add(i,x);
for(int i=1,b,x,k;i<=m;i++){
scanf("%d%d%d",&b,&x,&k);
if(b==1) add(x,k);
else printf("%d
",sum(k)-sum(x-1));
}
return 0;
}
//区间修改,单点查询
#include<cstdio>
using namespace std;
int n,m;
long long tree[500005];
int lowbit(int x){
return x&-x;
}
void add(int x,int k){
while(x<=n){
tree[x]+=k;
x+=lowbit(x);
}
}
long long sum(int x){
long long ans=0;
while(x>0){
ans+=tree[x];
x-=lowbit(x);
}
return ans;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1,x,y=0;i<=n;++i){
scanf("%d",&x);
add(i,x-y);
y=x;
}
while(m--){
int b,x,y,k;
scanf("%d",&b);
if(b==1){
scanf("%d%d%d",&x,&y,&k);
add(x,k),add(y+1,-k);
}
else{
scanf("%d",&x);printf("%d
",sum(x));
}
}
return 0;
}
//求逆序对,结构体+逆序插入
#include<algorithm>
#include<iostream>
#include<cstdio>
using namespace std;
int n;
long long ans,tree[500005];
struct Dier{
long long x;
int k;
}a[500005];
long long read() {
long long x=0;int f=0;char c=getchar();
while(c<'0'||c>'9'){f|=c=='-';c=getchar();}
while(c>='0'&&c<='9'){x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return f?-x:x;
}
bool cmp(const Dier &x,const Dier &y){
if(x.x == y.x) return x.k < y.k;
return x.x<y.x;
}
int lowbit(int x){
return x&-x;
}
void insert(int x,int k){
while(k<=n){
tree[k]+=x,k+=lowbit(k);
}
}
long long getsum(int k){
long long sum=0;
while(k){
sum+=tree[k],k-=lowbit(k);
}
return sum;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;++i) a[i].x=read(),a[i].k=i;
sort(a+1,a+1+n,cmp);
for(int i=n;i;--i){
insert(1,a[i].k);
ans+=getsum(a[i].k-1);
}
printf("%lld
",ans);
return 0;
}
//求逆序对,离散化+顺序插入
#include<iostream>
#include<cstdio>
using namespace std;
long long read(){
long long x=0;int f=0;char c=getchar();
while(c>'9'||c<'0') f|=c=='-',c=getchar();
while(c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
return f?-x:x;
}
int n,a[500010],temp[500010],tree[500010];
long long ans;
inline int lowbit(int x){
return x&-x;
}
int getsum(int x){
int ans=0;
while(x){
ans+=tree[x];
x-=lowbit(x);
}
return ans;
}
void add(int x){
while(x<=n){
++tree[x];
x+=lowbit(x);
}
}
int main() {
n=read();
for(int i=1;i<=n;++i) a[i]=read(),temp[i]=a[i];
sort(temp+1,temp+1+n);
int cnt=unique(temp+1,temp+1+n)-temp-1;
for(int i=1;i<=n;++i) a[i]=lower_bound(temp+1,temp+cnt+1,a[i])-temp;
for(int i=1;i<=n;++i){
ans+=getsum(n)-getsum(a[i]);
add(a[i]);
}
printf("%d",ans);
return 0;
}
欢迎指正评论O(∩_∩)O~~