Strassen矩阵乘法问题(Java)
文章目录
- Strassen矩阵乘法问题(Java)
- 1、前置介绍
- 3、代码实现
- 4、复杂度分析
- 5、参考资料
1、前置介绍
矩阵乘法是线性代数中最常见的问题之一 ,它在数值计算中有广泛的应用。 设A和B是2个nXn矩阵,
它们的乘积AB同样是一个nXn矩阵。 A和B的乘积矩阵C中元素C[i][j]定义为:
C
[
i
]
[
j
]
=
∑
k
=
1
n
A
[
i
]
[
k
]
B
[
k
]
[
j
]
C[i][j] = \sum_{k=1}^{n}A[i][k]B[k][j]
C[i][j]=k=1∑nA[i][k]B[k][j]
采用传统方法,时间复杂度为:O(n3)
因为按照上述的定义来计算A和 B的乘积矩阵c,则每计算C的一个元素C[i][j],需要做n次乘法运算和n-1次加法运算
。 因此,得到矩阵C的n2 个元素所需的计算时间为 O(n3) 。
为解决计算计算效率问题,Strassen算法由此出现,该算法基本思想是分治
,将计算2个n阶矩阵乘积所需的计算时间改进到0(nlog7) = 0(n2.81)
我们知道,C11=A11*B11+A12*B21
矩阵A和B的示意图如下:
传统方法:
2个n阶方阵的乘积转换为8个n/2 阶方阵的乘积和4个n/2阶方阵的加法。
由此可得:
C11 = A11B11 + A12B21
C12 = A11B12 + A12B22
C21 = A21B11 + A22B21
C22 = A21B12 + A22B22
分治法:
为了降低时间复杂度,必须减少乘法的次数。
使用与上例类似的技术,将矩阵A,B和C中每一矩阵都分块成4个大小相等的子矩阵。由此可将方程C=AB重写为:
2个n阶方阵的乘积转换为7个n/2 阶方阵的乘积和18个n/2阶方阵的加减法。
伪代码如下:
// 递归维度分半算法:
public void STRASSEN(n,A,B,C);
{
if n=2 then MATRIX-MULTIPLY(A,B,C)
/ /结束循环,计算 两个2阶方阵的乘法
else{
将矩阵A和B分块;
STRASSEN(n/2,A11,B12-B22,M1);
STRASSEN(n/2,A11+A12,B22,M2);
STRASSEN(n/2,A21+A22,B11,M3);
STRASSEN(n/2,A22,B21-B11,M4);
STRASSEN(n/2,A11+A22,B11+B22,M5);
STRASSEN(n/2,A12-A22,B21+B22,M6);
STRASSEN(n/2,A11-A21,B11+B12,M7);}
}
算法导论伪代码:
3、代码实现
public class StrassenMatrixMultiply
{
public static void main(String[] args)
{
int[] a = new int[]
{
1, 1, 1, 1,
2, 2, 2, 2,
3, 3, 3, 3,
4, 4, 4, 4
};
int[] b = new int[]
{
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4
};
int length = 4;
int[] c = sMM(a, b, length);
for(int i = 0; i < c.length; i++)
{
System.out.print(c[i] + " ");
if((i + 1) % length == 0) //换行
System.out.println();
}
}
public static int[] sMM(int[] a, int[] b, int length) {
if(length == 2) {
return getResult(a, b);
}
else {
int tlength = length / 2;
// 把a数组分为四部分,进行分治递归
int[] aa = new int[tlength * tlength];
int[] ab = new int[tlength * tlength];
int[] ac = new int[tlength * tlength];
int[] ad = new int[tlength * tlength];
// 把b数组分为四部分,进行分治递归
int[] ba = new int[tlength * tlength];
int[] bb = new int[tlength * tlength];
int[] bc = new int[tlength * tlength];
int[] bd = new int[tlength * tlength];
// TODO 划分子矩阵
for(int i = 0; i < length; i++) {
for(int j = 0; j < length; j++) {
/*
* 划分矩阵:
* 例子:将 4 * 4 的矩阵,变为 2 * 2 的矩阵,
* 那么原矩阵左上、右上、左下、右下的四个元素分别归为新矩阵
*/
if(i < tlength) {
if(j < tlength) {
aa[i * tlength + j] = a[i * length + j];
ba[i * tlength + j] = b[i * length + j];
} else {
ab[i * tlength + (j - tlength)] = a[i * length + j];
bb[i * tlength + (j - tlength)] = b[i * length + j];
}
} else {
if(j < tlength) {
//i 大于 tlength 时,需要减去 tlength,j同理
//因为 b,c,d三个子矩阵有对应了父矩阵的后半部分
ac[(i - tlength) * tlength + j] = a[i * length + j];
bc[(i - tlength) * tlength + j] = b[i * length + j];
} else {
ad[(i - tlength) * tlength + (j - tlength)] = a[i * length + j];
bd[(i - tlength) * tlength + (j - tlength)] = b[i * length + j];
}
}
}
}
// TODO 分治递归
int[] result = new int[length * length];
// temp:4个临时矩阵
int[] t1 = add(sMM(aa, ba, tlength), sMM(ab, bc, tlength));
int[] t2 = add(sMM(aa, bb, tlength), sMM(ab, bd, tlength));
int[] t3 = add(sMM(ac, ba, tlength), sMM(ad, bc, tlength));
int[] t4 = add(sMM(ac, bb, tlength), sMM(ad, bd, tlength));
// TODO 归并结果
for(int i = 0; i < length; i++) {
for(int j = 0; j < length; j++) {
if (i < tlength){
if(j < tlength) {
result[i * length + j] = t1[i * tlength + j];
} else {
result[i * length + j] = t2[i * tlength + (j - tlength)];
}
} else {
if(j < tlength) {
result[i * length + j] = t3[(i - tlength) * tlength + j];
} else {
result[i * length + j] = t4[(i - tlength) * tlength + (j - tlength)];
}
}
}
}
return result;
}
}
public static int[] getResult(int[] a, int[] b) {
int p1 = a[0] * (b[1] - b[3]);
int p2 = (a[0] + a[1]) * b[3];
int p3 = (a[2] + a[3]) * b[0];
int p4 = a[3] * (b[2] - b[0]);
int p5 = (a[0] + a[3]) * (b[0] + b[3]);
int p6 = (a[1] - a[3]) * (b[2] + b[3]);
int p7 = (a[0] - a[2]) * (b[0] + b[1]);
int c00 = p5 + p4 - p2 + p6;
int c01 = p1 + p2;
int c10 = p3 + p4;
int c11 = p5 + p1 -p3 - p7;
return new int[] {c00, c01, c10, c11};
}
public static int[] add(int[] a, int[] b) {
int[] c = new int[a.length];
for(int i = 0; i < a.length; i++) {
c[i] = a[i] + b[i];
}
return c;
}
// TODO 返回一个数是不是2的幂次方
public static boolean adjust(int x) {
return (x & (x - 1)) == 0;
}
}
4、复杂度分析
传统方法和分治法的复杂度比较,如下图所示;
T ( n ) = { O ( 1 ) , n = 2 7 T ( n / 2 ) + O ( n 2 ) , n > 2 T(n) = \left\{ \begin{matrix} O(1), n = 2 \\ 7T(n/2) + O(n^2), n > 2\\ \end{matrix} \right. T(n)={O(1),n=27T(n/2)+O(n2),n>2
T(n) = 0(nlog7 ) = 0(n2.81)
5、参考资料
- 算法分析与设计(第四版)
- 算法导论第三版
- 博客园