$$`
left[
egin{matrix}
A ,B
C,D
end{matrix}
ight]
imes
left[
egin{matrix}
E,F
G,H
end{matrix}
ight]
left[
egin{matrix}
AE+BG, AF+BH
CE+DG,CF+DH
end{matrix}
ight]
`$$
Strassen算法于1969年由德国数学家Strassen提出,该方法引入七个中间变量,每个中间变量都只需要进行一次乘法运算。而朴素算法却需要进行8次乘法运算。
原理
Strassen算法的原理如下所示,使用sympy验证Strassen算法的正确性
import sympy as s
A = s.Symbol("A")
B = s.Symbol("B")
C = s.Symbol("C")
D = s.Symbol("D")
E = s.Symbol("E")
F = s.Symbol("F")
G = s.Symbol("G")
H = s.Symbol("H")
p1 = A * (F - H)
p2 = (A + B) * H
p3 = (C + D) * E
p4 = D * (G - E)
p5 = (A + D) * (E + H)
p6 = (B - D) * (G + H)
p7 = (A - C) * (E + F)
print(A * E + B * G, (p5 + p4 - p2 + p6).simplify())
print(A * F + B * H, (p1 + p2).simplify())
print(C * E + D * G, (p3 + p4).simplify())
print(C * F + D * H, (p1 + p5 - p3 - p7).simplify())
复杂度分析
$$f(N)=7 imes f(frac{N}{2})=7^2 imes f(frac{N}{4})=...=7^k imes f(frac{N}{2^k})
$$
最终复杂度为$7^{log_2 N}=N^{log_2 7}
$
验证有效性
使用numpy验证Strassen算法的有效性:
import timeit
import numpy as np
N = 5000
M = 5000
a = np.random.random((N, M))
test_count = 10
def use_numpy():
ans = np.matmul(a, a.T)
return ans
def use_numpy_strassen():
# numpy使用strassen方法
NN = N // 2
MM = M // 2
A, B, C, D = a[:NN, :MM], a[:NN, MM:], a[NN:, :MM], a[NN:, MM:]
b = a.T
E, F, G, H = b[:NN, :MM], b[:NN, MM:], b[NN:, :MM], b[NN:, MM:]
p1 = np.matmul(A, F - H)
p2 = np.matmul(A + B, H)
p3 = np.matmul(C + D, E)
p4 = np.matmul(D, (G - E))
p5 = np.matmul(A + D, E + H)
p6 = np.matmul(B - D, G + H)
p7 = np.matmul(A - C, E + F)
ans = np.hstack((np.vstack((p5 + p4 - p2 + p6, p3 + p4)), np.vstack((p1 + p2, p1 + p5 - p3 - p7))))
return ans
print(timeit.timeit(use_numpy, number=10))
print(timeit.timeit(use_numpy_strassen, number=10))
one = use_numpy()
three = use_numpy_strassen()
print(one.reshape(-1)[:5])
print(three.reshape(-1)[:5])
实验说明:这里只使用了一层Strassen,正常的Strassen应该是递归的并且需要在空间上进行优化,从而避免太多的空间复制。
实验结果出人意料:不采用strassen算法仅需要17秒,采用strassen算法需要40秒。
在进行实验验证时,最好不要使用python,因为python隐藏的细节太多了。