题目
给出 (n) 个三元组({ a_i,b_i,c_i })和(x,y,z);
将每个三元组扩展成((x)个(a_i),(y)个(b_i),(z)个(c_i));
问从(n)组里面每组选一个数,这(n)个数异或值为x 的方案数(mod 998244353)是多少;
(1 le n le 10^5 , 1 le k le 17 , 0 le x,y,z le 10^9 , 0 le a_i,b_i,c_i lt 2^k) ;
题解
-
最后的答案异或一个 (oplus_{i=1}^{n} a_i) ,令({a_i,b_i,c_i})变成$ { 0 , a_i wedge b_i , a_i wedge c_i } $ ;
-
令(F_{i,0}+=x , F_{i,b_i}+=y , F_{i,c_i}+=z) ,把所有(fwt(F_i))点乘起来再(ifwt)回去即可;
-
考虑如何求最后的乘积(Pi F_i);
-
对于(fwt(F_i)),每一项一定都是(x+y+z , x+y-z , x-y+z , x - y - z) 之一;
设纵向的个数为(i,j,k,l),解出每一位(i,j,k,l)即可快速算出最后的乘积,首先:
[egin{align} i +j +k + l = n end{align} ]令只考虑(F_i,b_i=1),设所有(F)加起来(fwt)到得到对应位值上的值为(p)(x=0,y=1,z=0):
[i + j - k - l = p ]同理只令(F_i,c_i = 1),有(x=0,y=0,z=1):
[i - j + k - l = p ]令(F_{i,b_i wedge c_i}=1),相当于上面两个的点值乘法,有
[i - j - k + l = p ]解方程即可;
-
最后(ifwt)回来;
#include<bits/stdc++.h> #define mod 998244353 #define ll long long using namespace std; const int N=1<<17; int n,X,Y,Z,l,s; int A[N],B[N],C[N],ans[N]; char gc(){ static char*p1,*p2,S[1000000]; if(p1==p2)p2=(p1=S)+fread(S,1,1000000,stdin); return(p1==p2)?EOF:*p1++; } int rd(){ int x=0;char c=gc(); while(c<'0'||c>'9')c=gc(); while(c>='0'&&c<='9')x=(x<<1)+(x<<3)+c-'0',c=gc(); return x; } int pw(int x,int y){ int re=1; while(y){ if(y&1)re=(ll)re*x%mod; y>>=1;x=(ll)x*x%mod; } return re; } void fwt(int*a){ for(int i=1;i<l;i<<=1) for(int j=0;j<l;j+=i<<1) for(int k=0;k<i;++k){ int t1=a[j+k],t2=a[j+k+i]; a[j+k]=t1+t2; a[j+k+i]=t1-t2; } } void dec(int&x,int y){x-=y;if(x<0)x+=mod;} void ifwt(int*a){ for(int i=1;i<l;i<<=1) for(int j=0;j<l;j+=i<<1) for(int k=0;k<i;++k){ int iv2=(mod+1)/2; int t1=a[j+k],t2=a[j+k+i]; a[j+k]=(ll)(t1+t2)*iv2%mod; a[j+k+i]=(ll)(t1-t2+mod)*iv2%mod; } } int main(){ //freopen("H.in","r",stdin); //freopen("H.out","w",stdout); n=rd();l=1<<rd(); X=rd();Y=rd();Z=rd(); for(int i=1;i<=n;++i){ int a=rd(),b=rd(),c=rd(); s^=a;b^=a;c^=a;a=b^c; A[b]++,B[c]++,C[a]++; } fwt(A);fwt(B);fwt(C); int t1=((ll)X+Y+Z)%mod; int t2=((ll)X+Y-Z+mod)%mod; int t3=((ll)X-Y+Z+mod)%mod; int t4=((ll)X-Y-Z+mod+mod)%mod; for(int i=0;i<l;++i){ ans[i] = (ll)pw(t1,(n+A[i]+B[i]+C[i])>>2) *pw(t2,(n+A[i]-B[i]-C[i])>>2)%mod *pw(t3,(n-A[i]+B[i]-C[i])>>2)%mod *pw(t4,(n-A[i]-B[i]+C[i])>>2)%mod; } ifwt(ans); for(int i=0;i<l;++i)printf("%d ",ans[i^s]); return 0; }