文章目录
- @[toc]
- 问题描述
- 基础算法
- 时间复杂性
- `Strassen`算法
- 时间复杂性
- 问题时间复杂性
- `Python`实现
文章目录
- @[toc]
- 问题描述
- 基础算法
- 时间复杂性
- `Strassen`算法
- 时间复杂性
- 问题时间复杂性
- `Python`实现
个人主页:丷从心.
系列专栏:Python基础
学习指南:Python学习指南
问题描述
- 设 A A A和 B B B是两个 n × n n \times n n×n矩阵, A A A和 B B B的乘积矩阵 C C C中元素 c i j = ∑ k = 1 n a i k b k j c_{ij} = \displaystyle\sum\limits_{k = 1}^{n}{a_{ik} b_{kj}} cij=k=1∑naikbkj
- 每计算 C C C的一个元素 c i j c_{ij} cij,需要做 n n n次乘法和 n − 1 n - 1 n−1次加法,求出矩阵 C C C的 n 2 n^{2} n2个元素所需的时间为 O ( n 3 ) O(n^{3}) O(n3)
基础算法
- 假设 n n n是 2 2 2的幂,将矩阵 A A A、 B B B和 C C C中每个矩阵都分块成 4 4 4个大小相等的子矩阵,每个子矩阵都是 n / 2 × n / 2 n / 2 \times n / 2 n/2×n/2的方阵
∣ C 11 C 12 C 21 C 22 ∣ = ∣ A 11 A 12 A 21 A 22 ∣ ∣ B 11 B 12 B 21 B 22 ∣ \begin{vmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{vmatrix} = \begin{vmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{vmatrix} \begin{vmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{vmatrix} C11C21C12C22 = A11A21A12A22 B11B21B12B22
C 11 = A 11 B 11 + A 12 B 21 C 12 = A 11 B 12 + A 12 B 22 C 21 = A 21 B 11 + A 22 B 21 C 22 = A 21 B 12 + A 22 B 22 C_{11} = A_{11} B_{11} + A_{12} B_{21} \\ C_{12} = A_{11} B_{12} + A_{12} B_{22} \\ C_{21} = A_{21} B_{11} + A_{22} B_{21} \\ C_{22} = A_{21} B_{12} + A_{22} B_{22} C11=A11B11+A12B21C12=A11B12+A12B22C21=A21B11+A22B21C22=A21B12+A22B22
时间复杂性
- 计算 2 2 2个 n n n阶方阵的乘积转化为计算 8 8 8个 n / 2 n / 2 n/2阶方阵的乘积和 4 4 4个 n / 2 n / 2 n/2阶方阵的加法, 2 2 2个 n / 2 × n / 2 n / 2 \times n / 2 n/2×n/2矩阵的加法显然可以在 O ( n 2 ) O(n^{2}) O(n2)时间内完成
T ( n ) = { O ( 1 ) n = 2 8 T ( n / 2 ) + O ( n 2 ) n > 2 T(n) = \begin{cases} O(1) & n = 2 \\ 8 T(n / 2) + O(n^{2}) & n > 2 \end{cases} T(n)={O(1)8T(n/2)+O(n2)n=2n>2
T ( n ) = O ( n 3 ) T(n) = O(n^{3}) T(n)=O(n3)
Strassen
算法
Strassen
算法只用了 7 7 7次乘法运算,但增加了加减法的运算次数
M 1 = A 11 ( B 12 − B 22 ) M 2 = ( A 11 + A 12 ) B 22 M 3 = ( A 21 + A 22 ) B 11 M 4 = A 22 ( B 21 − B 11 ) M 5 = ( A 11 + A 22 ) ( B 11 + B 22 ) M 6 = ( A 12 − A 22 ) ( B 21 + B 22 ) M 7 = ( A 11 − A 21 ) ( B 11 + B 12 ) M_{1} = A_{11} (B_{12} - B_{22}) \\ M_{2} = (A_{11} + A_{12}) B_{22} \\ M_{3} = (A_{21} + A_{22}) B_{11} \\ M_{4} = A_{22} (B_{21} - B_{11}) \\ M_{5} = (A_{11} + A_{22})(B_{11} + B_{22}) \\ M_{6} = (A_{12} - A_{22})(B_{21} + B_{22}) \\ M_{7} = (A_{11} - A_{21})(B_{11} + B_{12}) M1=A11(B12−B22)M2=(A11+A12)B22M3=(A21+A22)B11M4=A22(B21−B11)M5=(A11+A22)(B11+B22)M6=(A12−A22)(B21+B22)M7=(A11−A21)(B11+B12)
C 11 = M 5 + M 4 − M 2 + M 6 C 12 = M 1 + M 2 C 21 = M 3 + M 4 C 22 = M 5 + M 1 − M 3 − M 7 C_{11} = M_{5} + M_{4} - M_{2} + M_{6} \\ C_{12} = M_{1} + M_{2} \\ C_{21} = M_{3} + M_{4} \\ C_{22} = M_{5} + M_{1} - M_{3} - M_{7} C11=M5+M4−M2+M6C12=M1+M2C21=M3+M4C22=M5+M1−M3−M7
时间复杂性
Strassen
算法用了 7 7 7次对于 n / 2 n / 2 n/2阶矩阵乘积的递归调用和 18 18 18次 n / 2 n / 2 n/2阶矩阵的加减运算
T ( n ) = { O ( 1 ) n = 2 7 T ( n / 2 ) + O ( n 2 ) n > 2 T(n) = \begin{cases} O(1) & n = 2 \\ 7 T(n / 2) + O(n^{2}) & n > 2 \end{cases} T(n)={O(1)7T(n/2)+O(n2)n=2n>2
T ( n ) = O ( n log 7 ) ≈ O ( n 2.81 ) T(n) = O(n^{\log{7}}) \approx O(n^{2.81}) T(n)=O(nlog7)≈O(n2.81)
问题时间复杂性
- H o p c r o f t Hopcroft Hopcroft和 K e r r Kerr Kerr已经证明计算 2 2 2个 2 × 2 2 \times 2 2×2矩阵的乘积, 7 7 7次乘法是必要的
- 目前最好的计算时间上界是 O ( n 2.376 ) O(n^{2.376}) O(n2.376),所知的矩阵乘法的最好下界仍是它的平凡下界 Ω ( n 2 ) \Omega(n^{2}) Ω(n2)
Python
实现
import numpy as np
def strassen_matrix_multiply(a, b):
n = a.shape[0]
# 如果输入矩阵的维度小于等于阈值, 使用传统的矩阵乘法
if n <= 128:
return np.dot(a, b)
# 将输入矩阵划分为四个子矩阵
mid = n // 2
a11 = a[:mid, :mid]
a12 = a[:mid, mid:]
a21 = a[mid:, :mid]
a22 = a[mid:, mid:]
b11 = b[:mid, :mid]
b12 = b[:mid, mid:]
b21 = b[mid:, :mid]
b22 = b[mid:, mid:]
# 递归地计算七个矩阵乘法
m1 = strassen_matrix_multiply(a11, b12 - b22)
m2 = strassen_matrix_multiply(a11 + a12, b22)
m3 = strassen_matrix_multiply(a21 + a22, b11)
m4 = strassen_matrix_multiply(a22, b21 - b11)
m5 = strassen_matrix_multiply(a11 + a22, b11 + b22)
m6 = strassen_matrix_multiply(a12 - a22, b21 + b22)
m7 = strassen_matrix_multiply(a11 - a21, b11 + b12)
# 计算结果矩阵的四个子矩阵
c11 = m5 + m4 - m2 + m6
c12 = m1 + m2
c21 = m3 + m4
c22 = m5 + m1 - m3 - m7
# 组合四个子矩阵形成结果矩阵
c = np.zeros((n, n))
c[:mid, :mid] = c11
c[:mid, mid:] = c12
c[mid:, :mid] = c21
c[mid:, mid:] = c22
return c
a = np.random.randint(0, 10, (256, 256))
b = np.random.randint(0, 10, (256, 256))
res = strassen_matrix_multiply(a, b)
print('矩阵 a:')
print(a)
print('\n矩阵 b:')
print(b)
print('\n乘积矩阵 c:')
print(res)
矩阵 a:
[[2 6 1 ... 9 7 7]
[7 1 8 ... 1 0 9]
[5 0 8 ... 6 2 5]
...
[6 7 2 ... 6 0 1]
[1 4 7 ... 0 5 2]
[3 3 3 ... 9 6 8]]
矩阵 b:
[[3 5 1 ... 9 8 9]
[9 6 0 ... 8 1 5]
[3 9 6 ... 9 0 5]
...
[5 7 7 ... 3 8 8]
[9 5 3 ... 1 8 1]
[3 3 9 ... 4 7 0]]
乘积矩阵 c:
[[5090. 5517. 4863. ... 4977. 4769. 4939.]
[5148. 5909. 5747. ... 5603. 5070. 5260.]
[4376. 5175. 4717. ... 4968. 4668. 4526.]
...
[4708. 5294. 4991. ... 4945. 4681. 5202.]
[4641. 5307. 4955. ... 5087. 5157. 4795.]
[4722. 5173. 5144. ... 5050. 4845. 4829.]]