CF1470E Strange Permutation
题目大意
给出一个 (1) 到 (n) 的排列 (p_{1dots n})。你可以选择若干个互不重叠的区间,并将它们翻转,称为一组翻转操作。翻转一个区间 ([l,r]) 的代价是 (r - l)。一组翻转的代价是所选区间的代价之和。你希望花费的代价不超过 (c)。
每组可能的翻转操作,都会得到一个排列。考虑其中本质不同的排列,将这些排列按字典序从小到大排序。
你需要回答 (q) 次询问。每次询问有两个参数 (i,j),表示问字典序第 (j) 小的排列里,第 (i) 个位置上的数是几。如果不存在 (j) 个排列,则输出 (-1)。
数据范围:(1leq nleq 3 imes10^4),(1leq cleq 4),(1leq qleq 3 imes 10^5)。
本题题解
首先,因为操作互不重叠,且原序列是个排列,所以得到的结果序列一定互不相同。那么一组翻转操作,就唯一对应一个结果序列。问题从找字典序第 (j) 小结果序列,转化为找排序后第 (j) 个翻转操作组。
考虑一个长度为 (n) 的序列,进行总代价不超过 (c) 的一组翻转操作,能得到多少种结果?枚举总代价 (i) ((0leq ileq c))。在序列每两个元素之间放一个小球(共 (n - 1) 个小球)。则一次操作的代价,就是区间里的小球数。所以恰好进行 (i) 次操作的方案数就是 ({n - 1choose i})。总代价不超过 (c) 的方案数,就是 ( ext{ways}(n, c) = sum_{i = 0}^{c}{n - 1choose i})。若询问的 (j) 大于这个数,则可以直接输出 (-1)。
设 (F(i, c, k)),表示仅考虑 ([i,n]) 这段后缀,花费的总代价不超过 (c) 的,(结果序列的字典序)第 (k) 小的翻转操作组中的,第一个翻转操作(换句话说它返回一组 ((l,r)) 表示这个操作)。
为了快速实现 (F) 函数,我们先预处理一个列表 (L(i, c))。表示 ([i, n]) 这段后缀,花费的总代价不超过 (c),按得到的序列的字典序小到大排序的,“第一次操作”的序列。即:序列里每个元素,是一个区间,表示对这个区间进行翻转操作,以这个翻转操作为第一次操作的翻转操作组。假设翻转操作为 ((l,r)),则这样的翻转操作组数量就是 (w = ext{ways}(n - r, c - (r - l)))。把这样的 ((w, l, r)) 作为一个元素存在 (L) 中。即:(L(i, c) = {(l_1, r_1, w_1), (l_2,r_2, w_2),dots,(l_k, r_k, w_k)})。那么求 (F(i, c, k)) 时,我们只需要在 (L(i, c)) 序列里二分出最小的 (j),满足 (w_1 + w_2 + dots + w_j geq k),则 (F(i, c, k) = (l_j,r_j))。
考虑求 (L)。按 (i) 从大到小递推。
从 (i + 1) 推到 (i) 时,只需要插入以 (i) 为左端点的翻转操作。即 ((i, i + j)) ((1leq jleq min(c, n - i))),这样的操作最多不超过 (c) 个。暴力枚举这些操作。考虑原来在 (L(i + 1, c)) 里的操作,它们对应的序列的第一个元素,现在都是 (p_i)。而 ((i, i + j)) 对应的序列的第一个元素,是 (p_{i + j})。因为 (p_{i + j} eq p_i),所以 ((i, i + j)) 这个操作,要么插入到原来所有操作的左边,要么插入到原来所有操作的右边。
于是我们用一个 ( exttt{deque}) 来维护,即可推出 (L(1, c))。在 (L(1, c)) 上找到一个连续的区间,即可得到 (L(i, c))。从构造过程不难看出 (L(1, c)) 的长度是 (mathcal{O}(nc)) 的。
预处理出 (L) 后,可以通过二分,在 (mathcal{O}(log(nc))) 的时间内实现 (F) 函数。通过调用最多不超过 (c) 次 (F) 函数,可以回答一个询问。
时间复杂度 (mathcal{O}(nc^2 +qclog (nc)))。
参考代码
实际提交时建议加上读入、输出优化。详见本博客公告。
// problem: CF1470E
#include <bits/stdc++.h>
using namespace std;
#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXC = 4;
const int MAXN = 3e4;
int n, c, q, a[MAXN + 5];
ll ways_eq(int len, int c) {
// 长度为 len, 操作代价和恰好为 c
// comb(len - 1, c)
if (len <= 1) {
return c == 0;
}
if (len - 1 < c) {
return 0;
}
ll res = 1;
for (int i = len - 1; i >= len - c; --i) {
res *= i;
}
for (int i = c; i > 1; --i) {
res /= i;
}
return res;
}
ll ways_leq(int len, int c) {
// 长度为 len, 操作代价和小于或等于 c
ll res = 0;
for (int i = 0; i <= c; ++i) {
res += ways_eq(len, i);
}
return res;
}
struct Node {
int l, r;
ll w; // 后面的操作方案数
Node() {}
Node(int _l, int _r, ll _w) {
l = _l;
r = _r;
w = _w;
}
};
Node dq[MAXC + 1][MAXN * MAXC * 2 + 10];
int ql[MAXC + 1], qr[MAXC + 1];
int sum[MAXC + 1][MAXN + 5];
ll sumw[MAXC + 1][MAXN * MAXC + 5];
bool cmp(Node lhs, Node rhs) {
return a[lhs.r] < a[rhs.r];
}
Node F(int st, int s, ll rank) {
int i = sum[s][st]; // 在 st 之前的, 一共有这么多区间
/*
// 暴力查找
ll cur = 0;
for (i = ql[s] + i; i <= qr[s]; ++i) {
cur += dq[s][i].w;
assert(cur == sumw[s][i - ql[s] + 1] - sumw[s][sum[s][st]]);
if (cur >= rank) {
return Node(dq[s][i].l, dq[s][i].r, cur - dq[s][i].w);
}
}
*/
int l = i + 1, r = qr[s] - ql[s] + 2;
while (l < r) {
int mid = (l + r) >> 1;
if (sumw[s][mid] - sumw[s][i] >= rank) {
r = mid;
} else {
l = mid + 1;
}
}
assert(l <= qr[s] - ql[s] + 1);
return Node(dq[s][ql[s] + l - 1].l, dq[s][ql[s] + l - 1].r, sumw[s][l - 1] - sumw[s][i]);
}
void solve_case() {
cin >> n >> c >> q;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
for (int s = 1; s <= c; ++s) { // 总代价小于或等于 s
// cerr << "-------- maxcost " << s << " --------" << endl;
ql[s] = MAXN * MAXC + 5, qr[s] = MAXN * MAXC + 4; // 队列清空
for (int i = 1; i <= n; ++i)
sum[s][i] = 0;
dq[s][++qr[s]] = Node(n, n, 1); // 什么都不翻转
for (int i = n - 1; i >= 1; --i) {
int dl = 0, dr = 0;
for (int j = 1; j <= min(s, n - i); ++j) {
// 翻转区间 [i, i + j]
ll w = ways_leq(n - (i + j), s - j);
if (a[i + j] < a[i]) {
// 翻转后是 a[i + j], 不翻转是 a[i], 两者比一比
// 翻转后更小, push_front
dq[s][ql[s] - (++dl)] = Node(i, i + j, w);
sum[s][i + 1]++;
} else {
dq[s][qr[s] + (++dr)] = Node(i, i + j, w);
}
}
if (dl) {
sort(dq[s] + ql[s] - dl, dq[s] + ql[s], cmp);
ql[s] -= dl;
}
if (dr) {
sort(dq[s] + qr[s] + 1, dq[s] + qr[s] + dr + 1, cmp);
qr[s] += dr;
}
}
// cerr << "print queue: " << endl;
// for (int i = ql[s]; i <= qr[s]; ++i) {
// cerr << dq[s][i].l << " " << dq[s][i].r << " " << dq[s][i].w << endl;
// }
// cerr << "queue end" << endl;
for (int i = 1; i <= n; ++i) {
sum[s][i] += sum[s][i - 1];
}
for (int i = ql[s]; i <= qr[s]; ++i) {
sumw[s][i - ql[s] + 1] = sumw[s][i - ql[s]] + dq[s][i].w;
}
}
ll lim = ways_leq(n, c);
for (int tq = 1; tq <= q; ++tq) {
int pos;
ll rank;
// cerr << "-------- query: " << endl;
cin >> pos >> rank;
if (rank > lim) {
cout << -1 << endl;
continue;
}
vector<pii> revs;
int p = 1;
int s = c;
while (1) {
Node t = F(p, s, rank);
// cerr << "** " << t.l << " " << t.r << " " << t.w << endl;
revs.push_back(make_pair(t.l, t.r));
assert(t.w < rank);
rank -= t.w;
s -= (t.r - t.l);
p = t.r + 1;
if (!s)
break;
if (p > n)
break;
}
/*
// 暴力翻转
static int aa[MAXN + 5];
for (int i = 1; i <= n; ++i) {
aa[i] = a[i];
}
for (int i = 0; i < SZ(revs); ++i) {
reverse(aa + revs[i].fi, aa + revs[i].se + 1);
}
cout << aa[pos] << endl;
*/
bool flag = 0;
for (int i = 0; i < SZ(revs); ++i) {
if (revs[i].fi <= pos && revs[i].se >= pos) {
cout << a[revs[i].se - (pos - revs[i].fi)] << endl;
flag = 1;
break;
}
}
if (!flag) {
cout << a[pos] << endl;
}
}
}
int main() {
int T; cin >> T; while (T--) {
solve_case();
}
return 0;
}