【PyTorch API】 nn.RNN 和 nn.LSTM 介绍和代码详解

news2024/7/6 18:16:41

文章目录

    • 1. nn.RNN 构建单向 RNN
    • 2. nn.LSTM 构建单向 LSTM
    • 3. 推荐参考资料

1. nn.RNN 构建单向 RNN

torch.nn.RNN 的 PyTorch 链接:torch.nn.RNN(*args, **kwargs)

nn.RNN 的用法和输入输出参数的介绍直接看代码:

import torch
import torch.nn as nn

# 单层单向 RNN
embed_dim = 5    # 每个输入元素的特征维度,如每个单词用长度为 5 的特征向量表示
hidden_dim = 6   # 隐状态的特征维度,如每个单词在隐藏层中用长度为 6 的特征向量表示
rnn_layers = 4   # 循环层数
rnn = nn.RNN(input_size=embed_dim, hidden_size=hidden_dim, num_layers=rnn_layers, batch_first=True)

# 输入
batch_size = 2
sequence_length = 3    # 输入的序列长度,如 i love you 的序列长度为 3,每个单词用长度为 feature_num 的特征向量表示
input = torch.randn(batch_size, sequence_length, embed_dim)
h0 = torch.randn(rnn_layers, batch_size, hidden_dim)

# output 表示隐藏层在各个 time step 上计算并输出的隐状态
# hn 表示所有掩藏层的在最后一个 time step 隐状态, 即单词 you 的隐状态
output, hn = rnn(input, h0)

print(f"output = {output}")
print(f"hn = {hn}")
print(f"output.shape = {output.shape}")     # torch.Size([2, 3, 6])    [batch_size, sequence_length, hidden_dim]
print(f"hn.shape = {hn.shape}")             # torch.Size([4, 2, 6])    [rnn_layers, batch_size, hidden_dim]


"""
output = tensor([[[-0.3727, -0.2137, -0.3619, -0.6116, -0.1483,  0.8292],
         [ 0.1138, -0.6310, -0.3897, -0.5275,  0.2012,  0.3399],
         [-0.0522, -0.5991, -0.3114, -0.7089,  0.3824,  0.1903]],

        [[ 0.1370, -0.6037,  0.3906, -0.5222,  0.8498,  0.8887],
         [-0.3463, -0.3293, -0.1874, -0.7746,  0.2287,  0.1343],
         [-0.2588, -0.4145, -0.2608, -0.3799,  0.4464,  0.1960]]],
       grad_fn=<TransposeBackward1>)
       
hn = tensor([[[-0.2892,  0.7568,  0.4635, -0.2106, -0.0123, -0.7278],
         [ 0.3492, -0.3639, -0.4249, -0.6626,  0.7551,  0.9312]],

        [[ 0.0154,  0.0190,  0.3580, -0.1975, -0.1185,  0.3622],
         [ 0.0905,  0.6483, -0.1252,  0.3903,  0.0359, -0.3011]],

        [[-0.2833, -0.3383,  0.2421, -0.2168, -0.6694, -0.5462],
         [ 0.2976,  0.0724, -0.0116, -0.1295, -0.6324, -0.0302]],

        [[-0.0522, -0.5991, -0.3114, -0.7089,  0.3824,  0.1903],
         [-0.2588, -0.4145, -0.2608, -0.3799,  0.4464,  0.1960]]],
       grad_fn=<StackBackward0>)

output.shape = torch.Size([2, 3, 6])
hn.shape = torch.Size([4, 2, 6])
"""

需要特别注意的是 nn.RNN 的第二个输出 hn 表示所有掩藏层的在最后一个 time step 隐状态,听起来很难理解,看下面的红色方框内的数据就懂了。即 output[:, -1, :] = hn[-1, : , :]

这里 hn 保存了四次循环中最后一个 time step 隐状态的数值,以输入 i love you 为了,hn 保存的是单词 you 的隐状态。

在这里插入图片描述

2. nn.LSTM 构建单向 LSTM

torch.nn.RNN 的 PyTorch 链接:torch.nn.LSTM(*args, **kwargs)

nn.LSTM 的用法和输入输出参数的介绍直接看代码:

import torch
import torch.nn as nn


batch_size = 4
seq_len = 3      # 输入的序列长度
embed_dim = 5    # 每个输入元素的特征维度
hidden_size = 5 * 2    # 隐状态的特征维度,根据工程经验可取 hidden_size = embed_dim * 2
num_layers = 2    # LSTM 的层数,一般设置为 1-4 层;多层 LSTM 的介绍可以参考 https://blog.csdn.net/weixin_41041772/article/details/88032093

lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)

# h0 可以缺省
input = torch.randn(batch_size, seq_len, embed_dim)

"""
output 表示隐藏层在各个 time step 上计算并输出的隐状态
hn 表示所有隐藏层的在最后一个 time step 隐状态, 即单词 you 的隐状态;所以 hn 与句子长度 seq_len 无关
hn[-1] 表示最后一个隐藏层的在最后一个 time step 隐状态,即 LSTM 的输出
cn 表示句子的最后一个单词的细胞状态;所以 cn 与句子长度 seq_len 无关
其中 output[:, -1, :] = hn[-1,:,:]
"""
output, (hn, cn) = lstm(input)


print(f"output.shape = {output.shape}")   # torch.Size([4, 3, 10])
print(f"hn.shape = {hn.shape}")           # torch.Size([2, 4, 10])
print(f"cn.shape = {cn.shape}")           # torch.Size([2, 4, 10])

print(f"output = {output}")
print(f"hn = {hn}")
print(f"output[:, -1, :] = {output[:, -1, :]}")
print(f"hn[-1,:,:] = {hn[-1,:,:]}")


"""
output = tensor([[[ 0.0447,  0.0111,  0.0292,  0.0692, -0.0547, -0.0120, -0.0202, -0.0243,  0.1216,  0.0643],
                  [ 0.0780,  0.0279,  0.0231,  0.1061, -0.0819, -0.0027, -0.0269, -0.0509,  0.1800,  0.0921],
                  [ 0.0993,  0.0160,  0.0516,  0.1402, -0.1146, -0.0177, -0.0607, -0.0715,  0.2110,  0.0954]],

                 [[ 0.0542, -0.0053,  0.0415,  0.0899, -0.0561, -0.0376, -0.0327, -0.0276,  0.1159,  0.0545],
                  [ 0.0819, -0.0015,  0.0640,  0.1263, -0.1021, -0.0502, -0.0495, -0.0464,  0.1814,  0.0750],
                  [ 0.0914,  0.0034,  0.0558,  0.1418, -0.1327, -0.0643, -0.0616, -0.0674,  0.2195,  0.0886]],

                 [[ 0.0552, -0.0006,  0.0351,  0.0864, -0.0486, -0.0192, -0.0305, -0.0289,  0.1103,  0.0554],
                  [ 0.0835, -0.0099,  0.0415,  0.1396, -0.0758, -0.0829, -0.0616, -0.0604,  0.1740,  0.0828],
                  [ 0.1202, -0.0113,  0.0570,  0.1608, -0.0836, -0.0801, -0.0792, -0.0874,  0.1923,  0.0829]],

                 [[ 0.0115, -0.0026,  0.0267,  0.0747, -0.0867, -0.0250, -0.0199, -0.0154,  0.1158,  0.0649],
                  [ 0.0628,  0.0003,  0.0297,  0.1191, -0.1028, -0.0342, -0.0509, -0.0496,  0.1759,  0.0831],
                  [ 0.0569,  0.0105,  0.0158,  0.1300, -0.1367, -0.0207, -0.0514, -0.0629,  0.2029,  0.1042]]], grad_fn=<TransposeBackward0>)
            
hn = tensor([[[-0.1933, -0.0058, -0.1237,  0.0348, -0.1394,  0.2403,  0.1591, -0.1143,  0.1211, -0.1971],
              [-0.2387,  0.0433, -0.0296,  0.0877, -0.1198,  0.1919,  0.0832, 0.0738,  0.1907, -0.1807],
              [-0.2174,  0.0721, -0.0447,  0.1081, -0.0520,  0.2519,  0.4040, -0.0033,  0.1378, -0.2930],
              [-0.2130, -0.0404, -0.0588, -0.1346, -0.1865,  0.1032, -0.0269, 0.0265, -0.0664, -0.1800]],

             [[ 0.0993,  0.0160,  0.0516,  0.1402, -0.1146, -0.0177, -0.0607, -0.0715,  0.2110,  0.0954],
              [ 0.0914,  0.0034,  0.0558,  0.1418, -0.1327, -0.0643, -0.0616, -0.0674,  0.2195,  0.0886],
              [ 0.1202, -0.0113,  0.0570,  0.1608, -0.0836, -0.0801, -0.0792, -0.0874,  0.1923,  0.0829],
              [ 0.0569,  0.0105,  0.0158,  0.1300, -0.1367, -0.0207, -0.0514, -0.0629,  0.2029,  0.1042]]], grad_fn=<StackBackward0>)


验证 output[:, -1, :] = hn[-1,:,:]
output[:, -1, :] = tensor([[ 0.0993,  0.0160,  0.0516,  0.1402, -0.1146, -0.0177, -0.0607, -0.0715,0.2110,  0.0954],
                           [ 0.0914,  0.0034,  0.0558,  0.1418, -0.1327, -0.0643, -0.0616, -0.0674,  0.2195,  0.0886],
                           [ 0.1202, -0.0113,  0.0570,  0.1608, -0.0836, -0.0801, -0.0792, -0.0874, 0.1923,  0.0829],
                           [ 0.0569,  0.0105,  0.0158,  0.1300, -0.1367, -0.0207, -0.0514, -0.0629, 0.2029,  0.1042]], grad_fn=<SliceBackward0>)
          
hn[-1,:,:] = tensor([[ 0.0993,  0.0160,  0.0516,  0.1402, -0.1146, -0.0177, -0.0607, -0.0715,0.2110,  0.0954],
                     [ 0.0914,  0.0034,  0.0558,  0.1418, -0.1327, -0.0643, -0.0616, -0.0674,  0.2195,  0.0886],
                     [ 0.1202, -0.0113,  0.0570,  0.1608, -0.0836, -0.0801, -0.0792, -0.0874,  0.1923,  0.0829],
                     [ 0.0569,  0.0105,  0.0158,  0.1300, -0.1367, -0.0207, -0.0514, -0.0629, 0.2029,  0.1042]], grad_fn=<SliceBackward0>)
"""

3. 推荐参考资料

多层 LSTM 的介绍可以参考博客 RNN之多层LSTM理解:输入,输出,时间步,隐藏节点数,层数

RNN 的原理和 PyTorch 源码复现可以参考视频 PyTorch RNN的原理及其手写复现

LSTM 的原理和 PyTorch 源码复现可以参考视频 PyTorch LSTM和LSTMP的原理及其手写复现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/725789.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

商业模式画布

商业模式画布给了创业者一个思考的框架&#xff0c;在行动之前充分思考和演练。 文章目录 认识商业模式画布九个组成部分&#xff08;以Zoom为例拆解&#xff09;收入成本 九个组成部分的关系总结 认识商业模式画布 九个组成部分&#xff08;以Zoom为例拆解&#xff09; 收入 成…

七月创作之星挑战赛开始咯~

活动火热进行中&#xff01; 欢迎各位大佬积极参与~ 大家请加入卡奥斯开源社区官方社群哦&#xff0c;最新活动实时更新&#xff01; 还有专属群内福利&#xff08;蛋糕券、购物卡、周边礼品&#xff09;等你来拿~ 礼品详情

高效管理工作任务,推荐优秀任务管理软件助力工作效率提升

任务管理软件是一种用于组织任务、将任务分配给个人并监控其进展的软件。该软件可以帮助确保任务在预算内按时完成。它在协同工作环境中特别有用&#xff0c;在这种环境中&#xff0c;多人在处理需要跟踪和监视的任务。 任务管理软件可以帮助简化分配任务和监控任务进度的过程。…

ModaHub魔搭社区:基于阿里云 ACK 搭建开源向量数据库 Milvus

目录 一、准备资源 二、集群创建&#xff1a; 本集群基于Terway网络构建 二、连接刚刚创建的ACK集群 三、部署Milvus数据库 四、优化Milvus配置 简介&#xff1a; 生成式 AI&#xff08;Generative AI&#xff09;引爆了向量数据库&#xff08;Vector Database&#xff0…

STM8低门槛快速入门,类似Arduino封装库模式开发介绍

STM8低门槛快速入门&#xff0c;类似Arduino封装库模式开发介绍 &#x1f4cc;STM8外设封装库原项目开源地址&#xff1a;https://github.com/gicking/STM8_templates&#x1f4cd;个人整理过的项目地址&#xff1a;https://github.com/perseverance51/STM8-Templates &#x1…

前端开发常用Nginx设置说明

前端部署常用到Nginx&#xff0c;作为前端开发常用的配置不多&#xff0c;担也需要掌握 常见配置说明&#xff0c;这里只列表server模块的核心代码 server {listen 9015; # 端口号server_name 172.16.101.191; # 浏览器访问域名&#xff0c;不配置默认为本服务器地址index in…

redhat6安装mysql8.0.33

1、下载mysql 官网地址&#xff1a;https://downloads.mysql.com/archives/community/ 下载步骤&#xff1a; 过滤操作系统版本 下载后&#xff0c;上传到服务器Downloads目录 2、安装mysql8 解压压缩包 tar -xvf mysql-8.0.31-1.el9.x86_64.rpm-bundle.tar [rootrhel64 …

node搭建一个简单的脚手架

一、什么是脚手架 脚手架&#xff08;Scaffold&#xff09;是指在软件开发过程中为提高开发效率而提供的一套基础代码结构、组织规范、开发工具和工程化配置的工具。脚手架可以帮助开发团队快速搭建项目的基础框架&#xff0c;规范项目的开发流程&#xff0c;并提供一些常用的…

指针函数与函数指针

指针函数 指针函数&#xff1a;指针函数是一个函数&#xff0c;返回值是一个指针。 int *fun; //fun是指针变量 int *fun(x,y); //fun是指针函数; #include<iostream> using namespace std;char* day_name() {return("Monday"); //返回地址 }int main() {char…

堆排序选择排序

选择排序 选择排序&#xff08;Selection sort&#xff09;是一种简单直观的排序算法。它的工作原理如下。首先在未排序序列中找到最小&#xff08;大&#xff09;元素&#xff0c;存放到排序序列的起始位置&#xff0c;然后&#xff0c;再从剩余未排序元素中继续寻找最小&…

Linux系统下 - [linux命令]查找包含指定内容的文件

格式1&#xff1a;grep -r “指定内容” 目录 eg:输出包含"指定内容"的文件列表以及简要信息 查找当前目录下的 CONFIG_ESP_SMARTCONFIG_TYPE grep -r "CONFIG_ESP_SMARTCONFIG_TYPE" .格式2&#xff1a;grep -r -l “指定内容” 目录 eg:仅输出包含&q…

模拟Toast 自定义提示框

模拟Toast 自定义提示框 前言 为满足产品需求&#xff0c;发现现在的ToastUtils不是太重就是不太满足需求&#xff0c;这边写个简单易用的工具&#xff0c;几十行代码解决的问题,还要啥轮子。 功能如下&#xff1a; 自动消失相对锚点位置 可配置&#xff0c;正中间&#x…

刷题日记06《回溯算法》

问题描述 力扣https://leetcode.cn/problems/Ygoe9J/ 给定一个无重复元素的正整数数组 candidates 和一个正整数 target &#xff0c;找出 candidates 中所有可以使数字和为目标数 target 的唯一组合。 candidates 中的数字可以无限制重复被选取。如果至少一个所选数字数量不同…

计算机体系结构基础知识介绍之缓存性能的十大进阶优化之非阻塞缓存(四)

优化四&#xff1a;非阻塞缓存&#xff0c;提高缓存带宽 对于允许乱序执行的流水线计算机&#xff0c;处理器不需要因数据高速缓存未命中而停止。 例如&#xff0c;处理器可以继续从指令高速缓存获取指令&#xff0c;同时等待数据高速缓存返回丢失的数据。 非阻塞高速缓存或无…

23家企业推出昇腾AI系列新品 覆盖云、边、端智能硬件

[中国&#xff0c;上海&#xff0c;2023年7月6日] 昇腾人工智能产业高峰论坛在上海举办。论坛现场&#xff0c;大模型联合创新启动&#xff0c;26家行业领军企业、科研院所与华为将共同基于昇腾AI进行基础大模型与行业大模型应用创新。同时&#xff0c;华为携手伙伴联合发布昇腾…

Java虚拟机(JVM)、垃圾回收器

一、Java简介 1、Java开发及运行版本 JRE(Java Runtime Environment&#xff0c;运行环境) 所有的程序都要在JRE下才能够运行。包括JVM和Java核心类库和支持文件。JDK(Java Development Kit&#xff0c;开发工具包) 用来编译、调试Java程序的开发工具包。包括Java工具(javac/…

【LNMP】架构及应用部署 搭建电影网站

准备环境 一台虚拟机192.168.108.67 关闭防火墙 systemctl stop firewalld iptables -F setenforce 0 检查光盘 查看yum仓库 安装nginx依赖 [rootlocalhost ~]# yum -y install pcre-devel zlib-devel 创建管理nginx用户&#xff08;用来运行nginx&#xff09; [rootlocalh…

picard安装时报错“Exception in thread “main“ java.lang.UnsupportedClassVersionError”

最近在通过GATK所介绍的best practice流程来call SNP流程 1.流程 1.1 BWA比对&#xff0c;获得sam文件 1.2 准备用picard来压缩排序sam文件为bam文件&#xff0c;并对bam文件进行去重复&#xff08;duplicates marking&#xff09; 这是就需要用到picard软件 按照教程网页上…

go-zero的rpc服务案例解析

go-zero的远程调用服务是基于gRpc的gRPC教程与应用。 zero使用使用gRpc需要安装protoc插件&#xff0c;因为gRpc基于protoc插件使用protocol buffers文件生成rpc服务器和api的代码的。 gRPC 的代码生成还依赖 protoc-gen-go&#xff0c;protoc-gen-go-grpc 插件来配合生成 Go…

机器学习笔记 - 局部敏感哈希简介

一、算法简述 局部敏感散列 (LSH) 技术,可显著加快对数据的邻居搜索或近似重复检测。例如,这些技术可用于以惊人的速度过滤掉抓取网页的重复项,或者从地理空间数据集中对附近点执行近恒定时间查找。 让我们快速回顾一下其他类型的哈希函数,哈希函数的传统用途是…