CF1188D Make Equal
题目大意
有 (n) 个非负整数 (a_1, a_2, dots, a_n)。每次操作,你可以从中任选一个数,并把它加上 (2) 的任意非负整数次幂。求使得 (n) 个数相等所需的最小操作次数。可以证明答案不超过 (10^{18})。
数据范围:(1leq nleq 10^5),(1leq a_ileq 10^{17})。
本题题解
先将 (a) 序列排序。
假设最终所有数相等时等于 (y)。那么所需的最小操作次数一定是:
其中 (mathrm{bitcnt}(x)) 表示 (x) 在二进制下 (1) 的个数。
显然 (ygeq a_n)。因为减法比加法复杂,我们稍微转化一下。设 (x = y - a_n),(b_i = a_n - a_i),则上式等于:
问题转化为,求一个 (x),使得该式的值最小。
按二进制位,从低到高 DP,假设当前考虑到第 (k) 位。考虑 ((x + b_i)) 第 (k) 位的数值((0) 还是 (1)),会受到哪些因素影响:
- (b_i) 第 (k) 位的数值。
- (x) 第 (k) 位的数值。
- 第 (k - 1) 位有没有进位。
1 是已知的。2 在转移时分两种情况讨论即可。关键是 3,因为 (n) 个数第 (k - 1) 位的进位情况是不同的,也就是说有 (2^n) 种可能的状态,全记下来的话状态数爆炸了。
下面是本题的核心:因为每个数加上的 (x) 相同,所以 (b_imod 2^k) 越大的数,越有可能进位。也就是说,如果把 (b) 序列按照 (b_imod 2^k) 从小到大排序,那么在上一位发生进位的数是一段后缀。于是,我们在状态里,只需要记录上一位有几个数发生了进位,就相当于知道了哪些数发生了进位。
具体来说,转移时枚举上一位发生进位的数的数量,记为 (j)。将 (b) 序列按照 (b_imod 2^k) 的大小排序后,分为前 (n - j) 个数(没发生进位的)和后 (j) 个数(发生了进位的),再根据 (b_i) 第 (k) 位是 (0) 还是 (1),分为四小类。转移考虑 (x) 的第 (k) 位是 (0) 还是 (1)。那么:
这张表格描述的是 (n) 个 ((x + b_i)),共有 (4 imes 2 = 8) 种情况:第一行里的四类表示,(b_imod 2^k) 是前 (n - j) 个数还是后 (j) 个数(这决定了它在上一位有没有进位),以及 (b_i) 的第 (k) 位是 (0) 还是 (1)。第一列里的两类,表示 (x) 的第 (k) 位是 (0) 还是 (1)。表格中,“进位/不进位”决定了转移到的 DP 状态;(0, 1) 也就是 ((x + b_i)) 第 (k) 位的值,这决定了新的 DP 值。具体可以见代码。
接下来还有最后一个问题,(x) 要枚举到多大(有多少位)?可以证明,一定存在最优解,满足 (x leq max{b_i})。
证明:
设 (B = max{b_i})。
如果,(x > B)。设 (s) 表示 ((x + B)) 的最高位。那么 (2^{s + 1} > x + B geq 2^{s})。所以 (2x > 2^{s}),(x > 2^{s - 1})。设 (x' = x - 2^{s - 1})。
对任意位置 (i),((x + b_i)) 的最高位要么为 (s),要么为 (s - 1)。
- 若 ((x + b_i)) 的第 (s - 1) 位为 (1),那么 (mathrm{bitcnt}(b_i + x') = mathrm{bitcnt}(x + b_i) - 1)。
- 若 ((x + b_i)) 的第 (s - 1) 位为 (0),那么第 (s) 位一定为 (1),所以 (mathrm{bitcnt}(b_i + x') = mathrm{bitcnt}(x + b_i))。
可以发现,(x') 的结果一定不差于 (x)。所以只要最优的 (x > B),我们就不断将它减去 (2^{s - 1}),最终一定能得到 (xleq B) 的最优解。
因为 (xleq max{b_i}leqmax{a_i}),所以 DP 只需要进行 (mathcal{O}(log a)) 位,每一位要排序,所以总时间复杂度 (mathcal{O}(ncdot log acdot log n))。
有一个优化是,在上一位排好序的序列的基础上,根据当前位是 (0) 还是 (1),拆成两个子序列,再拼起来。这样就不用每一位都重新排序了,时间复杂度 (mathcal{O}(nlog a))。
参考代码
// problem: CF1188D
#include <bits/stdc++.h>
using namespace std;
#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 1e5;
const int MAXBIT = 57; // [0, 57]
const int INF = 1e9;
int n;
ull a[MAXN + 5];
int id[MAXN + 5], id0[MAXN + 5], id1[MAXN + 5];
int dp[MAXBIT + 1][MAXN + 5];
int pre[MAXN + 5][2], suf[MAXN + 5][2];
int main() {
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
sort(a + 1, a + n + 1);
for (int i = 1; i <= n; ++i) {
a[i] = a[n] - a[i];
}
for (int i = 0; i <= MAXBIT; ++i)
for (int j = 0; j <= n; ++j)
dp[i][j] = INF;
int cnt = 0;
for (int i = 1; i <= n; ++i)
if ((a[i] & 1) == 0)
id[++cnt] = i;
dp[0][0] = n - cnt; // x 当前位填 0, 0 人进位, 结果中有 n - cnt 个 1
ckmin(dp[0][n - cnt], cnt); // x 当前位填 1, n - cnt 人进位, 结果中有 cnt 个 1
for (int i = 1; i <= n; ++i)
if ((a[i] & 1) == 1)
id[++cnt] = i;
assert(cnt == n);
for (int i = 1; i <= MAXBIT; ++i) {
// sort(id + 1, id + n + 1, cmp);
for (int j = 1; j <= n; ++j) {
pre[j][0] = pre[j - 1][0] + (((a[id[j]] >> i) & 1) == 0);
pre[j][1] = pre[j - 1][1] + (((a[id[j]] >> i) & 1) == 1);
}
for (int j = n; j >= 1; --j) {
suf[j][0] = suf[j + 1][0] + (((a[id[j]] >> i) & 1) == 0);
suf[j][1] = suf[j + 1][1] + (((a[id[j]] >> i) & 1) == 1);
}
for (int j = 0; j <= n; ++j) if (dp[i - 1][j] != INF) {
ckmin(dp[i][suf[n - j + 1][1]], dp[i - 1][j] + pre[n - j][1] + suf[n - j + 1][0]); // x 当前位填 0
if (i < MAXBIT)
ckmin(dp[i][pre[n - j][1] + j], dp[i - 1][j] + pre[n - j][0] + suf[n - j + 1][1]); // x 当前位填 1
}
// 相当于排序:
int cnt0 = 0, cnt1 = 0;
for (int j = 1; j <= n; ++j)
if (((a[id[j]] >> i) & 1) == 0)
id0[++cnt0] = id[j];
for (int j = 1; j <= n; ++j)
if (((a[id[j]] >> i) & 1) == 1)
id1[++cnt1] = id[j];
for (int j = 1; j <= cnt0; ++j)
id[j] = id0[j];
for (int j = 1; j <= cnt1; ++j)
id[cnt0 + j] = id1[j];
}
int ans = INF;
for (int j = 0; j <= n; ++j)
ckmin(ans, dp[MAXBIT][j]);
cout << ans << endl;
return 0;
}