为什么最劣解选手还有脸来写题解啊……
这题首先考虑每个数 (x(1leq xleq n)) 分开算贡献。设 (B_1,B_2,cdots,B_{m_x}(A_{B_i}=x))。我们发现对于每个数分开计算是,(A_1,A_2,cdots,A_n) 只有两种数,一种是等于 (x),另一种是不等于 (x)。考虑一个答案区间 (l' sim r'),一定是由一个我们称为“根区间” (lsim r(l'leq l,rleq r')) 的东西扩展出来的。所谓的“根区间”就是指 (A_l=A_r=x),即区间左右两个端点都是这个数,于是就有了一个 (O(n^2)) 的暴力做法。我们发现“根区间”能够衍生、拓展出来的区间数,与“根区间”内等于 (x) 的数的个数 (-) 不等于 (x) 的数的个数以及 (l) 左边连续不为 (x) 的个数和 (r) 右边连续不为 (r) 的个数有关(不然根区间就不是 (lsim r) 了)。
既然无法直接枚举根区间,我们就要另外想办法。其实提到计算区间贡献这个东西,显然会想到分治。于是就是愉快的讨论:对于一个区间 (lsim r),中点为 (mid),“根区间” (l'sim r'(lleq l'leq r'leq r))一共分为 (3) 种:
- (l'leq r'leq mid)
- (l'leq mid < r')
- (mid< l' leq r')
对于第一种和第三种分别属于 (lsim mid) 和 (mid+1sim r) 区间内的,如实我们就只要想办法计算跨越 (mid) 的根区间即第二种 (l'leq midleq r'leq r) 的贡献了。我们考虑用一棵权值线段树维护对于每个 (y) 有多个个 (l') 满足 (l'sim mid) 中等于 (x) 的个数 (-) 不等于 (x) 的个数 (=) (y)。考虑 (B_{i}+1sim B_i) 这一段等于 (x) 的个数 (-) 不等于 (x) 的个数形成了一个公差为 (1) 的等差数列。我们先看一下我们会有什么形式的询问:对于 (B_{i}sim B_{i+1}-1) 这一段中等于 (x) 的个数 (-) 不等于 (x) 的个数形成了一个公差为 (-1) 的等差数列,对于等于 (x) 的个数 (-) 不等于 (x) 的个数 (=) (y) 的数,对于答案的贡献为左边段中等于 (x) 的个数 (-) 不等于 (x) 的个数可以为 (-y+1sim n),于是对答案的贡献就是左区间中等于 (x) 的个数 (-) 不等于 (x) 的个数为 (-y+1sim n) 的方案数和。比如说 (n=6,A={1,0,0,0,0,1}) 对于 (x=1) ,(B={1,6}),(l=1,r=2,mid=1),(2sim 3) 等于 (x) 的个数 (-) 不等于 (x) 的个数形成了一个公差为 (1) 的等差数列 (-2,-1),当然还有 (1) 等于 (x) 的个数 (-) 不等于 (x) 的个数 (= -1)。右区间 (4sim 5) 中等于 (x) 的个数 (-) 不等于 (x) 的个数形成了一个公差为 (-1) 的等差数列 (-1,-2),需要查询的是左边等于 (x) 的个数 (-) 不等于 (x) 的个数为 (-(-1)+1=2 sim n) 和 (-(-2)+1=3 sim n),当然还有最后的一项 (-1) 要查询。我们感兴趣的是 (4sim5) 这一段:观察项的系数值为 (2) 的数为 (1),值为 (3sim n) 的数为 (2)。发现一定是一个等差数列 (+) 一个所有数相同的数列,分别维护即可,时间复杂度 (O(nlog_ n^2))/kk。
实现:
等于 (x) 的个数 (-) 不等于 (x) 的个数的范围是 (-n sim n) 的,但是真正有用的区间没有那么多,所以考虑动态开点。
每一次都要注意将树清空或将清除之前的所有修改操作。
注:细节很多,比较难写难调/kk。
代码:
#include<bits/stdc++.h>
// #define ls (num<<1)
// #define rs (num<<1|1)
#define log(a) cerr<<" 33[32m[DEBUG] "<<#a<<'='<<(a)<<" @ line "<<__LINE__<<" 33[0m"<<endl
#define LL long long
#define SZ(x) ((int)x.size()-1)
#define ms(a,b) memset(a,b,sizeof a)
#define F(i,a,b) for(int i=(a);i<=(b);++i)
#define DF(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
inline int read(){char ch=getchar(); int w=1,c=0;
for(;!isdigit(ch);ch=getchar()) if (ch=='-') w=-1;
for(;isdigit(ch);ch=getchar()) c=(c<<1)+(c<<3)+(ch^48);
return w*c;
}
const int N=5e5+10;
vector<int>v[N];
int cnt,n,ls[N<<2],rs[N<<2],rt;
bool cl[N<<2];
LL ans,tag[N<<2],sum[N<<2],sum2[N<<2],len[N<<2],len2[N<<2];
inline void down(int&num,int x,int l,int r){
if(!num){
num=++cnt;
len2[num]=r-l+1;
len[num]=len2[num]*(l+r)/2;
}
tag[num]+=x;
sum[num]+=len[num]*x;
sum2[num]+=len2[num]*x;
}
inline void down2(int&num,int l,int r){
if(!num){
num=++cnt;
len2[num]=r-l+1;
len[num]=len2[num]*(l+r)/2;
}
cl[num]=true;
sum[num]=sum2[num]=tag[num]=0;
}
inline void pushdown(int num,int l,int r){
if(cl[num]){
int mid=(l+r)>>1;
down2(ls[num],l,mid);
down2(rs[num],mid+1,r);
cl[num]=false;
}
if(tag[num]){
int mid=(l+r)>>1;
down(ls[num],tag[num],l,mid);
down(rs[num],tag[num],mid+1,r);
tag[num]=0;
}
}
inline void pushup(int num){
sum[num]=sum[ls[num]]+sum[rs[num]];
sum2[num]=sum2[ls[num]]+sum2[rs[num]];
}
inline void change(int&num,int l,int r,int L,int R,int x){
if(!num){
num=++cnt;
len2[num]=r-l+1;
len[num]=len2[num]*(l+r)/2;
}
if(L<=l&&r<=R){
down(num,x,l,r);
return;
}
pushdown(num,l,r);
int mid=(l+r)>>1;
if(mid>=L)change(ls[num],l,mid,L,R,x);
if(mid<R)change(rs[num],mid+1,r,L,R,x);
pushup(num);
}
inline int query(int&num,int l,int r,int L,int R){
if(!num){
num=++cnt;
len2[num]=r-l+1;
len[num]=len2[num]*(l+r)/2;
return 0;
}
if(L<=l&&r<=R)return sum[num];
pushdown(num,l,r);
int mid=(l+r)>>1;
if(mid<L)return query(rs[num],mid+1,r,L,R);
if(mid>=R)return query(ls[num],l,mid,L,R);
return query(ls[num],l,mid,L,R)+query(rs[num],mid+1,r,L,R);
}
inline int query2(int&num,int l,int r,int L,int R){
if(!num){
num=++cnt;
len2[num]=r-l+1;
len[num]=len2[num]*(l+r)/2;
return 0;
}
if(L<=l&&r<=R)return sum2[num];
pushdown(num,l,r);
int mid=(l+r)>>1;
if(mid<L)return query2(rs[num],mid+1,r,L,R);
if(mid>=R)return query2(ls[num],l,mid,L,R);
return query2(ls[num],l,mid,L,R)+query2(rs[num],mid+1,r,L,R);
}
inline void solve(int num,int l,int r){
if(l==r){
ans++;
return;
}int mid=(l+r)>>1;
solve(num,l,mid);solve(num,mid+1,r);
int s=0;
down2(rt,-n,n);
DF(i,mid,l){
int r=s-(i==mid?1:0);
s-=v[num][i+1]-v[num][i]-1;
int l=s;
s++;
if(l<=r)change(rt,-n,n,l,r,1);
}
if(!l)change(rt,-n,n,s-(v[num][l]-(l?v[num][l-1]:0)-1),s,1);
else change(rt,-n,n,s,s,1);
s=0;
F(i,mid+1,r){
s++;
int l=s;
s-=(i==SZ(v[num])?n+1:v[num][i+1])-v[num][i]-1;
int r=s;
l=(-l)+1;
r=(-r)+1;
ans+=query(rt,-n,n,l,r);
if(l!=1)ans-=(l-1)*query2(rt,-n,n,l,r);
if(r!=n)ans+=(r-l+1)*query2(rt,-n,n,r+1,n);
}
}
signed main(){
n=read();int type;type=read();
F(i,1,n){
int x=read();
v[x].push_back(i);
}
F(i,0,n-1)
if(v[i].size())solve(i,0,SZ(v[i]));
cout<<ans;
return 0;
}