【深度学习】LSTM、BiLSTM详解

news2024/12/25 10:08:10

文章目录

  • 1. LSTM简介:
  • 2. LSTM结构图:
  • 3. 单层LSTM详解
  • 4. 双层LSTM详解
  • 5. BiLSTM
  • 6. Pytorch实现LSTM示例
  • 7. nn.LSTM参数详解

1. LSTM简介:

    LSTM是一种循环神经网络,它可以处理和预测时间序列中间隔和延迟相对较长的重要事件。LSTM通过使用门控单元来控制信息的流动,从而缓解RNN中的梯度消失和梯度爆炸的问题。LSTM的核心是三个门:输入门遗忘门输出门

遗忘门: 遗忘门的作用是决定哪些信息从记忆单元中遗忘,它使用sigmoid激活函数,可以输出在0到1之间的值,可以理解为保留信息的比例。
输入门: 作用是决定哪些新信息被存储在记忆单元中
输出门: 输出门决定了下一个隐藏状态,即生成当前时间步的输出并传递到下一时间步
记忆单元:负责长期信息的存储,通过遗忘门和输入门的相互作用,记忆单元能够学习如何选择性地记住或忘记信息

2. LSTM结构图:

在这里插入图片描述

涉及到的计算公式如下:
在这里插入图片描述

3. 单层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为1。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    LSTM单元包含三个输入参数x、c、h;首先t1时刻作为第一个时间步,输入到第一个LSTM单元中,此时输入的初始从c(0)和h(0)都是0矩阵,计算完成后,第一个LSTM单元输出一组h(t1)\c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

4. 双层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为2。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    第二层LSTM没有输入参数x(t1)、x(t2)、x(t3);所以我们将第一层LSTM输出的h(t1)、h(t2)、h(t3)作为第二层LSTM的输入x(t1)、x(t2)、x(t3)。第一个时间步输入的初始c(0)和h(0)都为0矩阵,计算完成后,第一个时间步输出新的一组h(t1)、c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

5. BiLSTM

单层的BiLSTM其实就是2个LSTM,一个正向去处理序列,一个反向去处理序列,处理完后,两个LSTM的输出会拼接起来。
在这里插入图片描述

6. Pytorch实现LSTM示例

import torch 
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim  # 隐藏层维度
        self.num_layers = num_layers  # LSTM层的数量

        # LSTM网络层
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        # 全连接层,用于将LSTM的输出转换为最终的输出维度
        self.fc = nn.Linear(hidden_dim, output_dim)
        
        
    def forward(self, x):
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)
        # 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态
        out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
        # 将LSTM的最后一个时间步的输出通过全连接层
        out = self.fc(out[:, -1, :])
        return out

7. nn.LSTM参数详解

pytorch官方定义:

CLASS torch.nn.LSTM(
        input_size,
        hidden_size,
        num_layers=1,
        bias=True,
        batch_first=False,
        dropout=0.0,
        bidirectional=False,
        proj_size=0,
        device=None,
        dtype=None
    )

input_size – 输入 x 中预期的特征数量
hidden_size – 隐藏状态 h 中的特征数量
num_layers – 循环层的数量。例如,设置 num_layers=2 表示将两个 LSTM 堆叠在一起形成一个 stacked LSTM,其中第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1
bias – 如果 False,则该层不使用偏差权重 b_ih 和 b_hh。默认值:True
batch_first – 如果 True,则输入和输出张量将以 (batch, seq, feature) 而不是 (seq, batch, feature) 的形式提供。请注意,这并不适用于隐藏状态或单元状态。有关详细信息,请参见下面的输入/输出部分。默认值:False
dropout – 如果非零,则在除最后一层之外的每个 LSTM 层的输出上引入一个 Dropout 层,其 dropout 概率等于 dropout。默认值:0
bidirectional – 如果 True,则变为双向 LSTM。默认值:False
proj_size – 如果 > 0,则将使用具有相应大小的投影的 LSTM。默认值:0

对于输入序列每一个元素,每一层都会进行以下计算:
在这里插入图片描述
网络输入:
在这里插入图片描述

网络输出:
在这里插入图片描述

本文参考:https://blog.csdn.net/qq_34486832/article/details/134898868
https://pytorch.ac.cn/docs/stable/generated/torch.nn.LSTM.html#

LSTM每层的输出都要经过全连接层吗,还是直接对隐藏层进行输出?
通过在代码中对lstm的输出进行print输出:

import torch 
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim  # 隐藏层维度
        self.num_layers = num_layers  # LSTM层的数量

        # LSTM网络层
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        # 全连接层,用于将LSTM的输出转换为最终的输出维度
        self.fc = nn.Linear(hidden_dim, output_dim)
        
        
    def forward(self, x):
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)
        # 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态
        out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
        print(out)
        print(hn)
        print(cn)
        # 将LSTM的最后一个时间步的输出通过全连接层
        out = self.fc(out[:, -1, :])
        return out
if __name__ == "__main__":
	input_dim = 3        # 输入特征的维度
	hidden_dim = 4       # 隐藏层的维度
	num_layers = 1       # LSTM 层的数量
	output_dim = 1       # 输出特征的维度
	    
	lstm = LSTM(input_dim, hidden_dim, num_layers, output_dim).to(device)
	batch_size = 1
	seq_length = 10
	input_tensor = torch.randn(batch_size, seq_length, input_dim).to(device)
	output = lstm(input_tensor)

通过对LSTM网络的输出我们可以看到,out的最后一层与最后一层隐藏层hn一致,说明并未经过全连接层,而是直接输出隐藏层
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2241606.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Queuing 表(buffer表)的优化实践 | OceanBase 性能优化实践

案例问题描述 该案例来自一个金融行业客户的问题:他们发现某个应用对一个数据量相对较小的表(仅包含数千条记录)访问时,频繁遇到性能下降的情况。为解决此问题,客户向我们求助进行分析。我们发现这张表有频繁的批量插…

【视觉SLAM】4b-特征点法估计相机运动之PnP 3D-2D

文章目录 1 问题引入2 求解P3P 1 问题引入 透视n点(Perspective-n-Point,PnP)问题是计算机视觉领域的经典问题,用于求解3D-2D的点运动。换句话说,当知道n个3D空间点坐标以及它们在图像上的投影点坐标时,可…

SpringBoot多环境+docker集成企业微信会话存档sdk

SpringBoot多环境docker集成企业微信会话存档sdk 文章来自于 https://developer.work.weixin.qq.com/community/article/detail?content_id16529801754907176021 SpringBoot多环境docker集成企业微信会话存档sdk 对于现在基本流行的springboot环境,官方文档真是比…

DAY64||dijkstra(堆优化版)精讲 ||Bellman_ford 算法精讲

dijkstra(堆优化版)精讲 题目如上题47. 参加科学大会(第六期模拟笔试) 邻接表 本题使用邻接表解决问题。 邻接表的优点: 对于稀疏图的存储,只需要存储边,空间利用率高遍历节点链接情况相对容…

在openi平台 基于华为顶级深度计算平台 openmind 动手实践

大家可能一直疑问,到底大模型在哪里有用。 本人从事的大模型有几个方向的业务。 基于生成式语言模型的海事航行警告结构化解析。 基于生成式语言模型的航空航行警告结构化解析。 基于生成式生物序列(蛋白质、有机物、rna、dna、mrna)的多模态…

Figma汉化:提升设计效率,降低沟通成本

在UI设计领域,Figma因其强大的功能而广受欢迎,但全英文界面对于国内设计师来说是一个不小的挑战。幸运的是,通过Figma汉化插件,我们可以克服语言障碍。以下是两种获取和安装Figma汉化插件的方法,旨在帮助国内的UI设计师…

深度学习-卷积神经网络CNN

案例-图像分类 网络结构: 卷积BN激活池化 数据集介绍 CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32323。下图列举了10个类,每一类随机展示了10张图片: 特征图计算 在卷积层和池化层结束后, 将特征…

关于adb shell登录开发板后terminal显示不完整

现象 今天有个同事跟我说,adb shell 登录开发板后,终端显示不完整,超出边界后就会出现奇怪的问题,比如字符覆盖显示等。如下图所示。 正常情况下应该如下图所示: 很明显,第一张图的显示区域只有完整区域…

【论文分享】三维景观格局如何影响城市居民的情绪

城市景观对居民情绪的影响是近些年来讨论的热门话题之一,现有的研究主要以遥感影像为数据来源,进行二维图像-数据分析,其量化结果精确度有限。本文引入了三维景观格局的研究模型,通过街景图片及网络发帖信息补充图像及数据来源&am…

ChatGPT学术专用版,一键润色纠错+中英互译+批量翻译PDF

ChatGPT academic项目是由中科院团队基于ChatGPT专属定制。论文润色、语法检查、中英互译、代码解释等可一键搞定,堪称科研神器。 功能介绍 我们以3.5版本为例,ChatGPT学术版总共分为五个区域:输入控制区、输出对话区、基础功能区、函数插件…

Go 语言已立足主流,编程语言排行榜24 年 11 月

Go语言概述 Go语言,简称Golang,是由Google的Robert Griesemer、Rob Pike和Ken Thompson在2007年设计,并于2009年11月正式宣布推出的静态类型、编译型开源编程语言。Go语言以其提高编程效率、软件构建速度和运行时性能的设计目标,…

一、HTML

一、基础概念 1、浏览器相关知识 这五个浏览器市场份额都非常大,且都有自己的内核。 什么是内核: 内核是浏览器的核心,用于处理浏览器所得到的各种资源。 例如,服务器发送图片、视频、音频的资源,浏览…

VRRP HSRP GLBP 三者区别

1. VRRP(Virtual Router Redundancy Protocol,虚拟路由冗余协议) 标准协议:VRRP 是一种开放标准协议(RFC 5798),因此支持的厂商较多,通常用于多种网络设备中。主备模式:…

Elasticsearch:管理和排除 Elasticsearch 内存故障

作者:来自 Elastic Stef Nestor 随着 Elastic Cloud 提供可观察性、安全性和搜索等解决方案,我们将使用 Elastic Cloud 的用户范围从完整的运营团队扩大到包括数据工程师、安全团队和顾问。作为 Elastic 支持代表,我很乐意与各种各样的用户和…

Java集合(Collection+Map)

Java集合&#xff08;CollectionMap&#xff09; 为什么要使用集合&#xff1f;泛型 <>集合框架单列集合CollectionCollection遍历方式List&#xff1a;有序、可重复、有索引ArrayListLinkedListVector&#xff08;已经淘汰&#xff0c;不会再用&#xff09; Set&#xf…

大数据如何助力干部选拔的公正性

随着社会的发展和进步&#xff0c;干部选拔成为组织管理中至关重要的一环。传统的选拔方式可能存在主观性、不公平性以及效率低下等问题。大数据技术的应用&#xff0c;为干部选拔提供了更加全面、精准、客观的信息支持&#xff0c;显著提升选拔工作的科学性和公正性。以下是大…

EHOME视频平台EasyCVR多品牌摄像机视频平台监控视频编码H.265与Smart 265的区别?

在视频监控领域&#xff0c;技术的不断进步推动着行业向更高效、更智能的方向发展。特别是在编码技术方面&#xff0c;Smart 265作为一种新型的视频编码技术&#xff0c;相较于传统的H.265&#xff0c;有明显优势。这种技术的优势在EasyCVR视频监控汇聚管理平台中得到了充分的体…

Docker:查看镜像里的文件

目录 背景步骤1、下载所需要的docker镜像2、创建并运行临时容器3、停止并删除临时容器 背景 在开发过程中&#xff0c;为了更好的理解和开发程序&#xff0c;有时需要确认镜像里的文件是否符合预期&#xff0c;这时就需要查看镜像内容 步骤 1、下载所需要的docker镜像 可以使…

【Vitepress报错】Error: [vitepress] 8 dead link(s) found.

原因 VitePress 在编译时&#xff0c;发现 死链接(dead links) 会构建失败&#xff01;具体在哪我也找不到… 解决方案 如图第一行蓝色提示信息&#xff0c;设置 Vitepress 属性 ignoredeadlinks 为 true 可忽略报错。 .vuepress/config.js export default defineConfig(…

HTB:Squashed[WriteUP]

目录 连接至HTB服务器并启动靶机 使用rustscan对靶机TCP端口进行开放扫描 使用nmap对靶机开放端口进行脚本、服务扫描 使用浏览器访问靶机80端口页面 使用showmount列出靶机上的NFS共享 新建一个test用户 使用Kali自带的PHP_REVERSE_SHELL并复制到一号挂载点 尝试使用c…