numpy反向传播算法示例

news2024/11/24 1:03:20

numpy反向传播算法示例

数据

通过 scikit-learn 库提供的便捷工具生成 2000 个线性不可分的 2 分类数据集
按着7: 3比例切分训练集和测试集
在这里插入图片描述

backpropagation.py

#!/usr/bin/env python
# encoding: utf-8
"""
@desc:  反向传播算法
"""

import pickle 


import time 

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

plt.rcParams['font.size'] = 16
# plt.rcParams['font.family'] = ['STKaiti']
plt.rcParams['axes.unicode_minus'] = False


def load_dataset():
    # 采样点数
    N_SAMPLES = 2000
    # 测试数量比率
    TEST_SIZE = 0.3
    # 利用工具函数直接生成数据集
    X, y = make_moons(n_samples=N_SAMPLES, noise=0.2, random_state=100)
    # 将 2000 个点按着 7:3 分割为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42)
    return X, y, X_train, X_test, y_train, y_test


def make_plot(X, y, plot_name, XX=None, YY=None, preds=None, dark=False):
    # 绘制数据集的分布, X 为 2D 坐标, y 为数据点的标签
    if (dark):
        plt.style.use('dark_background')
    else:
        sns.set_style("whitegrid")
    plt.figure(figsize=(16, 12))
    axes = plt.gca()
    axes.set(xlabel="$x_1$", ylabel="$x_2$")
    plt.title(plot_name, fontsize=30)
    plt.subplots_adjust(left=0.20)
    plt.subplots_adjust(right=0.80)
    if XX is not None and YY is not None and preds is not None:
        plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha=1, cmap=plt.cm.Spectral)
        plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap="Greys", vmin=0, vmax=.6)
    # 绘制散点图,根据标签区分颜色
    plt.scatter(X[:, 0], X[:, 1], c=y.ravel(), s=40, cmap=plt.cm.Spectral, edgecolors='none')
    # plt.savefig('数据集分布.svg')
    # plt.close()
    plt.show()


class Layer:
    # 全连接网络层
    def __init__(self, n_input, n_neurons, activation=None, weights=None,
                 bias=None):
        """
        :param int n_input: 输入节点数
        :param int n_neurons: 输出节点数
        :param str activation: 激活函数类型
        :param weights: 权值张量,默认类内部生成
        :param bias: 偏置,默认类内部生成
        """
        # 通过正态分布初始化网络权值,初始化非常重要,不合适的初始化将导致网络不收敛
        self.weights = weights if weights is not None else np.random.randn(n_input, n_neurons) * np.sqrt(1 / n_neurons)
        self.bias = bias if bias is not None else np.random.rand(n_neurons) * 0.1
        self.activation = activation  # 激活函数类型,如’sigmoid’
        self.last_activation = None  # 激活函数的输出值o
        self.error = None  # 用于计算当前层的delta 变量的中间变量
        self.delta = None  # 记录当前层的delta 变量,用于计算梯度

    # 网络层的前向传播函数实现如下,其中last_activation 变量用于保存当前层的输出值:
    def activate(self, x):
        # 前向传播函数
        r = np.dot(x, self.weights) + self.bias  # X@W+b
        # 通过激活函数,得到全连接层的输出o
        self.last_activation = self._apply_activation(r)
        return self.last_activation

    # 上述代码中的self._apply_activation 函数实现了不同类型的激活函数的前向计算过程,
    # 尽管此处我们只使用Sigmoid 激活函数一种。代码如下:
    def _apply_activation(self, r):
        # 计算激活函数的输出
        if self.activation is None:
            return r  # 无激活函数,直接返回
        # ReLU 激活函数
        elif self.activation == 'relu':
            return np.maximum(r, 0)
        # tanh 激活函数
        elif self.activation == 'tanh':
            return np.tanh(r)
        # sigmoid 激活函数
        elif self.activation == 'sigmoid':
            return 1 / (1 + np.exp(-r))
        return r

    # 针对于不同类型的激活函数,它们的导数计算实现如下:
    def apply_activation_derivative(self, r):
        # 计算激活函数的导数
        # 无激活函数,导数为1
        if self.activation is None:
            return np.ones_like(r)
        # ReLU 函数的导数实现
        elif self.activation == 'relu':
            grad = np.array(r, copy=True)
            grad[r > 0] = 1.
            grad[r <= 0] = 0.
            return grad
        # tanh 函数的导数实现
        elif self.activation == 'tanh':
            return 1 - r ** 2
        # Sigmoid 函数的导数实现
        elif self.activation == 'sigmoid':
            return r * (1 - r)
        return r


# 神经网络模型
class NeuralNetwork:
    def __init__(self):
        self._layers = []  # 网络层对象列表

    def add_layer(self, layer):
        # 追加网络层
        self._layers.append(layer)

    # 网络的前向传播只需要循环调各个网络层对象的前向计算函数即可,代码如下:
    # 前向传播
    def feed_forward(self, X):
        for layer in self._layers:
            # 依次通过各个网络层
            X = layer.activate(X)
        return X

    def backpropagation(self, X, y, learning_rate):
        # 反向传播算法实现
        # 前向计算,得到输出值
        output = self.feed_forward(X)
        for i in reversed(range(len(self._layers))):  # 反向循环
            layer = self._layers[i]  # 得到当前层对象
            # 如果是输出层
            if layer == self._layers[-1]:  # 对于输出层
                layer.error = y - output  # 计算2 分类任务的均方差的导数
                # 关键步骤:计算最后一层的delta,参考输出层的梯度公式
                layer.delta = layer.error * layer.apply_activation_derivative(output)
            else:  # 如果是隐藏层
                next_layer = self._layers[i + 1]  # 得到下一层对象
                layer.error = np.dot(next_layer.weights, next_layer.delta)
                # 关键步骤:计算隐藏层的delta,参考隐藏层的梯度公式
                layer.delta = layer.error * layer.apply_activation_derivative(layer.last_activation)

        # 循环更新权值
        for i in range(len(self._layers)):
            layer = self._layers[i]
            # o_i 为上一网络层的输出
            o_i = np.atleast_2d(X if i == 0 else self._layers[i - 1].last_activation)
            # 梯度下降算法,delta 是公式中的负数,故这里用加号
            layer.weights += layer.delta * o_i.T * learning_rate

    def train(self, X_train, X_test, y_train, y_test, learning_rate, max_epochs):
        # 网络训练函数
        # one-hot 编码
        y_onehot = np.zeros((y_train.shape[0], 2))
        y_onehot[np.arange(y_train.shape[0]), y_train] = 1

        # 将One-hot 编码后的真实标签与网络的输出计算均方误差,并调用反向传播函数更新网络参数,循环迭代训练集1000 遍即可
        mses = []
        accuracys = []
        for i in range(max_epochs + 1):  # 训练1000 个epoch
            for j in range(len(X_train)):  # 一次训练一个样本
                self.backpropagation(X_train[j], y_onehot[j], learning_rate)
            if i % 10 == 0:
                # 打印出MSE Loss
                mse = np.mean(np.square(y_onehot - self.feed_forward(X_train)))
                mses.append(mse)
                accuracy = self.accuracy(self.predict(X_test), y_test.flatten())
                accuracys.append(accuracy)
                print('Epoch: #%s, MSE: %f' % (i, float(mse)))
                # 统计并打印准确率
                print('Accuracy: %.2f%%' % (accuracy * 100))
        return mses, accuracys

    def predict(self, X):
        return self.feed_forward(X)

    def accuracy(self, X, y):
        return np.sum(np.equal(np.argmax(X, axis=1), y)) / y.shape[0]


def main():
    X, y, X_train, X_test, y_train, y_test = load_dataset()
    # 调用 make_plot 函数绘制数据的分布,其中 X 为 2D 坐标, y 为标签
    # make_plot(X, y, "Classification Dataset Visualization ")

    mses = None
    accuracys = None

    ## make True to train  or  False to load trained modle
    train = False
    
    if (train) :
        nn = NeuralNetwork()  # 实例化网络类
        nn.add_layer(Layer(2, 25, 'sigmoid'))  # 隐藏层 1, 2=>25
        nn.add_layer(Layer(25, 50, 'sigmoid'))  # 隐藏层 2, 25=>50
        nn.add_layer(Layer(50, 25, 'sigmoid'))  # 隐藏层 3, 50=>25
        nn.add_layer(Layer(25, 2, 'sigmoid'))  # 输出层, 25=>2
        time1  = time.perf_counter()
        mses, accuracys = nn.train(X_train, X_test, y_train, y_test, 0.01, 1000)
        time2  = time.perf_counter() 

        print (f"train time : {time2-time1} s ")

        #--- save trained modle
        with open('trained_model.pickle', 'wb') as f:
            pickle.dump(nn, f)

        # 绘制MES曲线
        x = [i for i in range(0, 101, 10)]
        ax = plt.subplot(1,2,1)
        ax.set_title("MES Loss")
        plt.plot(x, mses[:11], color='blue')
        plt.xlabel('Epoch')
        plt.ylabel('MSE')
        # plt.savefig('训练误差曲线.svg')
        # plt.close()

        # 绘制Accuracy曲线
        ax = plt.subplot(1,2,2)
        ax.set_title("Accuracy")
        plt.plot(x, accuracys[:11], color='blue')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        # plt.savefig('网络测试准确率.svg')
        # plt.close()

        plt.savefig('训练曲线.svg')
    else :
    #--- load modle from pickle
        modle = None
        with open('trained_model.pickle', 'rb') as f:
            modle = pickle.load(f)
        # 预测
        out = modle.predict(X_test)
        # 预测标签
        pred = np.argmax(out, axis=1) 

        # 正确率
        accuracy = modle.accuracy(out, y_test.flatten())
        print(f"accuracy : {accuracy}")

        # 绘制散点图,根据标签区分颜色
        ax = plt.subplot(1,2,1)
        ax.set_title("X_test Classification by y_test")
        plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test.ravel(), s=40, cmap=plt.cm.Spectral, edgecolors='#356')

        # 绘制散点图,根据模型预测标签,正确的绿色/错误的红色
        ax = plt.subplot(1,2,2)
        ax.set_title("X_test Classification by pred ,error mark red,correct mark green")

        t_x = X_test[:, 0]
        t_y = X_test[:, 1]
        acc_points = []
        err_points = []
        for i in range(0,y_test.shape[0]):
            if np.equal(pred[i],y_test[i]):
                # print(f"pred[{i}]:{pred[i]}  \t y_test[{i}]: {y_test[i]}")
                acc_points.append(i)
            else :
                err_points.append(i)
                # print(f"pred[{i}]:{pred[i]}  \t y_test[{i}]: {y_test[i]}")

        plt.scatter(t_x[acc_points],t_y[acc_points],s=40,c="green" ,  edgecolors='#356')
        plt.scatter(t_x[err_points],t_y[err_points],s=80,c="red" ,  edgecolors='#789')


    
    plt.show()


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/703271.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode电话号码的字母组合C++实现教程

链接: 电话号码的字母组合 class Solution {char* PNumStr[10] {"","","abc","def","ghi","jkl","mno","pqrs","tuv","wxyz"};//因为每次都需要执行拷贝&#xff0c;…

了解嵌入式系统的不同细分领域:探索嵌入式BSP的定义

嵌入式BSP是指嵌入式系统中的板级支持软件。它是针对特定硬件平台的软件包&#xff0c;提供了操作系统和硬件之间的抽象层&#xff0c;以便开发人员可以更方便地使用硬件功能和编写应用程序。 嵌入式BSP的功能包括&#xff1a; 设备初始化和配置&#xff1a;BSP负责初始化硬件设…

学做测试平台开发-Vuetify 框架

Vuetify 是 Vue 的语义化组件框架&#xff0c;旨在提供整洁、语义化和可重用的组件&#xff0c;使得构建应用程序更方便。 Vuetify 核心是为了提供各种可重复使用的&#xff0c;即插即用并且适合各种项目规格的组件。 Vue 的语义成分。利用 Vue 的功能组件&#xff0c;所有基…

kubectl port-forward 指令

背景&#xff1a; 当K8s的Service类型为ClusterIP&#xff0c;不是NodePort&#xff0c;就只能集群内部访问&#xff0c;想在外部访问可以执行kubectl port-forward&#xff0c;将一个或多个本地端口转发到 Pod或者Service 作用&#xff1a; 做转发&#xff0c;将本地端口转发…

使用 fitter 拟合数据分布

一、简介 前面的文章中通过假设对比来检验样本是否服从泊松分布。得出的结论是总体分布不服从泊松分布&#xff0c;那么如何找到与总体分布最接近的分布呢&#xff1f;不可能一个个分布去验证。这里便可以用到 fitter 这个库。 fitter 是一个小型的第三方库&#xff0c;提供了…

【粉丝投稿】一文带你了解MySQL的左连接与右连接

前言&#xff1a; 昨天粉丝问了一个问题&#xff0c;因此本篇文章主要讲解MySQL的左连接和右连接的知识。该专栏比较适合刚入坑Java的小白以及准备秋招的大佬阅读。 如果文章有什么需要改进的地方欢迎大佬提出&#xff0c;对大佬有帮助希望可以支持下哦~ 小威在此先感谢各位小…

LeetCode——从上到下打印二叉树 II

题目来源 剑指 Offer 32 - II. 从上到下打印二叉树 II - 力扣&#xff08;LeetCode&#xff09; 题目 从上到下按层打印二叉树&#xff0c;同一层的节点按从左到右的顺序打印&#xff0c;每一层打印到一行。 示例 给定二叉树&#xff1a;[3,9,20,null,null,15,7] 返回其层…

vuex2实现时间列表选择器

目录 一、效果展示 二、代码分析 2.1、区域确定与坐标获取 2.2、单个点击与一次性点击 一、效果展示 主要借助自定义指令实现。在表格的"td们"上面进行移动框选&#xff0c;有一次性框选和单个框选&#xff0c;去掉自定义指令里的clearTargetNodes()会连续td,连…

Hudi学习5:Hudi的helloworld-编译源码

hudi是使用java代码编写的 部署hudi 1. 下载源码 Download | Apache Hudi https://dlcdn.apache.org/hudi/0.13.1/hudi-0.13.1.src.tgz 2.编译 安装maven 首先要先有JDK java8以上 配置镜像源 执行编译 测试

Yolov5小目标性能提升方案介绍

目录 1.小目标检测介绍 1.1 小目标定义 1.2 难点 2.小目标难点解决方案 2.1注意力提升小目标检测精度 2.1.1 上下文信息CAM 2.1.2 ConvNeXt 2.1.3 ECVBlock 2.1.4 多头上下文集成&#xff08;Context Aggregation&#xff09;的广义构建模块 2.2 多头检测头 2.3 loss优化…

spring 在容器中一个bean依赖另一个bean 需要通过ref方式注入进去 通过构造器 或property

spring 在容器中一个bean依赖另一个bean 需要通过ref方式注入进去 通过构造器 或property

Go语言net模块TCP和UDP编程实践

任何一种语言TCP和UDP网络编程总是必须的&#xff0c;接下来就将go语言中使用net标准库进行TCP和UDP编程进行梳理总结。 目录 1.代码结构 2.TCP通信 3.UDP通信 4.TCP模拟HTTP协议通信 5.利用TCP扫描那些端口被占用 1.代码结构 2.TCP通信 server.go package mainimport …

【正点原子STM32连载】 第四十八章 内存管理实验 摘自【正点原子】STM32F103 战舰开发指南V1.2

1&#xff09;实验平台&#xff1a;正点原子stm32f103战舰开发板V4 2&#xff09;平台购买地址&#xff1a;https://detail.tmall.com/item.htm?id609294757420 3&#xff09;全套实验源码手册视频下载地址&#xff1a; http://www.openedv.com/thread-340252-1-1.html# 第四…

全志V3S嵌入式驱动开发(spi-nand image制作)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing @163.com】 上一篇文章,我们说到了spi-nor image的制作和输入。相比较spi-nor,spi-nand虽然在稳定性上面差一点,但是价格上面有很大的优势。举例来说,一般32M的spi-nor大约在6-7元左右,但…

Springboot链接Redis实现AOP防止重复提交

目录 安装redis Springboot链接Redis 1.创建springboot项目 如果spring boot启动报Error creating bean with name redisUtil/redisTemplate 2.新建application.yml配置 3.redis配置类-直接用 4.redis工具类-直接用 5.写Controller测试 6.启动、测试 整合AOP&#xff…

内网横线移动—WmiSmbCrackMapExecProxyChainsImpacket

内网横线移动—Wmi&Smb&CrackMapExec&ProxyChains&Impacket 1. 前置环境准备2. wmic介绍2.1. wmic操作演示2.1.1. 受控主机上线2.1.1.1. 内网存活探测2.1.1.2. 密码抓取 2.1.2. 横向移动2.1.2.1. 上传文件2.1.2.2. 文件上传目标主机2.1.2.3. 执行木马 2.2. wmi…

JAVA 安全-JWT 安全及预编译 CASE 注入等(40)

在各种语言脚本的环境下&#xff0c;也会产生一些新的漏洞&#xff0c;如果是java又能产生那些漏洞&#xff0c;思维导图里面常规漏洞之前都有&#xff1b; java的访问控制&#xff0c;jwt令牌&#xff08;php几乎没有&#xff09;组件安全&#xff0c;这些都是java特有的 #综…

C#核心知识回顾——3.继承构造、拆装箱、多态

1.继承中的构造函数&#xff1a; 特点&#xff1a; 当申明一个子类对象时 先执行父类的构造函数&#xff0c;再执行子类的构造函数注意&#xff01;&#xff01;&#xff1a; 1.父类的无参构造很重要 2.子类可以通过base关键字代表父类调用父类构造 public class Mot…

【单片机】STM32单片机,定时器的输入捕获,基于捕获的频率计,STM32F103

文章目录 简单介绍外部计数频率计TIM5 频率计 简单介绍 下面的定时器都具有输入捕获能力&#xff1a; 外部计数频率计 查看另一篇文章&#xff1a;https://qq742971636.blog.csdn.net/article/details/131471539 外部计数频率计的缺点&#xff1a;需要两个定时器配合&#x…

控制 显示、隐藏

1、使用 v-if < v-if"isShow"></> data(){ return:{ isShow:false } } if (concepts && concepts.length >0){ this.isShow nv.concepts.includes(snap) } 确认 数组中有某个 字段&#xff0c;用includes 有&…