P6242 【模板】线段树 3
线段树维护历史最值+区间取min。
区间取min:
线段树维护一个区间最大值((MaxA))和严格次大值((se)),还要维护最大值个数(cnt),区间和(sum),然后分情况:(设当前与(k)取min)
当(k >= t[o].MaxA)时,直接返回;
当(t[o].se < k < t[o].MaxA)时,(t[o].sum += t[o].cnt * (k - t[o].MaxA)),(t[o].MaxA = k);
当(k <= t[o].se)时,继续往下递归。
具体维护看代码:
void up(int o) {
if(t[ls(o)].MaxA > t[rs(o)].MaxA) {
t[o].MaxA = t[ls(o)].MaxA; t[o].cnt = t[ls(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].MaxA);
}
else if(t[ls(o)].MaxA < t[rs(o)].MaxA) {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[rs(o)].cnt;
t[o].se = max(t[rs(o)].se, t[ls(o)].MaxA);
}
else {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].se);
}
t[o].MaxB = max(t[ls(o)].MaxB, t[rs(o)].MaxB);
t[o].sum = t[ls(o)].sum + t[rs(o)].sum;
}
维护历史最大值:
要维护4个标记:最大值加减标记((add1)),最大值历史最大加减标记((add1)_ ()),非最大值加减标记((add2)),非最大值历史最大加减标记((add2)_())。
void modify(int o, int l, int r, int add1, int add1_, int add2, int add2_) {
t[o].sum += 1ll * add1 * t[o].cnt + 1ll * add2 * (r - l + 1 - t[o].cnt);
t[o].MaxB = max(t[o].MaxB, t[o].MaxA + add1_); // MaxB代表历史最大值,用add1_更新
t[o].add1_ = max(t[o].add1_, t[o].add1 + add1_); //标记也记得更新
t[o].MaxA += add1; t[o].add1 += add1;
t[o].add2_ = max(t[o].add2_, t[o].add2 + add2_);
if(t[o].se != -inf) t[o].se += add2; t[o].add2 += add2;
}
完整代码:
#include <bits/stdc++.h>
#define ls(o) (o << 1)
#define rs(o) (o << 1 | 1)
#define mid ((l + r) >> 1)
using namespace std;
inline long long read() {
long long s = 0, f = 1; char ch;
while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
return s * f;
}
const int N = 1e6 + 5, inf = 2e9;
int n, m;
long long x;
struct tree {
long long sum;
int MaxA, MaxB, cnt, se;
int add1, add1_, add2, add2_;
} t[N << 2];
void up(int o) {
if(t[ls(o)].MaxA > t[rs(o)].MaxA) {
t[o].MaxA = t[ls(o)].MaxA; t[o].cnt = t[ls(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].MaxA);
}
else if(t[ls(o)].MaxA < t[rs(o)].MaxA) {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[rs(o)].cnt;
t[o].se = max(t[rs(o)].se, t[ls(o)].MaxA);
}
else {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].se);
}
t[o].MaxB = max(t[ls(o)].MaxB, t[rs(o)].MaxB);
t[o].sum = t[ls(o)].sum + t[rs(o)].sum;
}
void build(int o, int l, int r) {
if(l == r) {
t[o].MaxA = t[o].MaxB = t[o].sum = read();
t[o].se = -inf; t[o].cnt = 1;
return ;
}
build(ls(o), l, mid); build(rs(o), mid + 1, r);
up(o);
}
void modify(int o, int l, int r, int add1, int add1_, int add2, int add2_) {
t[o].sum += 1ll * add1 * t[o].cnt + 1ll * add2 * (r - l + 1 - t[o].cnt);
t[o].MaxB = max(t[o].MaxB, t[o].MaxA + add1_);
t[o].add1_ = max(t[o].add1_, t[o].add1 + add1_);
t[o].MaxA += add1; t[o].add1 += add1;
t[o].add2_ = max(t[o].add2_, t[o].add2 + add2_);
if(t[o].se != -inf) t[o].se += add2; t[o].add2 += add2;
}
void down(int o, int l, int r) {
int tmp = max(t[ls(o)].MaxA, t[rs(o)].MaxA);
if(t[ls(o)].MaxA == tmp)
modify(ls(o), l, mid, t[o].add1, t[o].add1_, t[o].add2, t[o].add2_);
else
modify(ls(o), l, mid, t[o].add2, t[o].add2_, t[o].add2, t[o].add2_);
//add1,add1_是维护区间最大值的标记,如果这个区间没有父节点的最大值,那么最大值标记不下传
if(t[rs(o)].MaxA == tmp)
modify(rs(o), mid + 1, r, t[o].add1, t[o].add1_, t[o].add2, t[o].add2_);
else
modify(rs(o), mid + 1, r, t[o].add2, t[o].add2_, t[o].add2, t[o].add2_);
t[o].add1 = t[o].add1_ = t[o].add2 = t[o].add2_ = 0;
}
void change_add(int o, int l, int r, int x, int y, int k) {
if(x <= l && y >= r) { modify(o, l, r, k, k, k, k); return ; }
down(o, l, r);
if(x <= mid) change_add(ls(o), l, mid, x, y, k);
if(y > mid) change_add(rs(o), mid + 1, r, x, y, k);
up(o);
}
void change_min(int o, int l, int r, int x, int y, int k) {
if(t[o].MaxA <= k) return ;
if(x <= l && y >= r && t[o].MaxA > k && t[o].se < k) {
modify(o, l, r, k - t[o].MaxA, k - t[o].MaxA, 0, 0);
return ;
}
down(o, l, r);
if(x <= mid) change_min(ls(o), l, mid, x, y, k);
if(y > mid) change_min(rs(o), mid + 1, r, x, y, k);
up(o);
}
long long query_sum(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) return t[o].sum;
down(o, l, r);
long long res = 0;
if(x <= mid) res += query_sum(ls(o), l, mid, x, y);
if(y > mid) res += query_sum(rs(o), mid + 1, r, x, y);
return res;
}
int query_A(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) { return t[o].MaxA; }
down(o, l, r);
int res = -inf;
if(x <= mid) res = max(res, query_A(ls(o), l, mid, x, y));
if(y > mid) res = max(res, query_A(rs(o), mid + 1, r, x, y));
return res;
}
int query_B(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) { return t[o].MaxB; }
down(o, l, r);
int res = -inf;
if(x <= mid) res = max(res, query_B(ls(o), l, mid, x, y));
if(y > mid) res = max(res, query_B(rs(o), mid + 1, r, x, y));
return res;
}
int main() {
n = read(); m = read();
build(1, 1, n);
for(int i = 1, opt, l, r;i <= m; i++) {
opt = read(); l = read(); r = read();
if(opt == 1) x = read(), change_add(1, 1, n, l, r, x);
if(opt == 2) x = read(), change_min(1, 1, n, l, r, x);
if(opt == 3) printf("%lld
", query_sum(1, 1, n, l, r));
if(opt == 4) printf("%d
", query_A(1, 1, n, l, r));
if(opt == 5) printf("%d
", query_B(1, 1, n, l, r));
}
return 0;
}