深度学习基础--LSTM学习笔记(李沐《动手学习深度学习》)

news2025/1/22 3:54:00

前言

  • LSTM是RNN模型的升级版,神经网络模型较为复杂,这里是学习笔记的记录;
  • LSTM比较复杂,可以先看:
    • 深度学习基础–一文搞懂RNN
    • 深度学习基础–GRU学习笔记(李沐《动手学习深度学习》)
  • RNN:RNN讲解
  • 参考:李沐动手学习深度学习;
  • 欢迎收藏加关注,本人将会持续更新。

    文章目录

        • 长距离依赖问题
        • LSTM的核心思想
        • LSTM门简介
          • 三个输入门
          • 候选记忆单元
          • 记忆状态
          • 隐状态
          • 总结
        • Pytorch实践
        • 参考资料

LSTM也称为长短期记忆网络,他说RNN、GRU的升级版,它能够学到长期依赖,说白了,RNN是理解一句话,但是LSTM就是理解一段话.

长距离依赖问题

RNN模型中,核心的是有一个隐藏层,这个隐藏层记录之前的信息,但是这个隐藏层的每次更新,权重都是一样的,但是我们生活中不是所有信息都是等价的,[知乎大佬一个案例](LSTM - 长短期记忆递归神经网络 - 知乎):

在这里插入图片描述

我们看到这句话,核心就是几个关键词:“纸好”、“没味道”、“便宜”、“质量好”,我们看完这句话其实和看到这几个关键词没什么大得区别,从这来看,这里也可以得出两点:

  • 在一个时间序列中,前后信息不是所有都是等效的,“关键词”往往最核心,也有一些词“没有啥效果”;
  • 我们在从左到右阅读的时候,脑子自动会帮我们过滤掉一些无用的信息,只留下一些“关键词”的理解,并且能够利用之前的信息去理解后面的信息,这也是我们熟悉的“上下文”;

LSTM也称“长短期记忆网络”,他的核心就是**“记忆”**,有点像我们大脑一样,对于过去的一些信息,有些“忘记”,有些“记得牢”,也有些“只是有个印象”。

LSTM的核心思想

相比于RNN,LSTM的核心就是,除了有隐藏态ht 之外,还有Ct, Ct代表T这个时刻的记忆,从Ct-1计算得来,用于信息的赛选,对重要信息进行保留,如图:

在这里插入图片描述

那怎么进行保留呢?对上一层的信息Ct-1保留,无非就是全部保留,全部不保留,或者保留一部分,这样的话就需要输入一个[0, 1]之间的值,而这个在神经网络中,有一个激活函数可以很好的做到,叫做:sigmoid,记忆保留过程,如图:

输入0,全部不要;输入1,全部保留;输入(0, 1),保留部分信息。

LSTM门简介

LSTM有三个门,分别是:

  • 忘记门(遗忘门):将朝着0减少
  • 输入门:决定是不是要忽略输入数据
  • 输出门:决定是不是要使用隐状态

👀 提示:一下数学公式组合成一块,我感觉就不是那么容易理解了,但是能大概理解即可,后面在案例中实践学习。

三个输入门

首先数据经过输入、输出、遗忘门,这三个门第一步都是做线性运算+激活函数进行非线性运算,由于是RNN的升级版,故都会吸取前面的特征Ht-1

在这里插入图片描述

候选记忆单元

候选记忆单元经过先进过线性计算,在经过激活函数tanh的作用,将函数值映射到[-1,1]之间,这个的作用需要结合记忆状态更新来看,结合隐藏层更新公式,可以发现,这个其实的作用可以理解为:对当前的输入信息“记忆多少”

在这里插入图片描述

记忆状态

记忆状态:这是LSTM的核心,看公式有两部分组成,第一部分是遗忘门的更新,决定对之前的记忆信息“吸取多少”,第二个是结合候选记忆单元结合输入门数据,这个我感觉就是代表者说是对当前的数据输入“吸收多少信息”,用于下一个数据的更新。

这个极端情况下,数据范围是[-2, 2]。

在这里插入图片描述

隐状态

在这里,对当前的记忆**Ct**再一次进行了tanh激活函数的作用,他的用处是将记忆单元数据映射到[-1, 1],然后再结合当前输入,这样当前的输入结合了之前的记忆做了更新,然后输出。

在这里插入图片描述
blog.csdnimg.cn/direct/29dd0a2418c04d2d98866a31dccd52d3.png#pic_center)

总结

LSTM原理具体细节确实复杂,但是我感觉可以结合实践慢慢理解,毕竟小编还是本科生🤠🤠.

在这里插入图片描述

Pytorch实践

pytorchAPI

class torch.nn.LSTM(
    input_size, 
    hidden_size, 
    num_layers=1, 
    bias=True, 
    batch_first=False, 
    dropout=0, 
    bidirectional=False, 
    proj_size=0
)
  • nput_size: 输入特征的数量。
  • hidden_size: 隐藏状态(或输出)特征的数量。
  • num_layers: LSTM 层的数量。默认是 1。
  • bias: 如果为 False,则不会使用偏置项。默认是 True。
  • batch_first: 如果为 True,则输入和输出张量提供给模块的形式为 (batch, seq, feature)。默认是 False,即 (seq, batch, feature)。
  • dropout: 如果非零,则在除了最后一层之外的所有RNN层之后引入一个Dropout层。默认是 0。
  • bidirectional: 如果为 True,则会变成双向LSTM。默认是 False。
  • proj_size: 如果 > 0,则将 LSTM 的隐藏状态投影到这个大小。这有助于减少内存消耗。默认是 0,表示没有投影。

下面将用这个API接口进行搭建一个简单的LSTM网络结构

import torch  
import torch.nn as nn 

class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        # 定义LSTM
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # 定义线性层
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        list_out, (hh, cn) = self.lstm(x)
        
        # 取最后一个时间步输出
        out = list_out[:, -1, :]
        
        output = self.fc(out)
        
        return output.view(-1, 1, 1)  # 保持维度
    
# 设置参数
input_size = 10 # 输入维度
hidden_size = 20  # 隐藏层维度
num_layers = 2  # LSTM层数
output_size = 1 # 输出维度

# 实例化模型
model = SimpleLSTM(input_size, hidden_size, num_layers, output_size)

# 随机生产数据
# 示例输入(batch_size, seq_len, input_size)
x = torch.randn(5, 15, input_size)

# 模型
out = model(x)

print(out.shape)

输出:

torch.Size([5, 1, 1])

参考资料

LSTM - 长短期记忆递归神经网络 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2280137.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电气防火保护器为高校学生宿舍提供安全保障

摘 要:3月2日,清华大学紫荆学生公寓发生火情,无人员伤亡。推断起火原因系中厅内通电电器发生故障引燃周边可燃物所致。2月27日,贵州某高校女生宿舍发生火灾,现场明火得到有效控制,无人员受伤。2月19日&…

每打开一个chrome页面都会【自动打开F12开发者模式】,原因是 使用HBuilderX会影响谷歌浏览器的浏览模式

打开 HBuilderX,点击 运行 -> 运行到浏览器 -> 设置web服务器 -> 添加chrome浏览器安装路径 chrome谷歌浏览器插件 B站视频下载助手插件: 参考地址:Chrome插件 - B站下载助手(轻松下载bilibili哔哩哔哩视频&#xff09…

C#使用WMI获取控制面板中安装的所有程序列表

C#使用WMI获取控制面板中安装的所有程序列表 WMI 全称Windows Management Instrumentation,Windows Management Instrumentation是Windows中用于提供共同的界面和对象模式以便访问有关操作系统、设备、应用程序和服务的管理信息。如果此服务被终止,多数基于 Windo…

企业级流程架构设计思路-基于价值链的流程架构

获取更多企业流程资料 纸上得来终觉浅,绝知此事要躬行 一.企业流程分级规则定义 1.流程分类分级的总体原则 2.完整的流程体系需要体现出流程的分类分级 03.通用的流程分级方法 04.流程分级的标准 二.企业流程架构设计原则 1.流程架构设计原则 流程框架是流程体…

PyTorch使用教程(8)-一文了解torchvision

一、什么是torchvision torchvision提供了丰富的功能,主要包括数据集、模型、转换工具和实用方法四大模块。数据集模块内置了多种广泛使用的图像和视频数据集,如ImageNet、CIFAR-10、MNIST等,方便开发者进行训练和评估。模型模块封装了大量经…

如何将自己本地项目开源到github上?

环境: LLMB项目 问题描述: 如何将自己本地项目开源到github上? 解决方案: 步骤 1: 准备本地项目 确保项目整洁 确认所有的文件都在合适的位置,并且项目的 README.md 文件已经完善。检查是否有敏感信息&#xff0…

ConvBERT:通过基于跨度的动态卷积改进BERT

摘要 像BERT及其变体这样的预训练语言模型最近在各种自然语言理解任务中取得了令人印象深刻的性能。然而,BERT严重依赖于全局自注意力机制,因此存在较大的内存占用和计算成本。尽管所有的注意力头都从全局角度查询整个输入序列以生成注意力图&#xff0…

2025web建议

随便收集的信息 新手入门路线推荐 第一步:Web安全相关概念 建议学习时间:2周 学习内容如下: 1、熟悉基本概念(SQL注入、上传、XSS、CSRF、一句话木马等)。 2、通过关键字(SQL注入、上传、XSS、CSRF、一句话木马等)进行Google。 3、阅读《Web…

用JAVA实现人工智能:采用框架Spring AI Java

Spring AI 集成人工智能,为Java项目添加AI功能指南 本文主旨是用实际的可操作的代码,介绍Java怎么通过spring ai 接入大模型。 例子使用spring ai alibaba QWen千问api完成,你可以跑通以后换自己的实现。QWen目前有100万免费Token额度&…

【JDBC】数据库连接的艺术:深入解析数据库连接池、Apache-DBUtils与BasicDAO

文章目录 前言🌍 一.连接池❄️1. 传统获取Conntion问题分析❄️2. 数据库连接池❄️3.连接池之C3P0技术🍁3.1关键特性🍁3.2配置选项🍁3.3使用示例 ❄️4. 连接池之Druid技术🍁 4.1主要特性🍁 4.2 配置选项…

canvas 图片组合并进行下载

运行图片&#xff1a; 思路&#xff1a;先画一个背景图片&#xff0c;再画一个二维码定位到你想要的位置&#xff0c;最后直接下载即可&#xff0c;可以扩散一下思维&#xff0c;画简单的海报的时候&#xff0c;也可以的 源代码 <!DOCTYPE html> <html lang"en&q…

记一次升级请求创建报错问题的调查过程(Windchill)

问题现象描述&#xff1a; ​ 新建申请请求单&#xff0c;在选择某些物料时会报此错误&#xff0c;选另外的物料时又可以正常创建&#xff0c;不报此错误。 问题原因分析&#xff1a; ​ 1.分析后台日志 —没有任何进展&#xff0c;此报错应该是前端的报错 ​ 2.从前端下手…

自旋锁与CAS

上文我们认识了许许多多的锁&#xff0c;此篇我们的CAS就是从上文的锁策略开展的新概念&#xff0c;我们来一探究竟吧 1. 什么是CAS&#xff1f; CAS: 全称Compare and swap&#xff0c;字⾯意思:“比较并交换”&#xff0c;⼀个CAS涉及到以下操作&#xff1a; 我们假设内存中…

线程池遇到未处理的异常会崩溃吗?

线程池中的 execute 和 submit 方法详解 目录 引言execute 方法 使用示例代码 submit 方法 2.1 提交 Callable 任务2.2 提交 Runnable 任务 遇到未处理异常 3.1 execute 方法遇到未处理异常3.2 submit 方法遇到未处理异常 小结 引言 在多线程编程中&#xff0c;线程池是提高性…

2024年第十五届蓝桥杯青少组国赛(c++)真题—快速分解质因数

快速分解质因数 完整题目和在线测评可点击下方链接前往&#xff1a; 快速分解质因数_C_少儿编程题库学习中心-嗨信奥https://www.hixinao.com/tiku/cpp/show-3781.htmlhttps://www.hixinao.com/tiku/cpp/show-3781.html 若如其他赛事真题可自行前往题库中心查找&#xff0c;题…

Linux内核编程(二十一)USB驱动开发

一、驱动类型 USB 驱动开发主要分为两种&#xff1a;主机侧的驱动程序和设备侧的驱动程序。一般我们编写的都是主机侧的USB驱动程序。 主机侧驱动程序用于控制插入到主机中的 USB 设备&#xff0c;而设备侧驱动程序则负责控制 USB 设备如何与主机通信。由于设备侧驱动程序通常与…

AI Agent:深度解析与未来展望

一、AI Agent的前世&#xff1a;从概念到萌芽 &#xff08;一&#xff09;早期探索 AI Agent的概念可以追溯到20世纪50年代&#xff0c;早期的AI研究主要集中在简单的规则系统上&#xff0c;这些系统的行为是确定性的&#xff0c;输出由输入决定。随着时间的推移&#xff0c;…

SuperMap iClient3D for WebGL选中抬升特效

在大屏展示系统中&#xff0c;对行政区划数据制作了立体效果&#xff0c;如果希望选中某一行政区划进行重点介绍&#xff0c;目前常见的方式是通过修改选中对象色彩、边线等方式进行实现&#xff1b;这里提供另外一种偏移动效的思路&#xff0c;并提供下钻功能&#xff0c;让地…

领域算法 - 字符串匹配算法

字符串匹配算法 文章目录 字符串匹配算法一&#xff1a;KMP算法1&#xff1a;算法概述2&#xff1a;部分匹配表3&#xff1a;算法实现 二&#xff1a;Moore算法1&#xff1a;算法概述2&#xff1a;代码实现3&#xff1a;完整实现 三&#xff1a;马拉车算法1&#xff1a;算法概述…

小红书用户作品列表 API 系列,返回值说明

item_search_shop_video-获得某书用户作品列表 公共参数 名称类型必须描述keyString是调用key&#xff08;必须以GET方式拼接在URL中&#xff09;secretString是调用密钥api_nameString是API接口名称&#xff08;包括在请求地址中&#xff09;[item_search,item_get,item_sea…