1. 思路小结
要优化你提供的稀疏矩阵乘法代码,我们可以引入CSR(压缩稀疏行)格式来避免遍历零元素,从而提高效率。CSR格式通过仅存储非零元素以及它们的行和列索引,可以有效减少稀疏矩阵计算时的时间复杂度。下面是对代码的优化版本,采用CSR格式进行稀疏矩阵的乘法:
优化步骤:
将稀疏矩阵转换为CSR格式,存储非零元素的位置和对应的值。
在矩阵乘法过程中,仅对非零元素进行计算,从而跳过零值。
对每一行的非零元素,在相应的列上执行乘法操作。
1.1 优化思路
进行的是两个稀疏矩阵的乘法。稀疏矩阵通常具有大量的零元素,因此直接使用常规矩阵乘法会导致大量的无效计算。为了提高效率,常用的优化方法是只对非零元素进行计算,而跳过零值。为此,我们采用**CSR(压缩稀疏行,Compressed Sparse Row)**格式进行稀疏矩阵存储和乘法计算。
1.1.1 核心步骤如下:
-
矩阵的稀疏表示:
- 原矩阵A和B可能有大量的零元素,因此我们采用CSR格式来存储这些矩阵。
- CSR格式由以下三个部分组成:
values[]
: 存储所有非零元素的值。colIndex[]
: 存储每个非零元素所在的列索引。rowPtr[]
: 记录每行的非零元素在values[]
中的起始位置。
-
矩阵的稀疏乘法:
- 对于矩阵A的每一行,我们找到其所有非零元素的位置及其值。
- 对于每一个非零元素,我们在矩阵B的相应列中查找与之匹配的非零元素。
- 最后将这些匹配的非零元素相乘,并累加到结果矩阵的对应位置。
-
优化:
- 通过CSR格式,避免了遍历和处理零元素,从而减少了不必要的计算。
- 我们直接对非零元素进行乘法运算,结果累积到结果矩阵C的对应位置。
1.2 算法复杂度分析
1.2.1 常规矩阵乘法的复杂度:
对于两个大小分别为 m x n
和 n x p
的矩阵,常规的矩阵乘法复杂度为O(m * n * p)。因为对于每一个 m x p
的结果元素,我们需要计算 n 次乘法操作。
1.2.2 稀疏矩阵乘法的复杂度:
由于稀疏矩阵大部分元素为零,我们只需要处理非零元素。假设矩阵A和矩阵B的非零元素分别为 nnzA
和 nnzB
,稀疏矩阵乘法的复杂度可以近似表示为:
- 对于每个非零元素
A[i][k]
,我们只需遍历矩阵B的第k列的非零元素进行乘法。因此稀疏矩阵乘法的复杂度大约为 O(nnzA * nnzB),其中nnzA
和nnzB
是矩阵A和矩阵B的非零元素数量。
这相比于常规矩阵乘法的复杂度有了显著的提升,尤其是当矩阵非常稀疏时(即大部分元素为0),非零元素的数量远小于矩阵的总大小。
1.2.3 空间复杂度:
使用CSR格式的空间复杂度为:
O(nnz)
:用于存储所有非零元素及其列索引。O(m)
:用于存储每一行的起始位置。
总体空间复杂度为 O(nnz + m),其中nnz
是矩阵的非零元素数量,m
是矩阵的行数。
1.2.4 总结
通过使用CSR格式存储稀疏矩阵,我们能够有效避免对零元素的计算,显著提升了稀疏矩阵乘法的计算效率。时间复杂度从常规的O(m * n * p)降低到接近于非零元素的数量 O(nnzA * nnzB)
,特别适合处理大规模稀疏矩阵的场景。
2. 优化后代码及其复杂度为
代码解析:
toCSR 函数:将普通的二维稀疏矩阵转换为CSR格式。values数组存储非零元素,colIndex存储每个非零元素的列索引,rowPtr则记录每行的非零元素在 values 数组中的起始位置。
multiplySparseMatricesCSR 函数:使用CSR格式进行矩阵乘法。通过 rowPtr 和 colIndex 来快速定位非零元素,避免了对零值的无效计算。
优化效果:
通过CSR格式存储非零元素,并跳过零元素的乘法操作,能够显著减少计算时间。
避免遍历零值,提高了计算效率,尤其在大规模稀疏矩阵的场景下。
#include <iostream>
#include <vector>
using namespace std;
// CSR格式的稀疏矩阵
struct CSRMatrix {
vector<int> values; // 存储非零元素的值
vector<int> colIndex; // 存储非零元素的列索引
vector<int> rowPtr; // 每一行的开始位置
};
// 将稀疏矩阵转换为CSR格式
CSRMatrix toCSR(const vector<vector<int>>& matrix) {
CSRMatrix csr;
int row = matrix.size();
int col = matrix[0].size();
csr.rowPtr.push_back(0); // 第一行的开始位置是0
// 遍历矩阵,收集非零元素的信息
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
if (matrix[i][j] != 0) {
csr.values.push_back(matrix[i][j]);
csr.colIndex.push_back(j);
}
}
csr.rowPtr.push_back(csr.values.size()); // 记录下一行的开始位置
}
return csr;
}
// 使用CSR格式进行稀疏矩阵乘法
vector<vector<int>> multiplySparseMatricesCSR(const CSRMatrix& A, const CSRMatrix& B, int colB) {
int rowA = A.rowPtr.size() - 1;
vector<vector<int>> C(rowA, vector<int>(colB, 0)); // 初始化结果矩阵
// 遍历A的每一行
for (int i = 0; i < rowA; i++) {
// A的第i行的非零元素从A.rowPtr[i]到A.rowPtr[i+1]-1
for (int aPos = A.rowPtr[i]; aPos < A.rowPtr[i+1]; aPos++) {
int colA = A.colIndex[aPos]; // 该非零元素所在的列
int aValue = A.values[aPos]; // 非零元素的值
// 对应B的第colA行
for (int j = B.rowPtr[colA]; j < B.rowPtr[colA+1]; j++) {
int colBIndex = B.colIndex[j];
int bValue = B.values[j];
C[i][colBIndex] += aValue * bValue;
}
}
}
return C;
}
int main() {
// 定义稀疏矩阵A
vector<vector<int>> A = {
{1, 0, 0},
{-1, 0, 3}
};
// 定义稀疏矩阵B
vector<vector<int>> B = {
{7, 0, 0},
{0, 0, 0},
{0, 0, 1}
};
// 将矩阵A和B转换为CSR格式
CSRMatrix csrA = toCSR(A);
CSRMatrix csrB = toCSR(B);
// 计算A和B的乘积
vector<vector<int>> C = multiplySparseMatricesCSR(csrA, csrB, B[0].size());
// 输出结果矩阵
cout << "Result of A * B:" << endl;
for (const auto& row : C) {
for (int elem : row) {
cout << elem << " ";
}
cout << endl;
}
return 0;
}
3. 优化前原始代码及其复杂度为O(m * n * p),这里是最朴素的思路,没有利用稀疏特性做任何优化
#include <iostream>
#include <vector>
using namespace std;
// 定义稀疏矩阵乘法函数
vector<vector<int>> multiplySparseMatrices(vector<vector<int>>& A, vector<vector<int>>& B) {
int rowA = A.size();
int colA = A[0].size();
int rowB = B.size();
int colB = B[0].size();
// 初始化结果矩阵,大小为rowA * colB
vector<vector<int>> C(rowA, vector<int>(colB, 0));
// 遍历矩阵A的每一行
for (int i = 0; i < rowA; i++) {
// 遍历矩阵A的每个列,寻找非零元素
for (int k = 0; k < colA; k++) {
if (A[i][k] != 0) {
// 当A的某个位置非零时,计算该元素和矩阵B的第k行
for (int j = 0; j < colB; j++) {
if (B[k][j] != 0) {
C[i][j] += A[i][k] * B[k][j];
}
}
}
}
}
return C;
}
int main() {
// 定义稀疏矩阵A
vector<vector<int>> A = {
{1, 0, 0},
{-1, 0, 3}
};
// 定义稀疏矩阵B
vector<vector<int>> B = {
{7, 0, 0},
{0, 0, 0},
{0, 0, 1}
};
// 计算A和B的乘积
vector<vector<int>> C = multiplySparseMatrices(A, B);
// 输出结果矩阵
cout << "Result of A * B:" << endl;
for (const auto& row : C) {
for (int elem : row) {
cout << elem << " ";
}
cout << endl;
}
return 0;
}