入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现)

news2024/11/26 9:32:25

入门深度学习——基于全连接神经网络的手写数字识别案例(python代码实现)

一、网络构建

1.1 问题导入

如图所示,数字五的图片作为输入,layer01层为输入层,layer02层为隐藏层,找出每列最大值对应索引为输出层。根据下图给出的网络结构搭建本案例用到的全连接神经网络
在这里插入图片描述

1.2 手写字数据集MINST

如图所示,MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。数据集也被嵌入到sklearn和pytorch框架中可以直接调用。这里我们默认已经安装了pytorch框架。不会使用的这里简单介绍一下。
大家可以用按住win+R键,打开运行窗口,输入cmd。
在这里插入图片描述
输入cmd,回车后,会显示如下。
在这里插入图片描述
输入以下的命令,可以看看自己的电脑的显卡是不是NVIDIA。如果是AMD的,那么就安装cpu的吧,毕竟CUDA内核,只支持NVIDIA的显卡。

#AMD显卡
pip install pytorch-cpu
#NVIDIA显卡
pip install pytorch
#如果速度慢的话,可以加入清华源的链接
pip install pytorch-cpu -i https://pypi.tuna.tsinghua.edu.cn/simple/
#NVIDIA显卡
pip install pytorch -i https://pypi.tuna.tsinghua.edu.cn/simple/

这样就完成了,仍然存在问题的小伙伴,可以参考小程序员推荐的这个up主的教程pytorch保姆级教程。
这里我们输出几张图片和对应的标签。作为对数据集的了解,也方便我们针对性的设计网络结构,做到心中有数。
在这里插入图片描述

二、采用Pytorch框架编写全连接神经网络代码实现手写字识别

2.1 导入必要的包

import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.utils.data import DataLoader

2.2 定义一些数据预处理操作

pipline=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])

2.3 下载数据集(训练集vs测试集)

train_dataset=datasets.MNIST('./data',train=True,transform=pipline,download=True)
test_dataset=datasets.MNIST('./data',train=False,transform=pipline,download=True)
print(len(train_dataset))
print(len(test_dataset))

60000
10000

2.4 分批加载训练集和测试集中的数据到内存里

train_loader=DataLoader(train_dataset,batch_size=32,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=32)

2.5 可视化数据集中的数据,做到心中有数

import matplotlib.pyplot as plt
examples=enumerate(train_loader)
_,(example_data,example_label)=next(examples)
print(example_data.shape)
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0],cmap='gray')
#     plt.title('Ground Truth:{}'.format(example_label[i]))
    plt.title(f'Ground Truth:{example_label[i]}')

torch.Size([32, 1, 28, 28])
在这里插入图片描述

2.6 网络模型设计(有时也称为网络模型搭建)

class Net(nn.Module):
    def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
        super(Net,self).__init__()
        self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))
        self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.Sigmoid())
        self.layer3=nn.Linear(n_hidden_2,out_dim)    
        
    def forward(self,x):
        x=self.layer1(x)
        x=self.layer2(x)
        x=self.layer3(x)
        return x
model=Net(28*28,300,100,10)
model

以下结果来自Jupyter Notebook
Net(
(layer1): Sequential(
(0): Linear(in_features=784, out_features=300, bias=True)
(1): ReLU(inplace=True)
)
(layer2): Sequential(
(0): Linear(in_features=300, out_features=100, bias=True)
(1): Sigmoid()
)
(layer3): Linear(in_features=100, out_features=10, bias=True)
)

import torch.optim as optim
criterion=nn.CrossEntropyLoss()   #选用Pytorch中nn模块封装好的交叉熵损失函数
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)  #选用随机梯度下降法(SGD)作为本模型的梯度下降法
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')   #确定代码运行设备究竟实在GPU还是CPU上跑
model.to(device)

2.7 训练网络模型

losses=[]
acces=[]

eval_losses=[]
eval_acces=[]

#训练轮数---epoch

for epoch in range(10):
    train_loss=0
    train_acc=0
    model.train()   #启用网络模型隐藏层中的dropout和BN(批归一化)操作
    
    if epoch%5==0:   #控制训练轮数间隔
        optimizer.param_groups[0]['lr']*=0.9    #动态调整学习率
        
    for img,label in train_loader:
        img=img.to(device)   #将训练图片写到设备里
        label=label.to(device)  #将图片类别写到设备里
        img=img.view(img.size(0),-1)
        
        out=model(img)   #调用前向传播函数得到预测值
        loss=criterion(out,label)   #计算预测值和真实值的损失
        
        optimizer.zero_grad()  #在新一轮反向传播开始前,清空上一轮反向传播得到的梯度
        loss.backward()  #把上一部得到的损失执行反向传播,得到新的网络模型参数(权值)
        optimizer.step()   #把上一部得到的新的权值更新到网络模型里
        
        #在前面前向传播和反向传播的额基础上,计算一些训练算法性能指标
        
        train_loss+=loss.item()  #记录反向传播每一轮得到的损失
        
        _,pred=out.max(1)   #得到图片的预测类别
        
        num_correct=(pred==label).sum().item()   #获取预测正确的样本数量
        acc=num_correct/img.shape[0]      #每一批次的正确率
        train_acc+=acc       #每一轮次的额正确率
        
    losses.append(train_loss/len(train_loader))    #所有轮次训练完之后总的损失
    acces.append(train_acc/len(train_loader))     #所有轮次训练完之后总的正确率

2.8 在测试集上测试网络模型,检验模型效果

eval_loss=0
eval_acc=0
model.eval()   #继续沿用BN操作,但是不再使用dropout操作

with torch.no_grad():
    for img,label in test_loader:
        img=img.to(device)
        label=label.to(device)
        
        img=img.view(img.size(0),-1)
        
        out=model(img)
        loss=criterion(out,label)
        
        eval_loss+=loss.item()   #记录每一批次的损失
        
        _,pred=out.max(1)
        
        num_correct=(pred==label).sum().item()
        acc=num_correct/img.shape[0]   #记录每一批次的准确率
        eval_acc+=acc     #记录每一轮的准确率
        

    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))
    print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'
      .format(epoch, train_loss / len(train_loader), train_acc / len(train_loader), 
                 eval_loss / len(test_loader), eval_acc / len(test_loader)))

epoch: 0, Train Loss: 1.1721, Train Acc: 0.6760, Test Loss: 0.4936, Test Acc: 0.8692
epoch: 1, Train Loss: 0.4093, Train Acc: 0.8866, Test Loss: 0.3368, Test Acc: 0.9020
epoch: 2, Train Loss: 0.3192, Train Acc: 0.9084, Test Loss: 0.2884, Test Acc: 0.9171
epoch: 3, Train Loss: 0.2755, Train Acc: 0.9194, Test Loss: 0.2552, Test Acc: 0.9271
epoch: 4, Train Loss: 0.2429, Train Acc: 0.9290, Test Loss: 0.2251, Test Acc: 0.9349
epoch: 5, Train Loss: 0.2160, Train Acc: 0.9367, Test Loss: 0.2001, Test Acc: 0.9405
epoch: 6, Train Loss: 0.1945, Train Acc: 0.9433, Test Loss: 0.1854, Test Acc: 0.9447
epoch: 7, Train Loss: 0.1761, Train Acc: 0.9494, Test Loss: 0.1716, Test Acc: 0.9504
epoch: 8, Train Loss: 0.1601, Train Acc: 0.9540, Test Loss: 0.1597, Test Acc: 0.9527
epoch: 9, Train Loss: 0.1468, Train Acc: 0.9572, Test Loss: 0.1434, Test Acc: 0.9567

2.10可视化训练及测试的损失值

plt.title('Train Loss')
plt.plot(np.arange(len(losses)),losses);
plt.legend(['Train Loss'],loc='upper right')                   

损失函数的结果:
在这里插入图片描述

三、代码文件

小程序员将代码文件和相关素材整理到了百度网盘里,因为文件大小基本不大,大家也不用担心限速问题。后期小程序员有能力的话,将在gitee或者github上上传相关素材。
链接:https://pan.baidu.com/s/1Ce14ZQYEYWJxhpNEP1ERhg?pwd=7mvf
提取码:7mvf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/345474.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云原生周刊 | 开源领导者应该如何应对碎片化挑战?

Linux Fundation 发布了一份关于开源开发中的碎片化问题的报告《实现全球协作:开源领导者如何应对碎片化挑战》,该报告由华为在美国的研发部门 Futurewei 赞助。报告指出,虽然开源社区越来越国际化,但美国对开源共享和开发进行了过…

源码项目中常见设计模式及实现

原文https://mp.weixin.qq.com/s/K8yesHkTCerRhS0HfB0LeA 单例模式 单例模式是指一个类在一个进程中只有一个实例对象(但也不一定,比如Spring中的Bean的单例是指在一个容器中是单例的) 单例模式创建分为饿汉式和懒汉式,总共大概…

Linux内核驱动开发(一)

Linux内核初探 linux操作系统历史 开发模式 git 分布式管理git clone 获取git push 提交git pull 更新 邮件组 mailing list patch 内核代码组成 Makfile arch 体系系统架构相关 block 块设备 crypto 加密算法 drivers 驱动(85%) atm 通信bluet…

MAC文件误删怎么办?mac数据恢复,亲测很好用的方法

电脑文件误删,应该很多人都经历过。之前分享了很多关于Windows电脑文件误删如何恢复的方法,那么MAC电脑文件误删该怎么办?有什么好方法可以使得mac数据恢复回来吗?下面就给大家分享一些亲测好用的方法! 一、MAC电脑的文…

使用Proxifier+burp抓包总结

一、微信小程序&网页抓包 1. Proxifier简介 Proxifier是一款功能非常强大的socks5客户端,可以让不支持通过代理服务器工作的网络程序能通过HTTPS或SOCKS代理或代理链。 2. 使用Proxifier代理抓包 原理:让微信相关流量先走127.0.0.1:80到burp。具体…

Final Cut Pro 10.6.5

软件介绍Final Cut Pro 10.6.5 已通过小编安装运行测试 100%可以使用。Final Cut Pro 10.6.5 破解版启用了全新的矩形图标,与最新的macOS Ventura设计风格统一,支持最新的macOS 13 文图拉系统,支持Apple M1/M2芯片。经过完整而彻底的重新设计…

数据结构之单链表

一、链表的组成 链表是由一个一个的节点组成的,节点又是一个一个的对象, 相邻的节点之间产生联系,形成一条链表。 例子:假如现在有两个人,A和B,A保存了B的联系方式,这俩人之间就有了联系。 A和…

HashMap底层实现原理概述

原文https://blog.csdn.net/fedorafrog/article/details/115478407 hashMap结构 常见问题 在理解了HashMap的整体架构的基础上,我们可以试着回答一下下面的几个问题,如果对其中的某几个问题还有疑惑,那就说明我们还需要深入代码&#xff0c…

ubuntu 20.04 安装 flameshot截图工具

ubuntu 20.04 安装 flameshot截图工具安装命令使用命令设置快捷键效果图安装命令 sudo apt-get install flameshot安装日志 $ sudo apt-get install flameshot [sudo] password for huifeimao: Reading package lists… Done Building dependency tree Reading state informat…

【零基础入门前端系列】—表格(五)

【零基础入门前端系列】—表格(五) 一、表格 表格在数据展示方面非常简单,并且表现优秀,通过与CSS的结合,可以让数据变得更加美观和整齐。 单元格的特点:同行等高、同列等宽。 表格的基本语法&#xff1…

性能测试之tomcat+nginx负载均衡

nginx tomcat 配置准备工作:两个tomcat 执行命令 cp -r apache-tomcat-8.5.56 apache-tomcat-8.5.56_2修改被复制的tomcat2下conf的server.xml 的端口号,不能与tomcat1的端口号重复,不然会启动报错 ,一台电脑上想要启动多个tomcat&#xff0c…

自定义bean 加载到spring IOC容器中

自定义bean加载到spring容器中的两种方式: 1.在类上添加注解Controller、RestController(本质是Controller)、Service、Repository、Component2.使用Configuration和Bean 这篇文章主要介绍第二种方式原理(因为在实际使用中&#…

SteaLinG:一款针对社工的开源安全渗透测试框架

关于SteaLinG SteaLinG是一款功能强大的开源渗透测试框架,该框架专为社会工程学研究人员设计,可以帮助广大研究人员或组织内的安全专家测试目标设备的安全性。该工具基于Python开发,因此具备良好的跨平台特性。在使用时,我们只需…

2023软考纸质证书领取通知来了!

不少同学都在关注2022下半年软考证书领取时间,截止至目前,上海、湖北、江苏、南京、安徽、山东、浙江、宁波、江西、贵州、云南、辽宁、大连、吉林、广西地区的纸质证书可以领取了。将持续更新2022下半年软考纸质证书领取时间,请同学们在证书…

信息安全保障

信息安全保障信息安全保障基础信息安全保障背景信息安全保障概念与模型基于时间的PDR模型PPDR模型(时间)IATF模型--深度防御保障模型(空间)信息安全保障实践我国信息安全保障实践各国信息安全保障我国信息安全保障体系信息安全保障…

SpringColud第四讲 Nacos的Windows安装方式和Linux的安装方式

在Nacos的GitHub页面,提供有下载链接,可以下载编译好的Nacos服务端或者源代码: 目录 1.Windows安装Nacos 1.1.下载 1.2.解压 1.3.修改相关配置: 1.4.启动: 1.5.登录: 2.Linux的安装方式Nacos 2.1.…

python cartopy手动导入地图数据绘制底图/python地图上绘制散点图:Downloading:warnings/散点图添加图里标签

……开学回所,打开电脑spyder一看一脸懵逼,简直不敢相信这些都是我自己用过的代码,想把以前的自己喊过来科研了() 废话少说,最近写小综述论文,需要绘制一个地图底图+散点图&#xff…

Cortex-M0存储器系统

目录1.概述2.存储器映射3.程序存储器、Boot Loader和存储器重映射4.数据存储器5.支持小端和大端数据类型数据对齐访问非法地址多寄存器加载和存储指令的使用6.存储器属性1.概述 Cortex-M0处理器具有32位系统总线接口,以及32位地址线(4GB的地址空间&…

TongWeb8数据源相关问题

问题一:数据源连接不足当TongWeb数据源连接用完时,除了监控中看到连接占用高以外,日志中会有如下提示信息。2023-02-14 10:24:43 [WARN] - com.tongweb.web.jdbc.pool.PoolExhaustedException: [TW-0.0.0.0-8088-3] Timeout: Pool empty. Una…

Hadoop高可用搭建(一)

目录 创建多台虚拟机 修改计算机名称 快速生效 修改网络信息 重启网络服务 关闭和禁用每台机的防火墙 同步时间 安装ntpdate 定时更新时间 启动定时任务 设置集群中每台机器的/etc/hosts 把hosts拷贝发送到每一台虚拟机 配置免密登陆 将本机的公钥拷贝到要免密登…