(mathcal{Description})
(mathcal{Solution})
有位运算先按位考虑贡献
考虑若区间长度没有特殊贡献,即所有长度的贡献是一样的
那么答案就是这一位异或起来为(1)的子集个数
关于这个,只需知道在这个区间中在这一位为(1)的数量(n_1)和为(0)的数量(n_0)即可
异或起来为(1)就要求有奇数个(1)和任意个(0)
选这一位有(1)的方案数总数为(2^{n_1}),其中一半有奇数个(1),一半有偶数个(1),所以有(2^{n_1-1})中方案使这一位有奇数个(1)
所以总方案就是(2^{n_1} imes 2^{n_0})
可惜区间长度是有贡献的,所以上面方法是不行的
容易想到的是按位考虑后再按长度考虑,这样的复杂度是(O(n^3logn)),结合上面可以得到(50)分的部分分
(egin{aligned}sum_{bit}2^{bit}sum_{len}x^{len}sum_{i,i&1=1}egin{pmatrix}n_1\ iend{pmatrix}egin{pmatrix}n_0\ len-iend{pmatrix}end{aligned})
另一个想法是分治,先解决出左边长度为(i)的方案数和解决右边长度为(n-i)的方案数
则长度为(n)的可以由左右两边的方案结合得到,容易发现这是个卷积的形式,可以用(FFT)优化
到此,上面的方法都不是正解
上面的方法都是单独考虑区间长度进行计算的,所以复杂度一定带了个(n),考虑一次性算出所有的答案
于是考虑将(x^i)和组合数套起来,得到(egin{aligned}ans=left(sum_{i,i&1=1}egin{pmatrix}n_1\iend{pmatrix}x^i
ight)left(sum_iegin{pmatrix}n_0\iend{pmatrix}x^i
ight)end{aligned})
那个(i&1=1)亦可写作(i\%2==1),这个可以用单位根反演做,但其实没有必要
仔细看一下这个式子,(x^i)旁乘一个(1^{n_{0/1}-i}),会发现这就是(left(x+1
ight)^{n_{0/1}})二项式展开后的结果
(egin{aligned}ans=left(sum_{i,i&1=1}egin{pmatrix}n_1\iend{pmatrix}x^i1^{n_1-i}
ight)left(sum_iegin{pmatrix}n_0\iend{pmatrix}x^i1^{n_0-i}
ight)end{aligned})
考虑算出左边可以为偶数的值,再减去为偶数的项,实际上(frac{left(x+1
ight)^n-left(1-x
ight)^n}{2})就是左边(i)为奇数的答案
最后得到
(egin{aligned}ans=left(frac{left(x+1
ight)^{n_1}-left(1-x
ight)^{n_1}}{2}
ight)left(x+1
ight)^{n_0}end{aligned})
这样一次询问每一位就是(logn),总复杂度就是(nlog^2n)
(mathcal{Code})
/*******************************
Author:Morning_Glory
LANG:C++
Created Time:2019年10月16日 星期三 08时59分51秒
*******************************/
#include <cstdio>
#include <fstream>
#include <algorithm>
using namespace std;
const int limit = 35;
const int maxn = 100005;
const int mod = 998244353;
const int inv = 499122177;
//{{{cin
struct IO{
template<typename T>
IO & operator>>(T&res){
res=0;
bool flag=false;
char ch;
while((ch=getchar())>'9'||ch<'0') flag|=ch=='-';
while(ch>='0'&&ch<='9') res=(res<<1)+(res<<3)+(ch^'0'),ch=getchar();
if (flag) res=~res+1;
return *this;
}
}cin;
//}}}
int n,m,lim;
int a[maxn];
int s[limit][maxn];
//{{{ksm
int ksm (int a,int b)
{
a%=mod;
int s=1;
for (;b;b>>=1,a=1ll*a*a%mod)
if (b&1) s=1ll*s*a%mod;
return s;
}
//}}}
int main()
{
cin>>n>>m;
for (int i=1;i<=n;++i) cin>>a[i],lim=max(lim,a[i]);
for (int i=30;~i;--i)
if (lim>(1<<i)){ lim=i+1;break;}
for (int i=0;i<=lim;++i)
for (int j=1;j<=n;++j) s[i][j]=s[i][j-1]+((a[j]>>i)&1);
while (m--){
int l,r,x,ans=0;
cin>>l>>r>>x;
int le=r-l+1;
for (int b=0;b<=lim;++b){
int n1=s[b][r]-s[b][l-1],n0=le-n1,mi=(1<<b)%mod;
ans=(ans+1ll*(ksm(x+1,n1)-ksm(mod+1-x,n1)+mod)%mod*inv%mod*ksm(x+1,n0)%mod*mi%mod)%mod;
}
printf("%d
",ans);
}
return 0;
}
如有哪里讲得不是很明白或是有错误,欢迎指正
如您喜欢的话不妨点个赞收藏一下吧