• Strassen算法


    $$`
    left[
    egin{matrix}
    A ,B
    C,D
    end{matrix}
    ight]
    imes
    left[
    egin{matrix}
    E,F
    G,H
    end{matrix}
    ight]

    left[
    egin{matrix}
    AE+BG, AF+BH
    CE+DG,CF+DH
    end{matrix}
    ight]
    `$$

    Strassen算法于1969年由德国数学家Strassen提出,该方法引入七个中间变量,每个中间变量都只需要进行一次乘法运算。而朴素算法却需要进行8次乘法运算。

    原理

    Strassen算法的原理如下所示,使用sympy验证Strassen算法的正确性

    import sympy as s
    
    A = s.Symbol("A")
    B = s.Symbol("B")
    C = s.Symbol("C")
    D = s.Symbol("D")
    E = s.Symbol("E")
    F = s.Symbol("F")
    G = s.Symbol("G")
    H = s.Symbol("H")
    p1 = A * (F - H)
    p2 = (A + B) * H
    p3 = (C + D) * E
    p4 = D * (G - E)
    p5 = (A + D) * (E + H)
    p6 = (B - D) * (G + H)
    p7 = (A - C) * (E + F)
    
    print(A * E + B * G, (p5 + p4 - p2 + p6).simplify())
    print(A * F + B * H, (p1 + p2).simplify())
    print(C * E + D * G, (p3 + p4).simplify())
    print(C * F + D * H, (p1 + p5 - p3 - p7).simplify())
    
    

    复杂度分析

    $$f(N)=7 imes f(frac{N}{2})=7^2 imes f(frac{N}{4})=...=7^k imes f(frac{N}{2^k})$$
    最终复杂度为$7^{log_2 N}=N^{log_2 7}$

    验证有效性

    使用numpy验证Strassen算法的有效性:

    import timeit
    
    import numpy as np
    
    N = 5000
    M = 5000
    a = np.random.random((N, M))
    test_count = 10
    
    
    def use_numpy():
        ans = np.matmul(a, a.T)
        return ans
    
    
    def use_numpy_strassen():
        # numpy使用strassen方法
        NN = N // 2
        MM = M // 2
        A, B, C, D = a[:NN, :MM], a[:NN, MM:], a[NN:, :MM], a[NN:, MM:]
        b = a.T
        E, F, G, H = b[:NN, :MM], b[:NN, MM:], b[NN:, :MM], b[NN:, MM:]
        p1 = np.matmul(A, F - H)
        p2 = np.matmul(A + B, H)
        p3 = np.matmul(C + D, E)
        p4 = np.matmul(D, (G - E))
        p5 = np.matmul(A + D, E + H)
        p6 = np.matmul(B - D, G + H)
        p7 = np.matmul(A - C, E + F)
        ans = np.hstack((np.vstack((p5 + p4 - p2 + p6, p3 + p4)), np.vstack((p1 + p2, p1 + p5 - p3 - p7))))
        return ans
    
    
    print(timeit.timeit(use_numpy, number=10))
    print(timeit.timeit(use_numpy_strassen, number=10))
    one = use_numpy()
    three = use_numpy_strassen()
    print(one.reshape(-1)[:5])
    print(three.reshape(-1)[:5])
    
    

    实验说明:这里只使用了一层Strassen,正常的Strassen应该是递归的并且需要在空间上进行优化,从而避免太多的空间复制。
    实验结果出人意料:不采用strassen算法仅需要17秒,采用strassen算法需要40秒。
    在进行实验验证时,最好不要使用python,因为python隐藏的细节太多了。

  • 相关阅读:
    php 创建文件
    php xml格式对象 返回->对应格式数组
    php 将16进制数串转换为二进制数据的函数
    php 生成随机字符串
    高质量PHP代码的50个实用技巧:非常值得收藏
    php __FILE__,__CLASS__等魔术变量,及实例
    纯js上传文件 很好用
    XMLHttpRequest上传文件实现进度条
    Java BufferedReader、InputStream简介
    Java socket通信
  • 原文地址:https://www.cnblogs.com/weiyinfu/p/10452265.html
Copyright © 2020-2023  润新知