CF1034D Intervals of Intervals
题目大意
给定数轴上的 (n) 条线段 ([a_i, b_i])((1leq ileq n),(1leq a_i < b_ileq 10^9))。
定义区间 ([l, r])((1leq lleq rleq n))的价值为 (l,dots ,r) 这些线段的并的长度。
请你选择 (k) 个不同的区间,最大化所选区间的价值之和,并输出这个最大值。
数据范围:(1leq nleq 3 imes 10^5),(1leq kleq min{frac{n(n + 1)}{2},10^9})。
本题题解
考虑一个子问题:有 (n) 条线段,(q) 次询问,每次问一个区间 ([l, r])((1leq lleq rleq n))里的线段的并的长度。
可以离线,按编号从 (1) 到 (n) 考虑每条线段,假设当前线段 (i) 是询问的右端点,我们维护每个左端点的答案。
与此同时,对数轴上所有位置 (p),维护最后一个覆盖到它的线段编号 (mathrm{lst}(p)),特别地,如果位置 (p) 没有被覆盖过,认为 (mathrm{lst}(p) = 0)。那么一次询问 ([l,i]) 的答案就是 (mathrm{lst}) 值大于等于 (l) 的位置有多少个。
加入线段 (i),相当于把 ([a_i, b_i]) 这段区间里所有位置的 (mathrm{lst}) 值改为 (i)。假设一个位置原本的 (mathrm{lst}) 值为 (i'),那么 ([i' + 1, i]) 里的左端点,答案都要增加 (1)。
把相邻、且 (mathrm{lst}) 值相同的位置,缩成一条线段。那么只需要用 ( exttt{std::set}) 维护若干条线段。每个 (i) 只会带来新增 (mathcal{O}(1)) 条线段(虽然可能它会一次删去很多线段,但这个复杂度是均摊的,所以暴力删就行)。删除一个原有线段,对答案数组(每个左端点的答案)带来的修改是区间加,用线段树维护答案数组即可。
总时间复杂度 (mathcal{O}(nlog n))。另外提一句,把线段树的修改可持久化一下,该问题也可以在线做。
回到本题,我们显然是选价值最大的 (k) 个区间。考虑二分价值第 (k) 大的区间的价值。设当前二分值为 (mathrm{mid}),问题转化为求:价值大于等于 (mathrm{mid}) 的区间数量,根据数量 (<) 或 (geq k),调整二分的上下界。
设区间 ([l, r]) 的价值为 (v(l, r))。那么显然有:(v(l, r)geq v(l + 1, r)),(v(l, r)geq v(l, r - 1)) ((l < r))。
我们仍然从 (1) 到 (n) 枚举右端点。对于同一个右端点 (i),设每个左端点 (l) 的答案为 (f_i(l)),即 (f_i(l) = v(l, i))。那么有:(f_i(1)geq f_i(2)geq dotsgeq f_i(i))。
记最后一个满足 (f_i(l)geq mathrm{mid}) 的 (l) 为 (L(i))。特别地,如果 (f_i(1) < mathrm{mid}),(L_i(i) = 0)。我们要求的就是 (sum_{i = 1}^{n}L(i))。如果用线段树维护 (f) 数组(支持区间加),那么 (L(i)) 可以在线段树上二分求出来,不过这样总时间复杂度是 (mathcal{O}(nlog^2 n)) 的。还可以继续优化。
考虑到 (L(1)leq L(2)leq dotsleq L(n)),所以可以在枚举 (i) 的同时,用 ( ext{two pointers}) 的方法维护 (L(i))。在二分之前,先预处理出要做的区间加法(也就是把 ( exttt{set}) 的 (log) 放在二分外面),然后用差分来实现区间加。
时间复杂度 (mathcal{O}(nlog n))。
有一个要注意的细节是,二分出第 (k) 大的价值 (v) 后,答案并不一定是所有 (geq v) 的价值之和。因为价值等于 (v) 的区间可能有多个,超过 (k) 个的部分我们不能取。
参考代码
实际提交时建议使用读入、输出优化,详见本博客公告。
// problem: CF1034D
#include <bits/stdc++.h>
using namespace std;
#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 3e5;
const int MAXR = 1e9;
int n, K;
pii seg[MAXN + 5];
vector<pii> ev[MAXN + 5];
#define mk3(x, y, z) mk(mk(x, y), z)
#define IT set<pair<pii, int> > :: iterator
set<pair<pii, int> > s;
void insert(int sl, int sr, int sid) {
IT it = s.lower_bound(mk3(sl, 0, 0));
if (it != s.begin()) {
--it;
if (it -> fi.se > sl) {
pair<pii, int> tmp = *it;
s.erase(it);
s.insert(mk3(tmp.fi.fi, sl, tmp.se));
s.insert(mk3(sl, tmp.fi.se, tmp.se));
}
}
it = s.lower_bound(mk3(sr, 0, 0));
if (it != s.begin()) {
--it;
if (it -> fi.se > sr) {
pair<pii, int> tmp = *it;
s.erase(it);
s.insert(mk3(tmp.fi.fi, sr, tmp.se));
s.insert(mk3(sr, tmp.fi.se, tmp.se));
}
}
IT bg = s.lower_bound(mk3(sl, 0, 0));
IT ed = bg;
if (bg != s.end() && bg -> fi.se <= sr) {
int lst = sl;
do {
int cl = ed -> fi.fi;
int cr = ed -> fi.se;
if (cl != lst)
ev[sid].push_back(mk(1, cl - lst));
ev[sid].push_back(mk(ed -> se + 1, cr - cl));
lst = cr;
++ed;
} while(ed != s.end() && ed -> fi.se <= sr);
if (lst != sr)
ev[sid].push_back(mk(1, sr - lst));
s.erase(bg, ed);
} else {
ev[sid].push_back(mk(1, sr - sl));
}
s.insert(mk3(sl, sr, sid));
// cerr << "**********" << endl;
// for (it = s.begin(); it != s.end(); ++it) {
// cerr << (it -> fi.fi) << " " << (it -> fi.se) << endl;
// }
}
#undef IT
#undef mk3
int f[MAXN + 5];
ll calc(int min_size) {
// 并的大小 >= min_size 的区间有几个
// 时间复杂度 O(n)
for (int i = 1; i <= n; ++i) {
f[i] = 0; // 清空
}
int l = 0;
ll res = 0;
for (int i = 1; i <= n; ++i) {
// 右端点: i
for (int j = 0; j < SZ(ev[i]); ++j) {
int p = ev[i][j].fi;
int v = ev[i][j].se;
ckmax(p, l);
// 对 l = p...i 的答案, 加上 v
f[p] += v;
f[i + 1] -= v;
}
// l: 使得 [l, i] 的答案 >= min_size 的最大的 l
while (l + 1 <= i && f[l + 1] + f[l] >= min_size) {
f[l + 1] += f[l];
if (l + 1 > 1)
assert(f[l + 1] <= f[l]);
++l;
}
// cerr << "** " << l << endl;
res += l;
}
return res;
}
ll calc_ans(int min_size) {
// 并的大小 >= min_size 的区间的答案之和
// 时间复杂度 O(n)
for (int i = 1; i <= n; ++i) {
f[i] = 0; // 清空
}
int l = 0;
ll sum = 0, res = 0;
for (int i = 1; i <= n; ++i) {
// 右端点: i
for (int j = 0; j < SZ(ev[i]); ++j) {
int p = ev[i][j].fi;
int v = ev[i][j].se; // 对 l = p...i 的答案, 加上 v
if (p <= l) {
sum += (ll)v * (l - p + 1);
}
ckmax(p, l);
f[p] += v;
f[i + 1] -= v;
}
// l: 使得 [l, i] 的答案 >= min_size 的最大的 l
while (l + 1 <= i && f[l + 1] + f[l] >= min_size) {
f[l + 1] += f[l];
sum += f[l + 1];
++l;
}
res += sum;
}
return res;
}
int main() {
cin >> n >> K;
for (int i = 1; i <= n; ++i) {
cin >> seg[i].fi >> seg[i].se;
}
for (int i = 1; i <= n; ++i) {
insert(seg[i].fi, seg[i].se, i);
}
int l = 0, r = MAXR;
while (l < r) {
int mid = (l + r + 1) >> 1;
if (calc(mid) >= K) {
l = mid;
} else {
r = mid - 1;
}
}
// cerr << l << endl;
cout << calc_ans(l + 1) + (ll)l * (K - calc(l + 1)) << endl;
return 0;
}