@description@
定义一个序列是好的,当且仅当这个序列中,相等的两个数之间的所有数全部相等。
每次操作可以将某个元素值对应的所有元素修改成另一元素值。
一个序列的困难度定义为,将这个序列修改成好的序列的最少需要修改的位置数。
现在给定初始序列 a1, a2, ..., an 以及 q 次操作,每次操作为 i x,表示将第 i 个元素修改为 x。
计算初始时以及每次操作后序列的困难度。
Input
第一行包含两个整数 n 与 q (1≤n≤200000, 0≤q≤200000),表示序列长度与操作数。
第二行包含 n 个整数 a1, a2, ..., an (1≤ai≤200000),表示初始序列。
接下来 q 行每行两个整数 it, xt (1≤it≤n, 1≤xt≤200000),描述了一次操作。
Output
输出 q + 1 个整数,表示初始序列以及每次操作后的序列的困难度。
Example
input
5 6
1 2 1 2 1
2 1
4 1
5 3
2 3
4 2
2 1
output
2
1
0
0
2
3
0
@solution@
假如如果没有修改(即 easy 版本),应该怎么做?
首先假如某个元素值 k,它最左边出现的位置为 l[k],最右边出现的位置为 r[k],则 l[k]~r[k] 必须要修改为同一元素。
假如最后 a[p]~a[q] 必须为同一元素,则我们保留 a[p]~a[q] 中出现次数最多的元素,修改掉其他的元素是最优的。
我们用数值来描述限制:定义 b[i] 表示 a[i] 与 a[i+1] 被多少次要求为同一元素。则对于每一个 k,我们将 b[ l[k] ... r[k]-1 ] 区间 + 1,即可维护 b 的值。
每次 b 中连续的非零值就形成了最后要求是同一元素的区间。
我们只需要维护出这些非零值形成的连续区间中,区间的长度 - 区间中元素出现次数最大值,这个值之和,即可得到答案。
我们将某一元素 k 出现的次数 cnt[k] 维护在 l[k] 这一位置,记作 c[l[k]]。再定义 d[i] 表示 max{c[i], c[i+1]}。
则只需要维护非零值形成的连续区间中,区间的长度 - 区间中 d 的最大值,该值之和即可。
给 b 序列区间 + 1 可以使用线段树实现。考虑使用线段树怎么实现查询功能。
我们再加入两个哨兵结点 b[0] 与 b[n+1],方便下面的实现。
因为非零值的相关数值不好维护,我们在线段树中维护的是 “删掉当前区间内所有的 b[i] 最小值后,剩下的连续区间的信息”。
为了方便合并,我们统计时并不把包含区间端点的区间统计入当前线段树结点。
因为哨兵节点的存在,我们最终线段树的根一定为我们需要的答案。
在线段树中每个结点先维护 tg, cmn, mx 这几个值,分别表示加法标记,当前区间的 b[i] 最小值,当前区间的 d[i] 最大值。
为了方便合并,再维护 lmx, rmx,表示删掉所有最小值后,左/右端点所在连续区间的最大值(注意可能不存在这样的区间,这时候记作 -1)。
最后,还要维护 ans, cnt,表示删掉所有最小值后, 区间 d 最大值之和与区间长度之和。
这里的“区间长度”,对于完整的区间(即不含区间端点)定义为 b 序列中的区间长度 + 1(即对应的原序列 a 中的区间长度);否则如果包含区间端点,定义为 b 序列中的区间长度本身。
维护过程,只需要讨论左右儿子的 cmn 是否等于当前结点的 cmn,进行分类讨论。
具体细节可以参考代码。
总时间复杂度 O(nlogn)。
@accepted code@
#include<set>
#include<cstdio>
#include<algorithm>
using namespace std;
const int MAXN = 200000;
int arr[MAXN + 5];
struct segtree{
#define lch (x<<1)
#define rch (x<<1|1)
struct node{
int l, r;
int tg, cmn;
int lmx, rmx, mx, ans, cnt;
}t[4*MAXN + 5];
void pushup(int x) {
t[x].mx = max(t[lch].mx, t[rch].mx);
t[x].cmn = min(t[lch].cmn, t[rch].cmn);
if( t[x].cmn == t[lch].cmn )
t[x].lmx = t[lch].lmx;
else t[x].lmx = max(t[lch].mx, t[rch].lmx);
if( t[x].cmn == t[rch].cmn )
t[x].rmx = t[rch].rmx;
else t[x].rmx = max(t[rch].mx, t[lch].rmx);
if( t[x].cmn != t[rch].cmn ) {
t[x].ans = t[lch].ans;
t[x].cnt = t[lch].cnt + t[rch].r - t[rch].l + 1;
}
else if( t[x].cmn != t[lch].cmn ) {
t[x].ans = t[rch].ans;
t[x].cnt = t[rch].cnt + t[lch].r - t[lch].l + 1;
}
else {
t[x].ans = t[lch].ans + t[rch].ans;
t[x].cnt = t[lch].cnt + t[rch].cnt;
if( t[lch].rmx != -1 || t[rch].lmx != -1 )
t[x].ans += max(t[lch].rmx, t[rch].lmx), t[x].cnt++;
}
}
void pushdown(int x) {
if( t[x].tg ) {
t[lch].tg += t[x].tg, t[rch].tg += t[x].tg;
t[lch].cmn += t[x].tg, t[rch].cmn += t[x].tg;
t[x].tg = 0;
}
}
void build(int x, int l, int r) {
t[x].l = l, t[x].r = r, t[x].tg = 0, t[x].ans = 0;
if( l == r ) {
t[x].cmn = 0, t[x].lmx = t[x].rmx = -1, t[x].mx = max(arr[l], arr[l + 1]);
return ;
}
int mid = (l + r) >> 1;
build(lch, l, mid), build(rch, mid + 1, r);
pushup(x);
}
void update(int x, int p) {
if( t[x].l > p || t[x].r < p )
return ;
if( t[x].l == t[x].r ) {
t[x].mx = max(arr[p], arr[p + 1]);
return ;
}
pushdown(x);
update(lch, p);
update(rch, p);
pushup(x);
}
void modify(int x, int l, int r, int d) {
if( l <= t[x].l && t[x].r <= r ) {
t[x].tg += d, t[x].cmn += d;
return ;
}
if( l > t[x].r || r < t[x].l )
return ;
pushdown(x);
modify(lch, l, r, d);
modify(rch, l, r, d);
pushup(x);
}
}T;
set<int>st[MAXN + 5];
int a[MAXN + 5], n, q;
void update(int p) {T.update(1, p - 1), T.update(1, p);}
void remove(int x) {
int l = *st[a[x]].begin(), r = *st[a[x]].rbegin();
T.modify(1, l, r - 1, -1), arr[l] = 0, update(l);
st[a[x]].erase(x);
if( !st[a[x]].empty() ) {
int l = *st[a[x]].begin(), r = *st[a[x]].rbegin();
T.modify(1, l, r - 1, 1), arr[l] = st[a[x]].size(), update(l);
}
}
void add(int x) {
if( !st[a[x]].empty() ) {
int l = *st[a[x]].begin(), r = *st[a[x]].rbegin();
T.modify(1, l, r - 1, -1), arr[l] = 0, update(l);
}
st[a[x]].insert(x);
int l = *st[a[x]].begin(), r = *st[a[x]].rbegin();
T.modify(1, l, r - 1, 1), arr[l] = st[a[x]].size(), update(l);
}
int main() {
scanf("%d%d", &n, &q);
for(int i=1;i<=n;i++)
scanf("%d", &a[i]);
T.build(1, 0, n);
for(int i=1;i<=n;i++) add(i);
printf("%d
", T.t[1].cnt - T.t[1].ans);
for(int i=1;i<=q;i++) {
int it, xt; scanf("%d%d", &it, &xt);
remove(it), a[it] = xt, add(it);
printf("%d
", T.t[1].cnt - T.t[1].ans);
}
}
@details@
好神奇的线段树。。。