树状数组
一、长什么样?
假设有一数组a,数组b为a的前缀和数组,即b[i] = a[i] + a[i-1] + ... + a[1],树状数组c为a的(部分)前缀和数组,即c[i] = a[i] + a[i-1] + ... + a[i+1-lowbit(i)],也即c[i]为lowbit(i)个数组a的元素的和。
lowbit(i): 把i以二进制展开,只留下最低位的1及其后的0,其余位清零后所得的数值。注意i & -i
可以直接算出这个值。
上述方式所得的数组c中每个元素对应的数组a元素和的个数,由下图可以看得很清晰,注意第一个元素从下标1开始。
进一步,求b[i]可以用x个数组c中的元素相加,其中x = 数字i的二进制展开中bit1的个数,具体由哪几个元素相加看图容易得到,下标的计算可以参考具体的代码。动态更新a[i]时,需要更新b[j],其中b[j] = a[j] + ... + a[i] + ... ,即含a[i]的数组b中的元素都需要被更新。
若用普通数组动态维护前缀和,复杂度为O(n),若用树状数组,由二进制展开的特性易知复杂度则为O(logn),看图也易得。
二、怎么维护?
1.单点修改 区间查询(查询区间和)
-
初始化树状数组c为0(第一个元素下标为1, 下标为0的元素也要初始化为0)
-
原数组a的第x个元素加v
void add(int x, int v) { while (x <= n) { c[x] += v; //c为a对应的树状数组 x += lowbit(x); } }
-
求原数组a中第一个至第x个元素的和
int sum(int x) { int t = 0; while (x) { t += c[x]; x -= lowbit(x); } return t; }
-
区间[x1, x2]和查询
int query(int x1, int x2) {return sum(x2) - sum(x1 - 1);}
2.区间修改,单点查询(对区间内所有值加上某一值)
思路:差分,设差分数组p,则p[i] = a[i] - a[i-1]=>a[i] = p[i] + p[i-1] + ... + p[1]。维护差分数组p的树状数组即可。
add()
sum()
的实现同单点修改,区间查询。
-
初始化树状数组c为0
-
初始化赋值时要注意维护的是差分数组p对应的树状数组c。例如:在原数组a下标为x处初始化为v,对应于在数组p中给p[x]初始化为v-p[x-1]。(p的初值置为0)
add(x, v - sum(x-1)); //按下标连续初始化时不用调用sum()
-
原数组a中下标在[x1, x2]范围的值均加上某值v,对应于在数组p中给p[x1]加v,给p[x2+1]减v
add(x1, v), add(x2 + 1, -v);
-
查询原数组a中下标为x的值,对应于求p[x] + p[x-1] + ... + p[1],即求p的某前缀和
sum(x);
3.区间修改,区间查询
由区间修改,单点查询的基础上,优化一下区间查询的效率即可。首先推一下公式,随便截了个图,该图的描述方法和之前的有些不同。
di为差分数组的第i个元素,an为原数组第n个元素。公式表示原数组的前缀和,等价变换可得,原数组的前缀和可以转化为差分数组的前缀和以及di*i的前缀和。故得一思路:通过树状数组维护差分数组的前缀和,再通过公式计算出原数组的前缀和。
-
add()
sum()
稍作修改:void add(int x, int v, int *a) { while (x <= n) { a[x] += v; x += lowbit(x); } } int sum(int x, int *a) { int ret = 0; while (x) { ret += a[x]; x -= lowbit(x); } return ret; }
-
初始化树状数组为0
-
初始化赋值:维护差分数组的同时维护数组array,其中array[i] = di * i,具体初始化的方法与区间修改,单点查询的情况类似。设在原数组a下标为x处初始化为v。
int t = v - sum(x-1); //按下标连续初始化时不用调用sum() add(x, t, 差分数组的指针); add(x, t * x, array的指针);
-
原数组中下标在[x1, x2]范围的值均加上某值v
add(x1, v, 差分数组的指针), add(x2 + 1, -v, array的指针); add(x1, v * x1, 差分数组的指针), add(x2 + 1, -v * (x2 + 1), array的指针);
-
查询原数组[x1, x2]区间和为
( (x2 + 1) * sum(x2, 差分数组的指针) - sum(x2, array的指针) ) - ( x1 * sum(x1 - 1, 差分数组的指针) - sum(x1 - 1, array的指针) );
三、模版题各一道。。
#include <cstdio>
const int N = 500000 + 20;
int n, m, c[N];
inline int read() {
int x = 0, f = 1; char ch = getchar();
while ('0' > ch || ch > '9') {if (ch == '-') f = -1; ch = getchar();}
while ('0' <= ch && ch <= '9') {x = 10 * x + ch - '0'; ch = getchar();}
return x * f;
}
inline int lowbit(int x) {return x & -x;}
void add(int x, int v) {
while (x <= n) {
c[x] += v;
x += lowbit(x);
}
}
int sum(int x) {
int t = 0;
while (x) {
t += c[x];
x -= lowbit(x);
}
return t;
}
int main() {
n = read(), m = read();
for (int i = 1; i <= n; i++) add(i, read());
while (m--) {
int cmd = read(), k1 = read(), k2 = read();
if (cmd == 1) add(k1, k2);
else printf("%d
", sum(k2) - sum(k1 - 1));
}
return 0;
}
#include <cstdio>
const int N = 500000 + 20;
int n, m, c[N];
inline int read() {
int x = 0, f = 1; char ch = getchar();
while ('0' > ch || ch > '9') {if (ch == '-') f = -1; ch = getchar();}
while ('0' <= ch && ch <= '9') {x = 10 * x + ch - '0'; ch = getchar();}
return x * f;
}
inline int lowbit(int x) {return x & -x;}
void add(int x, int v) {
while (x <= n) {
c[x] += v;
x += lowbit(x);
}
}
int sum(int x) {
int t = 0;
while (x) {
t += c[x];
x -= lowbit(x);
}
return t;
}
int main() {
n = read(), m = read();
int last = 0, t;
for (int i = 1; i <= n; i++) {
t = read();
add(i, t - last);
last = t;
}
while (m--) {
int cmd = read();
if (cmd == 1) {
int x = read(), y = read(), k = read();
add(x, k), add(y + 1, -k);
} else {
int x = read();
printf("%d
", sum(x));
}
}
return 0;
}
#include <cstdio>
#include <cstring>
typedef long long ll;
const int N = 100000 + 10;
ll n, q;
ll num1[N], num2[N];
inline ll lowbit(ll x) {return x & (-x);}
void add(ll x, ll v, ll *a) {
while (x <= n) {
a[x] += v;
x += lowbit(x);
}
}
ll sum(ll x, ll *a) {
ll ret = 0;
while (x) {
ret += a[x];
x -= lowbit(x);
}
return ret;
}
inline ll query(ll a, ll b) {
return (b + 1) * sum(b, num1) - sum(b, num2)
- (a * sum(a - 1, num1) - sum(a - 1, num2));
}
int main() {
scanf("%lld%lld", &n, &q);
ll last = 0, t;
for (int i = 1; i <= n; i++) {
scanf("%lld", &t);
add(i, t - last, num1);
add(i, (t - last) * i, num2); //一开始写成num1。。
last = t;
}
char line[30];
ll a, b, c;
getchar(); //别忘了。。
while (q--) {
gets(line);
if (line[0] == 'Q') {
sscanf(line + 1, "%lld%lld", &a, &b);
printf("%lld
", query(a, b));
} else {
sscanf(line + 1, "%lld%lld%lld", &a, &b, &c);
add(a, c, num1);
add(b + 1, -c, num1);
add(a, c * a, num2);
add(b + 1, -c * (b + 1), num2);
}
}
return 0;
}
`