题意
给定一个(N) 个数的序列 (A),如果序列 (A) 不是非降序的,你需要在其中选择一个数删掉,
不断重复这个操作直到序列 (A) 非降。求有多少种不同的删数方案。注意:删掉的数的集
合相同,但是删数的顺序不同,视作不同的删数方案。
答案对(10^9+7)取模
(Nleq 2000,A_i leq 2000)
解法
这题十分巧妙,结合了容斥和DP
首先用DP求出这个序列中所有的一定长度的非降序列以及它们的个数
设(f[i][j])为长度为i,末尾为(a[j])的非降序列的数目,那么转移就十分显然了
可以看出,上面的式子要求和,还与位置与权值大小有关,可以考虑用树状数组进行优化。
控制每次转移时树状数组中的元素都在当前枚举位置之前,保证位置的合法
枚举长度,长度每次改变时清空树状数组,把上一长度求出的(dp)值赋在树状数组中,达到转移的目的
这样就能求出(c[i])了 ((c[i])为长度为i的非降序列的个数)
如果不考虑非法情况,最后的答案
用容斥原理去除非法情况
如果用上面的方式算出答案,在什么情况下会出现重复呢?
我们在计算剩下((n-i))个数的删除顺序时,有可能在某一个时刻序列已经是非降的了,按照题意应该停止;但是我们没有停止。就是这一部分构成了非法的情况
如何去掉这种非法情况呢?
在形成长度为i的非降序列之前,我们还要删掉一个数:如果在删掉这个数之前,整个序列就已经是非降的了,那么这一种情况代表的方案就是所有非法的情况
由于非降序列删去任意一个数仍是非降序列,所以这个删去的数有((i+1))种取值,也就是会贡献((i+1))个非法序列
又因为所有长度为(i)的非降序列一定包含在长度为(i+1)的非降序列中
也就是说,只要存在长度为(i+1)的非降序列,就一定有对于长度为(i)的不合法情况
我们可以先构成一个长度为(i+1)的非降序列,即(c[i+1]*(n-i-1)!)
当然,在构成长度为(i+1)非降序列时,也会有不合法的情况,但是反正是要求不合法的情况,这种不合法的情况也应该包括进去(我在说什么@#$@#%#^)
所以答案为
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 2500;
const int mod = 1e9 + 7;
int n, rng;
int a[N], o[N];
int fac[N], sum[N], f[N][N];
struct BIT {
int c[N];
void clear() {
memset(c, 0, sizeof c);
}
void insert(int p, int v) {
if (!p || !v) return;
for (; p <= rng; p += p & -p) (c[p] += v) %= mod;
}
int query(int p) {
int res = 0;
for (; p; p -= p & -p) (res += c[p]) %= mod;
return res;
}
} bit;
void config() {
fac[0] = 1;
for (int i = 1; i <= n; ++i) fac[i] = 1LL * fac[i - 1] * i % mod;
for (int i = 1; i <= n; ++i) f[1][i] = 1;
for (int i = 2; i <= n; ++i) {
bit.clear();
for (int j = i - 1; j <= n; ++j) {
if (j >= i) f[i][j] = bit.query(a[j]);
bit.insert(a[j], f[i - 1][j]);
}
}
for (int i = 1; i <= n; ++i)
for (int j = i; j <= n; ++j) sum[i] = (sum[i] + f[i][j]) % mod;
}
int main() {
freopen("strong.in", "r", stdin);
freopen("strong.out", "w", stdout);
scanf("%d", &n);
for (int i = 1; i <= n; ++i) scanf("%d", a + i);
rng = *max_element(a + 1, a + n + 1);
config();
long long ans = 0;
for (int i = 1; i <= n; ++i)
ans = (ans + 1LL * sum[i] * fac[n - i] % mod - 1LL * sum[i + 1] * fac[n - i - 1] % mod * (i + 1) % mod + mod) % mod;
printf("%lld
", ans);
return 0;
}