AlgoC++第七课:手写Matrix

news2025/1/1 10:32:41

目录

  • 手写Matrix
    • 前言
    • 1. 明确需求
    • 2. 基本实现
      • 2.1 创建矩阵
      • 2.2 外部访问
      • 2.3 <<操作符重载
    • 3. 矩阵运算
      • 3.1 矩阵标量运算
      • 3.2 通用矩阵乘法
      • 3.3 矩阵求逆
    • 4. 完整示例代码
    • 总结

手写Matrix

前言

手写AI推出的全新面向AI算法的C++课程 Algo C++,链接。记录下个人学习笔记,仅供自己参考。

本次课程主要是手写Matrix代码

课程大纲可看下面的思维导图

在这里插入图片描述

1. 明确需求

我们先来明确下 Matrix 类中应该实现那些功能

1.只能够表示 2 维的矩阵形式,即使是向量也会用 matrix 表示

​ -我们指表达 float 格式的矩阵,不表达其它形式

2.矩阵的乘法,通用矩阵的乘法形式

3.要能够求解逆矩阵

4.可以通过指定行和列进行矩阵的创建

5.可以允许使用 {1, 2, 3} 这种形式进行数据填充的方式创建

6.能够与标量进行常规的加减法

2. 基本实现

2.1 创建矩阵

我们先来实现 Matrix 类中矩阵创建的功能

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;

private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

#endif // MATRIX_HPP

在上面的示例代码中,我们定义了矩阵 Matrix 类的默认构造函数,以及定义了三个私有成员变量 rows_、cols、data_ 分别代表矩阵的行号、列号以及数据。#ifndef #define #endif 语句是为了防止头文件重复包含。值得注意的是,我们习惯在成员变量尾部加上_

我们再来实现使用 {1, 2, 3}进行数据填充的方式来创建矩阵

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        
        // 隐式转换,其实是执行了vector的赋值操作
        data_ = data;
        
        // 1. data的元素为空,说明是不指定数据情况下进行创建
        // 2. data的元素不空,说明是指定数据情况下创建
        	// 1. 元素数量等于rows * cols
        	// 2. 元素数量小于rows * cols
       	
        if(data_.empty()){
            
            // resize表示分配rows * cols个元素,在vector中
            // 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

#endif // MATRIX_HPP

在上面的示例代码中,我们实现了可以通过 Matrix m(3, 1, {1, 3, 2}) 这种方式来创建矩阵。值得注意的是

  • {} 这种类型的数据在 C++ 中叫做 initializer_list,是一种容器。
  • 在 C++ 的容器中,比如 STL 对象中的 vector、list 等,它们分配的空间,如果不进行初始化,则其内部的值是 0
  • 对于空,在 C 语言里面大部分是指 malloc 分配出来的内存没有初始化的情况
    • 此时 malloc 分配内存的值其实是随机的
    • 此时 new 分配的内存的值也是随机的
  • data 参数是以 & 引用的方式传递,防止拷贝;而 const 参数表示传入常引用,在函数内部对其不进行修改
  • 我们习惯在传值的时候,对于非基础类型,一般会传递常引用,使得效率更高,避免拷贝

2.2 外部访问

我们还希望能够外部访问到矩阵的行、列以及矩阵的元素变量

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        
        // 隐式转换,其实是执行了vector的赋值操作
        data_ = data;
        
        // 1. data的元素为空,说明是不指定数据情况下进行创建
        // 2. data的元素不空,说明是指定数据情况下创建
        	// 1. 元素数量等于rows * cols
        	// 2. 元素数量小于rows * cols
       	
        if(data_.empty()){
            
            // resize表示分配rows * cols个元素,在vector中
            // 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    int rows() {return rows_;}
    int cols() {return cols_;}
    std::vector<float>& data(){return data_;}
    
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

#endif // MATRIX_HPP

在上面的示例代码中,我们分别实现了 rows()、cols()、data() 函数用来访问矩阵,如下所示:

#include <iostream>
#include "matrix.hpp"
using namespace std;


int main(){

    Matrix m(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    std::cout << "矩阵行 = " << m.rows() << std::endl;
    std::cout << "矩阵列 = " << m.cols() << std::endl;
    
    std::cout << "矩阵数据如下: " << std::endl;
    for(int i = 0; i < m.data().size(); ++i){
        std::cout << m.data()[i] << std::endl;
    }    

    return 0;
}

在这里插入图片描述

我们希望能够通过 m[i][j] 这种形式去访问到矩阵的元素,需要用到操作符重载

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        
        // 隐式转换,其实是执行了vector的赋值操作
        data_ = data;
        
        // 1. data的元素为空,说明是不指定数据情况下进行创建
        // 2. data的元素不空,说明是指定数据情况下创建
        	// 1. 元素数量等于rows * cols
        	// 2. 元素数量小于rows * cols
       	
        if(data_.empty()){
            
            // resize表示分配rows * cols个元素,在vector中
            // 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    int rows() {return rows_;}
    int cols() {return cols_;}
    std::vector<float>& data(){return data_;}
    
    float& operator()(int ir, int ic){
        
        // data_在内存中是连续的
        // 比如说我们有3x3的矩阵,那么
        // data_就等于 = {1, 2, 3, 4, 5, 6, 7, 8, 9}
        // 它代表的就是:
        /*
        	1 2 3
        	4 5 6
        	7 8 9
        */
        // 如果要访问2行,0列。此时应该是对应的7
        // 把2d的索引,映射到连续1d空间的索引上
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

#endif // MATRIX_HPP

在上述示例代码中,我们对操作符 () 进行了重载,使得矩阵可以通过 m(i,j) 这种方式访问矩阵的元素,值得注意的是:

  • 我们需要将 operator() 看成一个整体

  • 在 C++ 中不允许 [] 提供更多参数,这种操作只能提供一个参数,因此,可以换成 m(i,j),此时是可以允许的

  • 我们返回的是 float & 引用而不是常引用,意味着我们可以直接修改其内部元素

#include <iostream>
#include "matrix.hpp"
using namespace std;


int main(){

    Matrix m(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    std::cout << "矩阵行 = " << m.rows() << std::endl;
    std::cout << "矩阵列 = " << m.cols() << std::endl;
    
    auto data = m.data();
    
    m(1, 1) = 123.5;

    std::cout << "矩阵数据如下: " << std::endl;
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            std::cout << m(ir, ic) << "\t";
        }
        std::cout << "\n";
    }

    return 0;
}

在这里插入图片描述

2.3 <<操作符重载

我们还需要重载操作符 <<,使其能直接打印矩阵即 std::cout << m

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        
        // 隐式转换,其实是执行了vector的赋值操作
        data_ = data;
        
        // 1. data的元素为空,说明是不指定数据情况下进行创建
        // 2. data的元素不空,说明是指定数据情况下创建
        	// 1. 元素数量等于rows * cols
        	// 2. 元素数量小于rows * cols
       	
        if(data_.empty()){
            
            // resize表示分配rows * cols个元素,在vector中
            // 此时没有对vector做初始化,但是其内部的值全部为0,这是vector保证的
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    // 在 rows() 后面加const,表示这个函数是常量函数
    // 潜台词是:它不会对内部成员做修改,仅仅只做访问查询
    int rows() const{return rows_;}
    int cols() const{return cols_;}
    std::vector<float>& data(){return data_;}
    
    // 这个表示,重载一个可以修改元素的函数
    float& operator(int ir, int ic){
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
    // 这个表示,重载一个只能读取的函数
    const float& operator()(int ir, int ic)const{
        
        // data_在内存中是连续的
        // 比如说我们有3x3的矩阵,那么
        // data_就等于 = {1, 2, 3, 4, 5, 6, 7, 8, 9}
        // 它代表的就是:
        /*
        	1 2 3
        	4 5 6
        	7 8 9
        */
        // 如果要访问2行,0列。此时应该是对应的7
        // 把2d的索引,映射到连续1d空间的索引上
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix( %d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
        printf("\n");
    }
    return out;
}

#endif // MATRIX_HPP

在上述示例代码中我们重载了操作符 <<,值得注意的是:

  • << 等价于 operator<<,因此 std::cout << 123 等价于 std::cout.operator<<(123)

  • << 操作符重载有两种方式

    • 第一种是存在于类内的,例如 std::cout.operator<<(m);由于这个是系统文件,最好不要修改
    • 第二种是存在于全局作用域的,例如 std::ostream& operator<<(std::ostream& out, const Matrix& m)
      • 首先,全局操作符重载,是特定操作符为函数名称
      • 其次,第一个参数,称之为左操作数;第二个参数,称之为右操作数
      • 左操作数 out 对象为引用而非常引用是因为 out 对象存在写操作,势必是修改,因此不能是常量了,必须是非常量引用
      • 右操作数 Matrix 对象为常引用是因为避免拷贝值得发生,我们在这里只需要读取就行了
      • 左操作数 << 右操作数 相当于 operator<<(左操作数,右操作数)
  • 在 C++ 中类的函数,分为常规函数(具有修改和访问权限)和常量函数(只有访问权限,没有修改权限)。右操作数是一个常量对象只能访问常量函数,因此在调用的函数应该是常量函数,所以在 rows()、cols()、operator() 函数后面都要加上 const 关键字,且 operator () 返回的应该是一个常量引用

  • 我们需要对 operator() 重载既能实现常量函数,又能完成修改操作,可以将 operator() 写两遍,一个表示重载一个可以修改元素的函数,一个表示重载一个只能读取的函数

#include <iostream>
#include "matrix.hpp"
using namespace std;


int main(){

    Matrix m(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    cout << m ;

    return 0;
}

在这里插入图片描述

最后我们来区分下

int rows() const{return rows_};
const int rows(){return rows_};

第一个函数 int rows() const{return rows_} 是一个 const 成员函数,其中的 const 关键字表示该成员函数不会修改对象的成员变量,即保证了函数内部不会修改 rows_ 的值,同时该成员函数可以被 const 对象和非 const 对象调用

而第二个函数 const int rows(){return rows_} 是一个非 const 成员函数,其中的 const 关键字表示函数返回的是 const 类型的值,但并没有对函数本身做出限制。该函数可以被非 const 对象调用,但不能被 const 对象调用

因此,第一个函数可以用于保证对象的成员变量不被修改,并且可以被 const 和非 const 对象调用,而第二个函数只能被非 const 对象调用,如果被 const 对象调用则会产生编译错误

3. 矩阵运算

3.1 矩阵标量运算

实现矩阵与标量的 + - * / 四则运算

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
#include <functional>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        data_ = data;

        if(data_.empty()){
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    // ========== + - * / ==========
    Matrix element_wise(const std::function<float(float)>& func) const{
        Matrix output = *this;	// 复制一份
        for(int i = 0; i < output.data_.size(); ++i){
            output.data_[i] = func(output.data_[i]);
        }
        return output;
    }
    
    Matrix operator*(float value) const{
        // lambda函数 C++11特性
        return element_wise([&](float x){return x * value;});
    }
    
    Matrix operator+(float value) const{
        return element_wise([&](float x){return x + value});
    }
    
    Matrix operator-(float value) const{
        return element_wise([&](float x){return x - value;});
    }
    
    Matrix operator/(float value) const{
        return element_wise([&](float x){return x / value;});
    }
    
    int rows() const{return rows_;}
    int cols() const{return cols_;}
    std::vector<float>& data(){return data_;}
    
    float& operator(int ir, int ic){
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
    const float& operator()(int ir, int ic)const{
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix( %d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
    }
    return out;
}

Matrix operator*(float value, const Matrix& m){
    return m * value;
}

#endif // MATRIX_HPP

上述示例代码中实现了矩阵 Matrix 类中加、减、乘、除四个运算符的重载实现,值得注意的是:

  • element_wise() 函数是对矩阵中的每个元素都应用一个函数,得到一个新的矩阵并返回。这样可以将四则运算统一起来,四则运算的函数采用匿名函数实现可以简化代码

  • 四个运算符重载函数 operator*()、operator+()、operator-()、operator/(),都是调用 element_wise() 函数,传入一个对应的 lambda 表达式,对矩阵中的每个元素都进行相应的四则运算,得到一个新的矩阵并返回。

  • lambda 表达式定义了一个匿名函数,并且可以捕获一定范围内的变量。其语法形式如下:

    [capture](params) opt -> ret {body;}
    
    • 其中 capture 是捕获列表,& 表示捕获全局引用,= 表示捕获全局值,&value 表示捕获特定的值/引用
    • params 是参数列表,和普通函数的参数一样
    • opt 是函数选项,不需要可以省略
    • ret 是返回值类型,可以省略,编译器会自动推导
    • body 是函数体
    • 关于 lambda 表达式的更多细节可参考 https://subingwen.cn/cpp/lambda/
  • 在矩阵与标量乘操作中可能存在 m*2 和 2*m 两种情形,我们分别采用了类内重载和全局重载两种方式实现

#include <iostream>
#include "matrix.hpp"
using namespace std;


int main(){

    Matrix m1(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});

    std::cout << m1 * 2 + 1 - 10;
    std::cout << 2 * m2;

    return 0;
}

在这里插入图片描述

3.2 通用矩阵乘法

通用矩阵乘法 gemm() 的实现

我们需要先编译 OpenBLAS 库,具体可参考 Ubuntu20.04软件安装大全

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
#include <functional>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        data_ = data;

        if(data_.empty()){
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    // ========== + - * / ==========
    Matrix element_wise(const std::function<float(float)>& func) const{
        Matrix output = *this;	// 复制一份
        for(int i = 0; i < output.data_.size(); ++i){
            output.data_[i] = func(output.data_[i]);
        }
        return output;
    }
    
    Matrix operator*(float value) const{
        // lambda函数 C++11特性
        return element_wise([&](float x){return x * value;});
    }
    
    Matrix operator+(float value) const{
        return element_wise([&](float x){return x + value});
    }
    
    Matrix operator-(float value) const{
        return element_wise([&](float x){return x - value;});
    }
    
    Matrix operator/(float value) const{
        return element_wise([&](float x){return x / value;});
    }
    
    int rows() const{return rows_;}
    int cols() const{return cols_;}
    const std::vector<float>& data(){return data_;}
    const float* ptr()const{return data_.data();}
    float* ptr(return data_.data();)
    
    float& operator(int ir, int ic){
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
    const float& operator()(int ir, int ic)const{
        int index = ir * cols_ + ic;
        return data_[index];
    }
     
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix( %d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
    }
    return out;
}

Matrix operator*(float value, const Matrix& m){
    return m * value;
}

// ========== gemm ==========
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta){
    // AB = C
    // A^T B = C
    // A B^T = C
    // AB * scale + bias
    // C = ta(A) * tb(B) * alpha + beta
    // Cmxn = ta(A)mxk ta(B)kxn
    
    int ta_rows = ta ? a.cols() : a.rows();
    int ta_cols = ta ? a.rows() : a.cols();
    int tb_rows = tb ? b.cols() : b.rows();
    int tb_cols = tb ? b.rows() : b.cols();
    
    Matrix c(ta_rows, tb_cols);
    int m = ta_rows;
    int n = tb_cols;
    int k = ta_cols;
    
    // 为了解决比如步长不等于列数的情况
    int lda = a.cols();		// A矩阵的每一行所需要的步长
    int ldb = b.cols();
    int ldc = c.cols();
    
    cblas_sgemm(
    	CblasRowMajor,
        ta ? CblasTrans : CblasNoTrans,
        tb ? CblasTrans : CblasNoTrans,
        m, n, k, alpha, a.ptr(), lda, b.ptr(), ldb, beta, c.ptr(), ldc
    );
    return c;
}

#endif // MATRIX_HPP

cblas_sgemm 函数是 BLAS 库中的矩阵乘法函数,其参数如下:(from chatGPT)

void cblas_sgemm(
    const enum CBLAS_ORDER Order,
    const enum CBLAS_TRANSPOSE TransA,
    const enum CBLAS_TRANSPOSE TransB,
    const int M,
    const int N,
    const int K,
    const float alpha,
    const float *A,
    const int lda,
    const float *B,
    const int ldb,
    const float beta,
    float *C,
    const int ldc
);

其中各参数的含义如下:

  • Order:矩阵的存储顺序。CBLAS_ORDER 枚举类型,取值可以是 CblasRowMajor 或 CblasColMajor,分别表示按行存储和按列存储。
  • TransA:A 矩阵的转置情况。CBLAS_TRANSPOSE 枚举类型,取值可以是 CblasNoTrans(不转置)、CblasTrans(转置)或 CblasConjTrans(共轭转置)。
  • TransB:B 矩阵的转置情况。CBLAS_TRANSPOSE 枚举类型,取值可以是 CblasNoTrans(不转置)、CblasTrans(转置)或 CblasConjTrans(共轭转置)。
  • M:C 矩阵的行数。
  • N:C 矩阵的列数。
  • K:A 和 B 矩阵中共享的维度,即 A 矩阵的列数或 B 矩阵的行数。
  • alpha:乘法操作的系数,通常取值为1。
  • A:存储 A 矩阵的数组。
  • lda:A 矩阵每行的元素个数,通常为 A 矩阵的列数。
  • B:存储 B 矩阵的数组。
  • ldb:B 矩阵每行的元素个数,通常为 B 矩阵的列数。
  • beta:加法操作的系数,通常取值为 0。
  • C:存储结果 C 矩阵的数组。
  • ldc:C 矩阵每行的元素个数,通常为 C 矩阵的列数。

cblas_sgemm 函数会对 A、B、C 矩阵进行矩阵乘法运算,并将结果存储在 C 矩阵中。其中 A 矩阵的大小为 MxK,B 矩阵的大小为 KxN,C 矩阵的大小为 MxN

注意实现和声明需要分离 class 不存在这种情况,只有函数存在。当函数的声明和实现都放在头文件时,可能会出现重复定义的问题

#include <iostream>
#include "matrix.hpp"
using namespace std;

Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta);


int main(){

    Matrix m1(3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});

    std::cout << gemm(m1, false, m2, false, 1.0f, 0.0f);
    std::cout << gemm(m1, true, m2, false, 1.0f, 0.0f);

    return 0;
}

在这里插入图片描述

3.3 矩阵求逆

求逆矩阵

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <initializer_list>
#include <iostream>
#include <ostream>
#include <functional>

class Matrix{
public:
    Matrix() = default;
    Matrix(int rows, int cols, const std::initializer_list<float>& data={}){
        
        rows_ = rows;
        cols_ = cols;
        data_ = data;

        if(data_.empty()){
            data_.resize(rows * cols);
        }else{
            if(data_.size() != rows * cols)
                std::cout << "Invalid construct.\n";
        }
    }

    // ========== + - * / ==========
    Matrix element_wise(const std::function<float(float)>& func) const{
        Matrix output = *this;	// 复制一份
        for(int i = 0; i < output.data_.size(); ++i){
            output.data_[i] = func(output.data_[i]);
        }
        return output;
    }
    
    Matrix operator*(float value) const{
        // lambda函数 C++11特性
        return element_wise([&](float x){return x * value;});
    }
    
    Matrix operator+(float value) const{
        return element_wise([&](float x){return x + value});
    }
    
    Matrix operator-(float value) const{
        return element_wise([&](float x){return x - value;});
    }
    
    Matrix operator/(float value) const{
        return element_wise([&](float x){return x / value;});
    }
    
    int rows() const{return rows_;}
    int cols() const{return cols_;}
    const std::vector<float>& data(){return data_;}
    const float* ptr()const{return data_.data();}
    float* ptr(return data_.data();)
    
    float& operator(int ir, int ic){
        int index = ir * cols_ + ic;
        return data_[index];
    }
    
    const float& operator()(int ir, int ic)const{
        int index = ir * cols_ + ic;
        return data_[index];
    }
     
private:
    int rows_ = 0;
    int cols_ = 0;
    std::vector<float> data_;
    
};

std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix( %d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
    }
    return out;
}

Matrix operator*(float value, const Matrix& m){
    return m * value;
}

// ========== gemm ==========
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta) const{
    // AB = C
    // A^T B = C
    // A B^T = C
    // AB * scale + bias
    // C = ta(A) + tb(B) * alpha + beta
    // Cmxn = ta(A)mxk ta(B)kxn
    
    int ta_rows = ta ? a.cols() : a.rows();
    int ta_cols = ta ? a.rows() : a.cols();
    int tb_rows = tb ? b.cols() : b.rows();
    int tb_cols = tb ? b.rows() : b.cols();
    
    Matrix c(ta_rows, tb_cols);
    int m = ta_rows;
    int n = tb_cols;
    int k = ta_cols;
    
    // 为了解决比如步长不等于列数的情况
    int lda = a.cols();		// A矩阵的每一行所需要的步长
    int ldb = b.cols();
    int ldc = c.cols();
    
    cblas_sgemm(
    	CblasRowMajor,
        ta ? CblasTrans : CblasNoTrans,
        tb ? CblasTrans : CblasNoTrans,
        m, n, k, alpha, a.ptr(), lda, b.ptr(), ldb, beta, c.ptr(), ldc
    );
    return c;
}

// ========== inv ========== 
Matrix inverse(const Matrix& a){
    
    if(a.rows() != a.cols()){
        printf("Invalid to compute inverse matrix by %d x %d\n", a.rows(), a.cols());
        return Matrix();
    }
    
    Matrix output = a;
    int n = a.rows();
    int *ipiv = new int[n];
    
    /* LU分解 */
    int code = LAPACKE_sgetrf(LAPACK_COL_MAJOR, n, n, output.ptr(), n, ipiv);
    if(code == 0){
        /* 使用LU分解求解通用逆矩阵 */
        code = LAPACKE_sgetri(LAPACK_COL_MAJOR, n, output.ptr(), n, ipiv);
    }

    if(code != 0){
        printf("LAPACKE inverse matrix failed, code = %d\n", code);
        return Matrix();
    }

    delete[] ipiv;
    return output;
}
#endif // MATRIX_HPP

LAPACK_sgetrf函数和LAPACK_sgetri 函数的使用可以实现对逆矩阵的求解。具体来说,可以先使用LAPACK_sgetrf函数进行 LU 分解,然后再使用LAPACK_sgetri` 函数对 LU 分解后的矩阵进行求逆。(from chatGPT)

LAPACK_sgetrf 函数的参数:

  • order:表示矩阵数据的存储顺序,可以是LAPACK_ROW_MAJOR或者LAPACK_COL_MAJOR。
  • m:表示矩阵 A 的行数。
  • n:表示矩阵 A 的列数。
  • A:指向矩阵 A 的指针。
  • lda:表示矩阵 A 的行宽。
  • ipiv:指向一个长度为 min(m,n) 的整数数组,存储 LU 分解的行置换信息。

LAPACK_sgetrf 函数的返回值:

  • 如果返回值等于零,则表示操作成功完成。
  • 如果返回值小于零,则表示参数错误或某个 U(i,i) 为零,无法进行 LU 分解。
  • 如果返回值大于零,则表示 A 的前返回值列的 LU 分解出现奇异矩阵,无法求解。

LAPACK_sgetri 函数的参数:

  • order:表示矩阵数据的存储顺序,可以是LAPACK_ROW_MAJOR或者LAPACK_COL_MAJOR。
  • n:表示矩阵A的行数和列数。
  • A:指向矩阵A的指针。
  • lda:表示矩阵A的行宽。
  • ipiv:指向一个长度为n的整数数组,存储LU分解的行置换信息。

LAPACK_sgetri 函数的返回值:

  • 如果返回值等于零,则表示操作成功完成。
  • 如果返回值小于零,则表示参数错误。
  • 如果返回值大于零,则表示某个 A(i,i) 为零,无法进行求解。
#include <iostream>
#include "matrix.hpp"
using namespace std;

Matrix inverse(const Matrix& a);


int main(){

    Matrix m1(3, 3, {3, 2, 3, 4, 5, 6, 7, 8, 9});
    Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});

    std::cout << inverse(m1);
    std::cout << inverse(m2);

    return 0;
}

在这里插入图片描述

4. 完整示例代码

我们将声明和实现分离,共三个文件为 matrix.hpp、matrix.cpp、main.cpp

matrix.hpp

#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <vector>
#include <iostream>
#include <ostream>
#include <initializer_list>
#include <functional>

class Matrix{
public:
    Matrix();
    Matrix(int rows, int cols, const std::initializer_list<float>& data={});

    Matrix element_wise(const std::function<float(float)>& func) const;
    Matrix operator*(float value) const;
    Matrix operator-(float value) const;
    Matrix operator+(float value) const;
    Matrix operator/(float value) const;

    const int rows() const{return rows_;};
    const int cols() const{return cols_;};
    const std::vector<float>& data() const{return data_;};
    const float* ptr()const{return data_.data();};
    float* ptr(){return data_.data();};

    float& operator()(int ir, int ic);

    const float& operator()(int ir, int ic) const;

    // Matrix gemm(const Matrix& other, bool ta, bool tb, float alpha=1.0f, float beta=0.0f){
    //     return ::gemm(*this, ta, other, tb, alpha, beta);
    // }

private:
    int rows_;
    int cols_;
    std::vector<float> data_;
};

std::ostream& operator<<(std::ostream& out, const Matrix& m);

Matrix operator*(float value, const Matrix& m);

#endif // MATRIX_HPP

matrix.cpp

#include "cblas.h"
#include "lapacke.h"
#include "matrix.hpp"

Matrix::Matrix(){}
Matrix::Matrix(int rows, int cols, const std::initializer_list<float>& data){
    this->rows_ = rows;
    this->cols_ = cols;
    this->data_ = data;

    if(this->data_.size() < rows * cols)
        this->data_.resize(rows * cols);
};

Matrix Matrix::element_wise(const std::function<float(float)>& func) const{
    Matrix output = *this;
    for(int i = 0; i < output.data_.size(); ++i){
        output.data_[i] = func(output.data_[i]);
    }
    return output;
}

Matrix Matrix::operator*(float value) const{
    return element_wise([&](float x){return x * value;});
}
Matrix Matrix::operator-(float value) const{
    return element_wise([&](float x){return x - value;});
}
Matrix Matrix::operator+(float value) const{
    return element_wise([&](float x){return x + value;});
}
Matrix Matrix::operator/(float value) const{
    return element_wise([&](float x){return x / value;});
}

float& Matrix::operator()(int ir, int ic){
    int index = ir * cols_ + ic;
    return data_[index];
}

const float& Matrix::operator()(int ir, int ic) const{
    int index = ir * cols_ + ic;
    return data_[index];
}


std::ostream& operator<<(std::ostream& out, const Matrix& m){
    
    printf("Matrix (%d x %d)\n", m.rows(), m.cols());
    for(int ir = 0; ir < m.rows(); ++ir){
        for(int ic = 0; ic < m.cols(); ++ic){
            printf("%g\t", m(ir, ic));
        }
        printf("\n");
    }
    return out;
}

Matrix operator*(float value, const Matrix& m){
    return m * value;
}


// ========== gemm ==========
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta){
    // C = ta(A) * ta(B) * alpha + beta

    int ta_rows = ta ? a.cols() : a.rows();
    int ta_cols = ta ? a.rows() : a.cols();
    int tb_rows = tb ? b.cols() : b.rows();
    int tb_cols = tb ? b.rows() : b.cols();

    Matrix c(ta_rows, tb_cols);
    int m = ta_rows;
    int n = tb_cols;
    int k = ta_cols;
    
    // 为了解决比如步长不等于列数的情况
    int lda = a.cols();     // A矩阵的每一行所需要的步长
    int ldb = b.cols();
    int ldc = c.cols();

    cblas_sgemm(
        CblasRowMajor,
        ta ? CblasTrans : CblasNoTrans,
        tb ? CblasTrans : CblasNoTrans,
        m, n, k, alpha, a.ptr(), lda, b.ptr(), ldb, beta, c.ptr(), ldc 
    );
    return c;
}

// ========== inv ========== 
Matrix inverse(const Matrix& a){
    
    if(a.rows() != a.cols()){
        printf("Invalid to compute inverse matrix by %d x %d\n", a.rows(), a.cols());
        return Matrix();
    }
    
    Matrix output = a;
    int n = a.rows();
    int *ipiv = new int[n];
    
    /* LU分解 */
    int code = LAPACKE_sgetrf(LAPACK_COL_MAJOR, n, n, output.ptr(), n, ipiv);
    if(code == 0){
        /* 使用LU分解求解通用逆矩阵 */
        code = LAPACKE_sgetri(LAPACK_COL_MAJOR, n, output.ptr(), n, ipiv);
    }

    if(code != 0){
        printf("LAPACKE inverse matrix failed, code = %d\n", code);
        return Matrix();
    }

    delete[] ipiv;
    return output;
}

main.cpp

#include <iostream>
#include "matrix.hpp"
using namespace std;

Matrix inverse(const Matrix& a);
Matrix gemm(const Matrix& a, bool ta, const Matrix& b, bool tb, float alpha, float beta);


int main(){

    Matrix m1(3, 3, {3, 2, 3, 4, 5, 6, 7, 8, 9});
    Matrix m2(3, 3, {2, 0, 0, 0, 1, 0, 0, 0, 1});

    std::cout << gemm(m1, false, m2, false, 1.0f, 0.0f);
    std::cout << inverse(m2);

    std::cout << m1 * 2 + 5 - 1;

    return 0;
}

总结

本次课程跟随杜老师手写了 Matrix 类的具体实现,学习到了很多关于 C++ 语法、习惯的知识,同时也碰到了一些问题,学习的过程就是不断解决问题的过程😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/464907.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

01 背包 (二维 )

首先是我对背包问题的理解&#xff1a; 有一个背包可以放下 n kg&#xff0c;有一些物品&#xff0c;价值和重量一一对应&#xff0c;问题是&#xff0c;需要怎样才能使背包中的价值最大&#xff1f; 不同的规则对应不同的背包问题 01背包&#xff1a;每一个物品只能被放入一次…

Docker consul的容器集群的部署|consul-template部署

Docker consul的容器集群的部署|consul-template部署 一、Consul 概述基于nginx和consul构建高可用及自动发现的Docker服务架构 二 consul实验步骤2.1 部署Consul集群 (server)2.2 Consul部署&#xff08;Client端&#xff09;2.3 consul-template部署(server)2.4 编译安装ngin…

【翻译一下官方文档】邂逅uniCloud云函数(基础篇)

我将用图文的形式&#xff0c;把市面上优质的课程加以自己的理解&#xff0c;详细的把&#xff1a;创建一个uniCloud的应用&#xff0c;其中的每一步记录出来&#xff0c;方便大家写项目中&#xff0c;做到哪一步不会了&#xff0c;可以轻松翻看文章进行查阅。&#xff08;此文…

量表题如何分析?

量表是一种测量工具&#xff0c;量表设计标准有很多&#xff0c;并且每种量表的设计都有各自的特性&#xff0c;不同量表的特性也决定了测量尺度&#xff0c;在数据分析中常用的量表为李克特量表。李克特量表1932年由美国社会心理学家李克特在当时原有总加量表的基础上进行改进…

Java8使用Stream流实现List列表简单使用

目录 1.forEach() 2.filter&#xff08;T -> boolean&#xff09; 3.findAny()和findFirst() 4.map(T -> R) 和flatMap(T -> stream) 5.distinct() 去重 6.limit(long n)和skip(long n) 7.anyMatch(T -> boolean) 8.allMatch(T -> boolean) 9.noneMat…

ASP.NET Core MVC 从入门到精通之数据库

随着技术的发展&#xff0c;ASP.NET Core MVC也推出了好长时间&#xff0c;经过不断的版本更新迭代&#xff0c;已经越来越完善&#xff0c;本系列文章主要讲解ASP.NET Core MVC开发B/S系统过程中所涉及到的相关内容&#xff0c;适用于初学者&#xff0c;在校毕业生&#xff0c…

ThingsBoard教程(三四):筛选规则节点 根据资产,设备,筛选,asset profile switch,device profile switch

前言 这是规则节点解析系列的第一篇,让我们先从Filter Nodes ,筛选节点类型开始。 筛选节点的作用主要是为了从筛选进入规则链的数据,根据一定的判断表达式来判断,数据向下游的那个分支流转。类似我们编程中的switch语句或if语句。 本篇主要讲解asset profile switch 与de…

每天一道算法练习题--Day13 第一章 --算法专题 --- ----------动态规划(重置版)

动态规划是一个从其他行业借鉴过来的词语。 它的大概意思先将一件事情分成若干阶段&#xff0c;然后通过阶段之间的转移达到目标。由于转移的方向通常是多个&#xff0c;因此这个时候就需要决策选择具体哪一个转移方向。 动态规划所要解决的事情通常是完成一个具体的目标&…

什么是渲染农场?我什么时候应该使用渲染农场?

网络上有关渲染农场的概念数不胜数&#xff0c;有一部分说法甚至让我们对渲染农场有了很大误解&#xff0c;究竟真正什么是渲染农场、渲染农场有多少种类型&#xff1f;我们怎么选择适合自己的渲染农场&#xff1f;这些都是各位小伙伴们近期比较关心的一些问题。 首先渲染农场是…

【C语言】基础语法7:文件操作

上一篇&#xff1a;字符串和字符处理 ❤️‍&#x1f525;前情提要❤️‍&#x1f525;   欢迎来到C语言基本语法教程   在本专栏结束后会将所有内容整理成思维导图&#xff08;结束换链接&#xff09;并免费提供给大家学习&#xff0c;希望大家纠错指正。本专栏将以基础出…

域内密码凭证获取

Volume Shadow Copy 活动目录数据库 ntds.dit&#xff1a;活动目录数据库&#xff0c;包括有关域用户、组和成员身份的 信息。它还包括域中所有用户的哈希值。 ntds.dit文件位置&#xff1a;%SystemRoot%\NTDS\NTDS.dit system文件位置&#xff1a;%SystemRoot%\System32\c…

好程序员:前端JavaScript全解析——Canvas绘制形状(下)

接着上一篇&#xff0c;好程序员继续讲解前端技术文章&#xff01; 绘制椭圆 ●canvas 也提供了绘制椭圆的 API ●语法 : 工具箱.ellipse( x, y, radiusX, radiusY, rotation, startAngle, endAngle, antiClockwise ) ○x : 椭圆中心点的 x 轴坐标 ○y : 椭圆中心点的 y 轴坐标…

Maven详解

一、什么是Maven Maven 是⼀个项⽬构建⼯具&#xff0c;创建的项⽬只要遵循 Maven 规范&#xff08;称为Maven项目&#xff09;&#xff0c;即可使用Maven 来进行&#xff1a;管理 jar 包、编译项目&#xff0c;打包项目等功能。 为什么学习 Servlet 之前要学 Maven&#xff1…

SAM(2023)-分割万物

文章目录 摘要算法数据引擎实验7.1 零样本单点生成mask7.2 零样本边缘检测7.3. 零样本目标Proposals7.4. 零样本实例分割7.5. 零样本文本生成Mask7.6. 消融实验 讨论限制&#xff1a;结论&#xff1a; 论文: 《Segment Anything》 github: https://github.com/facebookresear…

java获取类结构信息

package com.hspedu.reflection;import org.junit.jupiter.api.Test;import java.lang.annotation.Annotation; import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method;/*** author 韩顺平* version 1.0* 演示如何通过反射获…

初级算法-回溯算法

主要记录算法和数据结构学习笔记&#xff0c;新的一年更上一层楼&#xff01; 初级算法-回溯算法 一、组合二、电话号码的字母组合三、组合总和四、组合Ⅱ五、组合Ⅲ六、分割回文串七、复原IP地址八、子集问题九、子集Ⅱ十、递增子序列十一、重新安排行程十二、全排列十三、全…

CASAIM自动化精密尺寸测量设备全尺寸检测铸件自动化检测铸件

铸造作为现代装备制造工业的基础共性技术之一&#xff0c;铸件产品既是工业制造产品&#xff0c;也是大型机械的重要组成部分&#xff0c;被广泛运用在航空航天、工业船舶、机械电子和交通运输等行业。 铸件形状复杂&#xff0c;一般的三坐标或者卡尺圆规等工具难以获取多特征…

【基础算法】八大排序算法:直接插入排序,希尔排序,选择排序,堆排序,冒泡排序,快速排序(快排),归并排序,计数排序

文章目录 ✔️前言直接插入排序希尔排序选择排序1. 选择排序基础2. 选择排序优化3. 复杂度的分析 堆排序【⭐重点掌握⭐】1. 对堆的认识和数组建堆2. 对数组进行堆排序操作3. 复杂度的分析 冒泡排序快速排序【⭐重点掌握⭐】1. 霍尔法2. 挖坑法3. 前后指针法4. 快速排序优化&am…

每日一个小技巧:1招教你提取伴奏怎么做

伴奏是指在演唱或演奏时&#xff0c;用来衬托或补充主唱或乐器的音乐声音。而伴奏提取是一种技术&#xff0c;它可以帮助我们从歌曲中将人声和乐器分离出来。当我们听到一些喜欢的歌曲时&#xff0c;往往会被它的旋律深深吸引&#xff0c;想要将其作为自己的演唱曲目&#xff0…

国考只考一门?免试入学还好毕业的在职研究生专业有哪些

读同等学力申硕的同学想要拿学位证&#xff0c;那么首先要过的坎就是国考。修满学分和通过校考一般都不会很难&#xff0c;只要按时上课、根据院校安排的课程复习即可。而国考是全国统一命题、考试&#xff0c;大部分专业要考2门&#xff0c;对于有的同学来说&#xff0c;备考压…