题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4747
题意:有一组序列a[i](1<=i<=N), 让你求所有的mex(l,r), mex(l,r)表示区间[l,r]中最小的未在序列中出现的非负整数。
思路:冥思苦想半天无想法,白做了那么多线段树。 很明显的维护区间问题,容易想到线段树,比较难想到操作。 枚举一个序列的所mex(1,i),mex(2,i)……可以发现序列mex(x,i)是一个单调递增序列,我们需要求得就是所有以x开头的序列和,mex(x,i)(x<=i<=n)。这点确定了就好办了,记录每个位置的数后面最早重复出现的位置next[x],如果无则为设n+1。那么我们就可以发现,当第x个数所对应的序列 mex(x,i)(x<=i<=n)所对应的序列求完之后,删去此位置的数,位置x+1~next[x]-1序列中mex值大于a[x]的都改为a[x],因为a[x]没有了,下一个a[x]还未出现,所以可以证明这样做是正确的。从1到n扫一遍亦求出了所有的mex()。
基本上所有的操作都可以用到线段树。开始没有想到一点的是如何找序列中刚好大于a[x]的位置,并且此位置到next[x]-1赋值为a[x],怎么都没想到log(n)的操作,其实这里依然可以用到线段树,因为序列是单调递增的,另开一个区间维护序列mavv[u]表示区间中最大的mex值,随着询问以及其他操作成段更新即可。
1 #include <iostream> 2 #include <cstdio> 3 #include <cmath> 4 #include <map> 5 #include <algorithm> 6 #include <cstring> 7 #include <sstream> 8 using namespace std; 9 10 #define lz 2*u,l,mid 11 #define rz 2*u+1,mid+1,r 12 typedef long long lld; 13 const int maxn=222222; 14 int a[maxn], b[maxn], next[maxn]; 15 lld sum[4*maxn], mavv[4*maxn], flag[4*maxn]; 16 map<int,int>mp; 17 18 void push_up(int u, int l, int r) 19 { 20 sum[u]=sum[2*u]+sum[2*u+1]; 21 mavv[u]=mavv[2*u+1]; 22 } 23 24 void push_down(int u, int l, int r) 25 { 26 int mid=(l+r)>>1; 27 if(flag[u]!=-1) 28 { 29 flag[2*u]=flag[2*u+1]=flag[u]; 30 mavv[2*u]=mavv[2*u+1]=flag[u]; 31 sum[2*u]=(lld)(mid-l+1)*flag[u]; 32 sum[2*u+1]=(lld)(r-mid)*flag[u]; 33 flag[u]=-1; 34 } 35 } 36 37 void build(int u, int l, int r) 38 { 39 flag[u]=-1; 40 int mid=(l+r)>>1; 41 if(l==r) 42 { 43 sum[u]=mavv[u]=b[l]; 44 return ; 45 } 46 build(lz); 47 build(rz); 48 push_up(u,l,r); 49 } 50 51 void Update(int u, int l, int r, int tl, int tr, int val) 52 { 53 if(tl>tr) return ; 54 if(tl<=l&&r<=tr) 55 { 56 mavv[u]=val; 57 sum[u]=(lld)val*(r-l+1); 58 flag[u]=val; 59 return ; 60 } 61 push_down(u,l,r); 62 int mid=(l+r)>>1; 63 if(tr<=mid) Update(lz,tl,tr,val); 64 else if(tl>mid) Update(rz,tl,tr,val); 65 else 66 { 67 Update(lz,tl,mid,val); 68 Update(rz,mid+1,tr,val); 69 } 70 push_up(u,l,r); 71 } 72 73 int find(int u, int l, int r, int tmp) 74 { 75 if(l==r) return l; 76 push_down(u,l,r); 77 int mid=(l+r)>>1; 78 if(mavv[2*u]>tmp) return find(lz,tmp); 79 else return find(rz,tmp); 80 } 81 82 int main() 83 { 84 int n; 85 while(cin >> n,n) 86 { 87 for(int i=1; i<=n; i++) scanf("%d",a+i); 88 mp.clear(); 89 for(int i=n; i>=1; i--) 90 { 91 if(mp[ a[i] ]) next[i]=mp[ a[i] ]; 92 else next[i]=n+1; 93 mp[ a[i] ]=i; 94 } 95 mp.clear(); 96 int x=0; 97 for(int i=1; i<=n; i++) 98 { 99 mp[ a[i] ]=1; 100 while(mp[x]) ++x; 101 b[i]=x; 102 } 103 build(1,1,n); 104 lld ans=0; 105 for(int i=1; i<=n; i++) 106 { 107 ans+=sum[1]; 108 if(mavv[1]>a[i]) 109 { 110 int id=find(1,1,n,a[i]); 111 Update(1,1,n,max(id,i+1),next[i]-1,a[i]); 112 } 113 Update(1,1,n,i,i,0); 114 } 115 cout << ans <<endl; 116 } 117 }