一文读懂官方给出torch.nn.RNN API的参数及手写RNN API复现

news2025/1/18 14:41:33

理论部分 

官方给出的文档解释: 

 计算公式:

该公式对应的结构框图:

        其中 xt 表示当前 t 时刻的输入,Wih表示 “输入层” 到 “隐藏层” 的权重矩阵。即 Wih 会将输入映射至隐藏层。bih表示输入层到隐藏层的偏置,ht-1表示 t-1 时刻的状态,ht 表示t时刻的隐藏状态,Whh表示隐藏层到隐藏层的权重矩阵,bhh表示隐藏层到隐藏层的偏置。
        `torch.nn.RNN` 是PyTorch中的一个循环神经网络(RNN)模块,用于构建RNN模型。它具有许多参数,以下是一些常用参数的解释:

参数: 

input_size (int):输入数据的特征大小(特征维度),即每个时间步的输入向量 xt 的维度。例如,如果你的输入数据是单词的嵌入向量,每个单词表示为一个100维的向量,那么`input_size`将等于100。
hidden_size (int):隐藏层的特征大小,即每个时间步的隐藏状态向量ht的维度。它决定了模型的表示能力和记忆能力。较大的`hidden_size`通常允许模型学习更复杂的模式,但也需要更多的计算资源。

num_layers (int,可选):RNN的层数,用于堆叠多个RNN层,默认值为1。当层数大于1时,RNN会变为多层RNN。多层RNN可以捕捉更复杂的时间依赖关系,但也会增加模型的复杂性。

nonlinearity (str,可选):指定激活函数,默认值为'tanh'。可选值有'tanh'和'relu'。
bias (bool,可选):如果设置为True,则在RNN中添加偏置项。默认值为True。偏差项通常有助于模型更好地拟合数据。

batch_first (bool,可选):一个布尔值,确定输入数据的维度顺序。如果设置为True,则输入数据的形状为(batch_size, seq_len, input_size)。否则,默认输入数据的形状为(seq_len, batch_size, input_size)。默认值为False。

dropout (float,可选):如果非零,则在除最后一层之外的每个RNN层之间添加dropout层,其丢弃概率为dropout。默认值为0。这有助于防止过拟合。
bidirectional (bool,可选):一个布尔值,确定是否使用双向RNN。如果设置为True,RNN将同时在时间步的正向和反向方向上运行(则使用双向RNN),以捕捉前后的上下文信息。默认值为False。

        一旦你创建了`torch.nn.RNN`模块,你可以将输入数据传递给它,然后使用输出来进行进一步的处理或者连接其他层。模型的隐藏状态也可以在需要时访问,以进行序列数据的持久化记忆或其他操作。

input : 对于没有批次的输入,即没有batch的输入,其形状大小为一个二维的张量,尺寸为:序列长度(总时间步数)✖input_size(输入数据的特征大小,即每个时间步的输入向量xt的维度)。
当batch_first = False,那么input的尺寸大小为:序列长度✖batch_size(批次大小)✖input_size。
当batch_first = True,那么input的尺寸大小为:batch_size(批次大小)✖序列长度✖input_size。

------------------------------------------------RNN网络input维度解释-------------------------------------------------

        xt是一个n维向量,RNN(递归神经网络)的输入将是一整个序列,也就是X = [ x1,…,xt-1,xt,xt+1,…xT ],T表示序列的长度,这也就是说RNN中一个样本就是一个序列。
        对于语言模型,每一个xt将代表一个词向量,一整个序列就代表一句话(也就是一个样本),T就是这句话包含的单词数量。又由于在神经网络中,我们的输入通常是多个样本作为一个批次的,所以在RNN中数据通常是三维的,也就是[ batch_size, seq_len, input_size ]或者[ seq_len, batch_size, input_size ],其中,batch_size 表示批次大小,也就是一个批次含有多少个序列(句子);seq_len表示一个序列(句子)的长度,也就是按时间序列展开每个样本有多少个可见的RNN cell;input_size 表示某时刻输入数据的维度,即xt的维度,也就是这个输入数据的特征数目(features)。

一个句子:一个sample

一个句子由n个词:n个timestep(seq_len)

一个词是k维的词向量:k个feature

        例如:一个数据集包含五句话(天气真好)(你是谁啊)(我是小明)(明天打球)(武汉加油)。数据集的维度就是(batch_size, time_step, feature_dim)= (5, 4, word_embedding)。

对于这样一个数据集,输入RNN的时候是什么情况?

        RNN是每个time_step输入一次数据,那么for循环time_step1时,进入网络的数据就是(天,你,我,明, 武)每句话的第一个字进入网络,然后依次往后,这里我们最简单的理解就是同时有batch_size个RNN在处理数据,每个RNN处理一个字,那么time_step1的输出就是(batch_size, hidden_size),整个batch处理完输出为(batch_size, time_step, hidden_size)。

---------------------------------------------------------------------------------------------------------------------------------

h_0: 对于没有批次的输入,其形状大小为一个二维的张量,尺寸为:D(单层RNN网络D取1,双层RNN网络D取2)*RNN层数✖hidden_size(隐藏层的特征大小,即每个时间步的隐藏状态向量ht的维度)。
对于有批次的输入,其形状大小为一个三维的张量,尺寸为:D(单层RNN网络D取1,双层RNN网络D取2)*RNN层数✖batch_size✖hidden_size(隐藏层的特征大小,即每个时间步的隐藏状态向量ht的维度)。
如果不提供初始隐藏状态h_0,默认情况下取0。

RNN的输出变量

output 总输出

        `torch.nn.RNN`模块的主要输出是一个包含RNN模型在每个时间步的输出的张量,通常称为`output`。这个输出张量的形状取决于输入数据的形状、RNN的参数设置以及输入序列的长度。

        具体来说,`output`张量的形状为`( batch_size, sequence_length, num_directions * hidden_size)`,其中:

- `sequence_length`是输入序列的长度,即时间步的数量。
- `batch_size`是输入数据的批次大小,即一次处理多少个序列。
- `num_directions`是一个可选的参数,如果你的RNN是双向的(`bidirectional=True`),则`num_directions`为2;如果是单向的,则为1。
- `hidden_size`是RNN的隐藏状态大小,即每个时间步的隐藏状态向量的维度。

        这个`output`张量包含了RNN在每个时间步的输出。通常,你可以在训练后对这些输出进行进一步处理,例如用于分类、回归或序列到序列的任务,或者用于获得序列中的某些信息。

h_n 最后一个时间步的输出 

        此外,`torch.nn.RNN`还返回一个包含最后一个时间步的隐藏状态的张量,通常称为`h_n`。这个张量的形状为`(num_layers * num_directions, batch_size, hidden_size)`,其中:

- `num_layers`是RNN模型的层数。
- `num_directions`是一个可选的参数,如果你的RNN是双向的,则`num_directions`为2;如果是单向的,则为1。
- `hidden_size`是RNN的隐藏状态大小,即每个时间步的隐藏状态向量的维度。

        `h_n`张量包含了每个层的最后一个时间步的隐藏状态,可以用于进行额外的处理或者作为下一个时间步的初始隐藏状态。

        综上所述,`output`和`h_n`是`torch.nn.RNN`模块的两个主要输出,它们提供了RNN在输入序列上的输出信息和最终的隐藏状态信息。

RNN网路的权重参数

weight_ih_l[k] : 输入层到隐藏层的权重矩阵,这是输入到隐藏层[k]的权重参数,其中k表示层的索引。当k=0时,形状为一个二维张量,尺寸大小为:hidden_size✖input_size。

weight_hh_l[k] : 隐藏层到隐藏层的权重矩阵,用于控制前一个时间步的隐藏状态如何影响当前时间步的隐藏状态。形状为一个二维张量,尺寸大小为:hidden_size✖hidden_size。

bias_ih_l[k] : 这是输入到隐藏层[k]的偏差(bias)参数,用于偏置输入数据对隐藏状态的影响。形状为一个一维张量,尺寸大小:hidden_size。

bias_hh_l[k] : 这是隐藏层[k]到自身的偏差(bias)参数,用于偏置前一个时间步的隐藏状态对当前时间步的隐藏状态的影响。形状为一个一维张量,尺寸大小:hidden_size。


        注意,如果你的RNN模型有多个层(`num_layers`大于1),那么每个层都会有一组权重参数。通常,这些权重参数是在模型初始化时随机初始化的,然后通过反向传播进行训练。
        这些权重参数是RNN模型的核心组成部分,它们决定了模型如何处理输入序列并学习时间依赖关系。在训练过程中,这些参数会根据损失函数的梯度进行更新,以逐渐提高模型的性能。

       

代码部分 

 调用官方给出的单层RNN API函数

import torch
import torch.nn as nn
# 单向RNN逐行实现
batch_size, T = 2, 3 # 批大小,输入序列长度
input_size, hidden_size = 2, 3 # input_size:输入数据的特征大小,即每个时间步的输入向量xt的维度。
                               # hidden_size:隐藏层的特征大小,即每个时间步的隐藏状态向量ht的维度。
input = torch.randn(batch_size, T, input_size) # 随机初始化一个输入
h_prev = torch.zeros(batch_size, hidden_size) # 初始隐藏状态,形状大小为batch_size*hidden_size
# step1 调用pytorch RNN API
rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
print("rnn API output: ")
rnn_output, state_final = rnn(input, h_prev.unsqueeze(0)) # 使用unsqueeze在h_prev的第0维扩充一维
print("rnn_output:", rnn_output)
print("state_final:", state_final)

手写单层RNN网络

# step2 手写RNN函数,实现RNN计算原理
def rnn_forward(input, weight_ih, bias_ih, weight_hh, bias_hh, h_prev):
    """
    :param input:输入。
    :param weight_ih: 输入层到隐藏层的权重矩阵。weight_ih的大小:h_dim*input_size
    :param bias_ih:输入层到隐藏层的偏置。bias_ih的大小:batch_size*hidden_size
    :param weight_hh:隐藏层到隐藏层的权重矩阵。weight_hh的大小:hidden_size✖hidden_size
    :param bias_hh:隐藏层到隐藏层的偏置。bias_hh的大小:batch_size*hidden_size
    :param h_prev:隐藏状态。h_prev的大小:batch_size*hidden_size
    :return: h_out:总输出; h_prev.unsqueeze(0):最后时刻的输出
    """

    batch_size, T, input_size = input.shape # 将input的形状拆解出来。batch_size表示批次大小,T表示序列长度,或时间步数,input_size表示输入xt的维度。

    h_dim = weight_ih.shape[0] # 隐藏状态的维度
    h_out = torch.zeros(batch_size, T, h_dim) # 初始化一个输出(状态)矩阵

    for t in range(T):
        x = input[:, t, :].unsqueeze(2) # 获取当前时刻输入. x的大小:batch_size*input_size*1
        # 由于input是三维张量:
        # 第一维度是batch_size,全部拿;
        # 第二维度是时间,我们拿当前第 t 时刻的输入;
        # 第三维度是特征维度,全部拿。
        w_ih_batch = weight_ih.unsqueeze(0).tile(batch_size, 1, 1) # w_ih_batch的大小:batch_size*h_dim*input_size
        w_hh_batch = weight_hh.unsqueeze(0).tile(batch_size, 1, 1) # w_hh_batch的大小:batch_size*h_dim*h_dim

        w_times_x = torch.bmm(w_ih_batch, x).squeeze(-1) # batch_size*h_dim
        w_times_h = torch.bmm(w_hh_batch, h_prev.unsqueeze(2)).squeeze(-1) # batch_size*h_dim
        h_prev = torch.tanh(w_times_x+bias_ih+w_times_h+bias_hh)
        h_out[:, t, :] = h_prev
    return h_out, h_prev.unsqueeze(0)

        用官方RNN API的参数喂入到我们自己写的rnn_forward函数中来验证我们的函数输出的结果与官方API输出的结果是否一致。 

# 验证rnn_forward的正确性
# for p,n in rnn.named_parameters():
#     print(p, n) # p表示参数名称,n表示参数取值
custom_rnn_out, custom_state_final = \
    rnn_forward(input, rnn.weight_ih_l0, rnn.bias_ih_l0,
                rnn.weight_hh_l0, rnn.bias_hh_l0, h_prev)
print("rnn_forward function output: ")
print("custom_rnn_out:", custom_rnn_out)
print("custom_state_final:", custom_state_final)

打印结果:

rnn API output: 
rnn_output: tensor([[[ 0.5228, -0.8081, -0.5678],
         [ 0.6968, -0.9421, -0.9379],
         [ 0.1090, -0.7488, -0.9370]],

        [[-0.3539, -0.1466, -0.3064],
         [-0.0056, -0.8178, -0.7956],
         [ 0.7735,  0.1453,  0.1368]]], grad_fn=<TransposeBackward1>)
state_final: tensor([[[ 0.1090, -0.7488, -0.9370],
         [ 0.7735,  0.1453,  0.1368]]], grad_fn=<StackBackward0>)
rnn_forward function output: 
custom_rnn_out: tensor([[[ 0.5228, -0.8081, -0.5678],
         [ 0.6968, -0.9421, -0.9379],
         [ 0.1090, -0.7488, -0.9370]],

        [[-0.3539, -0.1466, -0.3064],
         [-0.0056, -0.8178, -0.7956],
         [ 0.7735,  0.1453,  0.1368]]], grad_fn=<CopySlices>)
custom_state_final: tensor([[[ 0.1090, -0.7488, -0.9370],
         [ 0.7735,  0.1453,  0.1368]]], grad_fn=<UnsqueezeBackward0>)


发现结果一致。

代码部分可参考自:29、PyTorch RNN的原理及其手写复现_哔哩哔哩_bilibili本期直播视频主要讲解序列建模中RNN模型的原理、PyTorch API讲解以及如何逐行实现RNN算法。如果大家觉得本期视频有收获,欢迎支持或转发。直播深度学习算法原理与项目源码讲解,欢迎关注我的直播间:https://live.bilibili.com/14297368?spm_id_from=333.999.0.0, 视频播放量 24437、弹幕量 312、点赞数 650、投硬币枚数 573、收藏人数 981、转发人数 75, 视频作者 deep_thoughts, 作者简介 在有限的生命里怎么样把握住时间专注做点自己喜欢做的同时对别人也有价值的事情,是我们应该时常自查反省的(纯公益分享不接任何广告或合作),相关视频:55 循环神经网络 RNN 的实现【动手学深度学习v2】,【基于pytorch的】循环神经网络和LSTM的基本原理讲解与代码实现!,pytorch-LSTM原理及代码,54 循环神经网络 RNN【动手学深度学习v2】,【循环神经网络】5分钟搞懂RNN,3D动画深入浅出,30、PyTorch LSTM和LSTMP的原理及其手写复现,PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】,20分钟掌握RNN与LSTM原理及其结构应用(Seq2Seq & Attention),57 长短期记忆网络(LSTM)【动手学深度学习v2】,真-极度易懂的LSTM (代码)icon-default.png?t=N7T8https://www.bilibili.com/video/BV13i4y1R7jB/?spm_id_from=333.788&vd_source=fb7bfda367c76676e2483b9b60485e57

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/991898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【我的第一千篇文章】

作为一名Java开发者&#xff0c;我很自豪地宣布&#xff0c;这里是我输出的第一千篇文章。在过去的六年里&#xff0c;我一直坚持每月输出优质内容&#xff0c;并将其分享给了全世界的读者们。这一千篇文章中&#xff0c;有很多关于Java编程的技巧、经验分享、优秀实践示例、案…

16 “count(*)“ 和 “count(1)“ 和 “count(field1)“ 的差异

前言 经常会有面试题看到这样的问题 “ select count(*) ”, “ select count(field1) ”, “ select count(1) ” 的效率差异啥的 然后 我们这里 就来探索一下 这个问题 我们这里从比较复杂的 select count(field1) 开始看, 因为 较为复杂的处理过程 会留一下一些关键的调试…

C4D国潮场景3D模型合集

110个国潮场景3D模型&#xff0c;C4D源文件&#xff0c;部分效果图如下&#xff1a; 微信扫描下方二维码 回复关键字获取 100004

神经网络输出中间特征图

在进行神经网络的训练过程中&#xff0c;会生成不同的特征图信息&#xff0c;这些特征图中包含大量图像信息&#xff0c;如轮廓信息&#xff0c;细节信息等&#xff0c;然而&#xff0c;我们一般只获取最终的输出结果&#xff0c;至于中间的特征图则很少关注。 前两天师弟突然…

第24章 互斥锁实验(iTOP-RK3568开发板驱动开发指南 )

在上一章节中对信号量进行了学习&#xff0c;而本章节要学习的互斥锁可以说是“量值”为 1 的信号量&#xff0c;最终实现的效果相同&#xff0c;既然有了信号量&#xff0c;那为什么还要有互斥锁呢&#xff0c;带着疑问&#xff0c;让我们来进行本章节的学习吧&#xff01; 2…

古尔曼表示不服?郭明錤:苹果可能不会在10月发布M3芯片的机型

9月9日消息&#xff0c;据天风证券分析师郭明錤所言&#xff0c;苹果可能不会在今年发布搭载M3芯片的MacBook Air/Pro机型。这一说法与此前彭博社的马克古尔曼所透露的消息有所不同。根据古尔曼的消息&#xff0c;苹果最快在10月会发布M3款苹果MacBook Air和Pro电脑。他表示&am…

美国封锁激励中国制造业数字化转型的崛起 | 百能云芯

上海在近日公布了第二批工赋链主培育企业名单&#xff0c;共有15家企业入选。这些被称为“链主”的企业在上海制造业数字化转型的过程中扮演着关键角色&#xff0c;类似于领头大雁&#xff0c;它们是上海制造业的数字化网络中的关键节点。 中新社的报道指出&#xff0c;“数字技…

软件源码开发,网络中的“摄像头”:运维监控系统

在日常生活中&#xff0c;我们不管是在大街小巷&#xff0c;还是在商场大厦都可以见到一个圆形或是方形带有镜片的“小盒子”&#xff0c;这个“小盒子”就是摄像头&#xff0c;摄像头作为一个能实时录制记录它能照到范围内的视频图像的工具&#xff0c;可以在丢失物品、抓捕坏…

判断动物知识竞猜答案正误

判断动物知识竞猜答案正误 教学目标 1&#xff0e; 知识与技能&#xff1a; 结合实例&#xff0c;理解选择结构。掌握if语句的基本格式&#xff0c;掌握关系运算符。 过程与方法&#xff1a; 学会使用if编程解决实际生活中的一些问题。 情感态度与价值观&#xff1a; 教…

通讯软件018——分分钟学会UaExpert OPC UA Client配置

本文介绍如何配置UaExpert OPC UA Client&#xff0c;通过本文可以对OPC UA的基本概念有所了解&#xff0c;掌握OPC UA的本质。相关软件请登录网信智汇(wangxinzhihui.com)。 创建OPC UA 连接 这里需要掌握一下OPC UA的安全机制。 1&#xff09;安全模式&#xff1a; OPC UA安…

史上最详细的PyCharm安装教程,小白建议收藏!

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。PyCharm是由JetBrains公司开发的一款Python开发工具&#xff0c;在Windows、Mac OS和Linux操作系统中都可以使用&#xff0c;它具有语法高亮显示、Project&#xff08;项目&#xff09;管理、代码跳转、智能提示、自动完成…

初识集合框架 -Java

目录 一、集合框架的概念 二、集合框架的重要性 三、涉及的数据结构和算法 3.1 什么是数据结构 3.2 集合框架&#xff08;容器&#xff09;背后对应的数据结构 3.3 相关的Java知识 3.4 什么是算法 3.5 如何学好数据结构和算法 一、集合框架的概念 Java 集合框架&#xff0c;…

【图卷积神经网络】1-入门篇:为什么使用图神经网络(下)

为什么使用图神经网络? 在本书中,我们将重点介绍图学习技术中的深度学习家族,通常称为图神经网络。GNNs是一种新的深度学习架构类别,专门设计用于处理图结构化数据。与主要用于文本和图像的传统深度学习算法不同,GNNs明确地用于处理和分析图数据集(见图1.4)。 图1.4 - …

Vue3+Ts+Vite项目(第一篇)——使用Vite创建Vue3项目

概述 保姆级详解&#xff0c;带你使用 Vite 创建 Vue3 项目&#xff0c;全程cv即可 文章目录 概述一、 安装 Vite二、 创建项目2.1 运行上述命令后&#xff0c;会让我们输入项目名称。可以写一个 vue3-study2.2 选择项目模板&#xff0c;此处选择 Vue&#xff0c;然后回车确定…

无涯教程-JavaScript - IMPOWER函数

描述 IMPOWER函数以x yi或x yj文本格式返回加到幂的复数。求幂的复数的计算方法如下- $$(x yi)^ n r ^ ne ^ {n \theta} r ^ n \cos n \theta ir ^ n sin n \theta $$ 哪里- $$r \sqrt {x ^ 2 y ^ 2} \:\:和\:\:\theta \tan ^ {-1} \left(\frac {y} {x} \right)\:…

雅思写作 三小时浓缩学习顾家北 笔记总结(四)

目录 The company should provide maternity leave and other assistance to female employees with children. Community redevelopment provides opportunities for offenders to acquire vocational skills. The law should classify drunk driving as a criminal offens…

JavaScript对象方法

在 JavaScript 中&#xff0c;对象可以包含方法&#xff0c;即函数作为它的属性。这些被称为对象函数或方法。 例如&#xff1a; const ITshareArray {firstname: "张三",secondname: "二愣子",birthYear: "1996",job: "程序员",fri…

多元共进|科技促进艺术发展,助力文化传承

科技发展助力文化和艺术的传播 融合传统与创新&#xff0c;碰撞独特魅力 一起来了解 2023 Google 开发者大会上 谷歌如何依托科技创新 推动艺术与文化连接 传承和弘扬传统文化 自 2011 年成立以来&#xff0c;谷歌艺术与文化致力于提供体验艺术和文化的新方式&#xff0c;从生成…

SpringAOP面向切面编程

文章目录 一. AOP是什么&#xff1f;二. AOP相关概念三. SpringAOP的简单演示四. SpringAOP实现原理 一. AOP是什么&#xff1f; AOP&#xff08;Aspect Oriented Programming&#xff09;&#xff1a;面向切面编程&#xff0c;它是一种编程思想&#xff0c;是对某一类事情的集…

热迁移技术-QEMU

社区有言Talk is cheep, show me the code&#xff0c;我们尽量低纬度描述技术。 代码和版本&#xff1a; Qemu-5.0 #热迁移技术的实现者 Kernel-4.19 #提供kvm实现 热迁移的演进 Qemu有加载保存vm的功能&#xff0c;这是两个互补的操作。保存状态就是为每个vm中运行的设备保存…