P4462 [CQOI2018]异或序列
题目描述
已知一个长度为n的整数数列(a_1,a_2,...,a_n),给定查询参数(l、r),问在(a_l,a_{l+1},...,a_r)区间内,有多少子序列满足异或和等于(k)。也就是说,对于所有的(x,y (I ≤ x ≤ y ≤ r)),能够满足(a_x igoplus a_{x+1} igoplus ... igoplus a_y = k)的(x,y)有多少组。
输入格式
输入文件第一行,为3个整数(n,m,k。)
第二行为空格分开的n个整数,即(a_1,a_2,..a_n)。
接下来m行,每行两个整数(l_j,r_j),表示一次查询。
输出格式
输出文件共m行,对应每个查询的计算结果。
输入输出样例
输入 #1
4 5 1
1 2 3 1
1 4
1 3
2 3
2 4
4 4
输出 #1
4
2
1
2
1
说明/提示
对于30%的数据,(1 ≤ n, m ≤ 1000)
对于100%的数据,(1 ≤ n, m ≤ 10^5, 0 ≤ k, a_i ≤ 10^5,1 ≤ l_j ≤ r_j ≤ n)
Solution
异或操作有一个性质
已知a^b=c
则a^c=b,b^c=a
那么我们将原序列(a[])变成前缀和序列(s[])之后
(a_x igoplus a_{x+1} igoplus ... igoplus a_y = k)(Rightarrow)(s_{x-1}igoplus s_r=k)(Rightarrow)(s_{r}igoplus k=s_{x-1})
那么问题就由多少个区间异或和为(k)转化成了多少个数对的异或值为(k)
可以使用莫队求解
重点还是这个(add())函数和(remove())函数
当我们加入一个值(a[x])时,我们想要知道当前所在区间有多少个(a[i]igoplus a[x]=k),其中(i<x),也就是(a[x]igoplus k)的个数,因为这些数都可以和(a[x])异或起来得到(k)
void add(int x) {ans+=cnt[a[x]^k]; cnt[a[x]]++;}
当我们删除一个值时,我们依然需要知道前面有多少个(a[i]igoplus a[x]=k),其中(i<x)
void remove(int x) {ans-=cnt[a[x]^k]; cnt[a[x]]--;}
代码是这样子的吗,不对!
当k=0时,(a[x]igoplus k=a[x]),如果我们先用(ans)减去贡献,再使(cnt[]--),答案会少一
原因就是我们要统计的其实是(i<x)的(cnt[a[x]])的个数,不能把第(x)个位置上的值算进去,所以需要先(cnt[a[x]]--),再用(ans)减去贡献
void remove(int x) {cnt[a[x]]--; ans-=cnt[a[x]^k];}
当然,在k!=0的情况下,顺序是没有影响的,因为如果k!=0,那么(a[x]igoplus k!=a[x]),那么当前的(cnt[a[x]])不会影响到答案
如果不相信的话,下面是hack数据
3 2 0
0 0 0
1 2
3 3
正确答案应该是
3
1
Code
#include<bits/stdc++.h>
#define lol long long
#define in(i) (i=read())
using namespace std;
const lol N=1e5+10,mod=1e9+7;
lol read() {
lol ans=0,f=1; char i=getchar();
while(i<'0' || i>'9') {if(i=='-') f=-1; i=getchar();}
while(i>='0' && i<='9') ans=(ans<<1)+(ans<<3)+(i^48),i=getchar();
return ans*f;
}
int n,m,k,block,ans;
int a[N],cnt[N],sum[N];
struct query{
int l,r,id,pos;
bool operator < (const query &a) const {
return pos==a.pos?r<a.r:pos<a.pos;
}
}t[N];
void add(int x) {ans+=cnt[a[x]^k]; cnt[a[x]]++;}
void remove(int x) {cnt[a[x]]--; ans-=cnt[a[x]^k];}
int main() {
in(n), in(m), in(k); block=sqrt(n);
for (int i=1;i<=n;i++) in(a[i]), a[i]^=a[i-1];
for (int i=1,l,r;i<=m;i++) {
in(l), in(r);
t[i].l=l-1, t[i].r=r;
t[i].id=i;
t[i].pos=(l-1)/block+1;
}
sort(t+1,t+1+m);
for (int i=1,curl=1,curr=0;i<=m;i++) {
int l=t[i].l,r=t[i].r;
while(curl<l) remove(curl++);
while(curl>l) add(--curl);
while(curr<r) add(++curr);
while(curr>r) remove(curr--);
sum[t[i].id]=ans;
}
for (int i=1;i<=m;i++) cout<<sum[i]<<endl;
}