目录
- 手写Matrix
- 前言
- 1. 明确需求
- 2. 基本实现
- 2.1 创建矩阵
- 2.2 外部访问
- 2.3 <<操作符重载
- 3. 矩阵运算
- 3.1 矩阵标量运算
- 3.2 通用矩阵乘法
- 3.3 矩阵求逆
- 4. 完整示例代码
- 总结
手写Matrix
前言
手写AI推出的全新面向AI算法的C++课程 Algo C++,链接。记录下个人学习笔记,仅供自己参考。
本次课程主要是手写Matrix代码
课程大纲可看下面的思维导图
1. 明确需求
我们先来明确下 Matrix 类中应该实现那些功能
1.只能够表示 2 维的矩阵形式,即使是向量也会用 matrix 表示
-我们指表达 float 格式的矩阵,不表达其它形式
2.矩阵的乘法,通用矩阵的乘法形式
3.要能够求解逆矩阵
4.可以通过指定行和列进行矩阵的创建
5.可以允许使用 {1, 2, 3} 这种形式进行数据填充的方式创建
6.能够与标量进行常规的加减法
2. 基本实现
2.1 创建矩阵
我们先来实现 Matrix 类中矩阵创建的功能
#ifndef MATRIX_HPP
#define MATRIX_HPP
#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
class Matrix{
public:
Matrix() = default;
private:
int rows_ = 0;
int cols_ = 0;
std::vector<float> data_;
};
#endif // MATRIX_HPP
在上面的示例代码中,我们定义了矩阵 Matrix 类的默认构造函数,以及定义了三个私有成员变量 rows_、cols、data_
分别代表矩阵的行号、列号以及数据。#ifndef #define #endif
语句是为了防止头文件重复包含。值得注意的是,我们习惯在成员变量尾部加上_
我们再来实现使用 {1, 2, 3}进行数据填充的方式来创建矩阵
#ifndef MATRIX_HPP
#define MATRIX_HPP
#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
class Matrix{
public:
Matrix() = default;
Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
rows_ = rows;
cols_ = cols;
// 隐式转换,其实是执行了vector的赋值操作
data_ = data;
// 1. data的元素为空,说明是不指定数据情况下进行创建
// 2. data的元素不空,说明是指定数据情况下创建
// 1. 元素数量等于rows * cols
// 2. 元素数量小于rows * cols
if(data_.empty()){
// resize表示分配rows * cols个元素,在vector中
// 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
data_.resize(rows * cols);
}else{
if(data_.size() != rows * cols)
std::cout << "Invalid construct.\n";
}
}
private:
int rows_ = 0;
int cols_ = 0;
std::vector<float> data_;
};
#endif // MATRIX_HPP
在上面的示例代码中,我们实现了可以通过 Matrix m(3, 1, {1, 3, 2})
这种方式来创建矩阵。值得注意的是
- {} 这种类型的数据在 C++ 中叫做 initializer_list,是一种容器。
- 在 C++ 的容器中,比如 STL 对象中的 vector、list 等,它们分配的空间,如果不进行初始化,则其内部的值是 0
- 对于空,在 C 语言里面大部分是指 malloc 分配出来的内存没有初始化的情况
- 此时 malloc 分配内存的值其实是随机的
- 此时 new 分配的内存的值也是随机的
- data 参数是以 & 引用的方式传递,防止拷贝;而 const 参数表示传入常引用,在函数内部对其不进行修改
- 我们习惯在传值的时候,对于非基础类型,一般会传递常引用,使得效率更高,避免拷贝
2.2 外部访问
我们还希望能够外部访问到矩阵的行、列以及矩阵的元素变量
#ifndef MATRIX_HPP
#define MATRIX_HPP
#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
class Matrix{
public:
Matrix() = default;
Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
rows_ = rows;
cols_ = cols;
// 隐式转换,其实是执行了vector的赋值操作
data_ = data;
// 1. data的元素为空,说明是不指定数据情况下进行创建
// 2. data的元素不空,说明是指定数据情况下创建
// 1. 元素数量等于rows * cols
// 2. 元素数量小于rows * cols
if(data_.empty()){
// resize表示分配rows * cols个元素,在vector中
// 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
data_.resize(rows * cols);
}else{
if(data_.size() != rows * cols)
std::cout << "Invalid construct.\n";
}
}
int rows() {return rows_;}
int cols() {return cols_;}
std::vector<float>& data(){return data_;}
private:
int rows_ = 0;
int cols_ = 0;
std::vector<float> data_;
};
#endif // MATRIX_HPP
在上面的示例代码中,我们分别实现了 rows()、cols()、data()
函数用来访问矩阵,如下所示:
#include <iostream>
#include "matrix.hpp"
using namespace std;
int main(){
Matrix m(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
std::cout << "矩阵行 = " << m.rows() << std::endl;
std::cout << "矩阵列 = " << m.cols() << std::endl;
std::cout << "矩阵数据如下: " << std::endl;
for(int i = 0; i < m.data().size(); ++i){
std::cout << m.data()[i] << std::endl;
}
return 0;
}
我们希望能够通过 m[i][j] 这种形式去访问到矩阵的元素,需要用到操作符重载
#ifndef MATRIX_HPP
#define MATRIX_HPP
#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
class Matrix{
public:
Matrix() = default;
Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
rows_ = rows;
cols_ = cols;
// 隐式转换,其实是执行了vector的赋值操作
data_ = data;
// 1. data的元素为空,说明是不指定数据情况下进行创建
// 2. data的元素不空,说明是指定数据情况下创建
// 1. 元素数量等于rows * cols
// 2. 元素数量小于rows * cols
if(data_.empty()){
// resize表示分配rows * cols个元素,在vector中
// 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
data_.resize(rows * cols);
}else{
if(data_.size() != rows * cols)
std::cout << "Invalid construct.\n";
}
}
int rows() {return rows_;}
int cols() {return cols_;}
std::vector<float>& data(){return data_;}
float& operator()(int ir, int ic){
// data_在内存中是连续的
// 比如说我们有3x3的矩阵,那么
// data_就等于 = {1, 2, 3, 4, 5, 6, 7, 8, 9}
// 它代表的就是:
/*
1 2 3
4 5 6
7 8 9
*/
// 如果要访问2行,0列。此时应该是对应的7
// 把2d的索引,映射到连续1d空间的索引上
int index = ir * cols_ + ic;
return data_[index];
}
private:
int rows_ = 0;
int cols_ = 0;
std::vector<float> data_;
};
#endif // MATRIX_HPP
在上述示例代码中,我们对操作符 () 进行了重载,使得矩阵可以通过 m(i,j) 这种方式访问矩阵的元素,值得注意的是:
-
我们需要将
operator()
看成一个整体 -
在 C++ 中不允许 [] 提供更多参数,这种操作只能提供一个参数,因此,可以换成 m(i,j),此时是可以允许的
-
我们返回的是 float & 引用而不是常引用,意味着我们可以直接修改其内部元素
#include <iostream>
#include "matrix.hpp"
using namespace std;
int main(){
Matrix m(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
std::cout << "矩阵行 = " << m.rows() << std::endl;
std::cout << "矩阵列 = " << m.cols() << std::endl;
auto data = m.data();
m(1, 1) = 123.5;
std::cout << "矩阵数据如下: " << std::endl;
for(int ir = 0; ir < m.rows(); ++ir){
for(int ic = 0; ic < m.cols(); ++ic){
std::cout << m(ir, ic) << "\t";
}
std::cout << "\n";
}
return 0;
}
2.3 <<操作符重载
我们还需要重载操作符 <<,使其能直接打印矩阵即 std::cout << m
#ifndef MATRIX_HPP
#define MATRIX_HPP
#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
class Matrix{
public:
Matrix() = default;
Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
rows_ = rows;
cols_ = cols;
// 隐式转换,其实是执行了vector的赋值操作
data_ = data;
// 1. data的元素为空,说明是不指定数据情况下进行创建
// 2. data的元素不空,说明是指定数据情况下创建
// 1. 元素数量等于rows * cols
// 2. 元素数量小于rows * cols
if(data_.empty()){
// resize表示分配rows * cols个元素,在vector中
// 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
data_.resize(rows * cols);
}else{
if(data_.size() != rows * cols)
std::cout << "Invalid construct.\n";
}
}
// 在 rows() 后面加const,表示这个函数是常量函数
// 潜台词是:它不会对内部成员做修改,仅仅只做访问查询
int rows() const{return rows_;}
int cols() const{return cols_;}
std::vector<float>& data(){return data_;}
// 这个表示,重载一个可以修改元素的函数
float& operator(int ir, int ic){
int index = ir * cols_ + ic;
return data_[index];
}
// 这个表示,重载一个只能读取的函数
const float& operator()(int ir, int ic)const{
// data_在内存中是连续的
// 比如说我们有3x3的矩阵,那么
// data_就等于 = {1, 2, 3, 4, 5, 6, 7, 8, 9}
// 它代表的就是:
/*
1 2 3
4 5 6
7 8 9
*/
// 如果要访问2行,0列。此时应该是对应的7
// 把2d的索引,映射到连续1d空间的索引上
int index = ir * cols_ + ic;
return data_[index];
}
private:
int rows_ = 0;
int cols_ = 0;
std::vector<float> data_;
};
std::ostream& operator<<(std::ostream& out, const Matrix& m){
printf("Matrix( %d x %d)\n", m.rows(), m.cols());
for(int ir = 0; ir < m.rows(); ++ir){
for(int ic = 0; ic < m.cols(); ++ic){
printf("%g\t", m(ir, ic));
}
printf("\n");
}
return out;
}
#endif // MATRIX_HPP
在上述示例代码中我们重载了操作符 <<
,值得注意的是:
-
<< 等价于 operator<<,因此
std::cout << 123
等价于std::cout.operator<<(123)
-
<< 操作符重载有两种方式
- 第一种是存在于类内的,例如
std::cout.operator<<(m)
;由于这个是系统文件,最好不要修改 - 第二种是存在于全局作用域的,例如
std::ostream& operator<<(std::ostream& out, const Matrix& m)
- 首先,全局操作符重载,是特定操作符为函数名称
- 其次,第一个参数,称之为左操作数;第二个参数,称之为右操作数
- 左操作数 out 对象为引用而非常引用是因为 out 对象存在写操作,势必是修改,因此不能是常量了,必须是非常量引用
- 右操作数 Matrix 对象为常引用是因为避免拷贝值得发生,我们在这里只需要读取就行了
- 左操作数 << 右操作数 相当于 operator<<(左操作数,右操作数)
- 第一种是存在于类内的,例如
-
在 C++ 中类的函数,分为常规函数(具有修改和访问权限)和常量函数(只有访问权限,没有修改权限)。右操作数是一个常量对象只能访问常量函数,因此在调用的函数应该是常量函数,所以在
rows()、cols()、operator()
函数后面都要加上 const 关键字,且operator ()
返回的应该是一个常量引用 -
我们需要对
operator()
重载既能实现常量函数,又能完成修改操作,可以将operator()
写两遍,一个表示重载一个可以修改元素的函数,一个表示重载一个只能读取的函数
#include <iostream>
#include "matrix.hpp"
using namespace std;
int main(){
Matrix m(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
cout << m ;
return 0;
}
最后我们来区分下:
int rows() const{return rows_};
const int rows(){return rows_};
第一个函数
int rows() const{return rows_}
是一个 const 成员函数,其中的 const 关键字表示该成员函数不会修改对象的成员变量,即保证了函数内部不会修改rows_
的值,同时该成员函数可以被 const 对象和非 const 对象调用而第二个函数
const int rows(){return rows_}
是一个非 const 成员函数,其中的 const 关键字表示函数返回的是 const 类型的值,但并没有对函数本身做出限制。该函数可以被非 const 对象调用,但不能被 const 对象调用因此,第一个函数可以用于保证对象的成员变量不被修改,并且可以被 const 和非 const 对象调用,而第二个函数只能被非 const 对象调用,如果被 const 对象调用则会产生编译错误
3. 矩阵运算
3.1 矩阵标量运算
实现矩阵与标量的 + - * / 四则运算
#ifndef MATRIX_HPP
#define MATRIX_HPP
#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
#include <functional>
class Matrix{
public:
Matrix() = default;
Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
rows_ = rows;
cols_ = cols;
data_ = data;
if(data_.empty()){
data_.resize(rows * cols);
}else{
if(data_.size() != rows * cols)
std::cout << "Invalid construct.\n";
}
}
// ========== + - * / ==========
Matrix element_wise(const std::function<float(float)>& func) const{
Matrix output = *this; // 复制一份
for(int i = 0; i < output.data_.size(); ++i){
output.data_[i] = func(output.data_[i]);
}
return output;
}
Matrix operator*(float value) const{
// lambda函数 C++11特性
return element_wise([&](float x){return x * value;});
}
Matrix operator+(float value) const{
return element_wise([&](float x){return x + value});
}
Matrix operator-(float value) const{
return element_wise([&](float x){return x - value;});
}
Matrix operator/(float value) const{
return element_wise([&](float x){return x / value;});
}
int rows() const{return rows_;}
int cols() const{return cols_;}
std::vector<float>& data(){return data_;}
float& operator(int ir, int ic){
int index = ir * cols_ + ic;
return data_[index];
}
const float& operator()(int ir, int ic)const{
int index = ir * cols_ + ic;
return data_[index];
}
private:
int rows_ = 0;
int cols_ = 0;
std::vector<float> data_;
};
std::ostream& operator<<(std::ostream& out, const Matrix& m){
printf("Matrix( %d x %d)\n", m.rows(), m.cols());
for(int ir = 0; ir < m.rows(); ++ir){
for(int ic = 0; ic < m.cols(); ++ic){
printf("%g\t", m(ir, ic));
}
}
return out;
}
Matrix operator*(float value, const Matrix& m){
return m * value;
}
#endif // MATRIX_HPP
上述示例代码中实现了矩阵 Matrix 类中加、减、乘、除四个运算符的重载实现,值得注意的是:
-
element_wise()
函数是对矩阵中的每个元素都应用一个函数,得到一个新的矩阵并返回。这样可以将四则运算统一起来,四则运算的函数采用匿名函数实现可以简化代码 -
四个运算符重载函数
operator*()、operator+()、operator-()、operator/()
,都是调用element_wise()
函数,传入一个对应的 lambda 表达式,对矩阵中的每个元素都进行相应的四则运算,得到一个新的矩阵并返回。 -
lambda 表达式定义了一个匿名函数,并且可以捕获一定范围内的变量。其语法形式如下:
[capture](params) opt -> ret {body;}
- 其中
capture
是捕获列表,&
表示捕获全局引用,=
表示捕获全局值,&value
表示捕获特定的值/引用 params
是参数列表,和普通函数的参数一样opt
是函数选项,不需要可以省略ret
是返回值类型,可以省略,编译器会自动推导body
是函数体- 关于 lambda 表达式的更多细节可参考 https://subingwen.cn/cpp/lambda/
- 其中
-
在矩阵与标量乘操作中可能存在 m*2 和 2*m 两种情形,我们分别采用了类内重载和全局重载两种方式实现
#include <iostream>
#include "matrix.hpp"
using namespace std;
int main(){
Matrix m1(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});
std::cout << m1 * 2 + 1 - 10;
std::cout << 2 * m2;
return 0;
}
3.2 通用矩阵乘法
通用矩阵乘法 gemm() 的实现
我们需要先编译 OpenBLAS 库,具体可参考 Ubuntu20.04软件安装大全
#ifndef MATRIX_HPP
#define MATRIX_HPP
#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
#include <functional>
class Matrix{
public:
Matrix() = default;
Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
rows_ = rows;
cols_ = cols;
data_ = data;
if(data_.empty()){
data_.resize(rows * cols);
}else{
if(data_.size() != rows * cols)
std::cout << "Invalid construct.\n";
}
}
// ========== + - * / ==========
Matrix element_wise(const std::function<float(float)>& func) const{
Matrix output = *this; // 复制一份
for(int i = 0; i < output.data_.size(); ++i){
output.data_[i] = func(output.data_[i]);
}
return output;
}
Matrix operator*(float value) const{
// lambda函数 C++11特性
return element_wise([&](float x){return x * value;});
}
Matrix operator+(float value) const{
return element_wise([&](float x){return x + value});
}
Matrix operator-(float value) const{
return element_wise([&](float x){return x - value;});
}
Matrix operator/(float value) const{
return element_wise([&](float x){return x / value;});
}
int rows() const{return rows_;}
int cols() const{return cols_;}
const std::vector<float>& data(){return data_;}
const float* ptr()const{return data_.data();}
float* ptr(return data_.data();)
float& operator(int ir, int ic){
int index = ir * cols_ + ic;
return data_[index];
}
const float& operator()(int ir, int ic)const{
int index = ir * cols_ + ic;
return data_[index];
}
private:
int rows_ = 0;
int cols_ = 0;
std::vector<float> data_;
};
std::ostream& operator<<(std::ostream& out, const Matrix& m){
printf("Matrix( %d x %d)\n", m.rows(), m.cols());
for(int ir = 0; ir < m.rows(); ++ir){
for(int ic = 0; ic < m.cols(); ++ic){
printf("%g\t", m(ir, ic));
}
}
return out;
}
Matrix operator*(float value, const Matrix& m){
return m * value;
}
// ========== gemm ==========
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta){
// AB = C
// A^T B = C
// A B^T = C
// AB * scale + bias
// C = ta(A) * tb(B) * alpha + beta
// Cmxn = ta(A)mxk ta(B)kxn
int ta_rows = ta ? a.cols() : a.rows();
int ta_cols = ta ? a.rows() : a.cols();
int tb_rows = tb ? b.cols() : b.rows();
int tb_cols = tb ? b.rows() : b.cols();
Matrix c(ta_rows, tb_cols);
int m = ta_rows;
int n = tb_cols;
int k = ta_cols;
// 为了解决比如步长不等于列数的情况
int lda = a.cols(); // A矩阵的每一行所需要的步长
int ldb = b.cols();
int ldc = c.cols();
cblas_sgemm(
CblasRowMajor,
ta ? CblasTrans : CblasNoTrans,
tb ? CblasTrans : CblasNoTrans,
m, n, k, alpha, a.ptr(), lda, b.ptr(), ldb, beta, c.ptr(), ldc
);
return c;
}
#endif // MATRIX_HPP
cblas_sgemm
函数是 BLAS 库中的矩阵乘法函数,其参数如下:(from chatGPT)
void cblas_sgemm(
const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const float alpha,
const float *A,
const int lda,
const float *B,
const int ldb,
const float beta,
float *C,
const int ldc
);
其中各参数的含义如下:
- Order:矩阵的存储顺序。CBLAS_ORDER 枚举类型,取值可以是 CblasRowMajor 或 CblasColMajor,分别表示按行存储和按列存储。
- TransA:A 矩阵的转置情况。CBLAS_TRANSPOSE 枚举类型,取值可以是 CblasNoTrans(不转置)、CblasTrans(转置)或 CblasConjTrans(共轭转置)。
- TransB:B 矩阵的转置情况。CBLAS_TRANSPOSE 枚举类型,取值可以是 CblasNoTrans(不转置)、CblasTrans(转置)或 CblasConjTrans(共轭转置)。
- M:C 矩阵的行数。
- N:C 矩阵的列数。
- K:A 和 B 矩阵中共享的维度,即 A 矩阵的列数或 B 矩阵的行数。
- alpha:乘法操作的系数,通常取值为1。
- A:存储 A 矩阵的数组。
- lda:A 矩阵每行的元素个数,通常为 A 矩阵的列数。
- B:存储 B 矩阵的数组。
- ldb:B 矩阵每行的元素个数,通常为 B 矩阵的列数。
- beta:加法操作的系数,通常取值为 0。
- C:存储结果 C 矩阵的数组。
- ldc:C 矩阵每行的元素个数,通常为 C 矩阵的列数。
cblas_sgemm
函数会对 A、B、C 矩阵进行矩阵乘法运算,并将结果存储在 C 矩阵中。其中 A 矩阵的大小为 MxK,B 矩阵的大小为 KxN,C 矩阵的大小为 MxN
注意:实现和声明需要分离 class 不存在这种情况,只有函数存在。当函数的声明和实现都放在头文件时,可能会出现重复定义的问题
#include <iostream>
#include "matrix.hpp"
using namespace std;
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta);
int main(){
Matrix m1(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});
std::cout << gemm(m1, false, m2, false, 1.0f, 0.0f);
std::cout << gemm(m1, true, m2, false, 1.0f, 0.0f);
return 0;
}
3.3 矩阵求逆
求逆矩阵
#ifndef MATRIX_HPP
#define MATRIX_HPP
#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
#include <functional>
class Matrix{
public:
Matrix() = default;
Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
rows_ = rows;
cols_ = cols;
data_ = data;
if(data_.empty()){
data_.resize(rows * cols);
}else{
if(data_.size() != rows * cols)
std::cout << "Invalid construct.\n";
}
}
// ========== + - * / ==========
Matrix element_wise(const std::function<float(float)>& func) const{
Matrix output = *this; // 复制一份
for(int i = 0; i < output.data_.size(); ++i){
output.data_[i] = func(output.data_[i]);
}
return output;
}
Matrix operator*(float value) const{
// lambda函数 C++11特性
return element_wise([&](float x){return x * value;});
}
Matrix operator+(float value) const{
return element_wise([&](float x){return x + value});
}
Matrix operator-(float value) const{
return element_wise([&](float x){return x - value;});
}
Matrix operator/(float value) const{
return element_wise([&](float x){return x / value;});
}
int rows() const{return rows_;}
int cols() const{return cols_;}
const std::vector<float>& data(){return data_;}
const float* ptr()const{return data_.data();}
float* ptr(return data_.data();)
float& operator(int ir, int ic){
int index = ir * cols_ + ic;
return data_[index];
}
const float& operator()(int ir, int ic)const{
int index = ir * cols_ + ic;
return data_[index];
}
private:
int rows_ = 0;
int cols_ = 0;
std::vector<float> data_;
};
std::ostream& operator<<(std::ostream& out, const Matrix& m){
printf("Matrix( %d x %d)\n", m.rows(), m.cols());
for(int ir = 0; ir < m.rows(); ++ir){
for(int ic = 0; ic < m.cols(); ++ic){
printf("%g\t", m(ir, ic));
}
}
return out;
}
Matrix operator*(float value, const Matrix& m){
return m * value;
}
// ========== gemm ==========
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta) const{
// AB = C
// A^T B = C
// A B^T = C
// AB * scale + bias
// C = ta(A) + tb(B) * alpha + beta
// Cmxn = ta(A)mxk ta(B)kxn
int ta_rows = ta ? a.cols() : a.rows();
int ta_cols = ta ? a.rows() : a.cols();
int tb_rows = tb ? b.cols() : b.rows();
int tb_cols = tb ? b.rows() : b.cols();
Matrix c(ta_rows, tb_cols);
int m = ta_rows;
int n = tb_cols;
int k = ta_cols;
// 为了解决比如步长不等于列数的情况
int lda = a.cols(); // A矩阵的每一行所需要的步长
int ldb = b.cols();
int ldc = c.cols();
cblas_sgemm(
CblasRowMajor,
ta ? CblasTrans : CblasNoTrans,
tb ? CblasTrans : CblasNoTrans,
m, n, k, alpha, a.ptr(), lda, b.ptr(), ldb, beta, c.ptr(), ldc
);
return c;
}
// ========== inv ==========
Matrix inverse(const Matrix& a){
if(a.rows() != a.cols()){
printf("Invalid to compute inverse matrix by %d x %d\n", a.rows(), a.cols());
return Matrix();
}
Matrix output = a;
int n = a.rows();
int *ipiv = new int[n];
/* LU分解 */
int code = LAPACKE_sgetrf(LAPACK_COL_MAJOR, n, n, output.ptr(), n, ipiv);
if(code == 0){
/* 使用LU分解求解通用逆矩阵 */
code = LAPACKE_sgetri(LAPACK_COL_MAJOR, n, output.ptr(), n, ipiv);
}
if(code != 0){
printf("LAPACKE inverse matrix failed, code = %d\n", code);
return Matrix();
}
delete[] ipiv;
return output;
}
#endif // MATRIX_HPP
LAPACK_sgetrf函数和
LAPACK_sgetri 函数的使用可以实现对逆矩阵的求解。具体来说,可以先使用
LAPACK_sgetrf函数进行 LU 分解,然后再使用
LAPACK_sgetri` 函数对 LU 分解后的矩阵进行求逆。(from chatGPT)
LAPACK_sgetrf
函数的参数:
- order:表示矩阵数据的存储顺序,可以是LAPACK_ROW_MAJOR或者LAPACK_COL_MAJOR。
- m:表示矩阵 A 的行数。
- n:表示矩阵 A 的列数。
- A:指向矩阵 A 的指针。
- lda:表示矩阵 A 的行宽。
- ipiv:指向一个长度为 min(m,n) 的整数数组,存储 LU 分解的行置换信息。
LAPACK_sgetrf
函数的返回值:
- 如果返回值等于零,则表示操作成功完成。
- 如果返回值小于零,则表示参数错误或某个 U(i,i) 为零,无法进行 LU 分解。
- 如果返回值大于零,则表示 A 的前返回值列的 LU 分解出现奇异矩阵,无法求解。
LAPACK_sgetri
函数的参数:
- order:表示矩阵数据的存储顺序,可以是LAPACK_ROW_MAJOR或者LAPACK_COL_MAJOR。
- n:表示矩阵A的行数和列数。
- A:指向矩阵A的指针。
- lda:表示矩阵A的行宽。
- ipiv:指向一个长度为n的整数数组,存储LU分解的行置换信息。
LAPACK_sgetri
函数的返回值:
- 如果返回值等于零,则表示操作成功完成。
- 如果返回值小于零,则表示参数错误。
- 如果返回值大于零,则表示某个 A(i,i) 为零,无法进行求解。
#include <iostream>
#include "matrix.hpp"
using namespace std;
Matrix inverse(const Matrix& a);
int main(){
Matrix m1(3, 3, {3, 2, 3, 4, 5, 6, 7, 8, 9});
Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});
std::cout << inverse(m1);
std::cout << inverse(m2);
return 0;
}
4. 完整示例代码
我们将声明和实现分离,共三个文件为 matrix.hpp、matrix.cpp、main.cpp
matrix.hpp
#ifndef MATRIX_HPP
#define MATRIX_HPP
#include <vector>
#include <iostream>
#include <ostream>
#include <initializer_list>
#include <functional>
class Matrix{
public:
Matrix();
Matrix(int rows, int cols, const std::initializer_list<float>& data={});
Matrix element_wise(const std::function<float(float)>& func) const;
Matrix operator*(float value) const;
Matrix operator-(float value) const;
Matrix operator+(float value) const;
Matrix operator/(float value) const;
const int rows() const{return rows_;};
const int cols() const{return cols_;};
const std::vector<float>& data() const{return data_;};
const float* ptr()const{return data_.data();};
float* ptr(){return data_.data();};
float& operator()(int ir, int ic);
const float& operator()(int ir, int ic) const;
// Matrix gemm(const Matrix& other, bool ta, bool tb, float alpha=1.0f, float beta=0.0f){
// return ::gemm(*this, ta, other, tb, alpha, beta);
// }
private:
int rows_;
int cols_;
std::vector<float> data_;
};
std::ostream& operator<<(std::ostream& out, const Matrix& m);
Matrix operator*(float value, const Matrix& m);
#endif // MATRIX_HPP
matrix.cpp
#include "cblas.h"
#include "lapacke.h"
#include "matrix.hpp"
Matrix::Matrix(){}
Matrix::Matrix(int rows, int cols, const std::initializer_list<float>& data){
this->rows_ = rows;
this->cols_ = cols;
this->data_ = data;
if(this->data_.size() < rows * cols)
this->data_.resize(rows * cols);
};
Matrix Matrix::element_wise(const std::function<float(float)>& func) const{
Matrix output = *this;
for(int i = 0; i < output.data_.size(); ++i){
output.data_[i] = func(output.data_[i]);
}
return output;
}
Matrix Matrix::operator*(float value) const{
return element_wise([&](float x){return x * value;});
}
Matrix Matrix::operator-(float value) const{
return element_wise([&](float x){return x - value;});
}
Matrix Matrix::operator+(float value) const{
return element_wise([&](float x){return x + value;});
}
Matrix Matrix::operator/(float value) const{
return element_wise([&](float x){return x / value;});
}
float& Matrix::operator()(int ir, int ic){
int index = ir * cols_ + ic;
return data_[index];
}
const float& Matrix::operator()(int ir, int ic) const{
int index = ir * cols_ + ic;
return data_[index];
}
std::ostream& operator<<(std::ostream& out, const Matrix& m){
printf("Matrix (%d x %d)\n", m.rows(), m.cols());
for(int ir = 0; ir < m.rows(); ++ir){
for(int ic = 0; ic < m.cols(); ++ic){
printf("%g\t", m(ir, ic));
}
printf("\n");
}
return out;
}
Matrix operator*(float value, const Matrix& m){
return m * value;
}
// ========== gemm ==========
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta){
// C = ta(A) * ta(B) * alpha + beta
int ta_rows = ta ? a.cols() : a.rows();
int ta_cols = ta ? a.rows() : a.cols();
int tb_rows = tb ? b.cols() : b.rows();
int tb_cols = tb ? b.rows() : b.cols();
Matrix c(ta_rows, tb_cols);
int m = ta_rows;
int n = tb_cols;
int k = ta_cols;
// 为了解决比如步长不等于列数的情况
int lda = a.cols(); // A矩阵的每一行所需要的步长
int ldb = b.cols();
int ldc = c.cols();
cblas_sgemm(
CblasRowMajor,
ta ? CblasTrans : CblasNoTrans,
tb ? CblasTrans : CblasNoTrans,
m, n, k, alpha, a.ptr(), lda, b.ptr(), ldb, beta, c.ptr(), ldc
);
return c;
}
// ========== inv ==========
Matrix inverse(const Matrix& a){
if(a.rows() != a.cols()){
printf("Invalid to compute inverse matrix by %d x %d\n", a.rows(), a.cols());
return Matrix();
}
Matrix output = a;
int n = a.rows();
int *ipiv = new int[n];
/* LU分解 */
int code = LAPACKE_sgetrf(LAPACK_COL_MAJOR, n, n, output.ptr(), n, ipiv);
if(code == 0){
/* 使用LU分解求解通用逆矩阵 */
code = LAPACKE_sgetri(LAPACK_COL_MAJOR, n, output.ptr(), n, ipiv);
}
if(code != 0){
printf("LAPACKE inverse matrix failed, code = %d\n", code);
return Matrix();
}
delete[] ipiv;
return output;
}
main.cpp
#include <iostream>
#include "matrix.hpp"
using namespace std;
Matrix inverse(const Matrix& a);
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta);
int main(){
Matrix m1(3, 3, {3, 2, 3, 4, 5, 6, 7, 8, 9});
Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});
std::cout << gemm(m1, false, m2, false, 1.0f, 0.0f);
std::cout << inverse(m2);
std::cout << m1 * 2 + 5 - 1;
return 0;
}
总结
本次课程跟随杜老师手写了 Matrix 类的具体实现,学习到了很多关于 C++ 语法、习惯的知识,同时也碰到了一些问题,学习的过程就是不断解决问题的过程😄