RNN实战具体跑的代码

news2024/9/24 15:27:52

一、首先先上代码:这个是接pytorch API的调用代码

import torch
import torch.nn as nn
bs, T = 2,3#批次大小,输入序列长度
input_size,hidden_size=2,3#输入特征大小,隐含层特征大小
input = torch.randn(bs,T, input_size)
h_prev = torch.zeros(bs, hidden_size)#初始隐含状态

#step1:调用pytorch API
rnn = nn.RNN(input_size,hidden_size,batch_first=True)
rnn_output,state_final = rnn(input,h_prev.unsqueeze(0))
print(rnn_output)
print(state_final)

参数解释一下:

1、bs(batch_size):就每个批次输入的数据量

2、T(也就是序列长度):

如果是纯时间序列预测的模型,那么在纯时间序列模型中,T。时间步数代表着历史数据中的每个时间点,而模型通过学习这些历史数据来预测未来的结果。

假设我们有一段时间内的房价数据,每个月记录一个时间点的房价。我们想要根据过去几个月的房价数据来预测下个月的房价。在这种情况下,时间步数就是过去几个月的数量。

举个例子,假设我们有过去 12 个月的房价数据,现在想要预测下个月的房价。那么在这个例子中,时间步数就是 12。我们可以用过去 12 个月的房价数据作为输入特征,然后预测下个月的房价作为模型的输出。

        如果是NLP模型的话:

举个例子,假设我们要用RNN模型来预测一个句子中的下一个单词。我们可以将句子分割成单词,并按照单词的顺序依次输入到模型中。在这个例子中,每个时间步对应着句子中的一个单词。

考虑句子:"The cat is sitting on the mat.",如果我们按照单词的顺序将其输入到RNN模型中,则每个单词对应一个时间步。例如:

  • 时间步 1:输入单词 "The"
  • 时间步 2:输入单词 "cat"
  • 时间步 3:输入单词 "is"
  • 时间步 4:输入单词 "sitting"
  • 时间步 5:输入单词 "on"
  • 时间步 6:输入单词 "the"
  • 时间步 7:输入单词 "mat"

在这个例子中,整个句子被分成了7个时间步,每个时间步上都有相应的输入数据。模型在每个时间步上接收输入并产生输出,从而对整个序列进行处理。

3、layer 层

  • input_size输入特征的维度大小。

  • hidden_size:隐含层(隐藏层)的特征维度大小。

  • input:输入数据,形状为(batch_size, sequence_length, input_size)。

  • h_prev:初始隐含状态,形状为(batch_size, hidden_size)。在这里,初始隐含状态被初始化为全零张量。

  • rnn_output:这是 RNN 模型的输出结果。它是一个张量,包含了每个时间步的隐含状态的输出。形状为 (batch_size, sequence_length, hidden_size),其中 batch_size 表示批次大小,sequence_length 表示序列长度,hidden_size 表示隐含状态的特征维度大小。

  • state_final:这是 RNN 模型的最终隐含状态。它表示 RNN 模型在整个序列上的最后一个时间步的隐含状态。形状为 (num_layers, batch_size, hidden_size),其中 num_layers 表示 RNN 模型的层数,通常为 1。

以上解释还是有些抽象,我举个例子:

import torch
import torch.nn as nn

# 输入数据 input,形状为 (3, 5, 1),表示 3 个样本,每个样本序列长度为 5,每个时间步的输入特征维度为 1
input = torch.randn(3, 5, 1)

# 初始隐含状态 h_prev,形状为 (3, 10),表示 3 个样本,隐含状态的特征维度为 10
h_prev = torch.zeros(3, 10)

# 定义 RNN 模型
rnn = nn.RNN(input_size=1, hidden_size=10, batch_first=True)

# 将输入数据和初始隐含状态输入到 RNN 模型中
rnn_output, state_final = rnn(input, h_prev.unsqueeze(0))

# 输出结果 rnn_output 的形状为 (3, 5, 10),表示 3 个样本,每个样本序列长度为 5,每个时间步的隐含状态的特征维度为 10
# 最终的隐含状态 state_final 的形状为 (1, 3, 10),表示 1 层 RNN 模型,3 个样本,隐含状态的特征维度为 10

这里的 rnn_output 包含了每个时间步的隐含状态的输出,state_final 是整个序列上最后一个时间步的隐含状态。h_prev.unsqueeze(0)这个什么意思?

h_prev.unsqueeze(0)表示对初始隐含状态 h_prev 进行操作,使用 unsqueeze(0) 方法在第 0 维度上增加一个维度。

具体来说,假设 h_prev 的形状是 (3, 10),其中 3 是批次大小,10 是隐含状态的特征维度。那么使用 unsqueeze(0) 方法之后,h_prev 的形状将变为 (1, 3, 10)

这个操作是为了符合 RNN 模型接受初始隐含状态的要求。在 PyTorch 中,RNN 模型期望初始隐含状态的形状为 (num_layers, batch_size, hidden_size),其中 num_layers 是 RNN 模型的层数,通常为 1。因此,我们需要对初始隐含状态进行相应的维度调整,以适应模型的输入要求。

所以,h_prev.unsqueeze(0) 的作用是将原始的初始隐含状态 h_prev 转换为符合 RNN 模型输入要求的形状。

二、手写一个rnn_forward函数,实现RNN计算原理

解释一下这里的ht到底是什么意思在这个公式中,ht​ 表示RNN模型在时间步 t 的隐含状态,通常也被称为隐藏状态或者记忆状态。RNN(循环神经网络)的核心是根据当前的输入xt​ 和前一个时间步的隐含状态 ℎt−1来计算当前时间步的隐含状态 ht​。

具体来说:会有一个叠加的过程

  • xt​ 是当前时间步的输入,通常是一个特征向量。
  • Wih​ 是输入到隐含层的权重矩阵。
  • bih​ 是输入到隐含层的偏置向量。
  • Whh​ 是隐含层到隐含层的权重矩阵。
  • bhh​ 是隐含层到隐含层的偏置向量。
  • tanh 是双曲正切函数,用于加入非线性。
    #step2: 手写一个RNN forward,实现RNN的计算原理
    def run_forward(input,weight_ih,weitht_hh,bias_ih,bias_hh,h_prev):
        bs, T, input_size = input.shape
        h_dim = weight_ih.shape[0]
        h_out = torch.zeros(bs, T, h_dim)#初始化一个输出(状态)矩阵
    
        for t in range(T):
            x = input[:,t,:].unsqueeze(2)#获取当前时刻输入特征,bs*input_size*1
            w_ih_batch = weight_ih.unsqueeze(0).tile(bs,1,1)#ba*h_dim*input_size
            w_hh_batch = weight_hh.unsqueeze(0),tile(bs,1,1)#bs*h_dim
    
            w_times_x = torch.bmm(w_ih_batch,x).squeeze(-1)#bs*h_dim
            w_times_h = torch.bmm(w_hh_batch,h_prev.unsqueeze(2).squeeze(-1))#bs*h_dim
            h_prev = torch.tanh(w_times_h+bias_ih+weitht_hh)
    
            h_out[:,t,:] = h_prev
        return h_out,h_prev.unsqueeze(0)

    最后得到的就是ht

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1430418.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

css新手教程

css新手教程 课程:14、盒子模型及边框使用_哔哩哔哩_bilibili 一.什么是CSS 1.什么是CSS Cascading Style Sheet 层叠样式表。 CSS:表现(美化网页) 字体,颜色,边距,高度,宽度&am…

用户体验优化:HubSpot的秘密武器

在当今数字化市场中,提升用户体验已经成为企业成功的关键因素之一。HubSpot,作为一款领先的营销自动化工具,不仅在推动销售业绩上表现出色,同时通过其独特的策略也致力于提升用户体验。运营坛将深入探讨HubSpot是如何通过个性化推…

Leetcode92:反转链表II(区间反转链表)

一、题目 给你单链表的头指针 head 和两个整数 left 和 right &#xff0c;其中 left < right 。请你反转从位置 left 到位置 right 的链表节点&#xff0c;返回 反转后的链表 。 示例&#xff1a; 输入&#xff1a;head [1,2,3,4,5], left 2, right 4 输出&#xff1a…

Redisson看门狗机制

一、背景 网上redis分布式锁的工具方法&#xff0c;大都满足互斥、防止死锁的特性&#xff0c;有些工具方法会满足可重入特性。如果只满足上述3种特性会有哪些隐患呢&#xff1f;redis分布式锁无法自动续期&#xff0c;比如&#xff0c;一个锁设置了1分钟超时释放&#xff0c;…

Servlet+Ajax实现对数据的列表展示(极简入门)

目录 1.准备工作 1.数据库源&#xff08;这里以Mysql为例&#xff09; 2.映射实体类 3.模拟三层架构&#xff08;Dao、Service、Controller&#xff09; Dao接口 Dao实现 Service实现&#xff08;这里省略Service接口&#xff09; Controller层&#xff08;或叫Servlet层…

Vulnhub billu b0x

0x01 环境搭建 1. 从官方下载靶机环境&#xff0c;解压到本地&#xff0c;双击OVF文件直接导入到vmware虚拟机里面。2. 将虚拟机的网络适配器调成NAT模式&#xff0c;然后开机即可进行操作了。 0x02 主机发现 nmap -sn 192.168.2.0/24 成功获取靶机IP为192.168.2.129。 0x0…

网络时间协议NTP工作模式

单播服务器/客户端模式 单播服务器/客户端模式运行在同步子网中层数较高层上。这种模式下,需要预先知道服务器的IP地址。 客户端:运行在客户端模式的主机(简称客户端)定期向服务器端发送报文,报文中的Mode字段设置为3(客户端模式)。当客户端接收到应答报文时,客户端会…

指针2 1月31日学习笔记

一、strncpy、strncmp、strncat函数 strncpy函数用于将一个字符串的一部分拷贝到另一个字符串中。 char* strncpy(char *dest, const char *src, size_t n){size_t i;for (i 0; i < n && src[i] ! \0; i)dest[i] src[i];for ( ; i < n; i)dest[i] \0;return …

2024牛客寒假算法基础集训营1部分题解

// 能力有限&#xff0c;做多少发多少。 A-DFS搜索 题目描述 最近&#xff0c;fried-chicken完全学明白了DFS搜索&#xff08;如上图所示&#xff09;&#xff01;于是学弟向他请教DFS搜索&#xff0c;fried-chicken热心的进行了讲解&#xff1a; 所谓DFS搜索&#xff0c;就…

上海纽约大学信息技术部高级主任常潘:解密大数据引领的未来教育革命

大数据产业创新服务媒体 ——聚焦数据 改变商业 在数字化时代&#xff0c;大数据技术的应用已经深刻地改变着各行各业。特别是在教育领域&#xff0c;智慧校园建设作为现代化校园的代名词&#xff0c;正迎来大数据技术的巨大机遇。 1月17日&#xff0c;上海纽约大学信息技术部…

嵌入式软件工程师面试题——嵌入式专题(五十二)

说明&#xff1a; 面试群&#xff0c;群号&#xff1a; 228447240面试题来源于网络书籍&#xff0c;公司题目以及博主原创或修改&#xff08;题目大部分来源于各种公司&#xff09;&#xff1b;文中很多题目&#xff0c;或许大家直接编译器写完&#xff0c;1分钟就出结果了。但…

EtherCAT转ModbusTCP网关

一、功能概述 1.1设备简介 本产品是EtherCAT和Modbus TCP网关&#xff0c;使用数据映射方式工作。 本产品在EtherCAT侧作为EtherCAT从站&#xff0c;接TwinCAT、CodeSYS、PLC等&#xff1b;在ModbusTCP侧做为ModbusTCP主站&#xff08;Client&#xff09;或从站&#xff08;…

【蓝桥杯冲冲冲】[NOIP2003 普及组] 数字游戏

蓝桥杯备赛 | 洛谷做题打卡day25 文章目录 蓝桥杯备赛 | 洛谷做题打卡day25[NOIP2003 普及组] 数字游戏题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示思路 题解代码我的一些话 [NOIP2003 普及组] 数字游戏 题目描述 丁丁最近沉迷于一个数字游戏之中。这个游戏看…

MySQL for update锁表还是锁行校验

select * from user where id 1 for update ; 1. for update作用 在MySQL中&#xff0c;使用for update子句可以对查询结果集进行行级锁定&#xff0c;以便在事务中对这些行进行更新或者防止其他事务对这些行进行修改。 当使用for update时&#xff0c;锁定行的方式取决于wh…

【初中生讲机器学习】4. 支持向量机算法怎么用?一个实例带你看懂!

创建时间&#xff1a;2024-02-02 最后编辑时间&#xff1a;2024-02-03 作者&#xff1a;Geeker_LStar 你好呀~这里是 Geeker_LStar 的人工智能学习专栏&#xff0c;很高兴遇见你~ 我是 Geeker_LStar&#xff0c;一名初三学生&#xff0c;热爱计算机和数学&#xff0c;我们一起加…

【八大排序】冒泡排序 | 快速排序 + 图文详解!!

&#x1f4f7; 江池俊&#xff1a; 个人主页 &#x1f525;个人专栏&#xff1a; ✅数据结构冒险记 ✅C语言进阶之路 &#x1f305; 有航道的人&#xff0c;再渺小也不会迷途。 文章目录 交换排序一、冒泡排序1.1 算法步骤 动图演示1.2 冒泡排序的效率分析1.3 代码实现1.4 …

HSRP配置指南

实验大纲 第 1 部分&#xff1a;验证连通性 步骤 1&#xff1a;追踪从 PC-A 到 Web 服务器的路径 步骤 2&#xff1a;追踪从 PC-B 到 Web 服务器的路径 步骤 3&#xff1a;观察当 R3 不可用时&#xff0c;网络的行为 第 2 部分&#xff1a;配置 HSRP 主用和 备用路由器 步…

【Crypto | CTF】BUUCTF rsarsa1

天命&#xff1a;第二题RSA解密啦&#xff0c;这题比较正宗 先来看看RSA算法 这道题给出了 p&#xff0c;q&#xff0c;E&#xff0c;就是给了两个质数和公钥 有这三个东西&#xff0c;那就可以得出私钥了 最后把私钥和质数放进去解密即可得到解密后的明文 from gmpy2 impor…

公交最短距离-算法

题目 给定一个一维数组&#xff0c;其中每一个元素表示相邻公交站之间的距离&#xff0c;比如有四个公交站A,B,C,D&#xff0c;对应的距离数组为&#xff0c;1,2,3,4&#xff0c;如下图示 给定目标站X和Y&#xff0c;求他们之间最短的距离 解题 遍历一次整个数组&#xff0c;…

考研/计算机二级数据结构刷题之顺序表

目录 第一题 顺序表的初始化&#xff0c;销毁&#xff0c;头插&#xff0c;尾插&#xff0c;头删&#xff0c;尾删&#xff0c;指定位置插入&#xff0c;指定删除以及打印 第二题 移除元素 题目链接&#xff1a; OJ链接 题目详解&#xff1a;移除元素 第三题&#xff1a;删…