Description
给定一个长度为 (n (n le 2 imes 10^5)) 的数列 (a_i),求满足区间 (OR) 大于区间最大值的区间的个数。
Solution
注意到区间或永远不会小于区间最大值
于是转化为求区间或等于区间最大值的区间个数
枚举最大值元素 (a_p),结合预处理得到其控制的(下标)区间 ([l,r])
此时我们需要得到它向左/向右能扩展的最大长度
考虑到单调性,利用 ST 表即可
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1000005;
const int lgN = 20;
const int dbg = 0;
int n,a[N],lb[N],rb[N],lg2[N];
struct St_table
{
int st[N][lgN];
void build(int n,int *a)
{
for(int i=1;i<=n;i++)
{
st[i][0]=a[i];
}
for(int j=1;j<lgN;j++)
{
for(int i=1;i<=n;i++)
{
st[i][j]=st[i][j-1]|st[i+(1<<(j-1))][j-1];
}
}
}
int query(int l,int r)
{
int p=lg2[r-l+1];
int i=l, j=r;
return st[i][p] | st[j-(1<<p)+1][p];
}
} st;
void build_right_bound()
{
stack <int> s;
memset(rb,0,sizeof rb);
for(int i=1;i<=n;i++)
{
while(s.size() && a[s.top()]<=a[i])
{
rb[s.top()]=i-1;
s.pop();
}
s.push(i);
}
for(int i=1;i<=n;i++)
{
if(rb[i]==0)
{
rb[i]=n;
}
}
}
void build_left_bound()
{
stack <int> s;
memset(lb,0,sizeof lb);
for(int i=n;i>=1;--i)
{
while(s.size() && a[s.top()]<a[i])
{
lb[s.top()]=i+1;
s.pop();
}
s.push(i);
}
for(int i=1;i<=n;i++)
{
if(lb[i]==0)
{
lb[i]=1;
}
}
}
int get_left(int p,int len)
{
int l=p-len+1, r=p;
l=max(l,1ll);
return st.query(l,r);
}
int get_right(int p,int len)
{
int l=p, r=p+len-1;
r=min(r,n);
return st.query(l,r);
}
int bisect_left(int p,int mask)
{
int l=1, r=p+1;
while(l<r)
{
int mid=(l+r)/2;
if(get_left(p,mid) & ~mask) r=mid;
else l=mid+1;
}
return l-1;
}
int bisect_right(int p,int mask)
{
int l=1, r=n-p+2;
while(l<r)
{
int mid=(l+r)/2;
if(get_right(p,mid) & ~mask) r=mid;
else l=mid+1;
}
return l-1;
}
void self_check1()
{
cout<<"monostack dataout:"<<endl;
for(int i=1;i<=n;i++)
{
cout<<i<<" "<<lb[i]<<","<<rb[i]<<endl;
}
}
void self_check2()
{
cout<<"st_table selfcheck:"<<endl;
for(int i=1;i<=n;i++)
{
for(int j=i;j<=n;j++)
{
int ans=st.query(i,j);
int chk=0;
for(int k=i;k<=j;k++)
{
chk|=a[k];
}
cout<<(ans==chk);
}
}
cout<<endl;
}
signed main()
{
ios::sync_with_stdio(false);
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<=n;i++) lg2[i]=log2(i);
build_right_bound();
build_left_bound();
if(dbg) self_check1();
st.build(n,a);
if(dbg) self_check2();
int ans=0;
for(int i=1;i<=n;i++)
{
int ql=lb[i], qr=rb[i];
int al=bisect_left(i,a[i]);
int ar=bisect_right(i,a[i]);
al=min(al,i-lb[i]+1);
ar=min(ar,rb[i]-i+1);
if(dbg) cout<<"i="<<i<<" "<<al<<" "<<ar<<" "<<endl;
ans+=al*ar;
}
if(dbg) cout<<"src_ans = "<<ans<<endl;
ans=n*(n+1)/2-ans;
cout<<ans<<endl;
}