• 快速沃尔什变换(FWT)笔记


    开头Orz hy,Orz yrx

    部分转载自hy的博客

    快速沃尔什变换,可以快速计算两个多项式的位运算卷积(即and,or和xor)

    问题模型如下:

    给出两个多项式$A(x)$,$B(x)$,求$C(x)$满足$C[i]=sumlimits_{j⊗k=i}A[j]*B[k]$.

    约定记号

    $⊗$表示某种位运算(and,or和xor中的一种),若$a$,$b$是两个整数,则$a⊗b$表示对这两个数按位进行位运算;若$A$,$B$是两个多项式,则$A⊗B$表示对这两个多项式做如上卷积;两个多项式的点积用$·$表示。

    FWT

    感觉这个算法就是瞎凑出来的(大佬轻喷)

    考虑对$A$和$B$做某种变换(类似FFT),使得变换之后对应位相乘之后逆运算就可以得到卷积$C(x)$。

    那么这种变换$F(A)$(其中$A$是一个多项式)需要满足:

    $F(A)·F(B)=F(A⊗B)$

    $F(kast A)=kast F(A)$

    $F(A+B)=F(A)+F(B)$

    那就瞎凑呗,考虑用类似FFT的分治思想来解决,把多项式$A$的下标按照二进制最高位分类,最高位为0的记为$A_0$,为1的记为$A_1$,则$A=(A_0,A_1)$。

    继续凑,设$F(A)=(k_{1}A_{0}+k_{2}A_{1},k_{3}A_{0}+k_{4}A_{1})$,那么要做的就是求出这四个常数。

    不难发现$(k_{1},k_{2})$与$(k_{3},k_{4})$并没有本质上的区别,即求出了前半部分的多个解取其中两个代入即可。

    将结果写作$(C_{0},C_{1})$,看回之前变换的定义,这里分类讨论:

    对于$and$:

    因为

    $0⊗0=0$,$1⊗0=0$,$0⊗1=0$,$1⊗1=1$

    所以

    $(A_{0},A_{1})⊗(B_{0},B_{1})=(A_{0}⊗B_{0}+A_{0}⊗B_{1}+A_{1}⊗B_{0},A_{1}⊗B_{1})$

    用两种方法表示$C$,可得

    $(k_{1}A_{0}+k_{2}A_{1})·(k_{1}B_{0}+k_{2}B_{1})$

    $=k_{1}(A_{0}⊗B_{0}+A_{0}⊗B_{1}+A_{1}⊗B_{0})+k_{2}(A_{1}⊗B_{1})$

    拆括号得:

    $k_{1}^{2}(A_{0}⊗B_{0})+k_{1}k_{2}(A_{0}⊗B_{1})+k_{1}k_{2}(A_{1}⊗B_{0})+k_{2}^{2}(A_{1}⊗B_{1})$

    $=k_{1}(A_{0}⊗B_{0})+k_{1}(A_{0}⊗B_{1})+k_{1}(A_{1}⊗B_{0})+k_{2}(A_{1}⊗B_{1})$

    则有:

    $egin{cases}
    k_{1}=k_{1}^{2} \
    k_{1}=k_{1}k_{2} \
    k_{2}=k_{2}^{2} 
    end{cases}$

    解得$egin{cases} k_{1}=0 \k_{2}=0end{cases}$或$egin{cases} k_{1}=1 \k_{2}=0end{cases}$或$egin{cases} k_{1}=1 \k_{2}=1end{cases}$

    考虑到要可以逆变换,所以解不能选两个相同的或者两个零(类似于求逆矩阵),因此这里只能选$(0,1)$和$(1,1)$两组解

    令$(k_{1},k_{2})=(1,1)$,$(k_{3},k_{4})=(0,1)$

    把系数写成矩阵,那么

    $egin{bmatrix} k_1 & k_2 \ k_3 & k_4 end{bmatrix} = egin{bmatrix} 1 & 1 \ 0 & 1 end{bmatrix}$

    把矩阵求逆,就可以得到逆变换的系数:

    $egin{bmatrix} 1 & -1 \ 0 & 1 end{bmatrix}$

    对于$or$:$(A_{0},A_{1})⊗(B_{0},B_{1})=(A_{0}⊗B_{0},A_{0}⊗B_{1}+A_{1}⊗B_{0}+A_{1}⊗B_{1})$

    正变换:$egin{bmatrix} 1 & 1 \ 1 & 0 end{bmatrix}$

    逆变换:$egin{bmatrix} 0 & 1 \ 1 & -1 end{bmatrix}$

    对于$xor$:$(A_{0},A_{1})⊗(B_{0},B_{1})=(A_{0}⊗B_{0}+A_{1}⊗B_{1},A_{0}⊗B_{1}+A_{1}⊗B_{0})$

    正变换:$egin{bmatrix} 1 & 1 \ 1 & -1 end{bmatrix}$

    逆变换:$egin{bmatrix} frac{1}{2} & frac{1}{2} \ frac{1}{2} & -frac{1}{2} end{bmatrix}$

    代码(洛谷P4717):

     1 #include<iostream>
     2 #include<cstring>
     3 #include<cstdio>
     4 #include<cmath>
     5 #define OR 0
     6 #define AND 1
     7 #define XOR 2
     8 using namespace std;
     9 typedef long long ll;
    10 const int mod=998244353;
    11 int inv2,bit,bitnum,n,m,a[1000001],b[1000001],a1[1000001],a2[1000001],a3[1000001],b1[1000001],b2[1000001],b3[1000001];
    12 int fastpow(int x,int y){
    13     int ret=1;
    14     for(;y;y>>=1,x=(ll)x*x%mod){
    15         if(y&1)ret=(ll)ret*x%mod;
    16     }
    17     return ret;
    18 }
    19 void fwt(int s[],int n,int ty,int op){
    20     for(int i=2;i<=n;i<<=1){
    21         for(int j=0;j<n;j+=i){
    22             for(int k=0;k<i/2;k++){
    23                 int x=s[j+k],y=s[j+k+(i>>1)];//A0,A1
    24                 if(op==1){
    25                     if(ty==OR){
    26                         s[j+k+(i>>1)]=(x+y)%mod;
    27                     }else if(ty==AND){
    28                         s[j+k]=(x+y)%mod;
    29                     }else{
    30                         s[j+k+(i>>1)]=(x+mod-y)%mod;
    31                         s[j+k]=(x+y)%mod;
    32                     }
    33                 }else{
    34                     if(ty==OR){
    35                         s[j+k+(i>>1)]=((ll)y-x+mod)%mod;
    36                     }else if(ty==AND){
    37                         s[j+k]=(ll)((ll)x+mod-y)%mod;
    38                     }else{
    39                         s[j+k+(i>>1)]=(ll)((ll)x+mod-y)*inv2%mod;
    40                         s[j+k]=(ll)(x+y)*inv2%mod;
    41                     }
    42                 }
    43             }
    44         }
    45     }
    46 }
    47 int main(){
    48     scanf("%d",&bitnum);
    49     bit=(1<<bitnum);
    50     inv2=fastpow(2,mod-2);
    51     for(int i=0;i<bit;i++)scanf("%d",&a[i]),a1[i]=a[i],a2[i]=a[i],a3[i]=a[i];
    52     for(int i=0;i<bit;i++)scanf("%d",&b[i]),b1[i]=b[i],b2[i]=b[i],b3[i]=b[i];
    53     fwt(a1,bit,0,1);
    54     fwt(a2,bit,1,1);
    55     fwt(a3,bit,2,1);
    56     fwt(b1,bit,0,1);
    57     fwt(b2,bit,1,1);
    58     fwt(b3,bit,2,1);
    59     for(int i=0;i<bit;i++){
    60         a1[i]=(ll)a1[i]*b1[i]%mod;
    61         a2[i]=(ll)a2[i]*b2[i]%mod;
    62         a3[i]=(ll)a3[i]*b3[i]%mod;
    63     }
    64     fwt(a1,bit,0,-1);
    65     fwt(a2,bit,1,-1);
    66     fwt(a3,bit,2,-1);
    67     for(int i=0;i<bit;i++)printf("%d ",a1[i]);
    68     printf("
    ");
    69     for(int i=0;i<bit;i++)printf("%d ",a2[i]);
    70     printf("
    ");
    71     for(int i=0;i<bit;i++)printf("%d ",a3[i]);
    72     return 0;
    73 }
  • 相关阅读:
    TPS限流
    JDK并发基础与部分源码解读
    tomcat6-servlet规范对接 与 ClassLoader隔离
    tomcat6-输入输出buffer设计
    tomcat6-endpoint设计
    springMVC请求路径 与实际资源路径关系
    mysql 常用的数据类型
    认识IPv4分组
    CSMA/CD协议(载波侦听多路访问/碰撞检测) 最小帧长理解
    简单的vector--- 2
  • 原文地址:https://www.cnblogs.com/dcdcbigbig/p/9359330.html
Copyright © 2020-2023  润新知