题意:You are given an array A of N non-negative integers and an integer M.
Find the number of pair(i,j) such that 1≤i≤j≤N and min(Ai,Ai+1,...,Aj)⋅(Ai⊕Ai+1⊕...⊕Aj)≤M.
先用把数字从小到大依次插入,用个set预处理出每个最小值控制的子区间的范围,对于大小相同的数,靠左的数优先。
然后对于每个子区间,最小值将它划分成左右两段,枚举较短的那段,在较长的那段中用可持久化Trie查询满足条件的对数。前缀异或和别忘了处理,很经典了。
就是设M/min(Ai,Ai+1,...,Aj)为W,每次向左右儿子查询,如果W的这位是零,那就不累计答案,直接朝零(这里指异或上左端点的前缀和以后的零)走;如果W的这位是一,那就统计上零(这里指异或上左端点的前缀和以后的零)的子树的答案,然后朝一走。别忘了到叶子以后还要在累计一次。(叶子必然满足条件,可以累计。)
#include<cstdio> #include<set> #include<algorithm> using namespace std; typedef set<int>::iterator ITER; typedef long long ll; int n; ll m,a[100005]; int b[100005],lb[100005],ub[100005]; int w,xors[100005]; bool cmp(const int &x,const int &y){ return (a[x]!=a[y] ? a[x]<a[y] : x<y); } ll ans; #define N 100005 #define MAXBIT 31 int root[N],ch[N*(MAXBIT+1)][2],sz[N*(MAXBIT+1)],tot; void add(int now,int W) { int old=root[now-1]; root[now]=++tot; now=root[now]; for(int i=MAXBIT;i>=1;--i) { int Bit=((W>>(i-1))&1); sz[now]=sz[old]+1; ch[now][Bit^1]=ch[old][Bit^1]; ch[now][Bit]=++tot; now=ch[now][Bit]; old=ch[old][Bit]; } sz[now]=sz[old]+1; } void query(int bef,int L,int R,int W,int fl) { L=root[L];R=root[R+1]; for(int i=MAXBIT;i>=1;--i){ int Bit=(W>>(i-1)&1); if(Bit==1){ if((xors[bef-fl]>>(i-1)&1)==1){ ans+=(ll)(sz[ch[R][1]]-sz[ch[L][1]]); R=ch[R][0]; L=ch[L][0]; } else{ ans+=(ll)(sz[ch[R][0]]-sz[ch[L][0]]); R=ch[R][1]; L=ch[L][1]; } } else{ if((xors[bef-fl]>>(i-1)&1)==1){ R=ch[R][1]; L=ch[L][1]; } else{ R=ch[R][0]; L=ch[L][0]; } } } ans+=(ll)(sz[R]-sz[L]); } int main(){ //freopen("f.in","r",stdin); scanf("%d%lld",&n,&m); for(int i=1;i<=n;++i){ scanf("%lld",&a[i]); xors[i]=(xors[i-1]^a[i]); } for(int i=1;i<=n;++i){ b[i]=i; } sort(b+1,b+n+1,cmp); set<int> S; for(int i=1;i<=n;++i){ ITER it=S.lower_bound(b[i]); if(it!=S.end()){ ub[b[i]]=((*it)-1); } else{ ub[b[i]]=n; } if(it==S.begin()){ lb[b[i]]=1; } else{ --it; lb[b[i]]=((*it)+1); } S.insert(b[i]); } add(1,0); for(int i=1;i<=n;++i){ add(i+1,xors[i]); } for(int i=1;i<=n;++i){ if(m/a[i]>=(1ll<<31)){ ans+=(ll)(i-lb[i]+1)*(ll)(ub[i]-i+1); continue; } w=(int)(m/a[i]); if(i-lb[i]<=ub[i]-i){ for(int j=lb[i];j<=i;++j){ query(j,i,ub[i],w,1); } } else{ for(int j=i;j<=ub[i];++j){ query(j,lb[i]-1,i-1,w,0); } } } printf("%lld ",ans); return 0; }