AlgoC++第六课:BP反向传播算法

news2024/9/19 10:49:59

目录

  • BP反向传播算法
    • 前言
    • 1. MNIST
    • 2. 感知机
      • 2.1 前言
      • 2.2 感知机-矩阵表示
      • 2.3 感知机-矩阵表示-多个样本
      • 2.4 感知机-增加偏置
      • 2.5 感知机-多个输出
      • 2.6 总结
      • 2.7 关于广播
    • 3. BP
    • 4. 动量SGD
    • 5. BP示例代码
    • 总结

BP反向传播算法

前言

手写AI推出的全新面向AI算法的C++课程 Algo C++,链接。记录下个人学习笔记,仅供自己参考。

本次课程主要是讲解BP反向传播算法

课程大纲可看下面的思维导图

在这里插入图片描述

1. MNIST

MNIST是一个手写数字识别数据集,它包含60,000个训练图像和10,000个测试图像。每个图像是28x28像素大小的灰度图像,表示0到9之间的一个数字。该数据集最初由美国国家标准与技术研究所(NIST)在20世纪80年代收集和创建,后来由MNIST数据库维护。该数据集已经成为机器学习和计算机视觉领域的标准基准数据集之一。(from chatGPT)

MNIST官网有关于手写数字数据集及其格式的详细介绍

MNIST数据集的格式为二进制格式,分为训练集和测试集两部分,每个部分都由一个图像文件和一个标签文件组成。训练集包含60,000个图像,测试集包含10,000个图像。图像文件和标签文件的名称分别为:

  • 训练集图像文件:train-images-idx3-ubyte.gz
  • 训练集标签文件:train-labels-idx1-ubyte.gz
  • 测试集图像文件:t10k-images-idx3-ubyte.gz
  • 测试集标签文件:t10k-labels-idx1-ubyte.gz

图像文件包含原始的灰度图像数据,标签文件包含每个图像对应的数字标签。

具体来说,图像文件的格式如下:

  • 前4个字节:magic number(MSB first)
  • 4个字节:数据集中图像的数量
  • 4个字节:每个图像的行数
  • 4个字节:每个图像的列数
  • 后续的字节:表示每个图像的像素值,按照行优先的顺序排列

标签文件的格式如下:

  • 前4个字节:magic number(MSB first)
  • 4个字节:数据集中标签的数量
  • 后续的字节:表示每个图像对应的标签,取值范围为0到9

因此,MNIST数据集中的每个图像都是一个28x28像素的灰度图像,由0到255之间的整数值表示每个像素的灰度级别。标签是一个0到9之间的整数值,表示每个图像对应的数字。

下面的示例代码用于加载MNIST数据集并返回对应的图像和标签文件:

#include <stdio.h>
#include <string.h>
#include <iostream>
#include <vector>

struct __attribute__((packed)) mnist_labels_header_t{
    unsigned int magic_number;
    unsigned int number_of_items;
};

struct __attribute__((packed)) mnist_images_header_t{
    unsigned int magic_number;
    unsigned int number_of_images;
    unsigned int number_of_rows;
    unsigned int number_of_columns;
};

unsigned int inverse_byte(unsigned int v){
    unsigned char* p = (unsigned char*)&v;
    std::swap(p[0], p[3]);
    std::swap(p[1], p[2]);
    return v;
}

int main(){

    FILE* f = fopen("mnist/train-labels.idx1-ubyte", "rb");
    mnist_labels_header_t labels_header{0};
    fread(&labels_header, 1, sizeof(labels_header), f);
    printf("labels_header.magic_number = %X, number_of_items = %d\n", 
        inverse_byte(labels_header.magic_number), inverse_byte(labels_header.number_of_items));
    
    unsigned char label = 0;
    fread(&label, 1, sizeof(label), f);
    printf("First label is: %d\n", label);
    fclose(f);

    f = fopen("mnist/train-images.idx3-ubyte", "rb");
    mnist_images_header_t images_header{0};
    fread(&images_header, 1, sizeof(images_header), f);
    printf("images_header.magic_number = %X, number_of_images = %d, number_of_rows = %d, number_of_columns = %d\n", 
        inverse_byte(images_header.magic_number), 
        inverse_byte(images_header.number_of_images),
        inverse_byte(images_header.number_of_rows),
        inverse_byte(images_header.number_of_columns)
    );
    
    std::vector<unsigned char> image(inverse_byte(images_header.number_of_rows) * inverse_byte(images_header.number_of_columns));
    fread(image.data(), 1, image.size(), f);
    for(int i = 0;i < image.size(); ++i){
        if(image[i] == 0)
            printf("--- ");
        else
            printf("%03d ", image[i]);

        if((i + 1) % inverse_byte(images_header.number_of_columns) == 0)
            printf("\n");
    }
    fclose(f);
    return 0;
}

上述示例代码用于解析 MNIST 数据集中的标签和图像数据。代码首先定义了两个结构体:mnist_labels_header_tmnist_images_header_t,分别表示标签和图像的头部信息,其中使用 packed 属性来确保结构体不会被编译器优化对齐。

接下来定义了一个 inverse_byte 函数,用于将4字节整数类型的大小端转换

主函数首先打开标签文件,读取标签头部信息并打印。然后读取第一个标签,并打印;接着打开图像文件,读取图像头部信息并打印;然后读取第一张图像,将其打印出来;最后关闭文件。

运行效果如下图所示:

在这里插入图片描述

2. 感知机

2.1 前言

感知机可以简单用下图表示:

在这里插入图片描述

上图表示了:

  • 第一点: c = a × W 1 + b × W 2 c = a\times W_1 + b\times W_2 c=a×W1+b×W2
  • 第二点: d = a c t i v a t i o n ( c ) d = activation(c) d=activation(c) 这里的 activation 是一个激活函数
  • 激活函数通常是为了非线性映射,例如 s i g m o i d = 1 1 + e − x sigmoid = \frac{1}{1+e^{-x}} sigmoid=1+ex1 或者 r e l u = m a x ( 0 , x ) relu = max(0,x) relu=max(0,x)

重点

  • 任何两个节点的连接线是具有权重值的,例如 W 1 W_1 W1 W 2 W_2 W2
  • 多个节点连接到一个节点,指这多个节点值加权求和后,经过激活函数的结果即: d = a c t i v a t i o n ( a × W 1 + b × W 2 ) d = activation(a \times W_1 + b \times W_2) d=activation(a×W1+b×W2)

2.2 感知机-矩阵表示

我们在理解感知机可以通过用矩阵的方式来理解,定义 A = { a b } A=\left\{\begin{array}{cc}a&b\end{array}\right\} A={ab} W = { w 1 w 2 } W=\left\{\begin{array}{c}w1\\ w2\end{array}\right\} W={w1w2} 则输出为
c = { a b } × { w 1 w 2 } = A W c=\left\{a\quad b\right\}\times\left\{\begin{array}{c}w1\\ w2\end{array}\right\}=AW c={ab}×{w1w2}=AW

2.3 感知机-矩阵表示-多个样本

我们来看下多个样本的情况,当增加了一个样本后其实是在矩阵 A A A 中新增了一行(注意权重是同一组只是样本换了)

在这里插入图片描述

定义 A = { a b } A=\left\{\begin{array}{cc}a&b\end{array}\right\} A={ab} W = { w 1 w 2 } W=\left\{\begin{array}{c}w1\\ w2\end{array}\right\} W={w1w2} 那么输出为
{ c f } = { a b x y } × { w 1 w 2 } = A W \left\{\begin{array}{c}c\\ f\end{array}\right\}=\left\{\begin{array}{c}a&b\\ x&y\end{array}\right\}\times\left\{\begin{array}{c}w1\\ w2\end{array}\right\}=AW {cf}={axby}×{w1w2}=AW

2.4 感知机-增加偏置

关于偏置的存在,考虑 y = k x + b y = kx + b y=kx+b 直线公式,若 b = 0 b = 0 b=0,则退化为 y = k x y = kx y=kx,此时表达的直接必定过零点,无法表达不过零点的直线,所以偏置在这里非常重要,感知机增加偏置后的图如下所示:

在这里插入图片描述

c = { a b } × { w 1 w 2 } + b i a s = A W + b i a s c=\left\{\begin{array}{c}a&b\end{array}\right\}\times\left\{\begin{array}{c}w1\\ w2\end{array}\right\}+bias=AW+bias c={ab}×{w1w2}+bias=AW+bias

2.5 感知机-多个输出

当感知机存在多个输出时,如下图所示,

在这里插入图片描述

定义 c = { a b } × { w 11 w 12 } + b i a s 1 c=\left\{a\quad b\right\}\times\left\{\begin{array}{c}w11\\ w12\end{array}\right\}+bias1 c={ab}×{w11w12}+bias1 q = { a b } × { w 21 w 22 } + b i a s 2 q=\left\{\begin{array}{c}a&b\end{array}\right\}\times\left\{\begin{array}{c}w21\\ w22\end{array}\right\}+bias2 q={ab}×{w21w22}+bias2

将输出 c c c q q q 合并为一个矩阵,输出如下:
{ c q } = { a b } × { w 11 w 21 w 12 w 22 } + { b i a s 1 b i a s 2 } = A W + B \begin{aligned} \left\{\begin{array}{cc}c&q\end{array}\right\}=\left\{\begin{array}{cc}a&b\end{array}\right\}\times\left\{\begin{array}{cc}w11&w21\\ w12&w22\end{array}\right\}+\left\{\begin{array}{cc}bias1&bias2\end{array}\right\}=AW+B \end{aligned} {cq}={ab}×{w11w12w21w22}+{bias1bias2}=AW+B

2.6 总结

  • 新增一个样本, A A A 增加一行
  • 新增一个输出, W W W 增加一列
  • A A A 的行数是样本数, A A A 的列数是特征数
  • W W W 的行数是输入特征数, W W W 的列数是输出特征数
  • 可以认为 A A A 经过 W W W 映射为新的特征

2.7 关于广播

广播机制是非常重要的一种特性,它可以使得不同形状的矩阵在一些条件下能够进行数学运算。当运算中两个矩阵地形状不一致时,会自动扩展,以满足运算条件,这个过程就称为广播

对于矩阵 A A A B B B 的元素操作(如点乘、点加、点除等等),广播约定了假设 A A A 1 × 5 1\times 5 1×5 B B B 3 × 5 3\times 5 3×5,则约定把 A A A 在行方向复制 3 份后,再与 B B B 进行元素操作。同理可以发生在列上,发生在 B B B 上。

对于
{ c q f m } = { a b x y } × { w 11 w 21 w 12 w 22 } + { b i a s 1 b i a s 2 } = A W + B \left\{\begin{array}{cc}c&q\\f&m\end{array}\right\}=\left\{\begin{array}{cc}a&b\\x&y\end{array}\right\}\times\left\{\begin{array}{cc}w11&w21\\w12&w22\end{array}\right\}+\left\{\begin{array}{cc}bias1&bias2\end{array}\right\}=AW+B {cfqm}={axby}×{w11w12w21w22}+{bias1bias2}=AW+B
等价于
{ c q f m } = { a b x y } × { w 11 w 21 w 12 w 22 } + { b i a s 1 b i a s 2 b i a s 1 b i a s 2 } = A W + B \left\{\begin{array}{cc}c&q\\f&m\end{array}\right\}=\left\{\begin{array}{cc}a&b\\x&y\end{array}\right\}\times\left\{\begin{array}{cc}w11&w21\\w12&w22\end{array}\right\}+\left\{\begin{array}{cc}bias1&bias2\\bias1&bias2\end{array}\right\}=AW+B {cfqm}={axby}×{w11w12w21w22}+{bias1bias1bias2bias2}=AW+B

3. BP

BP(Back Propagation)误差反向传播算法,使用反向传播算法的多层感知器又称为 BP 神经网络。BP 是当前人工智能主要采用的算法,例如所知道的 CNN、GAN、NLP的Bert、Transformer,都是 BP 体系下的算法框架。

理解 BP 对于理解网络如何训练很重要

在这里我们采用最简单的思路理解BP。确保能够理解并且复现

在这里插入图片描述

使用BP的训练流程

  • 1.计算隐藏层的输出: H = r e l u ( X W 1 + B 1 ) H = relu(XW_1+B_1) H=relu(XW1+B1)
  • 2.计算输出层的预测概率: P = s i g m o i d ( H W 2 + B 2 ) P = sigmoid(HW_2+B_2) P=sigmoid(HW2+B2)
  • 3.计算损失: L = B i n a r y C r o s s E n t r o p y L o s s ( P , Y ) L = BinaryCrossEntropyLoss(P,Y) L=BinaryCrossEntropyLoss(P,Y)
  • 4.计算 L L L W 2 W_2 W2 B 2 B_2 B2 的梯度: ∂ L ∂ W 2 = H T ( P − Y ) \frac{\partial L}{\partial W_2}=H^{T}(P-Y) W2L=HT(PY) ∂ L ∂ B 2 = r e d u c e _ s u m ( P − Y ) \frac{\partial L}{\partial B_{2}}= reduce\_sum(P-Y) B2L=reduce_sum(PY)
  • 5.计算 L L L W 1 W_1 W1 B 1 B_1 B1 的梯度: ∂ L ∂ W 1 = X T ∂ L ∂ ( X W 1 + B 1 ) \frac{\partial L}{\partial W_{1}}=X^{T}\frac{\partial L}{\partial(X W_{1}+B_{1})} W1L=XT(XW1+B1)L ∂ L ∂ B 1 = r e d u c e _ s u m ∂ L ∂ ( X W 1 + B 1 ) \frac{\partial L}{\partial B_{1}}=reduce\_sum\frac{\partial L}{\partial(X W_{1}+B_{1})} B1L=reduce_sum(XW1+B1)L
  • 6.拿到梯度后,对每一个参数应用优化器进行更新迭代

部分核心代码如下

// 开始循环所有的batch
for(int ibatch = 0; ibatch < num_batch_per_epoch; ++ibatch, ++total_batch){
    
    // 前向传播
    auto x           = data::choice_rows(trainimages, image_indexs, ibatch * batch_size, batch_size);
    auto y           = data::choice_rows(trainlabels, image_indexs, ibatch * batch_size, batch_size);
    auto hidden      = x.gemm(input_to_hidden) + hidden_bias;
    auto hidden_act  = nn::relu(hidden);
    auto output      = hidden_act.gemm(hidden_to_output) + output_bias;
    auto probability = nn::sigmoid(output);
    float loss 		 = nn::compute_loss(probability, y);
    
    // 反向传播
    // C = AB
    // dA =  G * BT
    // dB = AT * G
    // loss部分求导,loss对output求导
    auto doutput           = (probability - y) * (1 / (float)batch_size);
    
    // 第二个Linear求导
    auto doutput_bias      = data::reduce_sum_by_row(output);
    auto dhidden_to_output = hidden_act.gemm(doutput, true);
    auto hidden_act        = doutput.gemm(hidden_to_output, false, true);
    
    // 第一个Linear输出求导
    auto dhidden           = nn::drelu(dhidden_act, hidden);
    
    // 第一个Linear求导
    auto dinput_to_hidden  = x.gemm(dhidden, true);
    auto dhidden_bias      = data::reduce_sum_by_row(dhidden);
    
    // 调用优化器来调整更新参数
    optim.update_params(
    	(&input_to_hidden,  &hidden_bias,  &hidden_to_output,  &output_bias),
        (&dinput_to_hidden, &dhidden_bias, &dhidden_to_output, $doutput_bias),
    	lr, momentum
    );
}

4. 动量SGD

对于参数更新方向等于 − l r ∗ g r a d -lr * grad lrgrad,我们定义 D = − l r ∗ g r a d D = -lr * grad D=lrgrad

而梯度下降时,我们有: θ + = θ + D \theta^+=\theta+D θ+=θ+D

假设梯度方向固定沿着右边取值相同,则每个时刻的推进都是均匀的,如下图所示:

在这里插入图片描述

对于动量 Momentum,则是基于物理上的惯性设计,定义动量系数 m m m

定义 t 1 t_1 t1 时刻的累计梯度量: A = D t 0 ⋅ m + D t 1 A = D_{t0}\cdot m+ D_{t1} A=Dt0m+Dt1 其中 D t 0 = 0 D_{t0} = 0 Dt0=0

A A A 就是动量 SGD 的参数更新方向 θ + = θ + A \theta^+=\theta+A θ+=θ+A

假设梯度方向固定沿着右边取值相同,则每个时刻的推进都有惯性作用,也可以连续下降的区域,会具有更快的下降速度。若在梯度方向不同时,也会存在正负抵消,从而更小心翼翼的前进,如下图所示:

在这里插入图片描述

5. BP示例代码

main.cpp

#include <vector>
#include <string>
#include <iostream>
#include <fstream>
#include <cmath>
#include <tuple>
#include <iomanip>
#include <stdarg.h>
#include <memory.h>
#include <random>
#include <algorithm>
#include <chrono>
#include "matrix.hpp"

using namespace std;

namespace Application{

    static default_random_engine global_random_engine;

    namespace logger{

        #define INFO(...)  Application::logger::__printf(__FILE__, __LINE__, __VA_ARGS__)

        void __printf(const char* file, int line, const char* fmt, ...){

            va_list vl;
            va_start(vl, fmt);

            // None   = 0,     // 无颜色配置
            // Black  = 30,    // 黑色
            // Red    = 31,    // 红色
            // Green  = 32,    // 绿色
            // Yellow = 33,    // 黄色
            // Blue   = 34,    // 蓝色
            // Rosein = 35,    // 品红
            // Cyan   = 36,    // 青色
            // White  = 37     // 白色
            /* 格式是: \e[颜色号m文字\e[0m   */
            printf("\e[32m[%s:%d]:\e[0m ", file, line);
            vprintf(fmt, vl);
            printf("\n");
        }
    };

    namespace io{

        struct __attribute__((packed)) mnist_labels_header_t{
            unsigned int magic_number;
            unsigned int number_of_items;
        };

        struct __attribute__((packed)) mnist_images_header_t{
            unsigned int magic_number;
            unsigned int number_of_images;
            unsigned int number_of_rows;
            unsigned int number_of_columns;
        };

        unsigned int inverse_byte(unsigned int v){
            unsigned char* p = (unsigned char*)&v;
            std::swap(p[0], p[3]);
            std::swap(p[1], p[2]);
            return v;
        }

        /* 加载mnist数据集 */
        tuple<Matrix, Matrix> load_data(const string& image_file, const string& label_file){

            Matrix images, labels;
            fstream fimage(image_file, ios::binary | ios::in);
            fstream flabel(label_file, ios::binary | ios::in);

            mnist_images_header_t images_header;
            mnist_labels_header_t labels_header;
            fimage.read((char*)&images_header, sizeof(images_header));
            flabel.read((char*)&labels_header, sizeof(labels_header));

            images_header.number_of_images = inverse_byte(images_header.number_of_images);
            labels_header.number_of_items  = inverse_byte(labels_header.number_of_items);

            images.resize(images_header.number_of_images, 28 * 28);
            labels.resize(labels_header.number_of_items, 10);

            std::vector<unsigned char> buffer(images.rows() * images.cols());
            fimage.read((char*)buffer.data(), buffer.size());

            for(int i = 0; i < buffer.size(); ++i)
                images.ptr()[i] = (buffer[i] / 255.0f - 0.1307f) / 0.3081f;
                //images.ptr()[i] = (buffer[i] - 127.5f) / 127.5f;

            buffer.resize(labels.rows());
            flabel.read((char*)buffer.data(), buffer.size());
            for(int i = 0; i < buffer.size(); ++i)
                labels.ptr(i)[buffer[i]] = 1;   // onehot
            return make_tuple(images, labels);
        }

        void print_image(const float* ptr, int rows, int cols){

            for(int i = 0;i < rows * cols; ++i){

                //int pixel = ptr[i] * 127.5 + 127.5;
                int pixel = (ptr[i] * 0.3081f + 0.1307f) * 255.0f;
                if(pixel < 20)
                    printf("--- ");
                else
                    printf("%03d ", pixel);

                if((i + 1) % cols == 0)
                    printf("\n");
            }
        }

        bool save_model(const string& file, const vector<Matrix>& model){

            ofstream out(file, ios::binary | ios::out);
            if(!out.is_open()){
                INFO("Open %s failed.", file.c_str());
                return false;
            }

            unsigned int header_file[] = {0x3355FF11, model.size()};
            out.write((char*)header_file, sizeof(header_file));

            for(auto& m : model){
                int header[] = {m.rows(), m.cols()};
                out.write((char*)header, sizeof(header));
                out.write((char*)m.ptr(), m.numel() * sizeof(float));
            }
            return out.good();
        }

        bool load_model(const string& file, vector<Matrix>& model){

            ifstream in(file, ios::binary | ios::in);
            if(!in.is_open()){
                INFO("Open %s failed.", file.c_str());
                return false;
            }

            unsigned int header_file[2];
            in.read((char*)header_file, sizeof(header_file));

            if(header_file[0] != 0x3355FF11){
                INFO("Invalid model file: %s", file.c_str());
                return false;
            }

            model.resize(header_file[1]);
            for(int i = 0; i < model.size(); ++i){
                auto& m = model[i];
                int header[2];
                in.read((char*)header, sizeof(header));
                m.resize(header[0], header[1]);
                in.read((char*)m.ptr(), m.numel() * sizeof(float));
            }
            return in.good();
        }
    };

    namespace data{

        int argmax(float* ptr, int size){
            return std::max_element(ptr, ptr + size) - ptr;
        }

        Matrix choice_rows(const Matrix& m, const vector<int>& indexs, int begin=0, int size=-1){

            if(size == -1) size = indexs.size();
            Matrix out(size, m.cols());
            for(int i = 0; i < size; ++i){
                int mrow = indexs[i + begin];
                int orow = i;
                memcpy(out.ptr(orow), m.ptr(mrow), sizeof(float) * m.cols());
            }
            return out;
        }

        Matrix reduce_sum_by_row(const Matrix& value){
            Matrix out(1, value.cols());
            auto optr = out.ptr();
            auto vptr = value.ptr();
            for(int i = 0; i < value.numel(); ++i, ++vptr)
                optr[i % value.cols()] += *vptr;
            return out;
        }
    };

    namespace tools{

        vector<int> range(int end){
            vector<int> out(end);
            for(int i = 0; i < end; ++i)
                out[i] = i;
            return out;
        }

        double timenow(){
            return chrono::duration_cast<chrono::microseconds>(chrono::system_clock::now().time_since_epoch()).count() / 1000.0;
        }
    };

    namespace nn{

        Matrix relu(const Matrix& input){
            Matrix out(input);
            for(int i = 0; i < out.numel(); ++i)
                out.ptr()[i] = std::max(0.0f, out.ptr()[i]);
            return out;
        }

        Matrix drelu(const Matrix& grad, const Matrix& x){
            Matrix out = grad;
            auto optr = out.ptr();
            auto xptr = x.ptr();
            for(int i = 0; i < out.numel(); ++i, ++optr, ++xptr){
                if(*xptr <= 0)
                    *optr = 0;
            }
            return out;
        }

        Matrix sigmoid(const Matrix& input){
            Matrix out(input);
            float eps = 1e-5;
            for(int i = 0; i < out.numel(); ++i){
                float& x = out.ptr()[i];

                /* 处理sigmoid数值稳定性问题 */
                if(x < 0){
                    x = exp(x) / (1 + exp(x));
                }else{
                    x = 1 / (1 + exp(-x));
                }

                /* 保证x不会等于0或者等于1 */
                x = std::max(0.0f + eps, std::min(1.0f - eps, x));
            }
            return out;
        }

        float compute_loss(const Matrix& probability, const Matrix& onehot_labels){

            float eps = 1e-5;
            float sum_loss  = 0;
            auto pred_ptr   = probability.ptr();
            auto onehot_ptr = onehot_labels.ptr();
            int numel       = probability.numel();
            for(int i = 0; i < numel; ++i, ++pred_ptr, ++onehot_ptr){
                auto y = *onehot_ptr;
                auto p = *pred_ptr;
                p = max(min(p, 1 - eps), eps);
                sum_loss += -(y * log(p) + (1 - y) * log(1 - p));
            }
            return sum_loss / probability.rows();
        }

        float eval_test_accuracy(const Matrix& probability, const Matrix& labels){

            int success = 0;
            for(int i = 0; i < probability.rows(); ++i){
                auto row_ptr = probability.ptr(i);
                int predict_label = std::max_element(row_ptr, row_ptr + probability.cols()) - row_ptr;
                if(labels.ptr(i)[predict_label] == 1)
                    success++;
            }
            return success / (float)probability.rows();
        }
    };

    namespace random{

        Matrix create_normal_distribution_matrix(int rows, int cols, float mean=0.0f, float stddev=1.0f){

            normal_distribution<float> norm(mean, stddev);
            Matrix out(rows, cols);
            for(int i = 0; i < rows; ++i){
                for(int j = 0; j < cols; ++j)
                    out.ptr(i)[j] = norm(global_random_engine);
            }
            return out;
        }
    };

    namespace optimizer{

        struct SGDMomentum{
            vector<Matrix> delta_momentums;

            // 提供对应的参数params,和对应的梯度grads,进行参数的更新
            void update_params(const vector<Matrix*>& params, const vector<Matrix*>& grads, float lr, float momentum=0.9){

                if(delta_momentums.size() != params.size())
                    delta_momentums.resize(params.size());

                for(int i =0 ; i < params.size(); ++i){
                    auto& delta_momentum = delta_momentums[i];
                    auto& param          = *params[i];
                    auto& grad           = *grads[i];

                    if(delta_momentum.numel() == 0)
                        delta_momentum.resize(param.rows(), param.cols());
                    
                    delta_momentum = momentum * delta_momentum - lr * grad;
                    param          = param + delta_momentum;
                }
            }
        };
    };
    
    int run(){

        Matrix trainimages, trainlabels;
        Matrix valimage, vallabels;
        tie(trainimages, trainlabels) = io::load_data("mnist/train-images.idx3-ubyte", "mnist/train-labels.idx1-ubyte");
        tie(valimage, vallabels)      = io::load_data("mnist/t10k-images.idx3-ubyte",  "mnist/t10k-labels.idx1-ubyte");
        
        int num_images  = trainimages.rows();
        int num_input   = trainimages.cols();
        int num_hidden  = 1024;
        int num_output  = 10;
        int num_epoch   = 10;
        float lr        = 1e-1;
        int batch_size  = 256;
        float momentum  = 0.9f;
        int num_batch_per_epoch = num_images / batch_size;
        auto image_indexs       = tools::range(num_images);

        // 凯明初始化,fan_in + fan_out
        // W1 B1
        Matrix input_to_hidden  = random::create_normal_distribution_matrix(num_input,  num_hidden, 0, 2.0f / sqrt((float)(num_input + num_hidden)));
        Matrix hidden_bias(1, num_hidden);

        // W2 B2
        Matrix hidden_to_output = random::create_normal_distribution_matrix(num_hidden, num_output, 0, 1.0f / sqrt((float)(num_hidden + num_output)));
        Matrix output_bias(1, num_output);

        optimizer::SGDMomentum optim;
        auto t0 = tools::timenow();
        int total_batch = 0;
        for(int epoch = 0; epoch < num_epoch; ++epoch){

            if(epoch == 8){
                lr *= 0.1;
            }

            // 打乱索引
            // 0, 1, 2, 3, 4, 5 ... 59999
            // 199, 20, 1, 9, 10, 6, ..., 111
            std::shuffle(image_indexs.begin(), image_indexs.end(), global_random_engine);
            
            // 开始循环所有的batch
            for(int ibatch = 0; ibatch < num_batch_per_epoch; ++ibatch, ++total_batch){

                // 前向过程
                // trainimages -> X(60000, 784)
                // idx = image_indexs[0:256] -> 乱的
                // X = trainimages[idx]
                auto x           = data::choice_rows(trainimages,   image_indexs, ibatch * batch_size, batch_size);
                auto y           = data::choice_rows(trainlabels,   image_indexs, ibatch * batch_size, batch_size);
                auto hidden      = x.gemm(input_to_hidden) + hidden_bias;
                auto hidden_act  = nn::relu(hidden);
                auto output      = hidden_act.gemm(hidden_to_output) + output_bias;
                auto probability = nn::sigmoid(output);
                float loss       = nn::compute_loss(probability, y);

                // 反向过程
                // C = AB
                // dA = G * BT
                // dB = AT * G
                // loss部分求导,loss对output求导
                auto doutput           = (probability - y) * (1 / (float)batch_size);

                // 第二个Linear求导
                auto doutput_bias      = data::reduce_sum_by_row(doutput);
                auto dhidden_to_output = hidden_act.gemm(doutput, true);
                auto dhidden_act       = doutput.gemm(hidden_to_output, false, true);

                // 第一个Linear输出求导
                auto dhidden           = nn::drelu(dhidden_act, hidden);

                // 第一个Linear求导
                auto dinput_to_hidden  = x.gemm(dhidden, true);
                auto dhidden_bias      = data::reduce_sum_by_row(dhidden);

                // 调用优化器来调整更新参数
                optim.update_params(
                    {&input_to_hidden,  &hidden_bias,  &hidden_to_output,  &output_bias},
                    {&dinput_to_hidden, &dhidden_bias, &dhidden_to_output, &doutput_bias},
                    lr, momentum
                );

                if((total_batch + 1) % 50 == 0){
                    auto t1 = tools::timenow();
                    auto batchs_time = t1 - t0;
                    t0 = t1;
                    INFO("Epoch %.2f / %d, Loss: %f, LR: %f [ %.2f ms / 50 batch ]", epoch + ibatch / (float)num_batch_per_epoch, num_epoch, loss, lr, batchs_time);
                }
            }

            // 模型对测试集进行测试,并打印精度
            auto test_hidden      = nn::relu(valimage.gemm(input_to_hidden) + hidden_bias);
            auto test_probability = nn::sigmoid(test_hidden.gemm(hidden_to_output) + output_bias);
            float accuracy        = nn::eval_test_accuracy(test_probability, vallabels);
            float test_loss       = nn::compute_loss(test_probability, vallabels);
            INFO("Test Accuracy: %.2f %%, Loss: %f", accuracy * 100, test_loss);
        }

        INFO("Save to model.bin .");
        io::save_model("model.bin", {input_to_hidden, hidden_bias, hidden_to_output, output_bias});

        for(int i = 0; i < valimage.rows(); ++i){

            auto input = data::choice_rows(valimage, {i});
            auto test_hidden      = nn::relu(input.gemm(input_to_hidden) + hidden_bias);
            auto test_probability = nn::sigmoid(test_hidden.gemm(hidden_to_output) + output_bias);

            int ilabel = data::argmax(test_probability.ptr(), test_probability.cols());
            float prob = test_probability.ptr()[ilabel];

            io::print_image(input.ptr(), 28, 28);
            INFO("Predict %d, Confidence = %f", ilabel, prob);

            printf("Pass [Enter] to next.");
            getchar();
        }
        return 0;
    }
};

int main(){
    return Application::run();
}

上述反向传播算法可以分为以下几个部分:(from chatGPT)

1.数据加载与预处理:通过 io 命名空间下的函数加载 MNIST 数据集,将数据集转换成矩阵形式,同时进行数据的预处理。值得注意的点有

  • 加载数据的格式和路径
  • 数据的维度和大小
  • 数据类型的一致性
  • 数据的归一化和预处理

2.神经网络的参数初始化:对于三层神经网络,我们需要初始化输入层到隐藏层之间的权重矩阵、隐藏层到输出层之间的权重矩阵以及对应的偏置向量。值得注意的点有

  • 网络结构的定义,包括层数、神经元个数等
  • 权重和偏置的初始化方法,以及如何保存和加载
  • 激活函数的选择

3.前向传播:使用矩阵乘法和激活函数 ReLU、Sigmoid 等操作完成前向传播,并计算出模型在当前参数下的预测结果。值得注意的点有

  • 矩阵运算的实现,包括矩阵乘法、加法等
  • 激活函数的实现
  • 矩阵的广播和重复

4.损失函数计算:计算模型在当前参数下的预测结果与真实标签之间的损失值,使用交叉熵作为损失函数。

5.反向传播:根据链式法则,依次计算各个参数的梯度,并进行参数更新。具体的过程是,首先计算输出层对输入的梯度,然后计算输出层参数的梯度,接着计算隐藏层的梯度,最后计算输入层参数的梯度。值得注意的点有

  • 梯度的计算方法和公式

  • 优化器的定义和实现

  • 学习率和动量的设置

  • 参数的更新和梯度的累加

6.模型训练:对于给定的训练集,将其分为多个 batch,每个 batch 中包含多个样本,通过反向传播更新参数,并计算损失值和准确率。值得注意的点有

  • 数据的划分和打乱
  • 批量训练的实现
  • 模型的保存和加载

7.模型测试:使用训练得到的模型对测试集进行测试,并计算准确率和损失值。

8.结果可视化:将模型对单张图像的预测结果可视化,包括输出结果和置信度。

运行效果如下:

在这里插入图片描述

总结

本次课程从最简单的感知机出发学习了 BP 反向传播算法,我们应该从矩阵的角度来分析,优化器方面基于传统的 SGD 学习了带动量的 SGD,它基于物理上的惯性设计,其效果更好。最后将 BP 反向传播的 C++ 示例代码跑了一遍,只进行了简单的分析,后续再跟着杜老师手写吧😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/463102.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PHP、一:概述

1.概念 2.wampsever安装 百度搜索直接下载 下图是解压后目录&#xff0c;所写文件必须写在www文件夹下。 例&#xff1a;www文件夹下新建1.php&#xff0c;phpinfo()查看当前版本等信息。 使用localhost访问 php版本切换&#xff1a; 鼠标左键点击wampserver&#xff0c;切…

git rebase

git rebase rebase 是一个……我觉得很麻烦的指令&#xff0c;不过没办法&#xff0c;公司算是有个软规定必须要使用 rebase。 rebase 的功能和 merge 很像&#xff0c;不过它能够保持一个相对干净的历史&#xff0c;继续举个例子&#xff0c;假设现在有一个新的功能开发好了…

Golang Gin HTTP 请求和参数解析

gin 网络请求与路由处理 我们介绍了Gin框架&#xff0c;并做了Gin框架的安装&#xff0c;完成了第一个Gin工程的创建。 创建Engine 在gin框架中&#xff0c;Engine被定义成为一个结构体&#xff0c;Engine代表gin框架的一个结构体定义&#xff0c;其中包含了路由组、中间件、…

26- OCR 基于PP-OCRv3的液晶屏读数识别

要点&#xff1a; 液晶屏识别示例github 地址 1. 简介 本项目基于PaddleOCR开源套件&#xff0c;以PP-OCRv3检测和识别模型为基础&#xff0c;针对液晶屏读数识别场景进行优化。主要是针对各种仪表进行识别&#xff1a; 2 安装环境 安装Git&#xff1a;Git 详细安装教程 # 首…

YOLOv8 Bug及解决方案汇总

Traceback (most recent call last): File “D:\Anaconda\Scripts\yolo-script.py”, line 33, in sys.exit(load_entry_point(‘ultralytics==8.0.83’, ‘console_scripts’, ‘yolo’)()) self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbos…

基于 Python 的 Meta AI —— SAM

Segment Anything Model&#xff08;SAM&#xff09;是 Facebook 的一个 AI 模型&#xff0c;旨在推广分割技术。在我们之前的文章中&#xff0c;我们讨论了 SAM 的一般信息&#xff0c;现在让我们深入了解其技术细节。SAM 模型的结构如下图所示&#xff0c;图像经过编码器得到…

【致敬未来的攻城狮计划】— 连续打卡第十三天:FSP固件库开发启动文件详解

系列文章目录 1.连续打卡第一天&#xff1a;提前对CPK_RA2E1是瑞萨RA系列开发板的初体验&#xff0c;了解一下 2.开发环境的选择和调试&#xff08;从零开始&#xff0c;加油&#xff09; 3.欲速则不达&#xff0c;今天是对RA2E1 基础知识的补充学习。 4.e2 studio 使用教程 5.…

《Spring MVC》 第八章 拦截器实现权限验证、异常处理

前言 Spring 提供了Interceptor 拦截器&#xff0c;可用于实现权限验证、异常处理等 1、拦截器 对用户请求进行拦截&#xff0c;并在请求进入控制器&#xff08;Controller&#xff09;之前、控制器处理完请求后、甚至是渲染视图后&#xff0c;执行一些指定的操作 1.1、定义…

UGUI中点击判断的原理

首选需要理解 EventSystem 中的代码结构&#xff0c;EventSystem 目录下包含4个子文件夹&#xff0c;分别是 EventData、InputModules&#xff0c;Raycasters 和 UIElements&#xff0c;UIElements 下是 UI Toolkit 相关代码&#xff0c;这里不做研究&#xff0c;主要关注其他三…

linux文件及文件内容查找命令总结

在linux环境下&#xff0c;我们经常要查找一个文件或者文件的内容&#xff0c;但搜索的命令有很多&#xff0c;这些命令都有什么区别&#xff0c;应该怎么选择和使用呢&#xff1f; 下面总结了一些常见的文件查找、内容查找的命令&#xff0c;收藏起来备用吧。 文件查找 where…

每日学术速递4.25

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 Subjects: cs.CV 1.Long-Term Photometric Consistent Novel View Synthesis with Diffusion Models 标题&#xff1a;具有扩散模型的长期光度一致的新视图合成 作者&#xff1a;Jason J. Yu, Feresh…

Python入门教程+项目实战-11.3节: 元组的操作方法

目录 11.3.1 元组的常用操作方法 11.3.2 元组的查找 11.3.3 知识要点 11.3.4 系统学习python 11.3.1 元组的常用操作方法 元组类型是一种抽象数据类型&#xff0c;抽象数据类型定义了数据类型的操作方法&#xff0c;在本节的内容中&#xff0c;着重介绍元组类型的操作方法…

hive udf, tried to access method org.bouncycastle.math.ec.ECPoint$AbstractFp

在hive中添加加密udf,测试报错&#xff1a; select encrypt_sm2("aa","04AD9356466C7A505B3B2E18F2484E1F096108FA19C0F61C707A808EDF7C132BC3CE33E63D2CC6D77FB0A172004F8F5282CEADE22ED9628A02FE8FD85AF1EFE8B3"); Error: Error while compiling statem…

从0搭建Vue3组件库(九):VitePress 搭建部署组件库文档

VitePress 搭建组件库文档 当我们组件库完成的时候,一个详细的使用文档是必不可少的。本篇文章将介绍如何使用 VitePress 快速搭建一个组件库文档站点并部署到GitHub上 安装 首先新建 site 文件夹,并执行pnpm init,然后安装vitepress和vue pnpm install -D vitepress vue安…

什么是分库分表?为什么需要分表?什么时候分库分表

不急于上手实战 ShardingSphere 框架&#xff0c;先来复习下分库分表的基础概念&#xff0c;技术名词大多晦涩难懂&#xff0c;不要死记硬背理解最重要&#xff0c;当你捅破那层窗户纸&#xff0c;发现其实它也就那么回事。 什么是分库分表 分库分表是在海量数据下&#xff0…

“星河杯”隐私计算大赛新闻发布会在京召开

4月24日下午&#xff0c;“星河杯”隐私计算大赛新闻发布会在京召开。本次大赛由中国信通院、中国通信学会、隐私计算联盟共同主办&#xff0c;中移动信息技术有限公司、联通数字科技有限公司、天翼电子商务有限公司、中国通信标准化协会大数据技术标准推进委员会联合协办&…

微信小程序 | 基于高德地图+ChatGPT实现旅游规划小程序

&#x1f388;&#x1f388;效果预览&#x1f388;&#x1f388; ❤ 路劲规划 ❤ 功能总览 ❤ ChatGPT交互 一、需求背景 五一假期即即将到来&#xff0c;在大家都阳过之后&#xff0c;截止到目前这应该是最安全的一个假期。所以出去旅游想必是大多数人的选择。 然后&#x…

Activity中startForResult的原理分析

前言&#xff1a; 如果使用androidX支持库中的ComponentActivity&#xff0c;会推荐使用registerForActivityResult的方式。但是对于不支持androidX的项目&#xff0c;或者就是继承自Activity的页面来说&#xff0c;startActivityForResult仍然是唯一的选择。 如果想了解andr…

虹科教您 | 虹科RELY-TSN-KIT操作指南(3)——基于Linux系统进行TSN协议测试

随着技术的变革和实际生产业务需求的推动&#xff0c;工厂内部互联架构逐渐趋于扁平化&#xff08;IT/OT融合&#xff09;&#xff0c;而TSN则是在这一背景下发展起来的新兴技术&#xff0c;旨在为以太网协议建立“通用”的时间敏感机制&#xff0c;以确保网络数据传输的时间确…

云计算服务安全评估办法

云计算服务安全评估办法 2019-07-22 14:46 来源&#xff1a; 网信办网站【字体&#xff1a;大 中 小】打印 国家互联网信息办公室 国家发展和改革委员会 工业和信息化部 财政部关于发布《云计算服务安全评估办法》的公告 2019年 第2号 为提高党政机关、关键信息基础设施运营者…