动手学深度学习(Pytorch版)代码实践 -循环神经网络-57长短期记忆网络(LSTM)

news2024/9/25 9:30:18

57长短期记忆网络(LSTM

1.LSTM原理

LSTM是专为解决标准RNN的长时依赖问题而设计的。标准RNN在训练过程中,随着时间步的增加,梯度可能会消失或爆炸,导致模型难以学习和记忆长时间间隔的信息。LSTM通过引入一组称为门的机制来解决这个问题:

  1. 输入门(Input Gate):控制有多少新的信息可以传递到记忆单元中。
  2. 遗忘门(Forget Gate):控制当前记忆单元中有多少信息会被保留。
  3. 输出门(Output Gate):控制记忆单元的输出有多少被传递到下一步。

LSTM还引入了一个称为记忆单元(Cell State)的概念,用于携带长期信息。这些门的组合使得LSTM能够选择性地记住或遗忘信息,从而解决了长时依赖问题。
在这里插入图片描述
在这里插入图片描述

2.优点
  1. 解决梯度消失问题:通过门控机制,LSTM能够有效地传递梯度,避免了梯度消失和爆炸的问题。
  2. 捕捉长时依赖LSTM能够记住和利用长时间间隔的信息,这是标准RNN难以做到的。
  3. 灵活性LSTM适用于各种序列数据处理任务,如时间序列预测、语言建模和序列到序列的翻译等。
3.LSTMGRU的区别

GRU(门控循环单元)是另一种解决长时依赖问题的RNN变体。GRULSTM都引入了门控机制,但它们的具体实现有所不同。

  1. 结构简化GRU的结构比LSTM更简单,参数更少,计算效率更高。
  2. 性能对比:在一些任务上,GRULSTM的性能相当,但在某些情况下,GRU可能表现更好,特别是在较小的数据集或较短的序列上。
  3. 门的数量LSTM有三个门(输入门、遗忘门和输出门),而GRU只有两个门(更新门和重置门)。
4.LSTM代码实践
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 设置批量大小和序列步数
batch_size, num_steps = 32, 35
# 加载时间机器数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

# 初始化LSTM模型参数
def get_lstm_params(vocab_size, num_hiddens, device):
    # 输入输出的维度大小
    num_inputs = num_outputs = vocab_size

    # 正态分布初始化权重
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    # 三个权重参数(用于输入门、遗忘门、输出门和候选记忆元)
    def three():
        return (normal((num_inputs, num_hiddens)),  # 输入到隐藏状态的权重
                normal((num_hiddens, num_hiddens)),  # 隐藏状态到隐藏状态的权重
                torch.zeros(num_hiddens, device=device))  # 偏置

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))  # 隐藏状态到输出的权重
    b_q = torch.zeros(num_outputs, device=device)  # 输出偏置
    # 将所有参数附加到参数列表中
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)  # 设置参数需要梯度
    return params

# 初始化LSTM的隐藏状态
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),  # 隐藏状态
            torch.zeros((batch_size, num_hiddens), device=device))  # 记忆元

# LSTM前向传播
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state  # 隐藏状态和记忆元
    outputs = []
    for X in inputs:
        # 输入门
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        # 遗忘门
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        # 输出门
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        # 候选记忆元
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        # 更新记忆元
        C = F * C + I * C_tilda
        # 更新隐藏状态
        H = O * torch.tanh(C)
        # 计算输出
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)  # 返回输出和状态

# 训练和预测模型
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# 创建自定义的LSTM模型
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.3, 34433.0 tokens/sec on cuda:0
# 预测结果示例:time traveller conellace there wardeal that are almost us we hou

# 使用PyTorch的简洁实现
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)  # 创建LSTM层
model = d2l.RNNModel(lstm_layer, len(vocab))  # 创建模型
model = model.to(device)  # 将模型移动到GPU
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.0, 317323.7 tokens/sec on cuda:0
# 预测结果示例:time travelleryou can show black is white by argument said filby

自定义的LSTM模型:

在这里插入图片描述
简洁实现:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1914902.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

rk3588s 定制版 tc358775 调试 lvds 屏幕 (第一部分)

硬件: 3588s 没有 lvds 接口 , 所以使用的 东芝的 tc358774 (mipi ---> lvds芯片), 这个芯片是参考 3399 的 官方设计得来的,3399 的官方demo 板上应该是 使用到了 这颗芯片 参考资料: 1 网上的 GM8775C 转换芯片。 2 瑞芯微的 3588s 的资料 总体的逻辑: 1 3588s…

25届近5年中国民航大学自动化考研院校分析

中国民航大学 目录 一、学校学院专业简介 二、考试科目指定教材 三、近5年考研分数情况 四、近5年招生录取情况 五、最新一年分数段图表 六、初试大纲复试大纲 七、学费&奖学金&就业方向 一、学校学院专业简介 二、考试科目指定教材 1、考试科目介绍 2、指定教…

centos系统查找mysql的配置文件位置

执行命令查找mysql的安装目录: which mysql cd进入mysql的安装目录 cd /usr/bin 查找配置文件位置 ./mysql --help | grep "my.cnf" 定位配置文件 cd /etc 查找命令还可以用find命令 find / -name "my.cnf"

Docker 部署 ShardingSphere-Proxy 数据库中间件

文章目录 Github官网文档ShardingSphere-Proxymysql-connector-java 驱动下载conf 配置global.yamldatabase-sharding.yamldockerdocker-compose.yml Apache ShardingSphere 是一款分布式的数据库生态系统, 可以将任意数据库转换为分布式数据库,并通过数…

绿盟培训入侵排查

一、webshell 排查 1、文件特征 2、windows 3、linux 4、内存马 二、web 日志排查 1、日志排查 2、中间件报错排查 三、服务器失陷处置

Linux常用选项和指令

目录 Linux指令使用注意 用户创建与删除 ls指令 ls指令介绍 ls常见选项 ls选项组合使用 pwd指令 Linux文件系统结构 多叉树结构文件系统介绍 多叉树结构文件系统的特点 cd指令 绝对路径 相对路径 cd指令介绍 家户目录 最近访问的目录 touch指令 ​编辑mkdir指…

【HarmonyOS NEXT】鸿蒙 代码混淆

代码混淆简介 针对工程源码的混淆可以降低工程被破解攻击的风险,缩短代码的类与成员的名称,减小应用的大小。 DevEco Studio提供代码混淆的能力并默认开启,API 10及以上版本的Stage模型、编译模式为release时自动进行代码混淆。 使用约束 …

【中项第三版】系统集成项目管理工程师 | 第 10 章 启动过程组

前言 第10章对应的内容选择题和案例分析都会进行考查,这一章节属于10大管理的内容,学习要以教材为准。本章上午题分值预计在2分。 目录 10.1 制定项目章程 10.1.1 主要输入 10.1.2 主要输出 10.2 识别干系人 10.2.1 主要输入 10.2.2 主要工具与技…

解决:WPS,在一个表格中,按多次换行,无法换到下一页

现象:在一个表格里面,多次按下回车,始终无法到下一页 解决方法:右击—>表格属性—>选择行—>勾选 允许跨页断行 效果演示 对比展示

vulnhub-NOOB-1

确认靶机 扫描靶机发现ftp Anonymous 的A大小写都可以 查看文件 解密 登录网页 点击about us会下载一个压缩包 使用工具提取 steghide info 目标文件 //查看隐藏信息 steghide extract -sf 目标文件 //提取隐藏的文件 steghide embed -cf 隐藏信息的文件 -ef…

【AI大模型新型智算中心技术体系深度分析 2024】

文末有福利! ChatGPT 系 列 大 模 型 的 发 布, 不 仅 引 爆 全 球 科 技 圈, 更 加 夯 实 了 人 工 智 能(Artificial Intelligence, AI)在未来改变人类生产生活方式、引发社会文明和竞争力代际跃迁的战略性地位。当…

CephFS文件系统存储服务

目录 1.创建 CephFS 文件系统 MDS 接口 服务端操作 1.1 在管理节点创建 mds 服务 1.2 创建存储池,启用 ceph 文件系统 1.3 查看mds状态,一个up,其余两个待命,目前的工作的是node02上的mds服务 1.4 创建用户 客户端操作 1.5…

【割点 C++BFS】2556. 二进制矩阵中翻转最多一次使路径不连通

本文涉及知识点 割点 图论知识汇总 CBFS算法 LeetCode2556. 二进制矩阵中翻转最多一次使路径不连通 给你一个下标从 0 开始的 m x n 二进制 矩阵 grid 。你可以从一个格子 (row, col) 移动到格子 (row 1, col) 或者 (row, col 1) ,前提是前往的格子值为 1 。如…

【论文阅读】Characterization of Large Language Model Development in the Datacenter

26.Characterization of Large Language Model Development in the Datacenter 出处: NSDI-2024 数据中心中大型语言模型开发的表征InternLM/AcmeTrace (github.com) 摘要 大语言模型(LLMs)在许多任务中表现出色。然而,要高效利用大规模集…

深入了解代理IP常见协议:区别与选择

代理服务器在网络使用中扮演着重要的角色,是您设备和互联网之间的中间层。它不仅可以增强网络访问的安全性和隐私保护,还可以提供许多灵活的应用。使用代理时,不同的协议类型对数据交换具有不同的规则和特征。常见的代理协议包括HTTP代理、HT…

什么样的开放式耳机好用舒服?南卡、倍思、Oladance高人气质量绝佳产品力荐!

​开放式耳机在如今社会中已经迅速成为大家购买耳机的新趋势,深受喜欢听歌和热爱运动的人群欢迎。当大家谈到佩戴的稳固性时,开放式耳机都会收到一致好评。对于热爱运动的人士而言,高品质的开放式耳机无疑是理想之选。特别是在近年来的一些骑…

有什么语音转文字免费的方法?7个软件教你快速的转换文件

有什么语音转文字免费的方法?7个软件教你快速的转换文件 将语音转化为文字是一项常见的需求,尤其是在需要记录会议、采访或演讲内容时。以下是七款免费且实用的语音转文字软件,它们各具特色,适合不同需求和用户水平。 迅捷文字识…

【正点原子i.MX93开发板试用连载体验】简单的音频分类

本文最早发表于电子发烧友论坛: 今天测试的内容是进行简单的音频分类。我们要想进行语音控制,就需要构建和训练一个基本的自动语音识别 (ASR) 模型来识别不同的单词。如果想了解这方面的知识可以参考TensorFlow的官方文档:简单的音频识别&…

DDoS攻击详解

DDoS 攻击,其本质是通过操控大量的傀儡主机或者被其掌控的网络设备,向目标系统如潮水般地发送海量的请求或数据。这种行为的目的在于竭尽全力地耗尽目标系统的网络带宽、系统资源以及服务能力,从而致使目标系统无法正常地为合法用户提供其所应…

光学、SAR卫星影像助力洞庭湖决堤抢险(附带数据下载)

​​ 点击下方全系列课程学习 点击学习—>ArcGIS全系列实战视频教程——9个单一课程组合系列直播回放 点击学习——>遥感影像综合处理4大遥感软件ArcGISENVIErdaseCognition 7月5日下午,湖南岳阳市华容县团洲乡团北村团洲垸洞庭湖一线堤防发生决口&#xff0…