比赛的时候感觉是线段树,直接求,超时了。
看了题解,才知道是有点规律的,要先固定端点,求出所有以该端点为左端点的区间,然后不断移动左端点,直到求出所有区间。
用nxt[i]存储下一个出现a[i]的下标。
先求出所有的mex(1,i) 再删除a[1],则在区间(2,nxt[1]-1)上 mex(2,i)>a[1]的值全部改为a[1],
因为mex(2,i)是递增的,所以可以求出第一个大于a[1]的mex(2,i)的下标,然后进行区间更新。
以此类推,从左到右删除所有元素,就可以得出答案。
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #include<map> 5 using namespace std; 6 const int maxn=200000+5; 7 int a[maxn]; 8 int vis[maxn]; 9 int mex[maxn]; 10 int nxt[maxn]; 11 long long sum[maxn<<2]; 12 int setv[maxn<<2]; 13 int mx[maxn<<2]; 14 int cur; 15 void push_up(int rt) 16 { 17 sum[rt]=sum[rt<<1]+sum[rt<<1|1]; 18 mx[rt]=max(mx[rt<<1],mx[rt<<1|1]); 19 } 20 void push_down(int rt,int len) 21 { 22 if(setv[rt]!=-1) 23 { 24 sum[rt<<1]=setv[rt]*(len-len/2); 25 mx[rt<<1]=setv[rt]; 26 sum[rt<<1|1]=setv[rt]*(len/2); 27 mx[rt<<1|1]=setv[rt]; 28 setv[rt<<1]=setv[rt<<1|1]=setv[rt]; 29 setv[rt]=-1; 30 } 31 } 32 void build(int l,int r,int rt) 33 { 34 if(l==r) 35 { 36 sum[rt]=mex[cur]; 37 mx[rt]=mex[cur++]; 38 return ; 39 } 40 int m=(l+r)>>1; 41 build(l,m,rt<<1); 42 build(m+1,r,rt<<1|1); 43 push_up(rt); 44 } 45 void update(int L,int R,int v,int l,int r,int rt) 46 { 47 if(L<=l&&r<=R) 48 { 49 sum[rt]=v*(r-l+1); 50 mx[rt]=v; 51 setv[rt]=v; 52 return ; 53 } 54 push_down(rt,r-l+1); 55 int m=(l+r)>>1; 56 if(L<=m) update(L,R,v,l,m,rt<<1); 57 if(R>m) update(L,R,v,m+1,r,rt<<1|1); 58 push_up(rt); 59 } 60 int get_index(int L,int R,int v,int l,int r,int rt) 61 { 62 if(l==r) 63 { 64 return l; 65 } 66 push_down(rt,r-l+1); 67 int m=(l+r)>>1; 68 if(mx[rt<<1]>v&&L<=m) return get_index(L,R,v,l,m,rt<<1); 69 else if(mx[rt<<1|1]>v&&R>m) return get_index(L,R,v,m+1,r,rt<<1|1); 70 return R+1; 71 } 72 long long query(int L,int R,int l,int r,int rt) 73 { 74 if(L<=l&&r<=R) 75 { 76 return sum[rt]; 77 } 78 push_down(rt,r-l+1); 79 int m=(l+r)>>1; 80 long long ret=0; 81 if(L<=m) ret+=query(L,R,l,m,rt<<1); 82 if(R>m) ret+=query(L,R,m+1,r,rt<<1|1); 83 return ret; 84 } 85 int main() 86 { 87 int n; 88 while(scanf("%d",&n)&&n) 89 { 90 for(int i=1;i<=n;i++) 91 scanf("%d",&a[i]); 92 memset(vis,0,sizeof(vis)); 93 mex[1]=0; 94 if(a[1]<n) vis[a[1]]=1; 95 while(vis[mex[1]]) mex[1]++; 96 for(int i=2;i<=n;i++) 97 { 98 if(a[i]<n) vis[a[i]]=1; 99 mex[i]=mex[i-1]; 100 if(a[i]==mex[i-1]) 101 { 102 while(vis[mex[i]]) mex[i]++; 103 } 104 } 105 map<int,int> mp; 106 for(int i=n;i>=1;i--) 107 { 108 if(mp.find(a[i])==mp.end()) nxt[i]=n+1; 109 else nxt[i]=mp[a[i]]; 110 mp[a[i]]=i; 111 } 112 cur=1; 113 memset(setv,-1,sizeof(setv)); 114 memset(sum,0,sizeof(sum)); 115 memset(mx,0,sizeof(mx)); 116 build(1,n,1); 117 long long ans=0; 118 for(int i=1;i<=n;i++) 119 { 120 ans+=query(i,n,1,n,1); 121 int s_index=get_index(i+1,nxt[i]-1,a[i],1,n,1); 122 int e_index=nxt[i]-1; 123 if(s_index<=e_index) update(s_index,e_index,a[i],1,n,1); 124 } 125 printf("%I64d ",ans); 126 } 127 return 0; 128 }