深度学习基础练习:从pytorch API出发复现LSTM与LSTMP

news2025/1/2 3:56:04

2024/11/5-2024/11/7:

前置知识:

[译] 理解 LSTM(Long Short-Term Memory, LSTM) 网络 - wangduo - 博客园

【官方双语】LSTM(长短期记忆神经网络)StatQuest_哔哩哔哩_bilibili

大部分思路来自于:

PyTorch LSTM和LSTMP的原理及其手写复现_哔哩哔哩_bilibiliicon-default.png?t=O83Ahttps://www.bilibili.com/video/BV1zq4y1m7aH/?spm_id_from=333.880.my_history.page.click&vd_source=db0d5acc929b82408b1040d67f2b1dde

部分常量设置与官方api使用:

        其实在实现RNN之后可以发现,lstm基本是同样的套路。 在看完上面的前置知识之后,理解三个门的作用即可对lstm有一个具体的认识,这里不再赘述。

         关于输入设置这方面,参考如下:

# 定义常量
bs, T, i_size, h_size = 2, 3, 4, 5
# 输入序列
input = torch.randn(bs, T, i_size)
# 初始值,不需要训练
c0 = torch.randn(bs, h_size)
h0 = torch.randn(bs, h_size)

        将定义的常量输入官方api:

# 调用官方api
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
# 单层单项lstm,h0与c0的第0维度为 D(是否双向)*num_layers 故增加0维,维度为1
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output.shape)
print(output)

for k, v in lstm_layer.named_parameters():
    print(k)
    print(v.shape)

        输出如下:

torch.Size([2, 3, 5])
tensor([[[ 0.1134, -0.1032,  0.1496,  0.1853, -0.3758],
         [ 0.1831,  0.0223,  0.0377,  0.0867, -0.1090],
         [ 0.1233,  0.1121,  0.0574, -0.0401, -0.1576]],

        [[-0.2761,  0.3259,  0.1687, -0.0632,  0.2046],
         [ 0.1796,  0.3110,  0.0974,  0.0294,  0.0220],
         [ 0.1205,  0.1815,  0.0840, -0.1714, -0.1216]]],
       grad_fn=<TransposeBackward0>)
weight_ih_l0
torch.Size([20, 4])
weight_hh_l0
torch.Size([20, 5])
bias_ih_l0
torch.Size([20])
bias_hh_l0
torch.Size([20])

        可以看到LSTM的内置参数有 weight_ih_l0、weight_hh_l0、bias_ih_l0、bias_hh_l0,将关于三个门的知识结合在一起看差不多就明白接下来应该怎么做了: 

        h和c都是统一经过*weight+bias的操作,加在一起后经过tahn或者sigmoid激活函数,最后或点乘或加在h或者c上进行对参数的更新。只要不把维度的对应关系搞混还是比较好复现的。

        需要注意的是:三个门中的四个weight和bias(遗忘门一个,输入门两个,输出门一个)全部都按照第0维度拼在了一起方便同时进行矩阵运算,所以我们可以看到这些权重和偏置的第0维度的大小为4*h_size。一开始这一点也带给了我比较大的困惑。

代码复现与验证:

        代码较为简单,跟上次实现RNN的思路也差不多,基本是照着官方api那给的公式一步一步来的:

# 代码复现
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    h0, c0 = initial_states
    bs, T, i_size = input.shape
    h_size = w_ih.shape[0] // 4

    prev_h = h0     # [bs, h_size]
    prev_c = c0     # [bs, h_size]
    """
    w_ih    # 4*h_size, i_size
    w_hh    # 4*h_size, h_size
    """
    # 输出序列
    output_size = h_size
    output = torch.zeros(bs, T, output_size)

    for t in range(T):
        x = input[:, t, :]  # 当前时刻输入向量, [bs, i_size]
        w_times_x = torch.matmul(w_ih, x.unsqueeze(-1)).squeeze(-1)         # [bs, 4*h_size]
        w_times_h_prev = torch.matmul(w_hh, prev_h.unsqueeze(-1)).squeeze(-1)    # [bs, 4*h_size]

        # 分别计算输入门(i),遗忘门(f),输出门(o),cell(g)
        i_t = torch.sigmoid(w_times_x[:, : h_size] + w_times_h_prev[:, : h_size] + b_ih[: h_size] + b_hh[: h_size])
        f_t = torch.sigmoid(w_times_x[:, h_size: 2*h_size] + w_times_h_prev[:, h_size: 2*h_size] +
                            b_ih[h_size: 2*h_size] + b_hh[h_size: 2*h_size])
        g_t = torch.tanh(w_times_x[:, 2*h_size: 3*h_size] + w_times_h_prev[:, 2*h_size: 3*h_size] +
                         b_ih[2*h_size: 3*h_size] + b_hh[2*h_size: 3*h_size])
        o_t = torch.sigmoid(w_times_x[:, 3*h_size:] + w_times_h_prev[:, 3*h_size:] + b_ih[3*h_size:] + b_hh[3*h_size:])
        # 更新流
        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)

        output[:, t, :] = prev_h
    return output, (prev_h, prev_c)

        输出结果对比验证:

# 调用官方api
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
# 单层单项lstm,h0与c0的第0维度为 D(是否双向)*num_layers 故增加0维,维度为1
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output.shape)
print(output)

for k, v in lstm_layer.named_parameters():
    print(k, v.shape)
output, (h_final, c_final) = lstm_forward(input, (h0, c0), lstm_layer.weight_ih_l0,
                                          lstm_layer.weight_hh_l0, lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0)

print(output)

        结果如下:

torch.Size([2, 3, 5])
tensor([[[-0.6394, -0.1796,  0.0831,  0.0816, -0.0620],
         [-0.5798, -0.2235,  0.0539, -0.0120, -0.0272],
         [-0.4229, -0.0798, -0.0762, -0.0030, -0.0668]],

        [[ 0.0294,  0.3240, -0.4318,  0.5005, -0.0223],
         [-0.1458,  0.0472, -0.1115,  0.3445,  0.3558],
         [-0.2922, -0.1013, -0.1755,  0.3065,  0.1130]]],
       grad_fn=<TransposeBackward0>)
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
tensor([[[-0.6394, -0.1796,  0.0831,  0.0816, -0.0620],
         [-0.5798, -0.2235,  0.0539, -0.0120, -0.0272],
         [-0.4229, -0.0798, -0.0762, -0.0030, -0.0668]],

        [[ 0.0294,  0.3240, -0.4318,  0.5005, -0.0223],
         [-0.1458,  0.0472, -0.1115,  0.3445,  0.3558],
         [-0.2922, -0.1013, -0.1755,  0.3065,  0.1130]]], grad_fn=<CopySlices>)

         复现成功。

appendix:

        这里放下LSTMP的参数设置:

# lstmp对h_size进行压缩
proj_size = 3
# h0的h_size也改为proj_size,而c0不变
h0 = torch.randn(bs, proj_size)

# 调用官方api
lstmp_layer = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)
# 单层单项lstm,h0与c0的第0维度为 D(是否双向)*num_layers 故增加0维,维度为1
output, (h, c) = lstmp_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)), )

print(output)
print(output.shape)
print(h.shape)
print(c.shape)
for k, v in lstmp_layer.named_parameters():
    print(k, v.shape)

tensor([[[-0.0492,  0.0265,  0.0883],
         [-0.1028, -0.0327, -0.0542],
         [ 0.0250, -0.0231, -0.1199]],

        [[-0.2417, -0.1737, -0.0755],
         [-0.2351, -0.0837, -0.0376],
         [-0.2527, -0.0258, -0.0236]]], grad_fn=<TransposeBackward0>)
torch.Size([2, 3, 3])
torch.Size([1, 2, 3])
torch.Size([1, 2, 5])
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 3])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])

         其实LSTMP就多出了个weight_hr_l0对h进行压缩,但是不对cell压缩,目的是减少lstm的参数量,在小一点的sequence上基本没啥区别。若要支持lstmp,在前面的代码上改动几行即可:

def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):
    h0, c0 = initial_states
    bs, T, i_size = input.shape
    h_size = w_ih.shape[0] // 4

    prev_h = h0     # [bs, h_size]
    prev_c = c0     # [bs, h_size]
    """
    w_ih    # 4*h_size, i_size
    w_hh    # 4*h_size, h_size
    """

    if w_hr is not None:
        # 输出压缩至p_size
        p_size = w_hr.shape[0]
        output_size = p_size
    else:
        output_size = h_size

    output = torch.zeros(bs, T, output_size)

    for t in range(T):
        x = input[:, t, :]  # 当前时刻输入向量, [bs, i_size]
        w_times_x = torch.matmul(w_ih, x.unsqueeze(-1)).squeeze(-1)         # [bs, 4*h_size]
        w_times_h_prev = torch.matmul(w_hh, prev_h.unsqueeze(-1)).squeeze(-1)    # [bs, 4*h_size]

        # 分别计算输入门(i),遗忘门(f),输出门(o),cell(g)
        i_t = torch.sigmoid(w_times_x[:, : h_size] + w_times_h_prev[:, : h_size] + b_ih[: h_size] + b_hh[: h_size])
        f_t = torch.sigmoid(w_times_x[:, h_size: 2*h_size] + w_times_h_prev[:, h_size: 2*h_size] +
                            b_ih[h_size: 2*h_size] + b_hh[h_size: 2*h_size])
        g_t = torch.tanh(w_times_x[:, 2*h_size: 3*h_size] + w_times_h_prev[:, 2*h_size: 3*h_size] +
                         b_ih[2*h_size: 3*h_size] + b_hh[2*h_size: 3*h_size])
        o_t = torch.sigmoid(w_times_x[:, 3*h_size:] + w_times_h_prev[:, 3*h_size:] + b_ih[3*h_size:] + b_hh[3*h_size:])
        # 更新流
        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)   # [bs, h_size]

        if w_hr is not None:
            prev_h = torch.matmul(w_hr, prev_h.unsqueeze(-1)).squeeze(-1)   # [bs, p_size]

        output[:, t, :] = prev_h
    return output, (prev_h, prev_c)

        经过验证,复现成功。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2238083.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【芯智雲城】Sigmastar星宸科技图传编/解码方案

一、图传技术简介 图传是指将图像或媒体内容从一个设备传输到另外一个设备的技术&#xff0c;传输的媒介可以是无线电波、光纤、以太网等。图传系统主要由图像采集设备、传输设备和接收设备组成&#xff0c;图像采集设备负责采集实时图像&#xff0c;传输设备将采集到的图像转…

JavaFX史上最全教程 - Shape - JavaFX矩形椭圆

avaFX Shape类定义了常见的形状&#xff0c;如线&#xff0c;矩形&#xff0c;圆&#xff0c;Arc&#xff0c;CubicCurve&#xff0c;Ellipse和QuadCurve。 在场景图上绘制矩形需要宽度&#xff0c;高度和左上角的&#xff08;x&#xff0c;y&#xff09;位置。 要在JavaFX中…

【Windows修改Docker Desktop(WSL2)内存分配大小】

记录一下遇到使用Docker Desktop占用内存居高不下的问题 自从使用了Docker Desktop&#xff0c;电脑基本每天都需要重启&#xff0c;内存完全不够用&#xff0c;从16g扩展到24&#xff0c;然后到40G&#xff0c;还是不够用&#xff1b;打开Docker Desktop 运行时间一长&#x…

使用QLoRA和自定义数据集微调大模型

大家好&#xff0c;大语言模型&#xff08;LLMs&#xff09;对自然语言处理&#xff08;NLP&#xff09;的影响是非常深远的&#xff0c;不仅提高了任务效率&#xff0c;还催生出新能力&#xff0c;推动了模型架构和训练方法的创新。尽管如此强大&#xff0c;但LLMs也有局限&am…

Mac M1 Docker创建Rocketmq集群并接入Springboot项目

文章目录 前言Docker创建rocketmq集群创建rocketmq目录创建docker-compose.yml新增broker.conf文件启动容器 Springboot 接入 rocketmq配置maven依赖修改appplication.yml新增消息生产者新增消费者测试发送消息 总结 前言 最近公司给配置了一台mac&#xff0c;正好有时间给装一…

golang分布式缓存项目 Day2

注&#xff1a;该项目原作者&#xff1a;https://geektutu.com/post/geecache-day1.html。本文旨在记录本人做该项目时的一些疑惑解答以及部分的测试样例以便于本人复习。 支持并发读写 接下来我们使用 sync.Mutex 封装 LRU 的几个方法&#xff0c;使之支持并发的读写。在这之…

abap 可配置通用报表字段级日志监控

文章目录 1.功能需求描述1.1 功能1.2 效果展示2.数据库表解释2.1 表介绍3.数据库表及字段3.1.应用日志数据库抬头表:ZLOG_TAB_H3.2.应用日志数据库明细表:ZLOG_TAB_P3.3.应用日志维护字段配置表:ZLOG_TAB_F4.日志封装类5.代码6.调用方式代码7.调用案例程序demo1.功能需求描述 …

材质(三)——材质参数集和材质函数

a.之前是针对材质在材质蓝图里面 类似 于静态更改的方法&#xff0c; b.材质参数集 &#xff0c;对外开放参数&#xff0c;可以手动更改&#xff0c;已然是一种封闭的静态更改方法 c.那么材质函数&#xff0c;将参数集对外开放&#xff0c;可以在关卡蓝图 通过程序 算法 去动…

随机采样之接受拒绝采样

之前提到的逆变换采样&#xff08;Inverse Transform Sampling&#xff09;是一种生成随机样本的方法&#xff0c;它利用累积分布函数&#xff08;CDF&#xff09;的逆函数来生成具有特定分布的随机变量。以下是逆变换采样的缺点&#xff1a; 计算复杂性&#xff1a;对于某些分…

软件设计师:排序算法总结

一、直接插入 排序方式&#xff1a;从第一个数开始&#xff0c;拿两个数比较&#xff0c;把后面一位跟前面的数比较&#xff0c;把较小的数放在前面一位 二、希尔 排序方式&#xff1a;按“增量序列&#xff08;步长&#xff09;”分组比较&#xff0c;组内元素比较交换 假设…

信息安全工程师(78)网络安全应急响应技术与常见工具

前言 网络安全应急响应是指为应对网络安全事件&#xff0c;相关人员或组织机构对网络安全事件进行监测、预警、分析、响应和恢复等工作。 一、网络安全应急响应技术 网络安全应急响应组织 构成&#xff1a;网络安全应急响应组织主要由应急领导组和应急技术支撑组构成。领导组负…

Kafka 的一些问题,夺命15连问

kafka-中的组成员 kafka四大核心 生产者API 允许应用程序发布记录流至一个或者多个kafka的主题&#xff08;topics&#xff09;。 消费者API 允许应用程序订阅一个或者多个主题&#xff0c;并处理这些主题接收到的记录流 StreamsAPI 允许应用程序充当流处理器&#xff08;s…

精选5款小程序设计工具,助力设计之路璀璨前行

在当今数字化浪潮中&#xff0c;小程序的重要性日益凸显&#xff0c;无论是电商、社交还是服务领域&#xff0c;小程序都成为连接用户与品牌的关键桥梁。而一款优秀的小程序离不开精心的设计&#xff0c;以下 5 款小程序设计工具将成为你设计事业的得力助手。 一、即时设计 即…

亚马逊评论爬虫+数据分析

爬取评论 做分析首先得有数据&#xff0c;数据是核心&#xff0c;而且要准确&#xff01; 1、爬虫必要步骤&#xff0c;选好框架 2、开发所需数据 3、最后测试流程 这里我所选框架是seleniumrequest&#xff0c;很多人觉得selenium慢&#xff0c;确实不快&#xff0c;仅针对此…

量子计算及其在密码学中的应用

&#x1f493; 博客主页&#xff1a;瑕疵的CSDN主页 &#x1f4dd; Gitee主页&#xff1a;瑕疵的gitee主页 ⏩ 文章专栏&#xff1a;《热点资讯》 量子计算及其在密码学中的应用 量子计算及其在密码学中的应用 量子计算及其在密码学中的应用 引言 量子计算概述 定义与原理 发展…

论文笔记:no pose,no problem-基于dust3r输出GS参数实现unpose稀疏重建

1.摘要 我们引入了 NoPoSplat&#xff0c;这是一种前馈模型&#xff0c;能够从未设置的稀疏多视图图像中重建由 3D 高斯参数化的 3D 场景。 我们的模型专门使用光度损失进行训练&#xff0c;在推理过程中实现了实时 3D 高斯重建。 为了消除重建过程中对准确pose的需要&#xff…

godot--自定义边框/选中时样式 StyleBoxTexture

前提知识&#xff1a; stylebox就像一个贴图&#xff0c;把图案贴到控件是。多个stylebox同时生效的话&#xff0c;那当然也有层级之分&#xff0c;上层覆盖下层&#xff08;可以设置透明度来显示下层&#xff09; 关于主题的概念&#xff1a; godot——主题、Theme、StyleB…

ReactPress 安装指南:从 MySQL 安装到项目启动

ReactPress Github项目地址&#xff1a;https://github.com/fecommunity/reactpress 欢迎Star。 ReactPress 是一个基于 React 的开源发布平台&#xff0c;适用于搭建博客、网站或内容管理系统&#xff08;CMS&#xff09;。本文将详细介绍如何安装 ReactPress&#xff0c;包括…

BMC运维管理:IPMI实现服务器远控制

IPMI实现服务器远控制 实操一、使用IPMI重置BMC用户密码实操二、使用IPMI配置BMC的静态IP实操三、IPMI实现BMC和主机控制操作实操四、ipmitool查看服务器基本信息实操五、ipmitool实现问题定位BMC(Baseboard Management Controller,基板管理控制器)是服务器硬件的一个独立管…

手机上用什么方法可以切换ip

手机上用什么方法可以切换IP&#xff1f;在某些特定情境下&#xff0c;用户可能需要切换手机的IP地址&#xff0c;以满足网络安全、隐私保护或绕过地域限制等需求。下面以华为手机为例&#xff0c;将详细介绍手机IP地址切换的几种方法&#xff0c;帮助用户轻松实现这一目标。 一…