链接
题意
给定一个长度为 (N) 的序列,求前 (k) 大的 (a_i operatorname{xor} a_j) 之和,答案对 (10^9+7) 取模。
其中 (a_i operatorname{xor} a_j) 与 (a_j operatorname{xor} a_i) 被看成是同一对。
( exttt{Data Range:}1leq nleq 10^6,1leq kleq frac{n(n-1)}{2})。
约定
以下称 (a_i operatorname{xor} a_j) 的值为异或值。
(b_v[i..j]) 为值 (v) 在 二进制意义 下第 (i) 位至第 (j) 位的取值,如 (b_3[0..1]=11_2),(b_5[1..2]=10_2),(b_5[0,2]=101_2)(下标 (2) 即表示二进制意义下)。
二进制意义下的最高位为 (mx)。
题解
此题是 [十二省联考2019]异或粽子 的加强版,解法吊打 ( exttt{std})。
提供了一种时间复杂度与 (k) 无关的解法,时间复杂度为 (O(nomega^2+nlog n)),其中 (omega) 为二进制意义下 (a_i) 的最高位,是一个(leq 31) 的常数
首先可以想到,要求出前 (k) 大的异或和,必须先求出第 (k) 大的异或和。
有一种很暴力的做法,枚举每个 (a_i,a_j),求出第 (k) 大的异或值。无疑可以优化,枚举每一个 (a_i),将上述过程放至 ( exttt{trie}) 上,从高位贪心,用类似权值线段树求第 (k) 大的方式在 ( exttt{trie}) 上二分即可。
更加具体地说,我们将一棵 ( exttt{trie}) 分为 (n) 层,第 (i) 层的指针即 (a_i) 所对应的 (a_j) 的前缀。从最高位开始枚举第 (k) 大的 (a_i operatorname{xor} a_j) 的每一位上取 (0/1),在枚举到第 (x) 位时,在每一层上计算 (a_i operatorname{xor} a_j) 的第 (x) 位为 (1) 的方案数,记这个数为 (cnt)。若 (cntgeq k),则说明这一位必须取 (1),否则所求会小于第 (k) 大。若 (cnt < k),则说明这一位必须取 (0) ,否则所求会大于第 (k) 大。在 (n) 层 (trie) 上对应更新指针即可,并统计第 (x) 位的取值,更新第 (k) 大的异或值在 (x) 上的取值即可。
求出第 (k) 大的 (a_i operatorname{xor} a_j) 后,记其为 (val)。若 (val) 第 (x) 位上为 (0),则若存在一个异或值 (t),令(b_{val}[x+1..mx]=b_t[x+1..mx]),仅在 (t) 的第 (x) 位上为 (1),(t) 一定大于 (val) ,属于前 (k) 大的取值。
考虑计算这部分值的贡献。仍然是从高位贪心,分成 (n) 层。每层按照 (val) 的值在 ( exttt{trie}) 上递归,枚举值 (val) 在二进制意义上为 (0) 的那些位,统计仅在这一位上取 (1) ,第 (x+1-mx) 位上与 (val) 在二进制意义下取值相同的那些异或值对答案造成的贡献。
由于即使确定了 (a_i),(a_j) 也是在一棵子树中,对一棵子树统计答案,不太好做。可以在初始时直接对 (a_i) 排序,这样 ( exttt{trie}) 的一棵子树即对应了 (a) 上一段连续的区间。统计贡献时按位统计,对每一位上的 (0/1) 分别维护一个前缀和,这样就可以了。
Show the Code
#include<cstdio>
#include<algorithm>
typedef long long ll;
const ll mod=1e9+7;
int n,tot=0,rt=0;
int a[50005],tmp[50005];
int sum[35][50005][2];
int ch[1000005][2],size[1000005],l[1000005],r[1000005];
inline ll read() {
register ll x=0,f=1;register char s=getchar();
while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
return x*f;
}
inline ll Add(ll x,ll y) {return ((x+y)%mod+mod)%mod;}
inline ll pow(ll x,ll p) {
ll res=1;
for(;p;p>>=1) {if(p&1) res=res*x%mod;x=x*x%mod;}
return res;
}
inline void insert(int val,int id) {
if(!rt) rt=++tot,l[rt]=1,r[rt]=n;
int p=rt;++size[rt];
for(register int d=31;d>=0;--d) {
int &nxt=ch[p][val>>d&1];
if(!nxt) {nxt=++tot;l[nxt]=id;}
p=nxt;r[p]=id;++size[p];
}
}
inline int getVal(ll &k) {
int res=0;
for(register int i=1;i<=n;++i) tmp[i]=rt;
for(register int d=31;d>=0;--d) {//若当前取1的个数为cnt,cnt>k
int x=0;
ll cnt=0;
for(register int i=1;i<=n;++i) cnt+=size[ch[tmp[i]][(a[i]>>d&1)^1]];
if(cnt>=k) res|=(1<<d),x=1;//cnt<k
else k-=cnt,x=0;
for(register int i=1;i<=n;++i) tmp[i]=ch[tmp[i]][(a[i]>>d&1)^x];
}
return res;
}
inline ll ask(int l,int r,int val) {
if(!l||!r) return 0;
ll res=0;
for(register int d=31;d>=0;--d) res=Add(res,(ll)(sum[d][r][(val>>d&1)^1]-sum[d][l-1][(val>>d&1)^1])*((1ll<<d)%mod)%mod);
return res;
}
inline ll getSum(ll k) {
int val=getVal(k);
ll res=val*k%mod;
for(register int i=1;i<=n;++i) tmp[i]=rt;
for(register int d=31;d>=0;--d) {
int x=val>>d&1;
if(!x) {for(register int i=1;i<=n;++i) res=Add(res,ask(l[ch[tmp[i]][(a[i]>>d&1)^1]],r[ch[tmp[i]][(a[i]>>d&1)^1]],a[i]));}
for(register int i=1;i<=n;++i) tmp[i]=ch[tmp[i]][(a[i]>>d&1)^x];
}
return res;
}
signed main() {
n=read();
ll k=read()*2ll;//25*1e8=2.5*1e9
for(register int i=1;i<=n;++i) a[i]=read();
std::sort(a+1,a+1+n);
for(register int i=1;i<=n;++i) insert(a[i],i);
for(register int i=1;i<=n;++i) {
for(register int d=31;d>=0;--d) {
sum[d][i][0]=sum[d][i-1][0];
sum[d][i][1]=sum[d][i-1][1];
++sum[d][i][a[i]>>d&1];
}
}
printf("%lld
",getSum(k)*pow(2,mod-2)%mod);
return 0;
}