题目
分析
先用 O(n) 的时间求出所有前缀的 mex(1,i)(1<=i<=n):
虽然 0<=ai<=10^9,但一共只有 n 个数,它们最多覆盖 0..n 中的 n 个值,所以 mex 一定小于等于 n,大于 n 的 ai 可以直接忽略。
// Sweep the prefix end j and compute mex(1,j) for every j.
// Assumes bz[k] starts out true for all k, i.e. "value k not seen yet".
for(long long j=1;j<=n;j++)
{
// Only values <= n are marked as seen; larger values are never recorded in bz.
if(a[j]<=n)
bz[a[j]]=false;
// top holds the previous prefix mex, so the scan resumes from it instead of from 0.
for(long long k=top;k<=n;k++)
{
if(bz[k])
{
// First value still flagged as unseen: record it as the new mex and add it to ans.
top=k;
ans+=top;
break;
}
}
}
显然,左端点固定时 mex(1,i) 随 i 增大单调不下降,所以指针 top 只会向右移动,上面的预处理总共是 O(n) 的。
接着用线段树维护mex。
当左端点右移、删掉a[i]时,从第一个mex大于a[i]的位置开始,到下一个值等于a[i]的位置之前(如果后面不再出现a[i],则一直到n),这一段的mex都会改变,都会变成a[i]。
所以用线段树维护区间最大mex以及区间mex和。
#include <cmath>
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <queue>
// Largest 32-bit signed value; findpos returns it as its "no position found" sentinel.
const long long maxlongint=2147483647;
// Modulus constant; not referenced anywhere in the visible code.
const long long mo=1000000007;
// Capacity of the per-position arrays below; the segment tree storage is sized N*4.
const long long N=2010000;
using namespace std;
// a[1..n]: input values; n: number of values; ans: running total printed by main.
// next[i]: index of the following occurrence of a[i] as recorded by main, or n+1
// when main recorded none.
// NOTE(review): with `using namespace std` in effect, an unqualified use of `next`
// can be ambiguous with std::next on C++11 and later compilers.
long long a[N],n,ans=0,next[N];
// Scratch array used twice by main: first as "last index seen for each value",
// then refilled and used as "value not yet present in the current prefix" flags.
long long bz[N]={0};
// One segment-tree node.
struct ddx
{
    long long sum;   // sum of the stored values over the node's range
    long long mxe;   // maximum stored value over the node's range
    long long lazy;  // pending range-assignment value; negative means nothing pending
};
// Node storage: the root is index 1 and node v has children 2v and 2v+1.
ddx tree[N*4];
// Point assignment: stores y at position x in the subtree of node v, which covers
// [l, r], then recomputes sum and max for every node on the path back to v.
//
// Pending lazy tags are not pushed down here; main only calls this while building
// the tree, before any range assignment has been issued.
//
// Always returns 0; the return value carries no information.
long long put(long long v,long long l,long long r,long long x,long long y)
{
    if(l==r)
    {
        tree[v].mxe=tree[v].sum=y;
        return 0;
    }
    long long mid=(l+r)/2;
    if(x<=mid)
        put(v*2,l,mid,x,y);
    else
        put(v*2+1,mid+1,r,x,y);
    tree[v].sum=tree[v*2].sum+tree[v*2+1].sum;
    tree[v].mxe=max(tree[v*2].mxe,tree[v*2+1].mxe);
    // Flowing off the end of a non-void function is undefined behaviour in C++,
    // so the recursive path must return explicitly as well.
    return 0;
}
// Searches the subtree of node v (covering [l, r]) for the leftmost position whose
// stored value is strictly greater than x. Returns that position, or maxlongint
// when the leaf finally reached does not exceed x.
// Pending lazy tags are pushed to the children on the way down.
long long findpos(long long v,long long l,long long r,long long x)
{
    if(l==r)
        return tree[v].mxe>x ? l : maxlongint;

    const long long lc=v*2;
    const long long rc=v*2+1;
    const long long mid=(l+r)/2;

    // Hand a pending range assignment over to both children before looking at them.
    if(tree[v].lazy>=0)
    {
        const long long tag=tree[v].lazy;
        tree[lc].lazy=tag;
        tree[rc].lazy=tag;
        tree[lc].mxe=tag;
        tree[rc].mxe=tag;
        tree[lc].sum=tag*(mid-l+1);
        tree[rc].sum=tag*(r-mid);
        tree[v].lazy=-1;
    }

    // Prefer the left half whenever it contains a value above x; otherwise go right.
    long long found=0;
    if(tree[lc].mxe>x)
        found=findpos(lc,l,mid,x);
    else
        found=findpos(rc,mid+1,r,x);

    tree[v].mxe=max(tree[lc].mxe,tree[rc].mxe);
    tree[v].sum=tree[lc].sum+tree[rc].sum;
    return found;
}
// Range assignment: sets every position in [x, y] to z inside the subtree of node v,
// which covers [l, r]. The recursion expects l <= x <= y <= r; main only calls it
// after checking that the range is non-empty.
//
// A fully covered node is updated in place and tagged lazily; otherwise any pending
// tag is pushed to the children, the range is split at mid, and sum/max are rebuilt.
//
// Always returns 0; the return value carries no information.
long long change(long long v,long long l,long long r,long long x,long long y,long long z)
{
    if(l==x && y==r)
    {
        tree[v].sum=z*(r-l+1);
        tree[v].mxe=z;
        tree[v].lazy=z;
        return 0;
    }
    long long mid=(l+r)/2;
    // Push a pending assignment down before touching only part of this node.
    if(tree[v].lazy>=0)
    {
        tree[v*2].mxe=tree[v*2+1].mxe=tree[v*2].lazy=tree[v*2+1].lazy=tree[v].lazy;
        tree[v*2].sum=tree[v].lazy*(mid-l+1);
        tree[v*2+1].sum=tree[v].lazy*(r-mid);
        tree[v].lazy=-1;
    }
    if(y<=mid)
        change(v*2,l,mid,x,y,z);
    else if(x>mid)
        change(v*2+1,mid+1,r,x,y,z);
    else
    {
        // The range straddles mid: update each half separately.
        change(v*2,l,mid,x,mid,z);
        change(v*2+1,mid+1,r,mid+1,y,z);
    }
    tree[v].sum=tree[v*2].sum+tree[v*2+1].sum;
    tree[v].mxe=max(tree[v*2].mxe,tree[v*2+1].mxe);
    // Flowing off the end of a non-void function is undefined behaviour in C++,
    // so the recursive path must return explicitly as well.
    return 0;
}
// Range-sum query: adds the sum of the stored values over [x, y] to the global `ans`.
// Node v covers [l, r]; the recursion expects l <= x <= y <= r.
//
// The result is accumulated into `ans` rather than returned. Pending lazy tags are
// pushed to the children whenever a node is only partially covered.
//
// Always returns 0; the return value carries no information.
long long find(long long v,long long l,long long r,long long x,long long y)
{
    if(l==x && y==r)
    {
        ans+=tree[v].sum;
        return 0;
    }
    long long mid=(l+r)/2;
    // Push a pending assignment down so the children hold up-to-date sums.
    if(tree[v].lazy>=0)
    {
        tree[v*2].mxe=tree[v*2+1].mxe=tree[v*2].lazy=tree[v*2+1].lazy=tree[v].lazy;
        tree[v*2].sum=tree[v].lazy*(mid-l+1);
        tree[v*2+1].sum=tree[v].lazy*(r-mid);
        tree[v].lazy=-1;
    }
    if(y<=mid)
        find(v*2,l,mid,x,y);
    else if(x>mid)
        find(v*2+1,mid+1,r,x,y);
    else
    {
        // The range straddles mid: query each half separately.
        find(v*2,l,mid,x,mid);
        find(v*2+1,mid+1,r,mid+1,y);
    }
    tree[v].sum=tree[v*2].sum+tree[v*2+1].sum;
    tree[v].mxe=max(tree[v*2].mxe,tree[v*2+1].mxe);
    // Flowing off the end of a non-void function is undefined behaviour in C++,
    // so the recursive path must return explicitly as well.
    return 0;
}
int main()
{
scanf("%lld",&n);
for(long long i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
if(a[i]<=n)
{
next[bz[a[i]]]=i;
bz[a[i]]=i;
}
}
for(long long i=1;i<=n;i++)
if(!next[i])
next[i]=n+1;
memset(bz,true,sizeof(bz));
for(long long i=0;i<=N*4-1;i++)
tree[i].lazy=-1;
long long top=0;
ans=0;
for(long long j=1;j<=n;j++)
{
if(a[j]<=n)
bz[a[j]]=false;
for(long long k=top;k<=n;k++)
{
if(bz[k])
{
top=k;
ans+=top;
break;
}
}
put(1,1,n,j,top);
}
for(long long i=2;i<=n;i++)
{
if(a[i-1]<=n)
{
long long pos=findpos(1,1,n,a[i-1]);
if(next[i-1]-1>=pos)
change(1,1,n,pos,next[i-1]-1,a[i-1]);
}
find(1,1,n,i,n);
}
printf("%lld",ans);
}