题目链接: BZOJ - 3110
题目分析
这道题是一道树套树的典型题目,我们使用线段树套线段树,一层是区间线段树,一层是权值线段树。一般的思路是外层用区间线段树,内层用权值线段树,但是这样貌似会很难写。多数题解都使用了外层权值线段树,内层区间线段树,于是我就这样写了。每次插入会在 logn 棵线段树中一共建 log^2(n) 个结点,所以空间应该开到 O(nlog^2(n)) 。由于这道题查询的是区间第 k 大,所以我们存在线段树中的数值是输入数值的相反数(再加上 n 使其为正数),这样查第 k 小就可以了。在查询区间第 k 大值的时候,我们用类似二分的方法,一层一层地逼近答案。
写代码的时候出现的错误:在每一棵区间线段树中修改数值的时候,应该调用的是像 Insert(Lc[x], 1, n, l, r) 这样子,但我经常写成 Insert(x << 1, s, t, l, r) 之类的。注意!
代码
#include <iostream> #include <cstdio> #include <cstring> #include <cstdlib> #include <cmath> #include <algorithm> using namespace std; const int MaxN = 100000 + 5, MaxM = 100000 * 16 * 16 + 5; int n, m, f, a, b, c, Index, Ans; int Root[MaxN * 4], Lc[MaxM], Rc[MaxM], Sum[MaxM], Lazy[MaxM]; inline int gmin(int a, int b) { return a < b ? a : b; } inline int gmax(int a, int b) { return a > b ? a : b; } int Get(int x, int s, int t, int l, int r) { if (l <= s && r >= t) return Sum[x]; int p = 0, q = 0, m = (s + t) >> 1; if (l <= m) p = Get(Lc[x], s, m, l, r); if (r >= m + 1) q = Get(Rc[x], m + 1, t, l, r); return (p + q + Lazy[x] * (gmin(t, r) - gmax(s, l) + 1)); } int GetKth(int l, int r, int k) { int s = 1, t = n * 2, m, x = 1, Temp; while (s != t) { m = (s + t) >> 1; if ((Temp = Get(Root[x << 1], 1, n, l, r)) >= k) { t = m; x = x << 1; } else { s = m + 1; x = x << 1 | 1; k -= Temp; } } return s; } void Insert(int &x, int s, int t, int l, int r) { if (x == 0) x = ++Index; if (l <= s && r >= t) { Sum[x] += t - s + 1; ++Lazy[x]; return; } int m = (s + t) >> 1; if (l <= m) Insert(Lc[x], s, m, l, r); if (r >= m + 1) Insert(Rc[x], m + 1, t, l, r); Sum[x] = Sum[Lc[x]] + Sum[Rc[x]] + Lazy[x] * (t - s + 1); } void Add(int l, int r, int Num) { int s = 1, t = n * 2, m, x = 1; while (s != t) { Insert(Root[x], 1, n, l, r); m = (s + t) >> 1; if (Num <= m) { t = m; x = x << 1; } else { s = m + 1; x = x << 1 | 1; } } Insert(Root[x], 1, n, l, r); } int main() { scanf("%d%d", &n, &m); Index = 0; for (int i = 1; i <= m; ++i) { scanf("%d%d%d%d", &f, &a, &b, &c); if (f == 1) { c = -c + n + 1; Add(a, b, c); } else { Ans = GetKth(a, b, c); Ans = -Ans + n + 1; printf("%d ", Ans); } } return 0; }