基于BP神经网络对MNIST数据集检测识别(numpy版本)

news2025/1/31 2:51:07

基于BP神经网络对MNIST数据集检测识别

  • 1.作者介绍
  • 2.BP神经网络介绍
    • 2.1 BP神经网络
  • 3.BP神经网络对MNIST数据集检测实验
    • 3.1 读取数据集
    • 3.2 前向传播
    • 3.3 损失函数
    • 3.4 构建神经网络
    • 3.5 训练
    • 3.6 模型推理
  • 4.完整代码

1.作者介绍

王凯,男,西安工程大学电子信息学院,2022级研究生
研究方向:机器视觉与人工智能
电子邮件:1794240761@qq.com

张思怡,女,西安工程大学电子信息学院,2022级研究生,张宏伟人工智能课题组
研究方向:机器视觉与人工智能
电子邮件:981664791@qq.com

2.BP神经网络介绍

2.1 BP神经网络

搭建一个两层(两个权重矩阵,一个隐藏层)的神经网络,其中输入节点和输出节点的个数是确定的,分别为 784 和 10。而隐藏层节点的个数还未确定,并没有明确要求隐藏层的节点个数,所以在这里取50个。现在神经网络的结构已经确定了,再看一下里面是怎么样的,这里画出了对一个数据的运算过程:
在这里插入图片描述
数学公式:
在这里插入图片描述

3.BP神经网络对MNIST数据集检测实验

3.1 读取数据集

安装numpy :pip install numpy
安装matplotlib pip install matplotlib
mnist是一个包含各种手写数字图片的数据集:其中有60000个训练数据和10000个测试时局,即60000个 train_img 和与之对应的 train_label,10000个 test_img 和 与之对应的test_label。
在这里插入图片描述
其中的 train_img 和 test_img 就是这种图片的形式,train_img 是为了训练神经网络算法的训练数据,test_img 是为了测试神经网络算法的测试数据,每一张图片为2828,将图片转换为2828=784个像素点,每个像素点的值为0到255,像素点值的大小代表灰度,从而构成一个1784的矩阵,作为神经网络的输入,而神经网络的输出形式为110的矩阵,个:eg:[0.01,0.01,0.01,0.04,0.8,0.01,0.1,0.01,0.01,0.01],矩阵里的数字代表神经网络预测值的概率,比如0.8代表第五个数的预测值概率。
其中 train_label 和 test_label 是 对应训练数据和测试数据的标签,可以理解为一个1*10的矩阵,用one-hot-vectors(只有正确解表示为1)表示,one_hot_label为True的情况下,标签作为one-hot数组返回,one-hot数组 例:[0,0,0,0,1,0,0,0,0,0],即矩阵里的数字1代表第五个数为True,也就是这个标签代表数字5。
数据集的读取:
load_mnist(normalize=True, flatten=True, one_hot_label=False):中,
normalize : 是否将图像的像素值正规化为0.0~1.0(将像素值正规化有利于提高精度)。flatten : 是否将图像展开为一维数组。 one_hot_label:是否采用one-hot表示。
在这里插入图片描述
完整代码及数据集下载:https://gitee.com/wang-kai-ya/bp.git

3.2 前向传播

前向传播时,我们可以构造一个函数,输入数据,输出预测。

def predict(self, x):
    w1, w2 = self.dict['w1'], self.dict['w2']
    b1, b2 = self.dict['b1'], self.dict['b2']

    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, w2) + b2
    y = softmax(a2)

3.3 损失函数

求出神经网络对一组数据的预测值,是一个1*10的矩阵。
在这里插入图片描述

其中,Yk表示的是第k个节点的预测值,Tk表示标签中第k个节点的one-hot值,举前面的eg:(手写数字5的图片预测值和5的标签)
Yk=[0.01,0.01,0.01,0.04,0.8,0.01,0.1,0.01,0.01,0.01]
Tk=[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
值得一提的是,在交叉熵误差函数中,Tk的值只有一个1,其余为0,所以对于这个数据的交叉熵误差就为 E = -1(log0.8)。
在这里选用交叉熵误差作为损失函数,代码实现如下:

def loss(self, y, t):
    t = t.argmax(axis=1)
    num = y.shape[0]
    s = y[np.arange(num), t]

    return -np.sum(np.log(s)) / num

3.4 构建神经网络

前面我们定义了预测值predict, 损失函数loss, 识别精度accuracy, 梯度grad,下面构建一个神经网络的类,把这些方法添加到神经网络的类中:

for i in range(epoch):
    batch_mask = np.random.choice(train_size, batch_size)  # 从0到60000 随机选100个数
    x_batch = x_train[batch_mask]
    y_batch = net.predict(x_batch)
    t_batch = t_train[batch_mask]
    grad = net.gradient(x_batch, t_batch)

    for key in ('w1', 'b1', 'w2', 'b2'):
        net.dict[key] -= lr * grad[key]
    loss = net.loss(y_batch, t_batch)
    train_loss_list.append(loss)

    # 每批数据记录一次精度和当前的损失值
    if i % iter_per_epoch == 0:
        train_acc = net.accuracy(x_train, t_train)
        test_acc = net.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print(
            '第' + str(i/600) + '次迭代''train_acc, test_acc, loss :| ' + str(train_acc) + ", " + str(test_acc) + ',' + str(
                loss))

3.5 训练

import numpy as np
import matplotlib.pyplot as plt
from TwoLayerNet import TwoLayerNet
from mnist import load_mnist

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
net = TwoLayerNet(input_size=784, hidden_size=50, output_size=10, weight_init_std=0.01)

epoch = 20400
batch_size = 100
lr = 0.1

train_size = x_train.shape[0]  # 60000
iter_per_epoch = max(train_size / batch_size, 1)  # 600

train_loss_list = []
train_acc_list = []
test_acc_list = []

保存权重:

np.save('w1.npy', net.dict['w1'])
np.save('b1.npy', net.dict['b1'])
np.save('w2.npy', net.dict['w2'])
np.save('b2.npy', net.dict['b2'])

结果可视化:
在这里插入图片描述
在这里插入图片描述

3.6 模型推理

import numpy as np
from mnist import load_mnist
from functions import sigmoid, softmax
import cv2
######################################数据的预处理
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
batch_mask = np.random.choice(100,1)  # 从0到60000 随机选100个数
#print(batch_mask)
x_batch = x_train[batch_mask]

#####################################转成图片
arr = x_batch.reshape(28,28)
cv2.imshow('wk',arr)
key = cv2.waitKey(10000)
#np.savetxt('batch_mask.txt',arr)
#print(x_batch)
#train_size = x_batch.shape[0]
#print(train_size)
########################################进入模型预测
w1 = np.load('w1.npy')
b1 = np.load('b1.npy')
w2 = np.load('w2.npy')
b2 = np.load('b2.npy')

a1 = np.dot(x_batch,w1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1,w2) + b2
y = softmax(a2)
p = np.argmax(y, axis=1)

print(p)

运行python reasoning.py
可以看到模型拥有较高的准确率。

4.完整代码

训练

import numpy as np
import matplotlib.pyplot as plt
from TwoLayerNet import TwoLayerNet
from mnist import load_mnist

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
net = TwoLayerNet(input_size=784, hidden_size=50, output_size=10, weight_init_std=0.01)

epoch = 20400
batch_size = 100
lr = 0.1

train_size = x_train.shape[0]  # 60000
iter_per_epoch = max(train_size / batch_size, 1)  # 600

train_loss_list = []
train_acc_list = []
test_acc_list = []

for i in range(epoch):
    batch_mask = np.random.choice(train_size, batch_size)  # 从0到60000 随机选100个数
    x_batch = x_train[batch_mask]
    y_batch = net.predict(x_batch)
    t_batch = t_train[batch_mask]
    grad = net.gradient(x_batch, t_batch)

    for key in ('w1', 'b1', 'w2', 'b2'):
        net.dict[key] -= lr * grad[key]
    loss = net.loss(y_batch, t_batch)
    train_loss_list.append(loss)

    # 每批数据记录一次精度和当前的损失值
    if i % iter_per_epoch == 0:
        train_acc = net.accuracy(x_train, t_train)
        test_acc = net.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print(
            '第' + str(i/600) + '次迭代''train_acc, test_acc, loss :| ' + str(train_acc) + ", " + str(test_acc) + ',' + str(
                loss))

np.save('w1.npy', net.dict['w1'])
np.save('b1.npy', net.dict['b1'])
np.save('w2.npy', net.dict['w2'])
np.save('b2.npy', net.dict['b2'])

markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

测试

import numpy as np
from mnist import load_mnist
from functions import sigmoid, softmax
import cv2
######################################数据的预处理
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
batch_mask = np.random.choice(100,1)  # 从0到60000 随机选100个数
#print(batch_mask)
x_batch = x_train[batch_mask]

#####################################转成图片
arr = x_batch.reshape(28,28)
cv2.imshow('wk',arr)
key = cv2.waitKey(10000)
#np.savetxt('batch_mask.txt',arr)
#print(x_batch)
#train_size = x_batch.shape[0]
#print(train_size)
########################################进入模型预测
w1 = np.load('w1.npy')
b1 = np.load('b1.npy')
w2 = np.load('w2.npy')
b2 = np.load('b2.npy')

a1 = np.dot(x_batch,w1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1,w2) + b2
y = softmax(a2)
p = np.argmax(y, axis=1)

print(p)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/631774.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EasyX实现简易贪吃蛇

📝个人主页:认真写博客的夏目浅石. 📣系列专栏:夏目的C语言宝藏 文章目录 前言一、头文件包含二、创建蛇与食物的结构体三、游戏的初始化四、游戏的绘画事件五、蛇的移动事件六、输入方向七、生成食物八、吃食物九、游戏失败的判定…

【Python开发】FastAPI 09:middleware 中间件及跨域

FastAPI 提供了一些中间件来增强它的功能,类似于 Spring 的切面编程,中间件可以在请求处理前或处理后执行一些操作,例如记录日志、添加请求头、鉴权等,跨域也是 FastAPI 中间件的一部分。 目录 1 中间件 1.1 创建中间件 1.2 使…

IDEA日常配置和操作小结

✅作者简介:大家好,我是Cisyam,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Cisyam-Shark的博客 💞当前专栏: IDEA专栏 ✨特色专栏&…

算法刷题-数组-长度最小的子数组

209.长度最小的子数组 力扣题目链接 给定一个含有 n 个正整数的数组和一个正整数 s ,找出该数组中满足其和 ≥ s 的长度最小的 连续 子数组,并返回其长度。如果不存在符合条件的子数组,返回 0。 示例: 输入:s 7, …

电子元器件解析之电容(一)——定义与性能参数

下篇文章:电子元器件解析之电容(二)——电容分类与应用场景:https://blog.csdn.net/weixin_42837669/article/details/131142767 摘要 电容是最基本的电子元器件之一,本文介绍了电容的定义,并总结了电容的各个性能参数&#xff0c…

攻防世界—easyupload

2023.6.10 又试了一下 他好像会检测php字段,所以一开始的aaa.jpg怎么也传不上去,用a.jpg那种就可以了 根据题目和页面我们可以猜到这应该是一个利用文件上传漏洞,利用一句话木马来获取服务器上保存的flag保险起见右键查看源代码 可以发现毫无…

基础实验篇 | RflySim底层飞行控制算法开发系列课程总体介绍

本讲主要介绍多旋翼的特点及选用多旋翼作为实验平台的原因、对于无人系统教育的一些新需求、RflySim平台对于飞控的底层控制算法的开发优势、本期平台课程的设置、以及如何开发自驾仪系统。 相较于固定翼和直升机,多旋翼具有机械结构简单、 易维护的优点。以四旋翼…

使用 Pycharm 调试远程代码

文章目录 背景同步远程代码Interpreter注意点 背景 工作机是一台 Windows 电脑,而很多时候需要在 Mac 电脑上编码、配合 iPhone 模拟器,所以我以前是用 VNC 或者向日葵来远程 Mac 来编程,其实还能接受,但是最让我不舒服的是快捷键…

MySQL表操作:提高数据处理效率的秘诀(进阶)(2)

💕“学习难免有坎坷,重要的是你能尽力而为,持之以恒。”💕 🐼作者:不能再留遗憾了🐼 🎆专栏:MySQL学习🎆 🚗本文章主要内容:MySQL表操…

K Shortest Paths算法之Yen algorithm

Yen’s算法是一种在图论中用于计算单源K最短无环路径的算法,该算法由Jin Y. Yen在1971年提出。这个算法的时间复杂度和空间复杂度都取决于用于计算偏离路径的最短路径算法。如果使用Dijkstra算法,那么时间复杂度为O(KN3),采用Fibonacci堆计算…

Elasticsearch:Explain API - 如何计算分数

你想了解你的文档为何获得该分数吗? 文档 让我们通过一组示例文档来了解 Explain API。 就我而言,我将使用一小部分电影名言。 POST _bulk { "index" : { "_index" : "movie_quotes" } } { "title" : "T…

从malloc到跑路

当我还是一个懵懂无知的少年时,内心也曾升起这样的疑问这内存咋来的捏❓ 有个帅气的小哥哥给我甩了一篇博客:对,就是这篇,看完后总感觉意犹未尽,似乎少了点什么的样子。。。 还是从malloc开始 如果申请的内存小于64B…

“配置DHCP服务器和DHCP中继的网络自动配置实验“

"配置DHCP服务器和DHCP中继的网络自动配置实验" 【实验目的】 部署DHCP服务器。熟悉DHCP中继的配置方法。验证拓扑。 【实验拓扑】 实验拓扑如图所示。 设备参数如下表所示。 设备 接口 IP地址 子网掩码 默认网关 DHCPSERVE S0/3/0 192.168.10.1 255.255.…

丰富上下文的超高分辨率分割:一种新的基准

文章目录 Ultra-High Resolution Segmentation with Ultra-Rich Context: A Novel Benchmark摘要数据集Dataset SummaryData Collection and Pre-processing 数据标注数据统计 WSDNet实验结果 Ultra-High Resolution Segmentation with Ultra-Rich Context: A Novel Benchmark …

SSM整合快速入门案例(一)

文章目录 前言一、设计数据库表二、创建工程三、SSM技术整合四、功能模块开发五、接口测试总结 前言 为了巩固所学的知识,作者尝试着开始发布一些学习笔记类的博客,方便日后回顾。当然,如果能帮到一些萌新进行新技术的学习那也是极好的。作者…

不认识docker,怎么好意思说自己是干IT的

1.Docker是什么 一款产品从开发到上线,从操作系统,到运行环境,再到应用配置。作为开发运维之间的协作我们需要关心很多东西,这也是很多互联网公司都不得不面对的问题,特别是各种版本的迭代之后,不同版本环…

用户行为数据分析

文章目录 用户行为数据分析1 项目描述2 项目需求3 数据准备1、创建user_data数据表用于导入user_data.csv中的数据2、加载user_data.csv中的数据到user_data表3、接下来进行数据清洗,包括:删除重复值,时间戳格式化,删除异常值。 4…

OpenGL光照之基础光照

文章目录 环境光照漫反射光照计算漫反射光照镜面光照代码 现实世界的光照是极其复杂的,而且会受到诸多因素的影响,这是我们有限的计算能力所无法模拟的。因此OpenGL的光照使用的是简化的模型,对现实的情况进行近似,这样处理起来会…

MyBatis-Plus(2.0)

ActiveRecord ActiveRecord(简称AR)一直广受动态语言(PHP、Ruby等)的喜爱,而java作为准静态语言,对于ActiveRecord往往只能感叹器优雅 什么是ActiveRecord? ActiveRecord也属于ORM(对象关系映射)层,由Rail…

视频|人人能看懂的苹果visionOS空间设计课程

本周的重磅消息无疑是苹果Vision Pro以及对应的visionOS,考虑到苹果头显硬件上当前以第一方App为主,因此本届WWDC的一个重点就是释放visionOS和相关能力给开发者,让开发者尽快打造出更多、更优质的第三方App阵容。 与此同时,苹果也…