这道题要是直接想正解实在太难了,还得从一些特殊的情况一点点入手。
1.如果ai本身就是递增的,那么令bi = ai即最优解。
2.如果ai严格递减,则b1 = b2 = b3 = ……= bn = 中位数为最优解。这个可以用初中的几何证明:把 |bi - ai| 想象成数轴上两点间距离,那么 |x - a1| + |x - a2|的最优解x一定a1和a2之间;扩展到三项:|x - a1| + |x - a2| + |x - a3|,那么x就是中位数;以此类推,x是中位数得证。
现在再想想一般情况:对于数列{an},可能有一段递增,有一段递减,那么就把这些看成一段一段的:初始每一段的长度都为1,令中位数为ci,则ci = ai。若ci < ci+1,那么就保持不变;否则将ci和ci+1所在的区间合并,取一个新的中位数,作为新区间的答案。就这样不断的合并,直到ci不下降。
因此我们需要一个数据结构,支持合并、查询最大值和删除。为什么要查询最大值和删除呢?因为维护中位数可以只维护⌈1/2区间长度⌉小的数,用一个大根堆,则堆顶就是中位数。合并完两个区间后,就一直删除堆顶,直到元素个数 = ⌈1/2区间长度⌉。
那么就能想到用左偏树实现啦。所以左偏树并不是这道题的核心,实际上只起到了一个优化的作用。
最后一点,就是题中让求的是严格递增,然而我们一直是在不下降的前提下思考的。有一个巧妙的做法,就是开始ai -= i,就转换成了不下降。
代码稍微参考了别人的题解(左偏树还是不太熟啊)。
1 #include<cstdio> 2 #include<iostream> 3 #include<cmath> 4 #include<algorithm> 5 #include<cstring> 6 #include<cstdlib> 7 #include<cctype> 8 #include<vector> 9 #include<stack> 10 #include<queue> 11 using namespace std; 12 #define enter puts("") 13 #define space putchar(' ') 14 #define Mem(a, x) memset(a, x, sizeof(a)) 15 #define rg register 16 typedef long long ll; 17 typedef double db; 18 const int INF = 0x3f3f3f3f; 19 const db eps = 1e-8; 20 const int maxn = 1e6 + 5; 21 inline ll read() 22 { 23 ll ans = 0; 24 char ch = getchar(), last = ' '; 25 while(!isdigit(ch)) {last = ch; ch = getchar();} 26 while(isdigit(ch)) {ans = ans * 10 + ch - '0'; ch = getchar();} 27 if(last == '-') ans = -ans; 28 return ans; 29 } 30 inline void write(ll x) 31 { 32 if(x < 0) x = -x, putchar('-'); 33 if(x >= 10) write(x / 10); 34 putchar(x % 10 + '0'); 35 } 36 37 int n; 38 39 struct Node 40 { 41 int root, l, r, siz; 42 ll val; 43 }st[maxn]; 44 int top = 0; 45 46 ll val[maxn]; 47 int dis[maxn], ls[maxn], rs[maxn]; 48 int merge(int x, int y) 49 { 50 if(!x || !y) return x | y; 51 if(val[x] < val[y]) swap(x, y); 52 rs[x] = merge(rs[x], y); 53 if(dis[rs[x]] > dis[ls[x]]) swap(rs[x], ls[x]); 54 dis[x] = dis[rs[x]] + 1; 55 return x; 56 } 57 int Del(int x) 58 { 59 return merge(ls[x], rs[x]); 60 } 61 62 ll ans = 0; 63 64 int main() 65 { 66 n = read(); 67 for(int i = 1; i <= n; ++i) val[i] = read() - i; 68 st[++top] = (Node){1, 1, 1, 1, val[1]}; 69 for(int i = 2; i <= n; ++i) 70 { 71 st[++top] = (Node){i, i, i, 1, val[i]}; 72 while(top && st[top].val < st[top - 1].val) 73 { 74 top--; 75 st[top].root = merge(st[top].root, st[top + 1].root); 76 st[top].siz += st[top + 1].siz; 77 st[top].r = st[top + 1].r; 78 while(st[top].siz > (st[top].r - st[top].l + 1 + 1) >> 1) //向上取整 79 { 80 st[top].siz--; 81 st[top].root = Del(st[top].root); 82 } 83 st[top].val = val[st[top].root]; 84 } 85 } 86 for(int i = 1, j = 1; i <= n; ++i) 87 { 88 if(i > st[j].r) j++; 89 ans += abs(st[j].val - val[i]); 90 } 91 write(ans); enter; 92 for(int i = 1, j = 1; i <= n; ++i) 93 { 94 if(i > st[j].r) j++; 95 write(st[j].val + i), space; 96 } 97 enter; 98 return 0; 99 }