把区间拆分成若干个子区间,满足每个区间长度是 (2) 的整数次幂并且最后几位是从全 (0) 到全 (1)。
容易发现,这样两个区间的合并最后是上面的位异或起来,较低位可以随便选。
因为需要排序这样复杂度会变成 (O(n^2log^2 V log(n^2log^2 V))) 。
分析两个原区间的合并。我们只考虑(拆分后)左边区间的长度大于(拆分后)右边区间的长度的情况(即左边较低位只选较低位),因为右边区间的不同前缀只会有 2 个,于是合并出来的区间只有 2 种。
我们只需要将重复的去掉,复杂度就成了 (O(n^2log V log(n^2log V))) 。
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int Bit = 60;
typedef pair<ll,ll> Pir;
vector<Pir> A,B;
inline void maketrie(ll l,ll r,vector<Pir>& rngs){// the range is : (l, r) // [l+1, r-1]
for(int i=60;~i;i--){
ll last = (1ll<<(i+1))-1;
if( l - (l&last) == r - (r&last))continue;
if(r & (1ll << i)){
rngs.push_back(Pir(i,r - (1ll<<i)));
}
}
for(int i=60;~i;i--){
ll last = (1ll<<(i+1))-1;
if( l - (l&last) == r - (r&last))continue;
if(!(l & (1ll << i))){
rngs.push_back(Pir(i,l + (1ll<<i)));
}
}
}
vector< pair<ll,int> > op;
typedef pair<ll,int> pi;
void merge(Pir a,Pir b){
ll Xor = a.second ^ b.second;
Xor -= Xor & ((1ll<<max(a.first,b.first))-1);
ll rnglen = 1ll << max(a.first, b.first);
op.push_back(pi(Xor, 1)), op.push_back(pi(Xor + rnglen, -1));
}
const int mod = 998244353;
const int inv2 = (mod+1)>>1;
ll S(ll n){
return n%mod*((n+1)%mod)%mod*inv2%mod;
}
int main()
{
int sa, sb;
cin >> sa;
for(int i=1;i<=sa;i++){
ll l,r;scanf("%lld%lld",&l,&r);
maketrie(l-1,r+1,A);
}
cin >> sb;
for(int i=1;i<=sb;i++){
ll l,r;scanf("%lld%lld",&l,&r);
maketrie(l-1,r+1,B);
}
for(size_t i=0;i<A.size();i++){
ll LastXor = -1;
for(size_t j=0;j<B.size();j++)if(A[i].first > B[j].first){
if(LastXor != (A[i].second ^ B[j].second) - ((A[i].second ^ B[j].second) & ((1ll << A[i].first)-1))){
LastXor = (A[i].second ^ B[j].second) - ((A[i].second ^ B[j].second) & ((1ll << A[i].first)-1));
merge(A[i], B[j]);
}
}
}
for(size_t j=0;j<B.size();j++){
ll LastXor = -1;
for(size_t i=0;i<A.size();i++)if(A[i].first <= B[j].first){
if(LastXor != (A[i].second ^ B[j].second) - ((A[i].second ^ B[j].second) & ((1ll << B[j].first)-1))){
LastXor = (A[i].second ^ B[j].second) - ((A[i].second ^ B[j].second) & ((1ll << B[j].first)-1));
merge(A[i], B[j]);
}
}
}
sort(op.begin(), op.end());
int ans = 0, t = 0;
for(size_t i=0,j;i<op.size();i=j+1){
j=i;while(j+1<op.size() && op[j+1].first == op[j].first)++j;
for(size_t k=i;k<=j;k++)t+=op[k].second;
if(t>0){
ans += (S(op[j+1].first-1)-S(op[j].first-1))%mod;
ans %= mod;ans += mod;ans %= mod;
}
}
cout << ans << endl;
return 0;
}