• 卷积理论 & 高维FWT学习笔记


    之前做了那么多生成函数和多项式卷积的题目,结果今天才理解了优化卷积算法的实质。


    首先我们以二进制FWT or作为最简单的例子入手。

    我们发现正的FWT or变换就是求$hat{a}_j=sum_{iin j}a_i$,即子集和,那这个是怎么来的呢?

    我们假设$a$到$hat{a}$的转移矩阵为$X$,则

    $$(sum_{j}X_{i,j}a_j)*(sum_{j}X_{i,j}b_j)=sum_jX_{i,j}(sum_{s|t=j}a_sb_t)$$

    所以考虑$a_sb_t$的贡献。

    $$X_{i,s}*X_{i,t}=X_{i,s|t}$$

    所以对于$X$的每一行都有$X_s*X_t=X_{s|t}$

    而且由于最后还要进行逆变换,也就是乘上$X^{-1}$,我们知道矩阵可以求逆当且仅当$X$的行列式不为0,所以$X$的任意两行都不相同。

    根据这个,我们先假设$X$中只有0和1(因为这样是最简单的),然后$X_{s|t}=1$与$X_s=X_t=1$等价,所以就可以推出来了。

    先看$n=8$的情形。

    $$X=egin{pmatrix}1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \1 & 0 & 1 & 0 & 0 & 0 & 0 & 0 \1 & 1 & 1 & 1 & 0 & 0 & 0 & 0 \1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 \1 & 1 & 0 & 0 & 1 & 1 & 0 & 0 \1 & 0 & 1 & 0 & 1 & 0 & 1 & 0 \1 & 1 & 1 & 1 & 1 & 1 & 1 & 1end{pmatrix}$$

    打表找规律可得

    $$X_{i,j}=prod_{k=0}^{n-1}C_{i[2^k],j[2^k]}$$

    其中$i[2^k]$表示$i$在二进制下的第$k$位。

    $$C=egin{pmatrix}1 & 0 \ 1 & 1end{pmatrix}$$

    然后我们就知道如何进行分治计算这个向量乘矩阵了。(对,就是那个三重循环)


    我们也可以把FFT的矩阵也这样写出来。

    $$A=egin{pmatrix}omega_n^0 & omega_n^0 & ldots & omega_n^0 & omega_n^0 \omega_n^0 & omega_n^1 & ldots & omega_n^{n-2} & omega_n^{n-1} \vdots & vdots & ddots & ddots & ddots \omega_n^0 & omega_n^{n-1} & ldots & omega_n^{(n-2)(n-1)} & omega_n^{(n-1)(n-1)}end{pmatrix}$$

    即$A_{i,j}=omega_n^{ij}$,所以

    $$A^{-1}=frac{1}{n}egin{pmatrix}omega_n^{-0} & omega_n^{-0} & ldots & omega_n^{-0} & omega_n^{-0} \omega_n^{-0} & omega_n^{-1} & ldots & omega_n^{-(n-2)} & omega_n^{-(n-1)} \vdots & vdots & ddots & ddots & ddots \omega_n^{-0} & omega_n^{-(n-1)} & ldots & omega_n^{-(n-2)(n-1)} & omega_n^{-(n-1)(n-1)}end{pmatrix}$$

    即$A^{-1}_{i,j}=frac{omega_n^{-ij}}{m}$


    UOJ272 【清华集训2016】石家庄的工人阶级队伍比较坚强

    我们设$B_{i,j}$表示$f_{i-1}$到$f_i$的转移矩阵。

    定义$aoplus b$表示三进制不进位加法,$aominus b$表示三进制不退位减法。易得这两个运算互为逆运算。

    则$forall k,B_{ioplus k,joplus k}=B_{i,j}$,由数学归纳法得$forall k,B_{ioplus k,joplus k}^n=B_{i,j}^n$即$B_{i,j}^n=B_{0,jominus i}^n$

    $$f_{n,i}=sum_{j}f_{0,j}*B_{j,i}^n=sum_{j}f_{0,j}*B_{0,iominus j}^n=sum_{xoplus y=i}f_x*B_{0,y}^n$$

    所以我们只需要求出$B$矩阵的第一行并与$f_0$做三进制下的异或卷积就可以了。

    我们先考虑二进制下的。

    $$C=egin{pmatrix}1 & 1 \1 & -1end{pmatrix}$$

    ($C$矩阵的意义见上)

    所以感性理解一下(或者可以自己推一推),三进制的异或卷积的矩阵就是:

    $$C=egin{pmatrix}1 & 1 & 1 \1 & omega & omega^2 \1 & omega^2 & omegaend{pmatrix}$$

    $$C^{-1}=frac{1}{3}egin{pmatrix}1 & 1 & 1 \1 & omega^2 & omega \1 & omega & omega^2end{pmatrix}$$

    其中$omega=frac{-1+sqrt{3}i}{2}$

    但是$sqrt{3}$运算非常麻烦,还会有精度问题,所以我们取$1,omega$作为基底而不是$1,i$,即把复数表示成$a+bomega$的形式。

    乘法与$a+bi$的乘法不一样,需要推一推。

    $$(a+bomega)(c+domega)=ac+(bc+ad)omega+bd(-omega-1)=(ac-bd)+(bc+ad-bd)omega$$

    然后就应该是做完了。

     1 #include<cstdio>
     2 #define Rint register int
     3 using namespace std;
     4 typedef long long LL;
     5 const int N = 531441;
     6 int n, m, t, p, po[13], cntx[N], cnty[N];
     7 inline void exgcd(int a, int b, int &x, int &y){
     8     if(!b){x = 1; y = 0; return;}
     9     exgcd(b, a % b, y, x); y -= (LL) a / b * x;
    10 }
    11 struct complex {
    12     int x, y;
    13     inline complex(int x = 0, int y = 0): x(x), y(y){}
    14     inline complex operator + (const complex &o) const {return complex((x + o.x) % p, (y + o.y) % p);}
    15     inline complex operator - (const complex &o) const {return complex((x - o.x + p) % p, (y - o.y + p) % p);}
    16     inline complex operator * (const complex &o) const {
    17         return complex(((LL) x * o.x % p - (LL) y * o.y % p + p) % p, ((LL) y * o.x % p + (LL) x * o.y % p - (LL) y * o.y % p + p) % p);
    18     }
    19 } A[N], B[N];
    20 inline complex kasumi(complex a, int b){
    21     complex res = complex(1, 0);
    22     while(b){
    23         if(b & 1) res = res * a;
    24         a = a * a;
    25         b >>= 1;
    26     }
    27     return res;
    28 }
    29 inline complex calc1(const complex &a){return complex((p - a.y) % p, (a.x - a.y + p) % p);}
    30 inline complex calc2(const complex &a){return complex((a.y - a.x + p) % p, (p - a.x) % p);}
    31 inline void dft(complex *A){
    32     for(Rint mid = 1;mid < n;mid *= 3)
    33         for(Rint j = 0;j < n;j += mid * 3)
    34             for(Rint k = 0;k < mid;k ++){
    35                 complex x = A[j + k], y = A[j + k + mid], z = A[j + k + mid * 2];
    36                 A[j + k] = x + y + z;
    37                 A[j + k + mid] = x + calc1(y) + calc2(z);
    38                 A[j + k + mid * 2] = x + calc1(z) + calc2(y);
    39             }
    40 }
    41 inline void idft(complex *A){
    42     for(Rint mid = 1;mid < n;mid *= 3)
    43         for(Rint j = 0;j < n;j += mid * 3)
    44             for(Rint k = 0;k < mid;k ++){
    45                 complex x = A[j + k], y = A[j + k + mid], z = A[j + k + mid * 2];
    46                 A[j + k] = x + y + z;
    47                 A[j + k + mid] = x + calc1(z) + calc2(y);
    48                 A[j + k + mid * 2] = x + calc1(y) + calc2(z);
    49             }
    50 }
    51 int trans[13][13];
    52 int main(){
    53     scanf("%d%d%d", &m, &t, &p);
    54     po[0] = 1;
    55     for(Rint i = 1;i <= m;i ++) po[i] = (LL) po[i - 1] * 3;
    56     n = po[m];
    57     for(Rint i = 1;i <= m;i ++) po[i] %= p;
    58     if(p == 1){
    59         for(Rint i = 0;i < n;i ++) puts("0");
    60         return 0;
    61     }
    62     for(Rint i = 0;i < n;i ++) scanf("%d", &A[i].x);
    63     for(Rint i = 0;i <= m;i ++)
    64         for(Rint j = 0;i + j <= m;j ++) scanf("%d", trans[i] + j);
    65     for(Rint i = 0;i < n;i ++){
    66         cntx[i] = cntx[i / 3] + (i % 3 == 1);
    67         cnty[i] = cnty[i / 3] + (i % 3 == 2);
    68         B[i].x = trans[cntx[i]][cnty[i]];
    69         //printf("%d ", B[i].x);
    70     }
    71     //putchar('
    ');
    72     dft(A); dft(B);
    73     //for(Rint i = 0;i < n;i ++) printf("(%d, %d)
    ", A[i].x, A[i].y);
    74     //for(Rint i = 0;i < n;i ++) printf("(%d, %d)
    ", B[i].x, B[i].y);
    75     for(Rint i = 0;i < n;i ++) A[i] = A[i] * kasumi(B[i], t);
    76     idft(A);
    77     int inv, tmp;
    78     exgcd(n, p, inv, tmp);
    79     inv = (inv + p) % p;
    80     for(Rint i = 0;i < n;i ++)
    81         printf("%d
    ", (LL) A[i].x * inv % p);
    82 }
    UOJ272
  • 相关阅读:
    SQLSERVER 2008 R2导入XLSX文件错误,提示报错:The 'Microsoft.ACE.OLEDB.12.0' provider is not registered on the local machine. (System.Data)
    转-马斯洛需求层次理论
    邮件发送提示错误“0x800CCC0F”,并存留在发件箱
    WIN10系统JAVA环境配置
    转-简单了解Python装饰器实现原理
    set集合
    阿里云服务器地址及端口
    转:苹果手机同步阿里云邮箱日历
    官方:金蝶实际成本在制品分配详解
    官方:金蝶实际成本在制品材料分配详解
  • 原文地址:https://www.cnblogs.com/AThousandMoons/p/10926924.html
Copyright © 2020-2023  润新知