不错的思维题,犀利的线段树。解题思路百度上有很多,我那蹩脚的表达能力,就不误导大家了。
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <map>

// Overview
// --------
// For each test case (an integer n followed by n integers a[1..n]; a test
// case with n == 0 ends the input) the program prints one number:
// the sum, over every start index i, of the values held by a segment tree
// whose leaf j initially stores b[j] = mex(a[1..j]).  After start index i is
// accounted for, leaf i is removed and every leaf in [i+1, nxt[i]] whose
// value exceeds a[i] is lowered to a[i].
// NOTE(review): this reads as "sum of mex(a[i..j]) over all i <= j" --
// confirm against the problem statement.

typedef long long ll;

const int maxn = 222222;

std::map<int, int> mp;  // scratch map, reused for two different passes in main()
int col[maxn << 2];     // lazy tag: pending "assign this value" for the subtree, -1 = none
int b[maxn];            // b[j] = mex of the prefix a[1..j]
int a[maxn];            // input values, 1-based
int mx[maxn << 2];      // value of the right-most leaf of the node's range (see push_up)
ll ans = 0;             // running answer for the current test case
ll sum[maxn << 2];      // sum of the leaf values in the node's range
int nxt[maxn];          // (index of next occurrence of a[i]) - 1, or n if there is none

// Recomputes node rt from its two children.
// mx is taken from the right child only: the stored leaf values are kept
// non-decreasing from left to right by the operations in this file, so the
// right-most value doubles as the maximum.
inline void push_up(int rt) {
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    mx[rt] = mx[rt << 1 | 1];
}

// Pushes a pending assignment on node rt down to its children.
// len is the number of leaves covered by rt; the left child covers
// len - len/2 of them and the right child len/2.
void push_down(int rt, int len) {
    if (col[rt] == -1) return;
    const int ls = rt << 1;
    const int rs = rt << 1 | 1;
    const int v = col[rt];
    col[ls] = col[rs] = v;
    mx[ls] = mx[rs] = v;
    sum[ls] = (ll)v * (len - (len >> 1));
    sum[rs] = (ll)v * (len >> 1);
    col[rt] = -1;
}

// Builds the tree over positions [l, r] rooted at rt, with leaf j = b[j].
void build(int l, int r, int rt) {
    col[rt] = -1;
    if (l == r) {
        sum[rt] = b[l];
        mx[rt] = b[l];
        return;
    }
    const int m = (l + r) >> 1;
    build(l, m, rt << 1);
    build(m + 1, r, rt << 1 | 1);
    push_up(rt);
}

// Returns the left-most position in [l, r] whose value is >= v.
// The caller must make sure such a position exists (solve() checks mx[1]);
// otherwise the right-most leaf is returned.
int find(int l, int r, int rt, int v) {
    if (l == r) return l;
    push_down(rt, r - l + 1);
    const int m = (l + r) >> 1;
    if (mx[rt << 1] >= v) return find(l, m, rt << 1, v);
    return find(m + 1, r, rt << 1 | 1, v);
}

// Range assignment: sets every position in [lo, hi] to val.
// (l, r, rt) describe the current node.
void update(int lo, int hi, int val, int l, int r, int rt) {
    if (lo <= l && r <= hi) {
        col[rt] = val;
        mx[rt] = val;
        sum[rt] = (ll)val * (r - l + 1);
        return;
    }
    push_down(rt, r - l + 1);
    const int m = (l + r) >> 1;
    if (lo <= m) update(lo, hi, val, l, m, rt << 1);
    if (m < hi) update(lo, hi, val, m + 1, r, rt << 1 | 1);
    push_up(rt);
}

// Point removal: position pos stops contributing (sum 0) and gets the
// sentinel value -1 so that find() never selects it.
void update(int pos, int l, int r, int rt) {
    if (l == r) {
        mx[rt] = -1;
        sum[rt] = 0;
        return;
    }
    push_down(rt, r - l + 1);
    const int m = (l + r) >> 1;
    if (pos <= m) update(pos, l, m, rt << 1);
    else update(pos, m + 1, r, rt << 1 | 1);
    push_up(rt);
}

// Sweeps the start index i from 1 to n, accumulating the tree total into
// `ans` and then adjusting the tree for start index i + 1.
void solve(int n) {
    for (int i = 1; i <= n; i++) {
        ans += sum[1];          // total of all live leaves for start index i
        update(i, 1, n, 1);     // position i no longer takes part
        if (mx[1] < a[i]) continue;  // no remaining value reaches a[i]
        const int l = find(1, n, 1, a[i]);
        const int r = nxt[i];   // last position before a[i] occurs again
        if (l <= r) update(l, r, a[i], 1, n, 1);
    }
}

int main() {
    int n;
    // Stop on EOF *and* on a failed conversion; comparing against EOF alone
    // would loop forever on a non-numeric token.
    while (scanf("%d", &n) == 1) {
        // n == 0 is the terminator.  Negative or oversized n would send
        // build() into unbounded recursion or overrun the static arrays.
        if (n <= 0 || n >= maxn) break;

        ans = 0;
        for (int i = 1; i <= n; i++) {
            if (scanf("%d", &a[i]) != 1) return 0;  // truncated input
        }

        // Pass 1: b[i] = mex(a[1..i]).  mp is used as a "seen" set and k
        // only ever moves forward.
        mp.clear();
        int k = 0;
        for (int i = 1; i <= n; i++) {
            mp[a[i]] = 1;
            while (mp.count(k)) k++;
            b[i] = k;
        }

        // Pass 2 (right to left): mp maps a value to the closest index to
        // the right at which it occurs.
        mp.clear();
        for (int i = n; i >= 1; i--) {
            std::map<int, int>::iterator it = mp.find(a[i]);
            nxt[i] = (it == mp.end()) ? n : it->second - 1;
            mp[a[i]] = i;
        }

        build(1, n, 1);
        solve(n);

        // %lld is the standard conversion for long long; the MSVC-only
        // %I64d is misparsed by other C libraries.
        // NOTE(review): the separator after the number is a single space,
        // as in the original source -- confirm whether the judge expects a
        // newline here.
        printf("%lld ", ans);
    }
    return 0;
}