对于一个长度为 (n) 的序列 (a) 做插入排序,即依次考虑 (a_2, cdots, a_n) 每个元素:
-
如果 (a_i ge a_{i - 1}),即前缀依然保持不下降,不做操作。
-
否则,找到最靠前的位置 (p),使得 (a_i < a_p),然后将 (a_i) 插入到 (a_p) 前面,并重新标号序列。记这次插入为 ((i, p))。
接下来给定 (m) 个二元组 ((x_i, y_i)),求有多少个长度为 (n) 的序列 (a),满足 (forall i, a_i in [1, n] cap mathbb N) 且做插入排序恰好进行给定的 (m) 次插入。
输出序列的数量对 (998244353) 取模的结果。
(2 le n le 2 cdot 10^5),(0 le m < n),(2 le x_1 < x_2 < cdots le n),(1 le y_i < x_i),(sum n le 2 cdot 10^5),(sum m) 没有限制。
3s, 512MB
当 (m) 次插入确定了以后,本质上是一个从原序列到新序列的置换,例如 (n = 3),只进行了 ((2, 1)) 这次插入,那么相当于是从原先的 ([a_1, a_2, a_3]) 变成了 ([a_2, a_1, a_3])。
那么,根据最终序列是不下降的,可以得到一个 (a_2 le a_1 le a_3) 的不等式。
这个条件显然不是充要的,原因在于,当发生插入 ((x, y)) 时,确定了一组明确的严格小于的关系(被插入的数 (<) 原来该位置上的数,注意不是 (a_x < a_y)),而非简单的不下降。
在上面的例子中,发生 ((2, 1)) 这次插入表示 (a_2 < a_1)。所以真正严格的限制应当是 (a_2 < a_1 le a_3)。
先不妨不关心最终的限制是什么,假设我们知道了最终限制中有 (c) 个「小于号」和 (n - 1 - c) 个「小于等于号」,根据隔板法,可以知道方案数应当是 (inom{2n - 1 - c}{n})。
问题转化为了求出 (c) 的值,即小于号的个数。
根据上面的过程,可以发现,一个数前面是小于号而不是小于等于号的要求是「有另一个数直接插入到了它的前面」。
考虑逆序这些插入,维护一个集合 (S)。初始时 (S) 包含了 (1,2, cdots, n) 的所有的数。
每次遇到一个 ((x_i, y_i)) 时,查询 (S) 中第 (y_i) 小的数 (p) 和第 (y_i + 1) 小的数 (q),即将 (p) 插入到了 (q) 之前。那么从 (S) 中删除 (p),令 (tag_q gets 1),表示有个数插入到了 (q) 之前。
最终 (c) 就等于 $tag= 1 $ 的位置的个数。
注意实现应当是 (O(m log n)),不要搞成了 (O(n log n))。
代码用线段树二分实现的。
#include <bits/stdc++.h>
typedef long long ll;
const int N = 2e5 + 5, MOD = 998244353;
int n, m, x[N], y[N], roll[N], val[N << 2], fac[N << 1], ifac[N << 1];
std::set<int> pos;
int qpow(int a, int p) {
int res = 1;
while(p) {
if(p & 1) res = (ll)res * a % MOD;
a = (ll)a * a % MOD; p >>= 1;
}
return res;
}
void init(int l, int r, int rt) {
val[rt] = r - l + 1;
if(l == r) return;
int mid = (l + r) >> 1;
init(l, mid, rt << 1); init(mid + 1, r, rt << 1 | 1);
}
int query(int l, int r, int rt, int k) {
if(l == r) return l;
int mid = (l + r) >> 1;
if(val[rt << 1] >= k) return query(l, mid, rt << 1, k);
else return query(mid + 1, r, rt << 1 | 1, k - val[rt << 1]);
}
void modify(int l, int r, int rt, int p, int v) {
val[rt] += v;
if(l == r) return;
int mid = (l + r) >> 1;
if(p <= mid) modify(l, mid, rt << 1, p, v);
else modify(mid + 1, r, rt << 1 | 1, p, v);
}
int C(int a, int b) {
if(a < 0 || b < 0 || a < b) return 0;
return (ll)fac[a] * ifac[b] % MOD * ifac[a - b] % MOD;
}
int main() {
fac[0] = 1;
for(int i = 1; i < N * 2; i++) fac[i] = (ll)fac[i - 1] * i % MOD;
ifac[N * 2 - 1] = qpow(fac[N * 2 - 1], MOD - 2);
for(int i = N * 2 - 1; i; i--) ifac[i - 1] = (ll)ifac[i] * i % MOD;
int test; scanf("%d", &test); init(1, N - 1, 1);
while(test--) {
scanf("%d %d", &n, &m); pos.clear();
for(int i = 1; i <= m; i++) scanf("%d %d", &x[i], &y[i]);
for(int i = m; i; i--) {
int p = query(1, N - 1, 1, y[i]), q = query(1, N - 1, 1, y[i] + 1);
modify(1, N - 1, 1, p, -1);
roll[i] = p; pos.insert(q);
}
for(int i = 1; i <= m; i++) modify(1, N - 1, 1, roll[i], 1);
int c = (int)pos.size();
printf("%d
", C(n * 2 - 1 - c, n));
}
return 0;
}