题目
题目链接:https://www.luogu.com.cn/problem/P5607
给你一个长度为 n 的整数序列 (a_1), (a_2), (ldots), (a_n) ,你需要实现以下两种操作,每个操作都可以用四个整数 (opt;l;r;v) 来表示:
(opt=1) 时,代表把一个区间 ([l,r]) 内的所有数都 xor 上 (v)。
(opt=2) 时, 查询一个区间 ([l,r]) 内选任意个数(包括 (0) 个)数 xor 起来,这个值与 (v) 的最大 xor 和是多少。
(n,mleq 5 imes 10^4,a_iin[0,10^9])。
思路
区间异或最大值。考虑线性基。
因为线性基合并是 (O(log^2 a)) 的,所以考虑再套上一个 (log) 数据结构维护区间线性基。显然采用线段树。
但是这个区间异或操作很烦,我们无法支持线性基内的数全部异或上 (v) 这种操作。我们必须把区间修改转化为单点修改。
所以我们令 (a'_i=a_i ext{ xor }a_{i-1}),这样我们每次区间异或上 (v) 就只需要让 (a'_l) 和 (a'_{r+1}) 异或上 (v) 了。那么 (a_i= ext{xor}^{i}_{j=1}a'_j),可以在线段树上 (O(log n)) 得到。
那么再回来考虑区间查询:不难发现,(a_lsim a_r) 所表示的线性基,与 (a'_{l+1}sim a'_{r}∪a_l) 所表示的线性基是一样的。那么直接线段树上合并线性基,然后插入 (a_l),最后查询与 (v) 异或起来的最大值即可。
时间复杂度 (O(mlog nlog a^2))。
可能有些卡常,我记 (cnt) 表示一个线性基插入的元素个数,如果 (cnt=log a) 那么直接返回不用插入了。这样快了整整 (1s)。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=50010,LG=30;
int n,m,a[N];
struct Xor
{
int cnt,d[LG+1];
void ins(int x)
{
if (cnt==LG) return;
for (int i=LG;i>=0;i--)
if (x&(1<<i))
{
if (!d[i]) { d[i]=x; cnt++; return; }
x^=d[i];
}
}
int query(int res)
{
for (int i=LG;i>=0;i--)
if (!(res&(1<<i)) && d[i]) res^=d[i];
return res;
}
void merge(Xor lc,Xor rc)
{
Xor x;
memset(x.d,0,sizeof(x.d)); x.cnt=0;
for (int i=LG;i>=0;i--)
x.ins(lc.d[i]),x.ins(rc.d[i]);
memcpy(d,x.d,sizeof(d)); cnt=x.cnt;
}
}b[N*4],c;
struct SegTree
{
int val[N*4];
void pushup(int x)
{
b[x].merge(b[x*2],b[x*2+1]);
val[x]=val[x*2]^val[x*2+1];
}
void build(int x,int l,int r)
{
if (l==r)
{
val[x]=a[l]; b[x].ins(val[x]);
return;
}
int mid=(l+r)>>1;
build(x*2,l,mid); build(x*2+1,mid+1,r);
pushup(x);
}
void update(int x,int l,int r,int k,int v)
{
if (l==r)
{
memset(b[x].d,0,sizeof(b[x].d)); b[x].cnt=0;
val[x]^=v; b[x].ins(val[x]);
return;
}
int mid=(l+r)>>1;
if (k<=mid) update(x*2,l,mid,k,v);
else update(x*2+1,mid+1,r,k,v);
pushup(x);
}
int query1(int x,int l,int r,int k)
{
if (l==r) return val[x];
int mid=(l+r)>>1;
if (k<=mid) return query1(x*2,l,mid,k);
else return query1(x*2+1,mid+1,r,k)^val[x*2];
}
void query2(int x,int l,int r,int ql,int qr)
{
if (ql<=l && qr>=r)
{
c.merge(c,b[x]);
return;
}
int mid=(l+r)>>1;
if (ql<=mid) query2(x*2,l,mid,ql,qr);
if (qr>mid) query2(x*2+1,mid+1,r,ql,qr);
}
}seg;
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
for (int i=n;i>=1;i--) a[i]^=a[i-1];
seg.build(1,1,n);
while (m--)
{
int opt,l,r,v;
scanf("%d%d%d%d",&opt,&l,&r,&v);
if (opt==1)
{
seg.update(1,1,n,l,v);
if (r<n) seg.update(1,1,n,r+1,v);
}
else
{
memset(c.d,0,sizeof(c.d)); c.cnt=0;
c.ins(seg.query1(1,1,n,l));
if (l<r) seg.query2(1,1,n,l+1,r);
printf("%d
",c.query(v));
}
}
return 0;
}