深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别

news2024/11/20 18:48:17

深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别

  • 一、前言
  • 二、网络结构
  • 二、可解释性
  • 三、记忆主线
  • 四、遗忘门
  • 五、输入门
  • 六、输出门
  • 七、手写数字识别实战
    • 7.1 引入依赖库
    • 7.2 加载数据
    • 7.3 迭代训练
    • 7.4 数据验证
  • 八、参考资料

一、前言

基本的RNN存在梯度消失和梯度爆炸问题,会忘记它在较长序列中以前看到的内容,只具有短时记忆。得到比较广泛应用的是LSTM(Long Short Term Memory)——长短期记忆网络,它在一定程度上解决了这两个问题。

二、网络结构

我们来看一下LSTM网络的结构图:
在这里插入图片描述
咱们放大看看,由于网上找不到清晰版的示例图,亲绘了一幅:
在这里插入图片描述
LSTM包含遗忘门、输入门、输出门。分别用于LSTM的三个步骤:旧记忆的遗忘、新记忆的输入、最终结果的输出。

二、可解释性

为什么要这么设计LSTM网络呢?我们打个比方:

小明上次考了数学,留下的大部分是数学的知识记忆 C t − 1 C_{t-1} Ct1;这次考生物,一些数学知识用不到,部分复杂的公式自然而然地被遗忘 f t ⊙ C t − 1 f_t\odot{C}_{t-1} ftCt1;复习生物知识一本书 C ~ t \tilde{C}_t C~t,大概记得八成 i t ⊙ C ~ t i_t\odot\tilde{C}_t itC~t,那么当前的记忆 C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t=f_t\odot{C}_{t-1}+i_t\odot\tilde{C}_t Ct=ftCt1+itC~t;考试时,成绩受到考题和当前记忆 C t C_t Ct的影响 h t = O t ⊙ tanh ⁡ C t h_t=O_t\odot\tanh{C_t} ht=OttanhCt

注: ⊙ \odot 是矩阵的点乘符号,即两个矩阵对应元素相乘

三、记忆主线

在这里插入图片描述
如上图所示,原有记忆是 C t − 1 C_{t-1} Ct1,经过遗忘(用矩阵参数进行点乘)、添加新记忆(加上新的记忆矩阵),当前最新的记忆就变成了 C t C_{t} Ct,如此循环,不重要的记忆就会忘记、重要的记忆就会一直流传下去。

四、遗忘门

第一步,我们会遗忘部分原有的记忆。
在这里插入图片描述
如上图所示, f t = σ ( W x f x t + W h f h t − 1 + b f ) f_t=\sigma(W_{xf}x_t+W_{hf} h_{t-1}+b_f) ft=σ(Wxfxt+Whfht1+bf)
σ \sigma σ代表sigmoid函数。原有记忆是 C t − 1 C_{t-1} Ct1,遗忘后为 f t ⊙ C t − 1 f_t\odot{C}_{t-1} ftCt1

五、输入门

第二步,我们会新增部分新的记忆。我们要确定,哪些新信息要保留到记忆细胞里。
在这里插入图片描述
如上图所示,
C ~ t = t a n h ( W x c x t + W h c h t − 1 + b c ) i t = σ ( W x i x t + W h i h t − 1 + b i ) \begin{aligned} \tilde{C}_t&=tanh(W_{xc}x_t+W_{hc}h_{t-1} +b_c)\\ i_t&=\sigma(W_{xi}x_t+W_{hi} h_{t-1}+b_i) \end{aligned} C~tit=tanh(Wxcxt+Whcht1+bc)=σ(Wxixt+Whiht1+bi)

C ~ t \tilde{C}_t C~t表示所有的输入信息,但我们不是所有的都记得, i t i_t it控制记忆程度, i t ⊙ C ~ t i_t\odot\tilde{C}_t itC~t是本次输入所记住的信息。
遗忘后的记忆是 f t ⊙ C t − 1 f_t\odot{C}_{t-1} ftCt1,输入新的记忆后, C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t=f_t\odot{C}_{t-1}+i_t\odot\tilde{C}_t Ct=ftCt1+itC~t

六、输出门

第三步,我们要根据现有记忆 C t C_t Ct,输出我们需要的内容。
在这里插入图片描述
如上图所示,
O t = σ ( W x o x t + W h o h t − 1 + b o ) h t = O t ⊙ tanh ⁡ ( C t ) \begin{aligned} O_t&=\sigma(W_{xo}x_t+W_{ho} h_{t-1}+b_o)\\ h_t&=O_t\odot\tanh(C_t) \end{aligned} Otht=σ(Wxoxt+Whoht1+bo)=Ottanh(Ct)

这就是LSTM网络的思想原理,接下来我们将用于手写数字识别实战。

七、手写数字识别实战

7.1 引入依赖库

import torch
import torch.nn as nn
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

7.2 加载数据

train_data = datasets.MNIST(root="./data",train=True,transform=transforms.ToTensor(),download=False)
batch_size=64

train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)

test_data = datasets.MNIST(root="./data",train=False,transform=transforms.ToTensor(),download=False)
test_x = test_data.data.type(torch.FloatTensor)[:2000]/255.   #取2000个样本数据并将其缩放为0~1范围
test_y = test_data.targets[:2000]

print(train_data.data.shape)


torch.Size([60000, 28, 28])

7.3 迭代训练


#迭代次数
epochs=1

#学习率
learning_rate=0.01

plt_epoch=[]
plt_loss=[]

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.rnn = nn.LSTM(     # LSTM 效果要比 nn.RNN() 好多了
            input_size=28,      # 图片每行的数据像素点
            hidden_size=64,     # rnn hidden unit
            num_layers=1,       # 有几层 RNN layers
            batch_first=True,   # input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
        )

        self.out = nn.Linear(64, 10)    # 输出层

    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)   LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # None 表示 hidden state 会用全0的 state

        # 选取最后一个时间点的 r_out 输出
        # 这里 r_out[:, -1, :] 的值也是 h_n 的值
        out = self.out(r_out[:, -1, :])
        return out

model = MyModel()

#损失函数
cost=nn.CrossEntropyLoss()
#迭代优化器
optmizer=torch.optim.Adam(model.parameters(),lr=learning_rate)

for epoch in range(epochs):

    for step, (images, labels) in enumerate(train_loader):

        images=images.view(-1,28,28)

        #预测结果
        output=model(images) #调用__call__函数

        #计算损失值
        loss=cost(output,labels)

        #在反向传播前先把梯度清零
        optmizer.zero_grad()

        #反向传播,计算各参数对于损失loss的梯度
        loss.backward()

        #根据刚刚反向传播得到的梯度更新模型参数
        optmizer.step()
    
        plt_epoch.append(step+1)
        plt_loss.append(loss.item())
        
        #打印损失值
        if step % 50 == 0:
            pred_y = model(test_x)
            pred_y=pred_y.argmax(dim=1) #返回最大值的下标
            print(f"step:{step},loss:{loss.item():.4f},accuracy: {(torch.sum(pred_y == test_y)/test_y.size()[0]) * 100:.2f}%")


# 保存模型
torch.save(model, 'LSTM_Digits.pt')

#绘制迭代次数与损失函数的关系
plt.plot(plt_epoch,plt_loss)
step:0,loss:2.3081,accuracy: 8.75%
step:50,loss:1.0913,accuracy: 59.40%
step:100,loss:0.7879,accuracy: 70.30%
step:150,loss:0.7618,accuracy: 73.75%
step:200,loss:0.4271,accuracy: 86.70%
step:250,loss:0.3963,accuracy: 90.65%
step:300,loss:0.2965,accuracy: 91.85%
step:350,loss:0.3396,accuracy: 94.15%
step:400,loss:0.2283,accuracy: 92.30%
step:450,loss:0.4932,accuracy: 94.05%
step:500,loss:0.2487,accuracy: 93.25%
step:550,loss:0.1460,accuracy: 94.20%
step:600,loss:0.1908,accuracy: 94.70%
step:650,loss:0.1521,accuracy: 92.35%
step:700,loss:0.1530,accuracy: 94.80%
step:750,loss:0.1192,accuracy: 94.65%
step:800,loss:0.0478,accuracy: 95.30%
step:850,loss:0.0535,accuracy: 95.70%
step:900,loss:0.1174,accuracy: 95.45%

在这里插入图片描述

7.4 数据验证

#加载模型
model=torch.load('LSTM_Digits.pt')

#预测结果
pred_y=model(test_x)
#计算损失值
loss=cost(pred_y,test_y)

print('loss:',loss.detach().item())

pred_y=pred_y.argmax(dim=1) #返回最大值的下标
print(f"准确率: {(torch.sum(pred_y == test_y)/test_y.size()[0]) * 100}%")

# 打印10个预测结果
pred_y = model(test_x[:10].view(-1, 28, 28))
pred_y=pred_y.argmax(dim=1) #返回最大值的下标
print('预测数字',pred_y)
print( '真实数字',test_y[:10])
loss: 0.11265470087528229
准确率: 96.45000457763672%
预测数字 tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
真实数字 tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])

八、参考资料

《如何从RNN起步,一步一步通俗理解LSTM》
《大白话讲解LSTM长短期记忆网络 如何缓解梯度消失,手把手公式推导反向传播》
《Understanding LSTM Networks》
《【Pytorch教程】:RNN 循环神经网络 (分类)》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/62340.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

毕业设计-基于机器视觉的深蹲检测识别-TensorFlow-opencv

目录 前言 课题背景和意义 实现技术思路 实现效果图样例 前言 📅大四是整个大学期间最忙碌的时光,一边要忙着备考或实习为毕业后面临的就业升学做准备,一边要为毕业设计耗费大量精力。近几年各个学校要求的毕设项目越来越难,有不少课题是研究生级别难度的,对本科…

LeetCode刷题复盘笔记—一文搞懂完全背包之377. 组合总和 Ⅳ问题(动态规划系列第十二篇)

今日主要总结一下动态规划完全背包的一道题目,377. 组合总和 Ⅳ 题目:377. 组合总和 Ⅳ Leetcode题目地址 题目描述: 给你一个由 不同 整数组成的数组 nums ,和一个目标整数 target 。请你从 nums 中找出并返回总和为 target 的…

[附源码]计算机毕业设计基于web的羽毛球管理系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

Python软件编程等级考试一级——20220915

Python软件编程等级考试一级——20220915理论单选题判断题实操第一题第二题理论 单选题 1、表达式len(“学史明理增信 ,读史终生受益”) > len(" reading history will benefit you ")的结果是? A、0 B、True C、False D、1 2、表达…

SLMi333国内首款兼容光耦带DESAT保护功能的隔离式栅极驱动器

SLMi333国内首款兼容光耦带DESAT保护功能的隔离式栅极驱动器,内置快速去饱和(DESAT)故障检测功能,米勒钳位功能,漏极开路故障反馈,软关断功能以及可选择的自恢复模式,兼容光耦隔离驱动器,一款高…

安装mongodb6

一、安装mongodb6.0.2 1.官网下载社区版 https://www.mongodb.com/ 2.双击下载的文件,按步骤安装 选择custom 自定义安装 改一下安装地址,路径最好不要带空格 Install MongoD as a Service 作为服务方式安装 Run the service as Network Service…

SuperMap iClient for Leaflet对EPSG:4509图加载滑动查询

作者:John SuperMap iClient for Leaflet对EPSG:4509地图加载&滑动查询 在WebGIS开发使用中,我们会遇到地图显示不了,以及查询到数据显示不出的问题,因此本文就以EPSG:4509为例介绍该坐标系地图加载和查询。 1、EPSG:4509地图…

数据分析案例:基于水色图像的水质识别

大数据分析课程、大数据分析班、大数据案例等,围绕大数据展开讲解。 数据分析案例:基于水色图像的水质识别,通过学习本案例,可以掌握图像切割、特征提取、模型构建和模型评价的主要方法和技能,并为后续相关课程学习及将…

蚂蚁面试官:Zookeeper 的选举流程是怎样的?我当场懵逼了

​ 编辑切换为居中 添加图片注释,不超过 140 字(可选) 面试经常会遇到面试官问 Zookeeper 的选举原理,我心想,问这些有啥用吗?又不要我造火箭! 每次面试也只知道个大概,并没有深究…

分布式电源接入对配电网影响的研究(Matlab代码实现)

👨‍🎓个人主页:研学社的博客 💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜…

vue3 + element 从0到1搭建前端基础框架

一、框架搭建 框架代码 个人博客 往往从0到1开发项目时发觉无从下手,或者很可能一步一个坑,因为大多基础框架公司已经搭建完毕的,新加入的成员也都是在此基础上进行功能模块的拓展。网上也鲜有详尽的全流程参考,多是某个局部功能的…

【Vue】webpack的基本使用

✍️ 作者简介: 前端新手学习中。 💂 作者主页: 作者主页查看更多前端教学 🎓 专栏分享:css重难点教学 Node.js教学 从头开始学习 ajax学习 文章目录webpack的学习目标前端工程化 小白眼中的前端开发 vs 实际的前端开发 什么是前端工程…

CISP考试大纲/范围

CISP考试主要是考CISP知识体系大纲,分别为信息安全保障、信息安全技术、信息安全管理、信息安全工程和信息安全标准法规这五大知识类,每个知识类根据其逻辑划分为多个知识体,每个知识体包含多个知识域,每个知识域由一个或多个知识…

Java项目:SSM失物招领管理系统

作者主页:源码空间站2022 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 主要功能包括: 用户发布失物,或者招领失物,管理员对用户,失物信息进行增删改查。 环境需要 1…

新课程教学杂志新课程教学杂志社新课程教学编辑部2022年第19期目录

核心素养 核心素养视域下的历史教学设计——以“清朝君主专制的强化”为例 王威; 1-3 新中考背景下文本分析能力与核心素养的培育 黄嫄; 4-5《新课程教学》投稿:cn7kantougao163.com 基于核心素养的物理教学评价改良 李红; 6-7 初中语文综合性学习的…

Metabase学习教程:系统管理-6

Metabase可扩展性 扩展Metabase以支持更多人和数据库的最佳实践。 Metabase是一个可扩展的、经过实战的软件,被成千上万的公司用来提供高质量的自助服务分析。它通过水平扩展支持高可用性,而且它是开箱即用的高效工具:一台拥有4gb内存的单核…

vue.js axios 数据不刷新

getServerList(){axios.get(/server/showList).then(function(response){this.servers response.data // 不刷新console.log(response.data)}).catch(function (error) {console.log(error);}); } 打印this:this不是vue对象修改为:getServerList(){axi…

Mysql各种缓冲区的功能及之间的联系

buffer poolmysql数据存放在磁盘里面,如果每次查询都直接从磁盘里面查询,会影响性能,因此需要内存态缓存池。另外缓存池的淘汰机制不是基础LRU,而是是改进版LRU,防止大量临时缓存挤出热点数据。buffer pool读缓存分为老…

代码随想录算法训练营第五十三天| LeetCode1143. 最长公共子序列、LeetCode1035. 不相交的线、LeetCode53. 最大子数组和

一、LeetCode1143. 最长公共子序列 1:题目描述(1143. 最长公共子序列) 给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 ,返回 0 。 一个字符串的 子序列 是指这样一…

Leetcode 1687. 从仓库到码头运输箱子 [四种解法] 动态规划 从朴素出发详细剖析优化步骤

你有一辆货运卡车,你需要用这一辆车把一些箱子从仓库运送到码头。这辆卡车每次运输有 箱子数目的限制 和 总重量的限制 。给你一个箱子数组 boxes 和三个整数 portsCount, maxBoxes 和 maxWeight ,其中 boxes[i] [ports​​i​, weighti] 。ports​​i …