LOJ6433 「PKUSC2018」最大前缀和
题目大意
给定一个长度为 (n) 的序列 (a),求它在所有 (n!) 种排列方式下的【最大前缀和】之和。即,对所有 (1dots n) 的排列 (p),求:
[sum_{p} max_{i = 1}^{n}left{sum_{j = 1}^{i} a_{p_{j}}
ight}
]
数据范围:(1leq nleq 20),(sum_{i = 1}^{n} |a_i|leq 10^9)。
本题题解
约定:对于一个序列,如果前缀和在多个位置都能取到最大值,我们以最靠后的位置为准。这样每个序列的最大前缀和所在位置都是唯一的,考虑枚举这个位置 (i),观察序列需要满足什么条件:
- (forall j in[i + 1, n]),(sum_{k = i + 1}^{j} a_k < 0),即 (a) 的 ([i + 1, n]) 这段后缀的所有前缀和都小于 (0)。
- (forall j in[1, i - 1]),(sum_{k = j + 1}^{i} a_k geq 0),即 (a) 的 ([2, i]) 这段子段的所有后缀和都大于等于 (0)。
这两个条件,就是【(i) 是最大前缀和所在位置】的充分必要条件。因为如果存在不符合上述要求的 (j),则 (j) 会取代 (i) 成为最大前缀和所在位置;反之,则没有 (j) 能成为最大前缀和所在位置。
另外,请注意,在满足上述两个条件的前提下,我们对 (a_1) 是没有限制的,它可正,可负,可零。
状压 DP。
- 设 (f(s)) 表示:一个序列,用了集合 (s) 里的这些数,且所有后缀和都 (geq 0),这样的序列有多少个。转移时,考虑在序列前面加入一个数即可。
- 设 (g(s)) 表示:一个序列,用了集合 (s) 里的这些数,且所有前缀和都 (< 0),这样的序列有多少个。转移时,考虑在序列后面加入一个数即可。
请注意,在 (f(s)) 和 (g(s)) 里,我们认为总共有 (|s|!) 种序列,也就是两个相同的数值交换位置后被认为是不同的序列。
这两个 DP 的时间复杂度都是 (mathcal{O}(2^n n))。
完成 DP 后,考虑统计答案。枚举最终序列里 (a_1) 的值(前文说过,我们对 (a_1) 没有限制)。然后枚举一个不包含 (a_1) 的集合 (s),表示 ([2, i]) 这个子段所用的数。设剩下的 (n - 1 - |s|) 个数为集合 (t)。则此时的方案数是:(f(s) imes g(t))。对答案的贡献是:((a_1 + sum_{xin s}x) imes f(s) imes g(t))。
时间复杂度 (mathcal{O}(2^n n))。
参考代码
// problem: LOJ6433
#include <bits/stdc++.h>
using namespace std;
#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 20;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }
int n, a[MAXN + 5];
ll sum[1 << MAXN];
int f[1 << MAXN], g[1 << MAXN];
int main() {
cin >> n;
for (int i = 0; i < n; ++i) {
cin >> a[i];
}
for (int i = 0; i < (1 << n); ++i) {
for (int j = 0; j < n; ++j) {
if ((i >> j) & 1) {
sum[i] += a[j];
}
}
}
f[0] = 1; // f: 任意一个前缀都 >= 0 的方案数
for (int i = 0; i <= (1 << n) - 2; ++i) {
for (int j = 0; j < n; ++j) {
if (!((i >> j) & 1) && sum[i] + a[j] >= 0) {
add(f[i | (1 << j)], f[i]);
}
}
}
g[0] = 1; // g: 任意一个前缀都 < 0 的方案数
for (int i = 0; i <= (1 << n) - 2; ++i) {
for (int j = 0; j < n; ++j) {
if (!((i >> j) & 1) && sum[i] + a[j] < 0) {
add(g[i | (1 << j)], g[i]);
}
}
}
int ans = 0;
for (int i = 0; i < n; ++i) { // a[1]
int all = (((1 << n) - 1) ^ (1 << i));
for (int j = 0; j < (1 << n); ++j) {
if (!((j >> i) & 1) && sum[j] >= 0) {
int s = (sum[j] + a[i] + MOD + MOD) % MOD;
add(ans, (ll)f[j] * g[all ^ j] % MOD * s % MOD);
}
}
}
cout << ans << endl;
return 0;
}