一文读懂LSTM及手写LSTM结构

news2025/1/9 15:32:36

        `torch.nn.LSTM`是PyTorch中用于创建长短时记忆网络(Long Short-Term Memory)的类。LSTM是一种用于处理序列数据的循环神经网络(Recurrent Neural Network,RNN)变体。

官方给出的LSTM API 文档 

 以下是 `torch.nn.LSTM` 的主要参数(用于配置和定制 LSTM 层的行为):

1. `input_size`(必需参数):输入数据的特征维度大小。这是输入序列的特征向量的维度。

2. `hidden_size`(必需参数):LSTM 单元的隐藏状态的维度大小。这决定了 LSTM 层的输出和内部隐藏状态的维度。

3. `num_layers`(可选参数,默认为 1):LSTM 层的堆叠层数。你可以将多个 LSTM 层叠加在一起,以增加模型的容量和表示能力。

4. `bias`(可选参数,默认为 True):一个布尔值,确定是否在 LSTM 单元中包含偏置项。

5. `batch_first`(可选参数,默认为 False):一个布尔值,指定输入数据的形状。如果设置为 True,输入数据的形状应为 `(batch_size, sequence_length, input_size)`,否则为 `(sequence_length, batch_size, input_size)`。

6. `dropout`(可选参数,默认为 0.0):应用于除最后一层外的每个 LSTM 层的丢弃率。这有助于防止过拟合。

7. `bidirectional`(可选参数,默认为 False):一个布尔值,指定是否使用双向 LSTM。如果设置为 True,LSTM 将具有前向和后向的隐藏状态,以更好地捕捉序列的上下文信息。

8. `batch_first`(可选参数,默认为 False):一个布尔值,用于指定输入数据的形状。如果设置为 True,则输入数据应为 `(batch_size, sequence_length, input_size)`,否则为 `(sequence_length, batch_size, input_size)`。

9. `device`(可选参数):指定要在哪个设备上创建 LSTM 层,例如 CPU 或 GPU。

10. `dtype`(可选参数):指定数据类型,例如 `torch.float32` 或 `torch.float64`。

11. `return_sequences`(可选参数,默认为 False):一个布尔值,指定是否返回每个时间步的输出序列。如果设置为 True,则返回完整的输出序列;否则,只返回最后一个时间步的输出。

        这些参数允许你根据具体的任务和模型架构来配置 LSTM 层。根据你的需求,你可以灵活地选择不同的参数值来构建不同的 LSTM 模型。

LSTM的输入

 

`torch.nn.LSTM` 层的输入通常是一个包含两个元素的元组 `(input, (h_0, c_0))`,调用方法为:

output, (h_n, c_n) = torch.nn.LSTM(input, (h_0,c_0))

其中:

(1)        

        input 通常是一个三维张量,具体形状取决于是否设置了 `batch_first` 参数。输入张量包括以下维度:

1. 批量维度(Batch Dimension):这是数据中的样本数量。如果 `batch_first` 设置为 True,那么批量维度将是第一个维度;否则,批量维度将是第二个维度。

2. 序列长度维度(Sequence Length Dimension):这是时间步的数量,也是序列的长度。它是输入序列中数据点的数量。

3. 特征维度(Feature Dimension):这是输入数据点的特征数量。它表示每个时间步的输入特征向量 xt 的维度。

根据上述描述,以下是两种常见的输入形状:

- 如果 `batch_first` 为 True:
    - 输入张量的形状为 `(batch_size, sequence_length, input_size)`。
    - `batch_size` 是批量大小,表示同时处理的样本数量。
    - `sequence_length` 是序列的长度,即时间步的数量。
    - `input_size` 是输入特征向量的维度。

- 如果 `batch_first` 为 False:
    - 输入张量的形状为 `(sequence_length, batch_size, input_size)`。
    - `sequence_length` 是序列的长度,即时间步的数量。
    - `batch_size` 是批量大小,表示同时处理的样本数量。
    - `input_size` 是输入特征向量的维度。

        要注意的是,这只是输入的形状,LSTM 层的参数(例如 `input_size` 和 `hidden_size`)必须与输入形状相匹配。根据你的具体任务和数据,你需要将输入数据整理成适当形状的张量,然后将其传递给 `torch.nn.LSTM` 层以进行前向传播。

(2)

  `(h_0, c_0)`:是包含初始隐藏状态和初始细胞状态的元组。
   - `h_0`:是初始隐藏状态,其形状为 `(num_layers * num_directions, batch_size, hidden_size)`。`num_layers` 是 LSTM 层的堆叠层数,`num_directions` 是 1 或 2,取决于是否使用双向 LSTM。
   - `c_0`:是初始细胞状态,其形状也为 `(num_layers * num_directions, batch_size, hidden_size)`。

LSTM的输出

`torch.nn.LSTM` 层的输出通常是一个包含两个元素的元组 `(output, (h_n, c_n))`,其中:

1. `output`:是一个包含每个时间步的 LSTM 输出的张量。其形状为 `(batch_size, sequence_length, num_directions * hidden_size)`【batch_first = True的情况下】,其中:
   - `sequence_length` 是序列的长度,即时间步的数量。
   - `batch_size` 是批量大小,表示同时处理的样本数量。
   - `num_directions` 是 1 或 2,取决于是否使用双向(bidirectional)LSTM。
   - `hidden_size` 是 LSTM 单元的隐藏状态的维度大小。

2. `(h_n, c_n)`:是包含最后一个时间步的隐藏状态和细胞状态的元组。
   - `h_n`:是最后一个时间步的隐藏状态,其形状为 `(num_layers * num_directions, batch_size, hidden_size)`。`num_layers` 是 LSTM 层的堆叠层数,`num_directions` 是 1 或 2,取决于是否使用双向 LSTM。
   - `c_n`:是最后一个时间步的细胞状态,其形状也为 `(num_layers * num_directions, batch_size, hidden_size)`。

        你可以选择是否要使用输出中的全部时间步的输出,或者只使用最后一个时间步的输出,具体取决于你的任务需求。

        通常,如果你只关心最终的输出,你可以使用 `output[-1]` 或 `h_n`。如果你需要完整的时间步输出序列,可以使用 `output`。这些输出可以传递到其他层或用于任务的后续处理。

 LSTM的权重参数

`torch.nn.LSTM`具有以下主要的权重参数(用于捕捉序列中的长期依赖关系):

1. `weight_ih_l[k]`:这是输入到LSTM单元的权重参数,其中k表示LSTM层的索引。`weight_ih_l[k]`的维度是(4 * hidden_size,input_size),其中hidden_size是LSTM隐藏状态的大小,input_size是输入数据的特征维度。这个权重参数控制着输入数据如何影响LSTM单元的状态。

2. `weight_hh_l[k]`:这是隐藏状态到LSTM单元的权重参数,其中k表示LSTM层的索引。`weight_hh_l[k]`的维度是(4 * hidden_size,hidden_size)。这个权重参数控制着前一个时间步的隐藏状态如何影响当前时间步的隐藏状态。

3. `bias_ih_l[k]`和`bias_hh_l[k]`:这是输入到LSTM单元和隐藏状态到LSTM单元的偏置参数,其中k表示LSTM层的索引。`bias_ih_l[k]`的维度是(4 * hidden_size),`bias_hh_l[k]`的维度也是(4 * hidden_size)。这些偏置参数用于调整输入和隐藏状态的影响。

以上权重参数中的4表示LSTM单元的门控机制,通常被称为输入门(input gate)、遗忘门(forget gate)、输出门(output gate)和细胞状态(cell state)。LSTM使用这些门来控制信息的流动,以捕捉长期依赖关系。

        要访问和修改这些权重参数,您可以使用`state_dict`属性来获取或设置模型的权重。例如,如果您有一个名为`lstm_model`的`torch.nn.LSTM`模型,您可以使用以下代码来获取权重参数的字典:lstm_weights = lstm_model.state_dict()。然后,您可以从`lstm_weights`字典中提取和修改特定的权重参数。请注意,修改权重参数可能会影响模型的性能,因此需要谨慎操作。

你还可以使用:

for k, v in lstm_model.named_parameters():
    print(k, v) # 打印权重参数名称及数值

方法得到模型的权重参数。

代码部分

        下述代码包括了官方API以及手写的LSTM源码。 


# 视频链接:
# https://www.bilibili.com/video/BV1zq4y1m7aH/?spm_id_from=333.788&vd_source=fb7bfda367c76676e2483b9b60485e57

# 实现LSTM 源码
# 定义常量
import torch
import torch.nn as nn
batch_size, T, input_size, hidden_size = 2, 3, 4, 5


input = torch.randn(batch_size, T, input_size)
c_0 = torch.randn(batch_size, hidden_size) # 初始细胞单元,不参与网络训练
h_0 = torch.randn(batch_size, hidden_size) # 初始隐藏状态

# 调用官方API
lstm_layer = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
output, (h_n, c_n) = lstm_layer(input, (h_0.unsqueeze(0), c_0.unsqueeze(0)))
print("LSTM API")
print("output:\n", output)
print("h_n:\n", h_n)
print("c_n:\n", c_n)

# for k, v in lstm_layer.named_parameters():
#     print(k, v)
lstm_weight = lstm_layer.state_dict() # 使用`state_dict`属性来获取或设置模型的权重
print("lstm_weight:\n", lstm_weight)

# 自己写一个LSTM模型
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    """

    :param input:
    :param initial_states:
    :param w_ih:
    :param w_hh:
    :param b_ih:
    :param b_hh:
    :return:
    """
    h_0, c_0 = initial_states # 初始状态
    batch_size, T, input_size = input.shape
    hidden_size = w_ih.shape[0] // 4
    prev_h = h_0
    prev_c = c_0

    batch_w_ih = w_ih.unsqueeze(0).tile(batch_size, 1, 1) # [batch_size, 4*hidden_size, input_size]
    batch_w_hh = w_hh.unsqueeze(0).tile(batch_size, 1, 1) # [batch_size, 4*hidden_size, hidden_size]
    output_size = hidden_size
    output = torch.zeros(batch_size, T, output_size) # 输出序列

    for t in range(T):
        x = input[:, t, :] # 当前时刻的输入向量,[batch_size*input_size]
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1)) # [batch_size, 4*hidden_size, 1]
        w_times_x = w_times_x.squeeze(-1) # [batch_size, 4*hidden_size]

        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1)) # [batch_size, 4*hidden_size, 1]
        w_times_h_prev = w_times_h_prev.squeeze(-1)  # [batch_size, 4*hidden_size]

        # 分别计算输入门(i)、遗忘门(f)、cell(g)、输出门(o)
        i_t = torch.sigmoid(w_times_x[:, :hidden_size] + w_times_h_prev[:, :hidden_size]
                            +b_ih[ :hidden_size] + b_hh[ :hidden_size])
        f_t = torch.sigmoid(w_times_x[:, hidden_size:2*hidden_size] + w_times_h_prev[:, hidden_size:2*hidden_size]
                            + b_ih[hidden_size:2*hidden_size] + b_hh[hidden_size:2*hidden_size])
        g_t = torch.tanh(w_times_x[:, 2*hidden_size:3*hidden_size] + w_times_h_prev[:, 2*hidden_size:3*hidden_size]
                            + b_ih[2*hidden_size:3*hidden_size] + b_hh[2*hidden_size:3*hidden_size])
        o_t = torch.sigmoid(w_times_x[:, 3*hidden_size:4*hidden_size] + w_times_h_prev[:, 3*hidden_size:4*hidden_size]
                            + b_ih[3*hidden_size:4*hidden_size] + b_hh[3*hidden_size:4*hidden_size])
        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)


output_custom, (h_final_custom, c_final_custom) = lstm_forward(input=input, initial_states = (h_0, c_0), w_ih=lstm_layer.weight_ih_l0,
             w_hh=lstm_layer.weight_hh_l0, b_ih=lstm_layer.bias_ih_l0, b_hh=lstm_layer.bias_hh_l0)

print("LSTM custom")
print("output_custom:\n", output_custom)
print("h_final_custom:\n", h_final_custom)
print("c_final_custom:\n", c_final_custom)

LSTM模型输入输出可视化理解

 

 

图文来自:pytorch中LSTM参数详解(一张图帮你更好的理解每一个参数)_pytorch lstm 参数一图_xjtuwfj的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/993801.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LORA项目源码解读

大模型fineturn技术中类似于核武器的LORA,简单而又高效。其理论基础为:在将通用大模型迁移到具体专业领域时,仅需要对其高维参数的低秩子空间进行更新。基于该朴素的逻辑,LORA降低大模型的fineturn门槛,模型训练时不需…

Redis-带你深入学习数据类型list

目录 1、list列表 2、list相关命令 2.1、添加相关命令:rpush、lpush、linsert 2.2、查找相关命令:lrange、lindex、llen 2.3、删除相关命令:lpop、rpop、lrem、ltrim 2.4、修改相关命令:lset 2.5、阻塞相关命令&#xff1a…

deepin V23通过flathub安装steam畅玩游戏

deepin V23缺少32位库,在星火商店安装的steam,打开报错,无法使用! 通过flathub网站安装steam,可以正常使用,详细教程如下: flathub网址:主页 | Flathub 注意:flathub下载速度慢,只…

【笔试强训选择题】Day38.习题(错题)解析

作者简介:大家好,我是未央; 博客首页:未央.303 系列专栏:笔试强训选择题 每日一句:人的一生,可以有所作为的时机只有一次,那就是现在!! 文章目录 前言一、Day…

ChatGPT实战与私有化大模型落地

文章目录 大模型现状baseline底座选择数据构造迁移方法评价思考 领域大模型训练技巧Tokenizer分布式深度学习数据并行管道并行向量并行分布式框架——Megatron-LM分布式深度学习框架——Colossal-AI分布式深度学习框架——DeepSpeedP-tuning 微调 资源消耗模型推理加速模型推理…

基于SSM的学院实验中心管理系统

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

从数据页的角度看 B+Tree

InnoDB 是如何存储数据的? MySQL支持多种存储引擎,不同的存储引擎,存储数据的方式也不相同,我们最常使用的是 InnoDB 存储引擎。 在数据库中的记录是按照行来存储的,但是数据库的读取并不是按照 [ 行] 为单位&#x…

MySQL进阶 —— 超详细操作演示!!!(上)

MySQL进阶 —— 超详细操作演示!!!(上) 一、存储引擎1.1 MySQL 体系结构1.2 存储引擎介绍1.3 存储引擎特点1.4 存储引擎选择 二、索引2.1 索引概述2.2 索引结构2.3 索引分类2.4 索引语法2.5 SQL 性能分析2.6 索引使用2…

BUUCTF rip 1

使用linux的file命令查看基本信息 64位 使用IDA64位进行反编译 看到gets就肯定有栈溢出 能看到有一个 _system函数,改函数能执行系统命令 既然反编译有这个函数说明有地方调用了他 果然在一个fun函数中有调用,执行的命令是 /bin/sh 也就是一个后门函数&…

【C++ • STL • 力扣】详解string相关OJ

文章目录 1、仅仅翻转字母2、字符串中的第一个唯一字符3、字符串里最后一个单词的长度4、验证一个字符串是否是回文5、字符串相加总结 ヾ(๑╹◡╹)ノ" 人总要为过去的懒惰而付出代价 ヾ(๑╹◡╹)ノ" 1、仅仅翻转字母 力扣链接 代码1展示&…

【Spring Cloud系列】 雪花算法原理及实现

【Spring Cloud系列】 雪花算法原理及实现 文章目录 【Spring Cloud系列】 雪花算法原理及实现一、概述二、生成ID规则部分硬性要求三、ID号生成系统可用性要求四、解决分布式ID通用方案4.1 UUID4.2 数据库自增主键4.3 基于Redis生成全局id策略 五、SnowFlake(雪花算…

数据结构与算法-----顺序表(链表篇)

目录 前言 顺序表 链表 概念 与数组的不同 单链表 1. 创建节点 2.插入节点 尾插节点(形成链表结构) 向指定位置插入节点(链表已有) ​编辑 3.遍历链表数据 4.获取链表长度 5.删除节点 删除尾节点 删除指定节点 …

51单片机项目(10)——基于51单片机的电压计

本次设计的电压计,使用ADC0832芯片,测到电压后,将电压信息发送到串口进行显示。仿真功能正常,能够运行。(工程文件和代码放在最后) 电路图如下: 运行过程如下: ADC0832介绍&#xff…

linux下检测CPU性能的mpstat命令安装与用法

1、安装命令 $ sudo apt-get install sysstat sysstat安装包还包括了检测设备其它状态的命令&#xff0c;查看命令如下&#xff1a; 2、检测CPU命令语法 $ mpstat --h //查看mpstat的语法 Usage: mpstat [ options ] [ <interval> [ <count> ] ] Options are: …

设计模式之访问器模式(Visitor)的C++实现

1、访问器模式的提出 在软件开发过程中&#xff0c;早已发布的软件版本&#xff0c;由于需求的变化&#xff0c;需要给某个类层次结构增加新的方法。如果在该基类和子类中都添加新的行为方法&#xff0c;将给代码原有的结构带来破坏&#xff0c;同时&#xff0c;也违反了修改封…

D. Sorting By Multiplication

Problem - D - Codeforces 思路&#xff1a;我们首先考虑当只能乘以正数时&#xff0c;那么变为单调增的方法就是找所有w[i]>w[i1]的对数&#xff0c;因为如果存在一个w[i]>w[i1]&#xff0c;那么我们一定至少需要进行一次操作&#xff0c;并且我们还知道我们进行一次操…

Redis经典问题:缓存穿透

&#xff08;笔记总结自《黑马点评》项目&#xff09; 一、产生原因 用户请求的数据在缓存中和数据库中都不存在&#xff0c;不断发起这样的请求&#xff0c;给数据库带来巨大压力。 常见的解决方式有缓存空对象和布隆过滤器。 二、缓存空对象 思路&#xff1a;当我们客户…

JP《乡村振兴振兴战略下传统村落文化旅游设计》许少辉书香续,山水长

JP《乡村振兴振兴战略下传统村落文化旅游设计》许少辉书香续&#xff0c;山水长

MySQL--MySQL表的增删改查(基础)

排序&#xff1a;ORDER BY 语法&#xff1a; – ASC 为升序&#xff08;从小到大&#xff09; – DESC 为降序&#xff08;从大到小&#xff09; – 默认为 ASC SELECT … FROM table_name [WHERE …] ORDER BY column [ASC|DESC], […]; *** update

【数据结构--顺序表】合并两个有序数组

题目描述&#xff1a; 代码实现&#xff1a; void merge(int* nums1, int nums1Size, int m, int* nums2, int nums2Size, int n){int x0;if(m0)//如果nums1为空&#xff0c;而nums2不为空&#xff0c;则将nums2拷贝至nums1{while(nums1Size--){nums1[x]nums2[x];x;}}if(n0)//…