思路:
线段树
先求出 mex(1,1), mex(1,2), mex(1,3), ..., mex(1,n)(关于右端点单调不降),将这些 mex 放进线段树里,整棵树的区间和就是以 1 为左端点的所有区间的 mex 之和
然后再求出 next[i],表示 a[i] 下一次出现的位置(若之后不再出现,则 next[i] = n+1)
然后从前往后依次删数。对于一个数 a[i],删掉它的影响是:设 l 为第一个 mex 大于 a[i] 的位置,r 为 next[i],则 l 到 r-1 之间的 mex 都变为 a[i](从 r 开始的区间仍然包含 a[i],所以不受影响)
然后这个线段树只需要维护区间最大值(方便查找第一个大于a[i]的位置)和区间和就可以了
代码:
#include <algorithm>
#include <cstdio>
#include <map>
#include <vector>

// Sums mex(i, j) over every subarray a[i..j] (1 <= i <= j <= n).
//
// Approach:
//   * Position j of a segment tree stores mex(i, j) for the current left end i
//     (initially i = 1). These values are non-decreasing in j.
//   * The left end is advanced one element at a time. Removing a[i] only
//     affects right ends j in [l, nxt[i] - 1], where l is the first position
//     whose stored mex exceeds a[i] and nxt[i] is the next occurrence of a[i];
//     all of those values become a[i].
//   * The tree therefore needs range-assign, range-sum, and range-max (the
//     max is used to locate l).
//
// Input : repeated test cases "n" followed by n integers, ended by n == 0 / EOF.
// Output: one answer per test case.

typedef long long LL;

const int N = 2e5 + 5;

// Marks "no pending assignment" on a node. It must not be a value that can be
// assigned, and 0 can be (a[i] == 0), so a negative sentinel is used.
const int NO_TAG = -1;

int a[N];          // input values, 1-based
int nxt[N];        // nxt[i]: next index holding a[i], or n + 1 if there is none
int mex[N];        // mex[j]: mex of a[1..j], used to seed the tree
int mx[N << 2];    // per-node maximum
int lazy[N << 2];  // per-node pending assignment, or NO_TAG
LL sum[N << 2];    // per-node sum
std::map<int, int> mp;  // value -> nearest index to the right, while filling nxt[]

// Recomputes node rt's sum and max from its two children.
void push_up(int rt) {
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    mx[rt] = std::max(mx[rt << 1], mx[rt << 1 | 1]);
}

// Pushes node rt's pending assignment (if any) to both children.
// len is the number of positions covered by rt; with m = (l + r) >> 1 the
// left child covers len - len / 2 of them and the right child len / 2.
void push_down(int rt, int len) {
    const int tag = lazy[rt];
    if (tag == NO_TAG) return;

    const int lc = rt << 1;
    const int rc = rt << 1 | 1;

    sum[lc] = 1LL * tag * (len - (len >> 1));
    mx[lc] = tag;
    lazy[lc] = tag;

    sum[rc] = 1LL * tag * (len >> 1);
    mx[rc] = tag;
    lazy[rc] = tag;

    lazy[rt] = NO_TAG;
}

// Builds the subtree rooted at rt over positions [l, r] from mex[], clearing
// every lazy tag it visits so no global reset is needed between test cases.
void build(int rt, int l, int r) {
    lazy[rt] = NO_TAG;
    if (l == r) {
        mx[rt] = mex[l];
        sum[rt] = mex[l];
        return;
    }
    const int m = (l + r) >> 1;
    build(rt << 1, l, m);
    build(rt << 1 | 1, m + 1, r);
    push_up(rt);
}

// Assigns x to every position in [L, R]. (rt, l, r) is the current node and
// the range it covers; call with (1, 1, n) from outside.
void update(int x, int L, int R, int rt, int l, int r) {
    if (L <= l && r <= R) {
        mx[rt] = x;
        sum[rt] = 1LL * (r - l + 1) * x;
        lazy[rt] = x;
        return;
    }
    push_down(rt, r - l + 1);
    const int m = (l + r) >> 1;
    if (L <= m) update(x, L, R, rt << 1, l, m);
    if (R > m) update(x, L, R, rt << 1 | 1, m + 1, r);
    push_up(rt);
}

// Returns the sum of positions [L, R]. (rt, l, r) as in update().
LL query(int L, int R, int rt, int l, int r) {
    if (L <= l && r <= R) return sum[rt];
    push_down(rt, r - l + 1);
    const int m = (l + r) >> 1;
    LL ans = 0;
    if (L <= m) ans += query(L, R, rt << 1, l, m);
    if (R > m) ans += query(L, R, rt << 1 | 1, m + 1, r);
    return ans;
}

// Returns the leftmost position in [l, r] whose value is greater than x.
// Precondition: mx[rt] > x (the caller checks this at the root); without it
// the descent falls through to the rightmost leaf and the result is meaningless.
int Find(int x, int rt, int l, int r) {
    if (l == r) return l;
    const int m = (l + r) >> 1;
    push_down(rt, r - l + 1);
    if (mx[rt << 1] > x) return Find(x, rt << 1, l, m);
    return Find(x, rt << 1 | 1, m + 1, r);
}

int main() {
    int n;
    // "== 1" also stops on malformed input, not just on EOF.
    while (scanf("%d", &n) == 1 && n > 0) {
        if (n >= N) return 1;  // input would overflow the static buffers

        for (int i = 1; i <= n; i++) {
            if (scanf("%d", &a[i]) != 1) return 0;
        }

        // Prefix mex. The mex of at most n values never exceeds n, so a
        // presence table over 0..n is enough; larger values cannot matter.
        std::vector<char> seen(n + 2, 0);
        int cur = 0;
        for (int i = 1; i <= n; i++) {
            if (a[i] >= 0 && a[i] <= n) seen[a[i]] = 1;
            while (seen[cur]) cur++;
            mex[i] = cur;
        }
        build(1, 1, n);

        // nxt[i]: scan right-to-left, remembering the latest index per value.
        mp.clear();
        for (int i = n; i >= 1; i--) {
            std::map<int, int>::iterator it = mp.find(a[i]);
            if (it == mp.end()) {
                nxt[i] = n + 1;
                mp[a[i]] = i;
            } else {
                nxt[i] = it->second;
                it->second = i;
            }
        }

        LL ans = 0;
        for (int i = 1; i <= n; i++) {
            // The tree currently holds mex(i, j) for j >= i and 0 for j < i.
            ans += query(1, n, 1, 1, n);

            // Remove a[i]. If no stored mex exceeds a[i], nothing changes.
            if (mx[1] <= a[i]) continue;
            const int l = Find(a[i], 1, 1, n);
            const int r = nxt[i];
            if (l < r) update(a[i], l, r - 1, 1, 1, n);
        }
        // One answer per line.
        printf("%lld\n", ans);
    }
    return 0;
}