【图像分类】基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)

news2024/11/15 22:49:23

写在前面:
首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌。

在https://blog.csdn.net/AugustMe/article/details/128969138文章中,我们使用了基于PyTorch搭建LSTM实现MNIST手写数字体识别,LSTM是单向的,现在我们使用双向LSTM试一试效果,和之前的单向LSTM模型稍微有差别,请注意查看代码的变化。

1.导入依赖库

这些依赖库是必须导入的,用于后续代码的构建:

import torch
from torch import nn, optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

依赖库的版本信息:

torch: 1.8.0+cpu
numpy: 1.19.3
matplotlib: 3.2.1
pillow: 7.2.0

2.数据集

训练模型肯定少不了数据集,本教程使用我们以比较熟悉的 mnist 数据集,该数据集是手写数字数据集,每一张图片得大小为28×28,训练集60000张,测试集10000张,mnist数据集下载代码如下:

# 训练集
train_data = datasets.MNIST(root="./",    # 存放位置
                            train = True, # 载入训练集
                            transform=transforms.ToTensor(), # 把数据变成tensor类型
                            download = True    # 下载
                           )
# 测试集
test_data = datasets.MNIST(root="./",
                            train = False,
                            transform=transforms.ToTensor(),
                            download = True
                           )

这个mnist下载成功与否,还和你的网络有关系,有时候网络不好,可能会导致下载失败。如果你下载不下来,可以联系我,我将数据集打包发给你。

下载得到的数据集存放如下:

在这里插入图片描述

3.数据导入

数据下载成功后,加载下载得到的数据集,核心代码如下:

# 批次大小
batch_size = 32
# 装载训练集
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

我们查看一下数据集中的图片,核心代码为:

# batch_size设为 1 时查看
for i, data in enumerate(train_loader):
    inputs, labels = data
    print(inputs.shape)
    print(labels.shape)
    img = inputs.view((28,28))
    print(img.shape)
    # plt.imshow(img)
    plt.imshow(img, cmap='gray')
    break

plt.imshow(img, cmap=‘gray’)
在这里插入图片描述
plt.imshow(img):

在这里插入图片描述

4.双向LSTM网络

Long Short-Term Memory (LSTM) 是一种特殊的循环神经网络,它能够处理较长的序列,并且能够记忆长期的依赖关系。LSTM 的结构包括输入门、输出门、遗忘门和记忆细胞,它们共同组成了一个“门控循环单元”,可以控制信息的流动,从而实现长期依赖关系的学习。LSTM 在自然语言处理、语音识别、机器翻译等领域有着广泛的应用。

基于pytorch深度学习框架搭建LSTM网络模型,使用了双向LSTM,一层:

这里面模型和之前的文章稍有不同,注意 output,(h_n,c_n)三个值的输出。

# 定义网络结构
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM,self).__init__()   # 初始化
        self.lstm = nn.LSTM(
            input_size = 28,       # 表示输入特征的大小
            hidden_size = 64,      # 隐藏层的特征维度
            num_layers = 1,        # 表示lstm隐藏层的层数
            batch_first = True,    # lstm默认格式input(seq_len,batch,feature)
                                   # 等于True表示input和output变成(batch,seq_len,feature)
            bidirectional = True  # True则为双向lstm默认为False
        )
        self.out = torch.nn.Linear(in_features=64*2, out_features=10)
        self.softmax = torch.nn.Softmax(dim=1) # 映射到0-1之间
    
    def forward(self,x):
        # (batch, seq_len, feature)
        x = x.view(-1, 28, 28)
        # output:(batch,seq_len,hidden_size)包含每个序列的输出结果
        # 虽然lstm的batch_first为True,但是h_n,c_n的第0个维度还是num_layers
        # h_n :[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
        # c_n:[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
        output,(h_n,c_n) = self.lstm(x) # x输入到lstm
        output_in_last_timestep = output[:,-1,:] # 获取下一个输入
        x = self.out(output_in_last_timestep) # 输入到out
        x = self.softmax(x)  # 输入到softmax
        return x

特别说明:

LSTM中存在维度的变化,一定要注意,下面以实例进行讲解,请看下面的代码和注释。
h_n包含的是句子的最后一个单词的隐藏状态,c_n包含的是句子的最后一个单词的细胞状态,所以它们都与句子的长度seq_length无关。output[:,-1,:]与h_n是相等的,因为output[-1]包含的正是batch_size个句子中每一个句子的最后一个单词的隐藏状态,注意LSTM中的隐藏状态其实就是输出,cell state细胞状态才是LSTM中一直隐藏的,记录着信息,output与h_n的关系。

实验代码,仅供参考:

# -*- coding: utf-8 -*-
"""
Created on Fri Feb 10 15:25:40 2023

@author: augustqi

维度变化:
https://blog.csdn.net/qq_54867493/article/details/128790652
"""

import torch
import torch.nn as nn


input_x = torch.randn(1, 28, 28)  
print(input_x.shape)

input_x_ = input_x.view(-1, 28, 28)
print(input_x_.shape)


lstm = nn.LSTM(
            input_size = 28,       # 输入数据的特征维数,通常就是embedding_dim(词向量的维度)
            hidden_size = 64,      # 隐藏层的特征维度
            num_layers = 1,        # 表示lstm循环神经网络的层数
            batch_first = True,    # lstm默认格式input(seq_len,batch,feature)
                                   # 等于True表示input和output变成(batch,seq_len,feature)
            bidirectional = True  # True则为双向lstm默认为False
        )

linear = torch.nn.Linear(in_features=64*2, out_features=10)


softmax = torch.nn.Softmax(dim=1)


output, (h_n, c_n) = lstm(input_x_)


'''
output的维度:(batch, seq_len, num_directions*hidden_size)
hn的维度:(num_directions*num_layer, batch_size, hidden_size)
cn的维度:同hn
'''


print(output)
# 如果bidirectional=True, num_directions=2; 如果bidirectional=False, num_directions=1
print(output.shape)  # [seq_length, batch_size, num_directions * hidden_size]


print(output[:,-1,:])
print(output[:,-1,:].shape)

print(h_n)
print(h_n.shape) #  [num_directions * num_layers, batch, hidden_size]

print(c_n)
print(c_n.shape) # c_n.shape = h_n.shape


print(h_n[-1,:,:])
print(h_n[-1,:,:].shape) 

linear_out = linear(h_n[-1,:,:])

softmax_out = softmax(linear_out)


linear_out_2 = linear(output[:,-1,:])
softmax_out_2 = softmax(linear_out_2)

"""
h_n包含的是句子的最后一个单词的隐藏状态,c_n包含的是句子的最后一个单词的细胞状态,
所以它们都与句子的长度seq_length无关。
output[:,-1,:]与h_n是相等的,因为output[-1]包含的正是batch_size个句子中每一个句子的最后一个单词的隐藏状态,
注意LSTM中的隐藏状态其实就是输出,cell state细胞状态才是LSTM中一直隐藏的,记录着信息,output与h_n的关系。

"""

5.模型训练

训练代码如下,主要包括定义模型、定义损失函数、定义优化器,训练时的超参数,详情如下:

# 定义模型
model = LSTM()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()   # 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001) # Adam

Epoch = 30
loss_train_list = []
loss_test_list = []
# 训练
for epoch in range(Epoch):
    # 模型的训练状态
    model.train()
    correct_train = 0
    loss_train = 0
    for i, data in enumerate(train_loader):
        # 获得一个批次的数据和标签
        inputs, labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 获得最大值,以及最大值所在的位置
        _, predicted = torch.max(out, 1)
        # 预测正确的数量
        correct_train += (predicted==labels).sum()
        # 交叉熵代价函数out(batch,C:类别的数量),labels(batch)
        loss = mse_loss(out, labels)
        loss_train += loss.item()  # loss.data, tensor(1.4612)
        # 梯度清零
        optimizer.zero_grad()
        # 计算梯度
        loss.backward()
        # 修改权值
        optimizer.step()     
    
    loss_train_list.append(loss_train/len(train_data))
    print("Epoch:{}/{}, Train acc:{:.4f}, Loss:{:.6f}".format(epoch+1, Epoch, (correct_train.item()/len(train_data)),  
          (loss_train/len(train_data))))

6.模型测试

每训练完一个epoch,就使用测试集测试一下模型,输出测试精度和损失情况:

# 模型的测试状态
model.eval()
correct_test = 0 # 测试集准确率
loss_test = 0
for i, data in enumerate(test_loader):
    # 获得一个批次的数据和标签
    inputs, labels = data
    # 获得模型预测结果(64,10)
    out = model(inputs)
    # 获得最大值,以及最大值所在的位置
    _,predicted = torch.max(out, 1)
    # 预测正确的数量
    correct_test += (predicted==labels).sum()
    loss = mse_loss(out, labels)
    loss_test += loss.item()  # loss.data, tensor(1.4612)

loss_test_list.append(loss_test/len(test_data))
print("Test acc:{:.4f}, Loss:{:.6f}".format(correct_test.item()/len(test_data), 
      loss_test/len(test_data)))

7.损失可视化

训练30个epoch,终端输出情况:

Epoch:1/30, Train acc:0.7438, Loss:0.054061
Test acc:0.8521, Loss:0.050427
Epoch:2/30, Train acc:0.8615, Loss:0.050059
Test acc:0.9322, Loss:0.047967
Epoch:3/30, Train acc:0.9387, Loss:0.047655
Test acc:0.9546, Loss:0.047182
Epoch:4/30, Train acc:0.9506, Loss:0.047248
Test acc:0.9618, Loss:0.046989
Epoch:5/30, Train acc:0.9620, Loss:0.046881
Test acc:0.9593, Loss:0.047013
Epoch:6/30, Train acc:0.9638, Loss:0.046818
Test acc:0.9630, Loss:0.046920
Epoch:7/30, Train acc:0.9647, Loss:0.046787
Test acc:0.9664, Loss:0.046818
Epoch:8/30, Train acc:0.9680, Loss:0.046681
Test acc:0.9700, Loss:0.046682
Epoch:9/30, Train acc:0.9698, Loss:0.046619
Test acc:0.9686, Loss:0.046729
Epoch:10/30, Train acc:0.9736, Loss:0.046505
Test acc:0.9710, Loss:0.046664
Epoch:11/30, Train acc:0.9761, Loss:0.046428
Test acc:0.9711, Loss:0.046657
Epoch:12/30, Train acc:0.9768, Loss:0.046398
Test acc:0.9771, Loss:0.046465
Epoch:13/30, Train acc:0.9784, Loss:0.046350
Test acc:0.9783, Loss:0.046434
Epoch:14/30, Train acc:0.9796, Loss:0.046312
Test acc:0.9773, Loss:0.046442
Epoch:15/30, Train acc:0.9809, Loss:0.046278
Test acc:0.9794, Loss:0.046393
Epoch:16/30, Train acc:0.9808, Loss:0.046270
Test acc:0.9789, Loss:0.046409
Epoch:17/30, Train acc:0.9807, Loss:0.046278
Test acc:0.9766, Loss:0.046474
Epoch:18/30, Train acc:0.9816, Loss:0.046243
Test acc:0.9793, Loss:0.046388
Epoch:19/30, Train acc:0.9840, Loss:0.046169
Test acc:0.9799, Loss:0.046367
Epoch:20/30, Train acc:0.9846, Loss:0.046152
Test acc:0.9823, Loss:0.046316
Epoch:21/30, Train acc:0.9853, Loss:0.046132
Test acc:0.9833, Loss:0.046268
Epoch:22/30, Train acc:0.9862, Loss:0.046103
Test acc:0.9814, Loss:0.046317
Epoch:23/30, Train acc:0.9850, Loss:0.046141
Test acc:0.9804, Loss:0.046343
Epoch:24/30, Train acc:0.9865, Loss:0.046091
Test acc:0.9815, Loss:0.046316
Epoch:25/30, Train acc:0.9873, Loss:0.046067
Test acc:0.9833, Loss:0.046262
Epoch:26/30, Train acc:0.9879, Loss:0.046048
Test acc:0.9813, Loss:0.046331
Epoch:27/30, Train acc:0.9870, Loss:0.046073
Test acc:0.9837, Loss:0.046250
Epoch:28/30, Train acc:0.9891, Loss:0.046014
Test acc:0.9830, Loss:0.046271
Epoch:29/30, Train acc:0.9875, Loss:0.046061
Test acc:0.9821, Loss:0.046299
Epoch:30/30, Train acc:0.9888, Loss:0.046023
Test acc:0.9815, Loss:0.046324

训练集上损失曲线图:

在这里插入图片描述

测试集上损失曲线图:

在这里插入图片描述

训练30个epoch后,模型在测试集上的精度达到98.15%,效果还不错。训练集上的损失和测试集上的损失都在下降并逐渐收敛。

参考资料

1.https://blog.csdn.net/AugustMe/article/details/128969138

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/342636.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【CSS】元素居中总结-水平居中、垂直居中、水平垂直居中

【CSS】元素居中一、 水平居中1.行内元素水平居中(1)text-align2.块级元素水平居中2.1 margin(1)margin2.2布局(1)flex justify-content(推荐)(2) flexmargin…

张驰咨询:关于六西格玛,有一些常见的疑惑!

​ 很多想要学习六西格玛的学员,经常会有这些困惑: 以前没有接触过六西格玛,需要什么基础吗?自学还是培训?哪些行业会用到六西格玛呢?学习六西格玛对以后的工作有哪些帮助?如何选择六西格玛培…

STM32配置读取双路24位模数转换(24bit ADC)芯片CS1238数据

STM32配置读取双路24位模数转换(24bit ADC)芯片CS1238数据 CS1238是一款国产双路24位ADC芯片,与CS1238对应的单路24位ADC芯片是CS1237,功能上相当于HX711和TM7711的组合。其功能如下所示: 市面上的模块: …

股票买卖接口怎么来的?

现在股票买卖接口主要是在线上研发,有专业的开发团队进行源码开发和完善,但是,常常会在开发过程中出现问题,也就是遇到一些特殊的情况需要及时处理,那么股票买卖接口怎么开发实现出来的?一、股票买卖接口开…

案例分享| 助力数字化转型:广州期货交易所全栈信创项目管理平台上线

广州期货交易所项目管理平台基于易趋(easytrack)进行实施,通过近半年的开发及试运行,现已成功交付上线、推广使用,取得了良好的应用效果。1. 关于广州期货交易所(以下简称广期所)广期所于2021年…

MySQL8.0安装教程

文章目录1.官网下载MySQL2.下载完记住解压的地址(一会用到)3.进入刚刚解压的文件夹下,创建data和my.ini在根目录下创建一个txt文件,名字叫my,文件后缀为ini,之后复制下面这个代码放在my.ini文件下&#xff…

华为手表开发:WATCH 3 Pro(4)创建项目 + 首页新建按钮,修改初始文本

华为手表开发:WATCH 3 Pro(4)创建项目 首页新建按钮,修改初始文本初环境与设备创建项目创建项目入口配置项目认识目录结构修改首页初始文本文件名:index.hml新建按钮 “ 按钮 ”index.hml初 鸿蒙可穿戴开发 希望能写…

直播预告 | 对谈谷歌云 DORA 布道师:聊聊最关键的四个 DevOps 表现指标

本期分享 DORA 的全称是 DevOps Research and Assessment,是一个致力于 DevOps 调研与研究的组织,2018 年加入 Google Cloud。自 2014 年起,DORA 每年会发布一份行业报告,基于对数千名从业者的调研,分析高效能团队与低…

联想K14电脑开机全屏变成绿色无法使用怎么U盘重装系统?

联想K14电脑开机全屏变成绿色无法使用怎么U盘重装系统?最近有用户使用联想K14电脑的时候,开机后桌面就变成了绿色的背景显示,无法进行任何的操作。而且通过强制重启之后还是会出现这个问题,那么这个情况如何去进行系统重装呢&…

PMP证书要怎么考,含金量怎么样?

很多朋友在对PMP不是了解的时候,会有些犹豫,PMP证书到底值不值得考。考下来有用吗? PMP证书当然有用,要含金量有含金量,要专业知识有专业知识,不过要是考了不用,久而久之忘了学习的内容&#x…

怿星科技校招礼盒:我想开了

校招礼盒大揭秘为了帮助2023届新同学快速了解怿星文化增强认同感经过1个多月的精心准备我们的校招大礼盒终于跟大家见面啦!!我们用了大量的公司IP形象-小怿通过各式各样的姿势和表情欢迎新同学的到来搭配着IP的蓝色色调传递出一种科幻与探索的感觉希望加…

计算机组成原理:1. 计算机系统概论

更好的阅读体验\huge{\color{red}{更好的阅读体验}}更好的阅读体验 文章目录1.1 计算机系统简介1.1.1 计算机软硬件概念1.1.2 计算机的层次1.1.3计算机组成和计算机体系结构1.2 计算机的基本组成1.2.1 冯诺伊曼计算机的特点1.2.2 计算机的硬件框图1.2.3 计算机的工作步骤1.3 计…

问卷调查会遇到哪些问题?怎么解决?

提到问卷调查我们并不陌生,它经常被用作调查市场、观察某类群体的行为特征等多种调查中。通过问卷调查得出的数据能够非常真实反映出是市场的现状和变化趋势,所以大家经常使用这个方法进行调查研究。不过,很多人在进行问卷调查的时候也会遇到…

JAVA八股、JAVA面经

还有三天面一个JAVA软件开发岗,之前完全没学过JAVA,整理一些面经...... 大佬整理的:Java面试必备八股文_-半度的博客-CSDN博客 另JAVA学习资料:Java | CS-Notes Java 基础Java 容器Java 并发Java 虚拟机Java IO目录 int和Inte…

电商新趋势来临!?解析Dtop 环球嘉年华电商是否值得加入!

近年来,电商平台的发展瞬息万变,加上疫情的推波助澜,让全球的电子商务来到前所未有的光景,营业销售额直达颠覆性的增长。 许多商家也因此纷涌而入,谋划分得电子商务的一杯羹。随着参与成为电商的商家日益剧增,商家们想从中谋利也不是件易事。再加上市场不断洗牌的形势下,传统电…

楔形文字的起源全2课-北京大学拱玉书 笔记

楔形文字的起源全2课-北京大学拱玉书 说明:以下图片素材均出自视频 【楔形文字的文源】 《吉尔伽美什史诗》记载的楔形文字起源: 学术界对楔形文字起源的总结: 【关于文字的演化理论发展史】 威廉.瓦尔伯顿提出文字“叙事图画”演变说…

浅谈 RBAC 权限模型

写作背景 工作两年半了,笔者一直在做 To B 的产品,像是后端管理系统、Saas 系统都有接触过,它们都有一个共同点:权限管理。我每天都在接触但只是从前端开发这个角色去理解,我对整个业务流程其实是比较模糊的&#xff…

第三部分:(主从)复合句——第一章:名词性从句

回顾:第二部分讲解的是并列句,即多件同等重要的事通过并列连词进行相连接,构成并列句 但是,现实生活中并不是许多事都是同等重要的,复合句就出现了,复合句全称为主从复合句 复合句 多件事不一样重要 主句…

GIS在地质灾害危险性评估与灾后重建中的实践技术应用及python机器学习灾害易发性评价模型建立与优化进阶

除滑坡灾害外,还包括崩塌、泥石流、地面沉降等各种地质灾害,具有类型多样、分布广泛、危害性大的特点。地质灾害危险性评价着重于根据多种影响因素和区域选择来评估在某个区域中某个阶段发生的地质灾害程度。以此预测和分析未来某个地形单位发生地质灾害…

社区买菜业务流程

前言 最近由于疫情的原因,很多城市的小区都不允许快递员上门送货了,用户只能到小区指定的地点进行取货。 多点、叮咚买菜、美菜、盒马等电商着实火了一把,每天的订单量都非常的多,他们都依托于超市或线下门店等进行接单、商品打…