Chapter 5 树状数组和线段树
+++
树状数组
1.单点修改 2.区间查询
原数组下标从1开始,假如树状数组是c[],那么c[x], x的二进制表示最后有几个0就是第几层
假设是第k层。那么c[x] = (x - lowbit(x),x]。lowbit(x) = 2 ^ k
1.动态求连续区间和 1264
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 1e5 + 10;
int a[N], tr[N];
int n, m;
int lowbit(int x)
{
return x & -x;
}
void add(int x, int y)//加上一个元素
{
for(int i = x; i <= n; i += lowbit(i) ) tr[i] += y;
}
int query(int x)//前缀和
{
int res = 0;
for(int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++ ) scanf("%d", a + i);
//当成将每个元素添加到原本全是0的数组中
for(int i = 1; i <= n; i++ ) add(i, a[i]);
while(m--)
{
int k, x, y;
cin >> k >> x >> y;
if(k == 0) cout << query(y) - query(x - 1) << endl;
else add(x , y);
}
return 0;
}
2.数星星 1265
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 32010;
int tr[N], level[N];
int lowbit(int x)
{
return x & -x;
}
void add(int x)
{
for(int i = x; i <= N; i+= lowbit(i) ) tr[i]++;
}
int query(int x)
{
int res = 0;
for(int i = x; i ; i -= lowbit(i) ) res += tr[i];
return res;
}
int main()
{
int n;
cin >> n;
for(int i = 0; i < n; i++ )
{
int x, y;
scanf("%d%d", &x, &y);
x++;//树状数组从下标1开始,整体右移
level[query(x)]++;
add(x);
}
for(int i = 0; i < n; i++ )
printf("%d
", level[i]);
return 0;
}
-
线段树
x-------- 父节点 x / 2 (x >> 1) ---------左儿子2 * x (x << 1) ------ 右儿子2 * x + 1 (x << 1 | 1)
主要用到的函数:
1.pushup 用子节点信息更新当前节点信息
2.build 在一段区间上初始化线段树
3.modify 修改
4.query 询问
1.动态求连续区间和 1264
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1e5 + 10;
int n, m, w[N];
struct Node
{
int l, r;
int sum;
}tr[N * 4];
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void build(int u, int l, int r)
{
if(l == r) tr[u] = {l, r, w[r]};//c++ 11
else
{
tr[u] = {l, r};//
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
int query(int u, int l, int r)
{
if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if(l <= mid) sum += query(u << 1, l, r);
if(r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
void modify(int u, int x, int v)
{
if(tr[u].l == tr[u].r) tr[u].sum += v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++ ) scanf("%d", w + i);
build(1, 1, n);
while(m--)
{
int k, a, b;
scanf("%d%d%d", &k, &a, &b);
if(k == 0) printf("%d
", query(1, a, b));
else modify(1, a, b);
}
return 0;
}
2.数列区间最大值 1270
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <climits>
using namespace std;
const int N = 1e5 + 10;
int n, m, w[N];
struct Node
{
int l, r;
int maxv;
}tr[N * 4];
void pushup(int u)
{
tr[u].maxv = max(tr[u << 1].maxv, tr[u << 1 | 1].maxv);
}
void build(int u, int l, int r)
{
if(l == r) tr[u] = {l, r, w[r]};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
int query(int u, int l, int r)
{
if(tr[u].l >= l && tr[u].r <= r) return tr[u].maxv;
int mid = tr[u].l + tr[u].r >> 1;
int max1 = INT_MIN;
if(l <= mid) max1 = max(max1, query(u << 1, l, r));
if(r > mid) max1 = max(max1, query(u << 1 | 1, l, r));
return max1;
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++ ) scanf("%d", w + i);
build(1, 1, n);
int L, R;
while(m--)
{
scanf("%d%d", &L, &R);
printf("%d
", query(1, L, R));
}
return 0;
}
3.小朋友排队 1215
因为最终是要从小到大排列,所以核心就是冒泡排序,计算逆序对的数量与每个小朋友需要换位置的次数。
树状数组tr[]盛放的是每个身高出现的次数
树状数组tr[]是以身高h[] + 1来当作下标的,因为树状数组下标不可以是0,所以就把身高全部自增1
树状数组维护的是每个身高出现的次数,刚开始树状数组是0,所以要一次add一个h[],因为之后add进去的元素是不会被之前的元素统计的,所以在讨论h[i]之前比h[i]大的格数是要从前往后遍历,同理计算h[i]之后比h[i]小的数要从后往前遍历,要注意的是统计完一个之后要把树状数组归零。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 1e6 + 10;
int h[N], tr[N];
int sum[N], n;
int lowbit(int x)
{
return x & -x;
}
void add(int x, int y)
{
for(int i = x; i < N; i+= lowbit(i) ) tr[i] += y;
}
int query(int x)
{
int res = 0;
for(int i = x; i ; i -= lowbit(i)) res += tr[i];
return res;
}
int main()
{
cin >> n;
for(int i = 1; i <= n; i++ ) scanf("%d", h + i), h[i]++;
//之前比h[i]大的数的个数
for(int i = 1; i <= n; i++ )
{
sum[i] = query(N) - query(h[i]);
add(h[i], 1);
}
memset(tr, 0, sizeof tr);
//之后比h[i]小的数的个数
for(int i = n; i >= 1; i-- )
{
sum[i] += query(h[i] - 1);
add(h[i], 1);
}
LL res = 0;
for(int i = 1; i <= n; i++ ) res += (LL)sum[i] * (1 + sum[i]) / 2;
cout << res << endl;
return 0;
}
-
差分
1.差分 797
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1e5 + 10;
int a[N], b[N];
int n, m;
void insert(int l, int r, int c)//核心
{
b[l] += c;
b[r + 1] -= c;
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++ ) scanf("%d", a + i);
for(int i = 1; i <= n; i++ ) insert(i, i, a[i]);
int l, r, c;
while(m--)
{
scanf("%d%d%d", &l, &r, &c);
insert(l, r, c);
}
for(int i = 1; i <= n; i++ ) a[i] = a[i - 1] + b[i];
for(int i = 1; i <= n; i++ ) printf("%d ", a[i]);
return 0;
}
2.差分矩阵 798
//比较简单,了解差分的思想
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 1010;
int a[N][N], b[N][N];
int n, m, q;
void insert(int x1, int y1, int x2, int y2, int c)
{
b[x1][y1] += c;
b[x2 + 1][y1] -= c;
b[x1][y2 + 1] -= c;
b[x2 + 1][y2 + 1] += c;
}
int main()
{
cin >> n >> m >> q;
for(int i = 1; i <= n; i++ )
for(int j = 1; j <= m; j++ )
{
scanf("%d", &a[i][j]);
insert(i, j, i, j, a[i][j]);
}
int x1, y1, x2, y2, c;
while(q--)
{
scanf("%d%d%d%d%d", &x1, &y1, &x2, &y2, &c);
insert(x1, y1, x2, y2, c);
}
for(int i = 1; i <= n; i++ )
for(int j = 1; j <= m; j++ )
a[i][j] = a[i - 1][j] + a[i][j - 1] - a[i - 1][j - 1] + b[i][j];
for(int i = 1; i <= n; i++ )
{
for(int j = 1; j <= m; j++ )
{
printf("%d ", a[i][j]);
}
cout << endl;
}
return 0;
}
- else
螺旋折线 1237
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
int main()
{
int x, y;
cin >> x >> y;
if (abs(x) <= y) // 在上方
{
int n = y;
cout << (LL)(2 * n - 1) * (2 * n) + x - (-n) << endl;
}
else if (abs(y) <= x) // 在右方
{
int n = x;
cout << (LL)(2 * n) * (2 * n) + n - y << endl;
}
else if (abs(x) <= abs(y) + 1 && y < 0) // 在下方
{
int n = abs(y);
cout << (LL)(2 * n) * (2 * n + 1) + n - x << endl;
}
else // 在左方
{
int n = abs(x);
cout << (LL)(2 * n - 1) * (2 * n - 1) + y - (-n + 1) << endl;
}
return 0;
}