pytorch- RNN循环神经网络

news2025/1/23 11:59:05

目录

  • 1. why RNN
  • 2. RNN
  • 3. pytorch RNN layer
    • 3.1 基本单元
    • 3.2 nn.RNN
      • 3.2.1 函数说明
      • 3.2.2 单层pytorch实现
      • 3.2.3 多层pytorch实现
    • 3.3 nn.RNNCell
      • 3.3.1 函数说明
      • 3.3.2 单层pytorch实现
      • 3.3.3 多层pytorch实现
  • 4.完整代码

1. why RNN

以淘宝的评论为例,判断评论是正面还是负面的,如下图:
在这里插入图片描述
上图中每个单词用一个线性层来表示,最后再聚合,每个单词都有一个单独的w和b。
这种方法的问题:

  • 对于长句子甚至是一段文章来说,就很难表示了,因为要用很多线性层和参数表示
  • 没有语境信息
    比如:
    我不喜欢数学,如果没看到不,只看到喜欢,理解的意思就完全不一样了,因此对于一个句子来说,必须有一个语境信息,才能正确理解句子的意思。

为了解决上述问题,RNN增加了权值共享和一个用于保存语境信息的memory h

2. RNN

如下图:
第一个单词不仅考虑到了x输入还考虑到了初始化输入,通过这两个输入产生了一个语境信息h1,第二个单词不仅考虑当前单词的输入还要考虑上一个单词的语境信息h1,以此类推。
在这里插入图片描述
在这里插入图片描述
RNN的核心就是有个语境信息ht,这个语境信息根据当前的输入和上次的语境信息ht-1不断更新自我,并不断向前传。
展开图如下:
在这里插入图片描述

3. pytorch RNN layer

3.1 基本单元

下图展示了ht的计算过程,假设句子长度为5,batch是3,每个单词用100维向量表示,h0初始值用20维表示,最终得到h(t+1)维度为[3,20]
在这里插入图片描述
在这里插入图片描述
上图中rnn=nn.RNN(100,10),100是feature len,10表示hidden len。
输出参数中rnn.weight_hh_10.shape=》[hidden len, hidden len]
rnn.weight_ih_10.shape=》[hidden len, feature len]

3.2 nn.RNN

3.2.1 函数说明

在这里插入图片描述
input_size-输入x的维度
hidden_size-h的维度
num_layers-有几次,默认1
在这里插入图片描述
上图中forward函数的返回值中
ht[num layers, b, h dim]=》是最后时间戳所有memory(h)的状态
out[seq len, b, h dim]=》是所有时间错最后一个memory(h)的状态

3.2.2 单层pytorch实现

在这里插入图片描述

3.2.3 多层pytorch实现

在这里插入图片描述
上图为2层RNN,h变由1层的[1,3,20]变为][2,3,20]([num_layer,b, h dim]),out和1层一样是[10,3,20]
在这里插入图片描述
下图为4层RNN,pytorch代码实现,注意一下输出shape的变化
在这里插入图片描述

3.3 nn.RNNCell

3.3.1 函数说明

nn.RNNCell与nn.RNN的初始化参数是完全一致
在这里插入图片描述
但是输入输出就不一样了,如下图:
在这里插入图片描述

3.3.2 单层pytorch实现

从pytorch代码可以看出,nn.RNNCell是循环处理每个单词,每次自更新h1
在这里插入图片描述

3.3.3 多层pytorch实现

下图为2层nn.RNNCell的pytorch代码,注意1层的h dim与2层的input dim必须一致,下图都是30
从代码中也可以看出第1层的h1作为第2层的输入参与更新h2。
在这里插入图片描述

4.完整代码

import  torch
from    torch import nn
from    torch import optim
from    torch.nn import functional as F


def main():


    rnn = nn.RNN(input_size=100, hidden_size=20, num_layers=1)
    print(rnn)
    x = torch.randn(10, 3, 100)
    out, h = rnn(x, torch.zeros(1, 3, 20))
    print(out.shape, h.shape)

    rnn = nn.RNN(input_size=100, hidden_size=20, num_layers=4)
    print(rnn)
    x = torch.randn(10, 3, 100)
    out, h = rnn(x, torch.zeros(4, 3, 20))
    print(out.shape, h.shape)
    # print(vars(rnn))

    print('rnn by cell')

    cell1 = nn.RNNCell(100, 20)
    h1 = torch.zeros(3, 20)
    for xt in x:
        h1 = cell1(xt, h1)
    print(h1.shape)


    cell1 = nn.RNNCell(100, 30)
    cell2 = nn.RNNCell(30, 20)
    h1 = torch.zeros(3, 30)
    h2 = torch.zeros(3, 20)
    for xt in x:
        h1 = cell1(xt, h1)
        h2 = cell2(h1, h2)
    print(h2.shape)

    print('Lstm')
    lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
    print(lstm)
    x = torch.randn(10, 3, 100)
    out, (h, c) = lstm(x)
    print(out.shape, h.shape, c.shape)

    print('one layer lstm')
    cell = nn.LSTMCell(input_size=100, hidden_size=20)
    h = torch.zeros(3, 20)
    c = torch.zeros(3, 20)
    for xt in x:
        h, c = cell(xt, [h, c])
    print(h.shape, c.shape)


    print('two layer lstm')
    cell1 = nn.LSTMCell(input_size=100, hidden_size=30)
    cell2 = nn.LSTMCell(input_size=30, hidden_size=20)
    h1 = torch.zeros(3, 30)
    c1 = torch.zeros(3, 30)
    h2 = torch.zeros(3, 20)
    c2 = torch.zeros(3, 20)
    for xt in x:
        h1, c1 = cell1(xt, [h1, c1])
        h2, c2 = cell2(h1, [h2, c2])
    print(h2.shape, c2.shape)






if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1909595.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

登录谷歌时系统提示“找不到您的Google账号”,原因通常有5个

时不时有朋友问我,说他明明后了谷歌账号,但是在登录谷歌时(有时是自己登录谷歌的网站或者APP,也有的是登录ourplay等加速器、虚拟机等第三方应用),输入了账号后系统却提示说“找不到您的Google账号”&#…

如何在 Odoo 16 中创建名称搜索功能

我们精通 Odoo,了解如何使用关系字段来建立不同模型之间的联系。为了填充这些关系字段,我们经常需要从一整套可用记录中搜索并找到特定值。Odoo 的名称搜索功能简化了此搜索过程,提供了一种根据我们的特定需求自定义搜索的便捷方式。 深入探…

图像识别和目标检测在超市电子秤上的应用

目录 前言深度学习的目标检测图像识别技术视觉秤的优势其他应用场景中的技术应用未来展望 前言 随着科技的不断发展,电子秤在生鲜超市中的应用也在不断升级。传统的电子秤需要打秤人员手动输入秤码,这不仅耗时费力,还需要大量的培训以记住各…

WMS海外仓系统应用:如何改善海外仓的12个核心业务流程

现代化跨境电商的发展依赖海外仓的高效运转,从货物入仓到订单拣货再到最后的货物出库,全部流程都需要海外仓可以顺畅应对。 作为海外仓,则需要借助诸如WMS海外仓系统这样的智能化管理方式,才能适应日益复杂的客户需求。今天我们就…

数据融合工具(4)正方形矩形图幅分幅计算

一、需求背景 对于工程方面需要的局部地区的大比例尺地形图、平面图和中小比例尺挂图和地图集,常使用矩形分幅。 二、矩形分幅 矩形分幅是按平面直接坐标系的横纵坐标线来划分的,图幅的上、下以坐标横轴为界,左、右以坐标纵轴为界。 供各种工…

KIVY ScreenManager 使用案例常见错误总结

# 导入Kivy的App类,它是所有kivy应用的基类 from kivy.app import App # kivy内置了丰富的控件widget 如 按钮button 复选框checkbox 标签label 输入框textinput 滚动容器scrollable container等 from kivy.uix.button import Button # 引入BoxLayout 布局 from ki…

「AI绘画Stable Diffusion 零基础入门 」文生图教程:什么是提示词?万字长文详解提示词的使用,建议收藏!

大家好,我是画画的小强 AI 绘画的一个必不可少的环节就是告诉 AI 描述画面的 Prompt(提示词),但是这种很长很乱、穿插着各种奇怪的数字符号、高深莫测的提示词,究竟在说着什么?难道真的是咒语吗&#xff1…

【力扣高频题】042.接雨水问题

上一篇我们通过采用 双指针 的方法解决了 经典 容器盛水 问题 ,本文我们接着来学习一道在面试中极大概率会被考到的经典题目:接雨水 问题 。 42. 接雨水 给定 n 个非负整数,表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子…

跨平台桌面应用开发工具:electron的优缺点

跨平台桌面应用开发工具Electron是一个由GitHub开发和维护的开源框架,它允许开发者使用HTML、CSS和JavaScript等Web技术来构建跨平台的桌面应用程序。以下是关于Electron的详细介绍: 一、Electron概述 定义:Electron是一个基于Chromium和Nod…

【《无主之地3》风格角色渲染在Unity URP下的实现_角色渲染(第四篇) 】

文章目录 概要描边问题外秒变分叉解决办法1:测试效果如下:外秒变分叉解决办法2:URP管线下PBR渲染源码关键词解释:完整shader代码如下:URP管线下二次元皮肤渲染源码URP管线下二次元头发渲染源码简要介绍文章的目的、主要内容和读者将获得的知识。 概要 提示:《无主之地3》…

Apache中使用SSI设置

先停服务在修改httpd.conf,备份下 Apache\Apache24\conf 设置httpd.conf LoadModule ssl_module modules/mod_ssl.so 取消该命令前的注释符# AddType text/html .shtml AddOutputFilter INCLUDES .shtml 取消该命令前的注释符# 加入html 搜索Options Indexes …

【STM32标准库】DMA双缓冲模式

1.双缓冲模式简介 设置DMA_SxCR寄存器的DBM位为1可启动双缓冲传输模式,并自动激活循环模式,所以设置普通模式或者循环模式都可以。 双缓冲不应用与存储器到存储器的传输。可以应用在从存储器到外设或者外设到存储器。 双缓冲模式下, 两个存…

pbootCMS 数据库sqlite转mysql数据库

前言 pbootCMS默认使用 sqlite数据库 ,那么什么是sqlite数据库呢? SQLite,是一款轻型的数据库,是遵守ACID的关系型数据库管理系统,它包含在一个相对小的C库中。它是D.RichardHipp建立的公有领域项目。它的设计目标是嵌…

叉车指纹锁有规定要装吗

叉车作为工业运输的重要工具,其安全性能一直备受关注。在这个信息化、智能化的时代,对于叉车这类高风险的设备,安全性措施显得尤为重要。而叉车指纹锁作为一种高科技安全设备,其在叉车管理中的应用逐渐受到重视。 那么&#xff0c…

探展2024世界人工智能大会之合合信息扫描黑科技~

文章目录 ⭐️ 前言⭐️ AIGC古籍修复文化遗产焕新⭐️ 高效的文档图像处理解决方案⭐️ AIGC扫描黑科技一键全搞定⭐️ 行业级的大模型加速器⭐️ 结语 ⭐️ 前言 大家好,我是 哈哥(哈哥撩编程) ,这次非常荣幸受邀作为专业观众参…

IP-GUARD如何禁止电脑自带摄像头

IP-GUARD可以通过设备管理模块禁止USB接口,所以USB外置摄像头很容易就可以禁止了。 但是笔记本自带摄像头无法禁止,配置客户端策略如下: device_control_unknown_mode1 device_control_unphysical_mode3

PMP–知识卡片--Scrum角色

Scrum 角色 Scrum 团队由 5 到 9 个(72)团队成员组成。有三种类型角色: 产品负责人(PO):产品负责人定义项目愿景、需求和优先级,对产品成功负责。Scrum Master:负责团队&#xff0c…

Unity海面效果——5、水沫和海平线

Unity引擎制作海面效果 大家好,我是阿赵。 继续做海面效果,上次做完了漫反射颜色和水波动画,还有法线和高光效果。 原则上来说,这个海面已经基本能看了,从性能的考虑,到这里差不多可以停止了。不过有些细节…

SpringCloud跨微服务的远程调用,如何发起网络请求,RestTemplate

在我们的业务流程之中不一定都会是自己模块查询自己模块的信息,有些时候就需要去结合其他模块的信息来进行一些查询完成相应的业务流程,但是在SpringCloud每个模块都相对独立,数据库也有数据隔离。所以当我们需要其他微服务模块的信息的时候&…

HackTheBox--IClean

IClean测试过程 1 信息收集 NMAP端口扫描 80端口测试 echo "10.10.11.12 capiclean.htb" | sudo tee -a /etc/hosts检查页面功能,除了 login 页面无其他可能利用点,可以尝试进行目录爆破和子域名扫描 目录扫描 ./gobuster dir -u http://c…