题目大意
给定一个长度为(n)的序列(a_1,a_2dots,a_n)。保证(forall i:1leq a_ileq n)。请你求出,序列里有多少三元组((i,j,k)),满足(a[i,j])里的所有数,都在(a[j+1,k])里出现过;且(a[j+1,k])里所有数,都在(a[i,j])里出现过。
(nleq 2 imes 10^5)。
本题题解
枚举(k)。对每个(j),维护使三元组((i,j,k))合法的最小的和最大的(i),分别记为( ext{mini}[j], ext{maxi}[j])。那么,当前(k)的三元组数量就是:(sum_{j=1}^{k-1}( ext{maxi}[j]- ext{mini}[j]+1))。考虑分别计算( ext{maxi})的和和( ext{mini})的和。
记每个位置(t)上的数上一次和下一次出现的位置分别为( ext{pre}[t])和( ext{nxt}[t]),特别地,如果前面/后面没有相同的数,则( ext{pre}[t]=0)或( ext{nxt}[t]=n+1)。那么,我们发现,三元组((i,j,k))合法的充分必要条件是:(max_{t=i}^{j}( ext{nxt}[t])leq k),且(min_{t=j+1}^{k}( ext{pre}[t])geq i)。
由此可知,( ext{maxi}[j])就是满足(min_{t=j+1}^{k}( ext{pre}[t])geq i)的最大的(i),( ext{mini}[j])就是满足(max_{t=i}^{j}( ext{nxt[}t])leq k)的最小的(i)。
( ext{maxi})比较好维护,他就等于(min_{t=j+1}^{k}( ext{pre}[j]))。当从(k-1)变到(k)时,我们让所有(jin[1,k-1])的( ext{maxi}[j])对( ext{pre}[k])取(min)即可。
考虑( ext{mini})。我们称( ext{nxt}[t]>k)的位置为不合法的,其他位置为合法的。那么对于每个(j),( ext{mini}[j])就相当于(j)前面、最靠近(j)的那个不合法的位置(+1)。特别地,如果(j)本身就不合法,我们认为( ext{mini}[j]=j+1)。从(k-1)变到(k),会使得所有( ext{nxt}[t]=k)的位置,从不合法变成合法。相当于把两段( ext{mini})的区间“合并”起来(令后一段区间的值等于前一段区间的值)。而( ext{nxt}[t]=k)的位置最多只有一个:就是( ext{pre}[k])。所以每次对一段区间执行区间覆盖(或者区间取(min))即可(事实上因为( ext{maxi})要支持的是区间取(min),所以都用区间取(min)反而更好写)。
还有一个要注意的点是,我们要始终保证,( ext{mini}[j]leq ext{maxi}[j]+1),所以对( ext{maxi})取(min)的时候,要对( ext{mini})做一样的操作。
总结来说,需要支持区间对一个数取(min),区间求和,可以用吉老师线段树实现。另外,我们还要对一个位置求它前面、最靠近它的不合法的位置,同时要支持单点修改(把某个位置从不合法变为合法),这个可以用线段上二分实现。
时间复杂度(O(nlog n))。
参考代码:
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
if(S==T){
T=(S=buf)+fread(buf,1,MAXN,stdin);
if(S==T)return EOF;
}
return *S++;
}
}
#ifdef ONLINE_JUDGE
#define getchar Fread::getchar
#endif
inline int read(){
int f=1,x=0;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline ll readll(){
ll f=1,x=0;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
/* ------ by:duyi ------ */ // dysyn1314
const int MAXN=2e5;
int n;
/*
struct Baoli{
int a[MAXN+5],val[MAXN+5],val2[MAXN+5];
int get_nxt0(int p){
for(int i=p;i<=n+1;++i)if(a[i]==0)return i;
throw;
}
int get_pre0(int p){
for(int i=p;i>=0;--i)if(a[i]==0)return i;
throw;
}
void set1(int p){
a[p]=1;
}
void init(){
for(int i=1;i<=n;++i)val[i]=val2[i]=i;
}
void modify_min_mxi(int l,int r,int x){
for(int i=l;i<=r;++i)val[i]=min(val[i],x);
}
void modify_min_mni(int l,int r,int x){
for(int i=l;i<=r;++i)val2[i]=min(val2[i],x);
}
int get_sum_mxi(){
int res=0;
for(int i=1;i<=n;++i)res+=val[i]*a[i];
return res;
}
int get_sum_mni(){
int res=0;
for(int i=1;i<=n;++i)res+=val2[i]*a[i];
return res;
}
}T;
*/
class SegmentTree{
private:
int sz[MAXN*4+5],mx[2][MAXN*4+5],se[2][MAXN*4+5],ct[2][MAXN*4+5];
ll sum[2][MAXN*4+5];
void _pu(int p,int *mx,int *se,int *ct,ll *sum){
sum[p]=sum[p<<1]+sum[p<<1|1];
if(mx[p<<1]>mx[p<<1|1]){
mx[p]=mx[p<<1];
se[p]=max(se[p<<1],mx[p<<1|1]);
ct[p]=ct[p<<1];
}
else if(mx[p<<1]<mx[p<<1|1]){
mx[p]=mx[p<<1|1];
se[p]=max(mx[p<<1],se[p<<1|1]);
ct[p]=ct[p<<1|1];
}
else{
mx[p]=mx[p<<1];
se[p]=max(se[p<<1],se[p<<1|1]);
ct[p]=ct[p<<1]+ct[p<<1|1];
}
}
void push_up(int p){
sz[p]=sz[p<<1]+sz[p<<1|1];
_pu(p,mx[0],se[0],ct[0],sum[0]);
_pu(p,mx[1],se[1],ct[1],sum[1]);
}
void _pd(int p,int *mx,int *ct,ll *sum){
if(mx[p]<mx[p<<1]){
sum[p<<1]-=(ll)ct[p<<1]*(mx[p<<1]-mx[p]);
mx[p<<1]=mx[p];
}
if(mx[p]<mx[p<<1|1]){
sum[p<<1|1]-=(ll)ct[p<<1|1]*(mx[p<<1|1]-mx[p]);
mx[p<<1|1]=mx[p];
}
}
void push_down(int p){
_pd(p,mx[0],ct[0],sum[0]);
_pd(p,mx[1],ct[1],sum[1]);
}
void build(int p,int l,int r){
if(l==r){
mx[0][p]=mx[1][p]=l;
se[0][p]=se[1][p]=-1;
return;
}
int mid=(l+r)>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
push_up(p);
}
void modify1(int p,int l,int r,int pos){
if(l==r){
sz[p]=1;
ct[0][p]=ct[1][p]=1;
sum[0][p]=mx[0][p];
sum[1][p]=mx[1][p];
return;
}
push_down(p);
int mid=(l+r)>>1;
if(pos<=mid)modify1(p<<1,l,mid,pos);
else modify1(p<<1|1,mid+1,r,pos);
push_up(p);
}
int __first0(int p,int l,int r){
if(l==r){assert(sz[p]==0);return l;}
push_down(p);
int mid=(l+r)>>1;
if(sz[p<<1]<mid-l+1)return __first0(p<<1,l,mid);
else return __first0(p<<1|1,mid+1,r);
}
int _nxt0(int p,int l,int r,int ql,int qr){
if(ql<=l && qr>=r){
if(sz[p]==r-l+1)return n+1;
else return __first0(p,l,r);
}
push_down(p);
int mid=(l+r)>>1,res=n+1;
if(ql<=mid&&sz[p<<1]<mid-l+1)res=_nxt0(p<<1,l,mid,ql,qr);
if(res!=n+1)return res;
if(qr>mid&&sz[p<<1|1]<r-mid)return _nxt0(p<<1|1,mid+1,r,ql,qr);
else return n+1;
}
int __last0(int p,int l,int r){
if(l==r){assert(sz[p]==0);return l;}
push_down(p);
int mid=(l+r)>>1;
if(sz[p<<1|1]<r-mid)return __last0(p<<1|1,mid+1,r);
else return __last0(p<<1,l,mid);
}
int _pre0(int p,int l,int r,int ql,int qr){
if(ql<=l && qr>=r){
if(sz[p]==r-l+1)return 0;
else return __last0(p,l,r);
}
push_down(p);
int mid=(l+r)>>1,res=0;
if(qr>mid&&sz[p<<1|1]<r-mid)res=_pre0(p<<1|1,mid+1,r,ql,qr);
if(res)return res;
if(ql<=mid&&sz[p<<1]<mid-l+1)return _pre0(p<<1,l,mid,ql,qr);
else return 0;
}
void modify2(int p,int l,int r,int ql,int qr,int x,int t){
//区间对x取min
if(x>=mx[t][p])return;
if(ql<=l && qr>=r && se[t][p]<x){
sum[t][p]-=(ll)ct[t][p]*(mx[t][p]-x);
mx[t][p]=x;
return;
}
push_down(p);
int mid=(l+r)>>1;
if(ql<=mid)modify2(p<<1,l,mid,ql,qr,x,t);
if(qr>mid)modify2(p<<1|1,mid+1,r,ql,qr,x,t);
push_up(p);
}
public:
//mxi tree0
//mni tree1
void set1(int p){modify1(1,1,n,p);}
int get_nxt0(int p){
if(p>n)return n+1;
if(p<1)return 0;
return _nxt0(1,1,n,p,n);
}
int get_pre0(int p){
if(p>n)return n+1;
if(p<1)return 0;
return _pre0(1,1,n,1,p);
}
void modify_min_mxi(int l,int r,int x){
if(l>r)return;
modify2(1,1,n,l,r,x,0);
}
void modify_min_mni(int l,int r,int x){
if(l>r)return;
modify2(1,1,n,l,r,x,1);
}
ll get_sum_mxi(){return sum[0][1];}
ll get_sum_mni(){return sum[1][1];}
void init(){build(1,1,n);}
}T;
int a[MAXN+5],nxt[MAXN+5],pre[MAXN+5],pos[MAXN+5];
int main(){
n=read();
for(int i=1;i<=n;++i){a[i]=read();pre[i]=pos[a[i]];pos[a[i]]=i;}
for(int i=1;i<=n;++i)pos[i]=n+1;
for(int i=n;i>=1;--i){nxt[i]=pos[a[i]];pos[a[i]]=i;}
T.init();
ll ans=0;
for(int k=1;k<=n;++k){
if(pre[k]){
int x=T.get_nxt0(pre[k]+1)-1;
//cout<<"* "<<x<<" "<<T.get_pre0(pre[k]-1)<<endl;
T.modify_min_mni(pre[k],x,T.get_pre0(pre[k]-1));
T.set1(pre[k]);
}
T.modify_min_mxi(1,k-1,pre[k]);
T.modify_min_mni(1,k-1,pre[k]);
ans+=T.get_sum_mxi()-T.get_sum_mni();
}
cout<<ans<<endl;
return 0;
}