人工智能应用-实验5-BP 神经网络分类手写数据集

news2024/11/18 4:03:44

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡代码🧡🧡
    • 🧡🧡分析结果🧡🧡
    • 🧡🧡实验总结🧡🧡

🧡🧡实验内容🧡🧡

编写 BP 神经网络分类, 实现对 MNIST 数据集分类的操作。


🧡🧡代码🧡🧡

需要配置torch。由于是小demo。为了提高效率,我采用的是google的colab进行实验编码,省去配环境的烦恼。

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim

#@title 加载
transform = transforms.Compose([
                transforms.ToTensor(), # 转为张量,同时如果是图片(uint8)类型,会自动进行归一化到(0,1)
                transforms.Normalize( (0.5, ) , (0.5, ) ) # 转为std=0.5、mean=0.5的分布, 灰色图像,通道只有一个  将值域(0,1)再次转为(-1,1)
                ])
train_set = datasets.MNIST('train_set', # 下载到该文件夹下
              download=not os.path.exists('train_set'), # 是否下载,如果下载过,则不重复下载
              train=True, # 是否为训练集
              transform=transform # 要对图片做的transform
              )
test_set = datasets.MNIST('test_set',
              download=not os.path.exists('test_set'),
              train=False,
              transform=transform
              )
test_set
# train_set[0][0]
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)

dataiter = iter(train_loader)
images, labels = next(iter(dataiter))
print(images.shape)
print(labels.shape)


#@title Bp net
class BP_Net(nn.Module):
    def __init__(self):
        super().__init__()
        """
        定义第一个线性层,
        输入为图片(28x28),
        输出为第一个隐层的输入,大小为128。
        """
        self.linear1 = nn.Linear(28 * 28, 128)
        self.relu1 = nn.ReLU() # 在第一个隐层使用ReLU激活函数
        """
        定义第二个线性层,
        输入是第一个隐层的输出,
        输出为第二个隐层的输入,大小为64。
        """
        self.linear2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU() # 在第二个隐层使用ReLU激活函数
        """
        定义第三个线性层,
        输入是第二个隐层的输出,
        输出为输出层,大小为10
        """
        self.linear3 = nn.Linear(64, 10)
        self.softmax = nn.LogSoftmax(dim=1) # 最终的输出经过softmax进行归一化

    def forward(self, x):
        """
        定义神经网络的前向传播
        x: 输入的图片数据, shape为(64, 1, 28, 28)
        """
        x = x.view(x.shape[0], -1) # 首先将x的shape转为(64, 784)

        # 进行前向传播
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.relu2(x)
        x = self.linear3(x)
        x = self.softmax(x)

        return x
model = BP_Net()
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

#@title 评估
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
model.eval() # 将模型设置为评估模式

correct_count, all_count = 0, 0
predictions = [] # 预测结果列表
true_labels = [] # 真实标签列表

for images,labels in test_loader: # 从test_loader中一批一批加载图片
    for i in range(len(labels)):
        logps = model(images[i])  # 进行前向传播,获取预测值
        probab = list(logps.detach().numpy()[0]) # 将预测结果转为概率列表。[0]是取第一张照片的10个数字的概率列表(因为一次只预测一张照片)
        pred_label = probab.index(max(probab)) # 取最大的index作为预测结果
        true_label = labels.numpy()[i]
        if(true_label == pred_label): # 判断是否预测正确
            correct_count += 1
        all_count += 1
        predictions.append(pred_label)
        true_labels.append(true_label)

# 准确率
print("Number Of Images Tested =", all_count)
print("Model Accuracy =", (correct_count/all_count))

# 混淆矩阵
def plot_confusion_matrix(cm, classes):
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    thresh = cm.max() / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

cm = confusion_matrix(true_labels, predictions)
classes = [str(i) for i in range(10)]
plot_confusion_matrix(cm, classes)

#@title 验证
model.train() # 切回训练模式

## 验证本地图片
import cv2
from PIL import Image
for num in range(0,10):
    img = cv2.imread('./myImg/{}.jpg'.format(num), 0)  # 以灰度图的方式读取要预测的图片
    img = cv2.resize(img, (28, 28))
    height, width = img.shape
    dst = np.zeros((height, width), np.uint8)
    for i in range(height):
        for j in range(width):
            dst[i, j] = 255 - img[i, j]
    dst= dst / 255.0 #归一化
    dst = (dst - 0.5) / 0.5  # 标准化到[-1, 1]
    img = dst
    # print(img)
    img = np.array(img).astype(np.float32)
    img = np.expand_dims(img, 0)  # 扩展后,为[1,28,28]
    img = np.expand_dims(img, 0)  # 扩展后,为[1,1,28,28]
    img = torch.from_numpy(img)
    # print(img.shape)
    with torch.no_grad():
        output=model(img)
    # print(output.data)
    print(output.data.max(1)[1])


🧡🧡分析结果🧡🧡

数据预处理

  • 加载数据集:
    加载torch自带的minst数据集
  • 转换数据:
    先转为tensor变量(相当于直接除255归一化到值域为(0,1))
    在这里插入图片描述
    然后根据std=0.5,mean=0.5,再将值域标准化到(-1,1)
    在这里插入图片描述

设置基本参数:
在这里插入图片描述

构建BP神经网络:
如下,输入为一张2828图片,拆解成2828=784个特征,最终经过三个线性层(784,128)、(128、64)、(64,10),输出为10个特征(对应10个类),归一化这10个特征,它们的大小即认为它属于哪张图片的概率值,取出概率最大的特征对应的类别作为最终预测类别。
在这里插入图片描述

模型训练:
在这里插入图片描述
在这里插入图片描述

模型评估:
准确率:达到97.69%
在这里插入图片描述
混淆矩阵
在这里插入图片描述

接下来,分析网络层数对分类准确率的影响。
被对照试验:隐藏层数目改为2,神经元数目分别为128、64
准确率为:97.69%
对照实验1:隐藏层数目改为3,神经元数目分别为256、128、64
在这里插入图片描述
Loss图:
在这里插入图片描述
准确率和混淆矩阵如下:97.55%
在这里插入图片描述
对照实验2:隐藏层数目改为5,神经元数目分别为512、256、128、64、32
在这里插入图片描述
Loss图:
在这里插入图片描述
准确率和混淆矩阵:97.85%
在这里插入图片描述
总结结果如下表:
在这里插入图片描述
分析可知:

  • 运行时间:从实验结果来看,在增加隐藏层数的情况下,运行时间明显增加。
  • 准确率:实验结果显示,在增加隐藏层数的情况下,准确率大体上有所提升,但是总体变化幅度并不大,可能是因为epochs或者随机梯度下降等参数已经设为较优值,使得准确率已经接近最优效果,从而导致增加网络层数的提优空间并不明显。
    综合来看,增加隐藏层数对于提高分类准确率有一定的帮助,但是也会明显增加运行时间。其次,需要注意的是,若增加隐藏层数并非一定能够带来准确率的提升,过多的隐藏层可能会导致过拟合等问题。

🧡🧡实验总结🧡🧡

在完成基础实验上,我自己画了几张数字图,以对模型进行验证
在这里插入图片描述
结果如下,可以看到,对数字1和数字5分类错误(分布预测成了5和8),其余均分类正确,大体上效果良好。考虑原因,可能是因为minst的数据集是“黑底白字”,而我手画的图片则为“黑字白底”,导致了一些误差。
在这里插入图片描述
理论理解:
通过本次实验,大体上掌握了BP神经网络的定义和结构,总的来说,BP神经网络可以理解为一个黑盒子,通过不断根据loss进行反向传播,最终目的就是得到线性参数w和b,从而根据Y=wx+b 对输入的新x进行预测分类。
代码实践:
一开始想用纯numpy进行BP网络的编写,但是在编写后向传播时,可能是线代和高数知识有些遗忘,求导数时琢磨了很久。后面还是选择直接使用pytorch进行编写,也容易调参,方便进行实验。对我而言,代码中比较纠结的是shape的转换和传入,因此最好多查看中间过程的shape,以便更好理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1716233.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

服务器内存与CPU要占用多少才合理?

一 通常服务器内存占用多少合理?cpu占用多少才合理? 1 通常配置范围建议: 建议CPU使用率不高于80%;内存使用率不高于80%; 注意:具体情况还需要根据服务器的实际负载和应用场景来判断。 2 内存使用率&…

【方法】如何禁止查看压缩包里的内容?

使用压缩文件,可以让文件更方便存储和传输,那对于重要的文件,如何防止随意查看压缩包的内容呢?我们可以试试以下两个方法。 方法1: 最常见的便是给压缩包设置“打开密码”,这样只有通过密码才能查看文件内…

MyBatis系统学习 - 使用Mybatis完成查询单条,多条数据,模糊查询,动态设置表名,获取自增主键

上篇博客我们围绕Mybatis链接数据库进行了相关概述,并对Mybatis的配置文件进行详细的描述,本篇博客也是建立在上篇博客之上进行的,在上面博客搭建的框架基础上,我们对MyBatis实现简单的增删改查操作进行重点概述,在MyB…

产品推荐 | 基于Xilinx Zynq-7015 FPGA的MYC-C7Z015核心板

一、产品概述 基于 Xilinx Zynq-7015,双Cortex-A9FPGA全可编程处理器;PS部分(ARM)与PL部分(FPGA)之间采用AXI高速片上总线通信,吉比特级带宽,突破传统ARMFPGA架构的通信瓶颈,通过PL部分(FPGA)灵活配置丰富的外设接口&…

windows 安装 使用 nginx

windows 安装 使用 nginx nginx官网下载地址:https://nginx.org/en/download.html 下载稳定版本即可 下载压缩包解压到即可 进入文件夹中,打开命令行窗口,执行启动命令 start nginx.exe验证(默认是80端口)&#x…

产品经理-原型绘制(五)

1. 概念 用线条、图形描绘出的产品框架,也称为线框图,是需求和功能的具体化表现 2. 常用工具 Axure 3. 类别 3.1 草图原型 手绘图稿,修改方便,规划的早期使用 3.2 低保真原型 简单交互,无设计图,无需…

【Docker】2、配置SSL证书远程访问Docker

1、使用 openssl 生成 ca 1、创建文件夹 mkdir -p /root/dockercd /root/docker2、创建 RSA 私钥 会提示 2 次输入证书密码,至少 4 位,创建后会生成一个 ca-key.pem 文件 openssl genrsa -aes256 -out ca-key.pem 4096得到 ca-key.pem 文件 3、创建…

桌面上怎么记工作任务更加合理?能设置桌面提醒的便签软件

在快节奏的现代工作中,电脑已成为我们处理工作的主要工具。每天,我们都要面对电脑屏幕,处理大量的工作任务。为了更好地管理这些琐碎却重要的工作,将工作任务直接记录在桌面上,随时查看和调整,无疑是一种高…

什么是边缘计算网关?工业方向应用有哪些?天拓四方

在数字化时代,信息的传输与处理变得愈发重要,而其中的关键节点之一便是边缘计算网关。这一先进的网络设备,不仅扩展了云端功能至本地边缘设备,还使得边缘设备能够自主、快速地响应本地事件,提供了低延时、低成本、隐私…

【FPGA】Verilog语言从零到精通

接触fpga一段时间,也能写点跑点吧……试试系统地康康呢~这个需要耐心但是回报巨大的工作。正原子&&小梅哥 15_语法篇:Verilog高级知识点_哔哩哔哩_bilibili 1Verilog基础 Verilog程序框架:模块的结构 类比:c语言的基础…

5款ai文案自动生成器,让你写作爆款文案不犯难!

现如今,无论是用于社交媒体、广告宣传、网站内容还是其他各种领域,优秀的文案都能吸引更多的关注和流量。但是,对于许多创作者来说,想要创作出高质量的文案并非易事,常常会面临灵感枯竭、思路卡顿等问题。而现在有了一…

Python开发:简单的密码爆破工具

当我们进行在线密码破解时,使用 BurpSuite 以及 wfuzz 足以应对大部分网站应用场景。但是在遇到一些特殊情况时我们还是需要自己来开发密码爆破工具,本文将介绍如何使用Python开发一款简单的密码爆破工具。 0x01 背景介绍 密码破解 记得有大佬曾经说过…

企业网络的“瑞士军刀”:探索“一端多能”设备的多面性

在数字化时代,企业网络需求的复杂性和多样性不断增长,传统的单一功能网络设备已难以满足这些需求。企业需要一种集多种功能于一身的“一端多能”网络设备,以应对各种网络环境和业务需求,就像是一把多功能、灵活、可靠的瑞士军刀&a…

函数编程实际应用-异步任务

背景 常见的函数式接口,就是对函数编程的应用Runnable 没有返回值的函数式接口Callable 有返回值的函数式接口 使用线程池 一般来说,很少使用new Thread(函数对象)这种方式来直接 创建线程,更多的时候使用的线程成来集…

SOL 交易机器人基本知识

有没有可以盈利的机器人? 是的,各行各业都有许多盈利机器人。在金融领域,交易机器人被广泛用于自动化投资策略并根据预定义的算法执行交易。这些机器人可以分析市场趋势并做出快速决策,从而可能带来可观的回报。同样,在…

英飞凌24GHz毫米波雷达-BGT24LTR11N16家用机器人应用

BGT24LTR11N16基础描述: 关于BGT24LTR11N16,它是一款用于信号生成和接收的硅锗雷达MMlC,工作频率为24.00GHz至24.25GHz ISM频段。它基于24GHz基本电压控制振荡器(VCO)。 这颗芯片是属于1T1R,也就是一发一收…

linux开发之设备树五、设备树描述中断实践

设备树是基于设备总线模型的(platform) 1、添加节点 假设中断引脚为:GPIO0_B5 下面使用设备树来描述它 1、写节点,起节点名字 这里用了ft5x06的触摸芯片,然后I2C的地址为38 2、为节点添加属性 首先添加compatible…

【CALayer-时钟练习-CADisplayLink Objective-C语言】

一、我们接着来看,这个CADisplayLink啊, 1.刚才我们说那个时间呢,差上1秒钟的样子,然后呢,我们现在呢,用这个叫做CADisplayLink的东西,来解决,用这个类,来解决啊, 我们说,NSTimer,是跑到这儿了以后,一秒钟以后, 它才会执行,这个timeChange方法,真正的时间,不…

【NumPy】深入理解NumPy的dot函数:矩阵乘法与点积运算详解

🧑 博主简介:阿里巴巴嵌入式技术专家,深耕嵌入式人工智能领域,具备多年的嵌入式硬件产品研发管理经验。 📒 博客介绍:分享嵌入式开发领域的相关知识、经验、思考和感悟,欢迎关注。提供嵌入式方向…

MySQL8找不到my.ini配置文件以及报sql_mode=only_full_group_by解决方案

一、找不到my.ini配置文件 MySQL 8 安装或启动过程中,如果系统找不到my.ini文件,通常意味着 MySQL服务器没有找到其配置文件。在Windows系统上,MySQL 8 预期使用my.ini作为配置文件,而不是在某些情况下用到的my.cnf文件。 通过 …