[PyTorch][chapter 44][RNN]

news2024/11/23 23:46:10

简介

            循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的递归神经网络(recursive neural network) [1]  。

            对循环神经网络的研究始于二十世纪80-90年代,并在二十一世纪初发展为深度学习(deep learning)算法之一 [2]  ,其中双向循环神经网络(Bidirectional RNN, Bi-RNN)和长短期记忆网络(Long Short-Term Memory networks,LSTM)是常见的循环神经网络 [3]  。


目录:

  1.  模型
  2. Forward
  3. Backward
  4. nn.RNN
  5. nn.RNNCell

一  模型

    

                   x_t: t 时刻样本输入\sim R^{n,1}

                   h_t: t 时刻样本隐藏状态\sim R^{m,1}

                   o_t: t时刻输出\sim R^{k,1}

                  \hat{y_t}:  t时刻样本预测类别(只有分类算法才有)\sim R^{k,1}

                  L_t: t 时刻损失函数


二  RNN 前向传播算法 Forward

     2.1   t 时刻隐藏值h_t 更新

             z_t=Ux_t+Wh_{t-1}+b

             h_t=\sigma(z_t)

            其中激活函数\sigma通常用tanh

   2.2   t 时刻输出

           o_t=Vh_t+c

           \hat{y_t}=\sigma(o_t)

           其中激活函数\sigma 为softmax


三 RNN 反向传播算法 BPTT(back-propagation through time)

      3.1 输出层参数v,c梯度

             \frac{\partial L}{\partial v}=\sum_{t=1}^{T}(\hat{y_t}-y_t)h_t^T

             \frac{\partial L}{\partial c}=\sum_{t=1}^{T}(\hat{y_t}-y_t)

      

     3.2  隐藏层参数更新

             定义

               \delta_t=\frac{\partial L}{\partial h_t}

                   =V^T(\hat{y_t}-y_t)+W^Tdiag(1-h_{t+1}^2)\delta_{t+1}

               证明:

                     \delta_{t}=\frac{\partial L_t}{\partial h_t}+(\frac{\partial h_{t+1}}{\partial h_t})^T\frac{\partial L}{\partial h_{t+1}}

                            =(\frac{\partial o_t}{\partial L_t})^T\frac{\partial L_t}{\partial o_t}+(\frac{\partial h_{t+1}}{\partial h_t})^T\delta_{t+1}

                             =V^T(\hat{y_t}-y_t)+(diag(1-h_{t+1}^2)W)^T\delta_{t+1}

                            =V^T(\hat{y_t}-y_t)+W^Tdiag(1-h_{t+1}^2)\delta_{t+1}

                  对于最后一个时刻T

                   \delta_T=V^T(\hat{y_T}-y_T)

          3.3 计算权重系数U,W,b

                   \frac{\partial L}{\partial W}=\sum_t diag(1-h_t^2)\delta_t h_{t-1}^T

                   \frac{\partial L}{\partial U}=\sum_t diag(1-h_t^2)\delta_t x_t^T

                    \frac{\partial L}{\partial b}=\sum_t diag(1-h_t^2)\delta_t


四 nn.RNN 

   这里面介绍PyTorch 使用RNN 类

                       

    4.1 更新规则:

                      h_t= tanh(W_{ih}h_t+W_{ih}b_{ih}+W_{hh}h_{t-1}+b_{hh})

                    

                

参数说明
L时间序列长度T  or 句子长度为 L
Nbatch_size 
d输入特征维度

                  

                     

# -*- coding: utf-8 -*-
"""
Created on Wed Jul 19 15:30:01 2023

@author: chengxf2
"""

import torch
import torch.nn as nn


rnn = nn.RNN(input_size=100, hidden_size=5)
param = rnn._parameters

print("\n 权重系数",param.keys())

print(rnn.weight_ih_l0.shape)

输出:

 

 RNN参数说明:

参数

说明

input_size =d

 输入维度

hidden_size=h

隐藏层维度

num_layers

RNN默认是 1 层。该参数大于 1 ,会形成 Stacked RNN,又称多层RNN或深度RNN

nonlinearity

非线性激活函数。可以选择 tanh relu

bias

即偏置。默认启用

batch_first

选择让 batch_size=N 作为输入的形状中的第一个参数。默认是 False,L × N × d 形状

batch_first=True 时, N × L × d

dropout

即是否启用 dropout。如要启用,则应设置 dropout 的概率,此时除最后一层外,RNN的每一层后面都会加上一个dropout层。默认是 0,即不启用

bidirectional

即是否启用双向RNN,默认关闭

 4.2 单层例子

import torch.nn as nn
import torch

rnn = nn.RNN(input_size= 100, hidden_size=20, num_layers=1)

X = torch.randn(10,3,100)

h_0 = torch.zeros(1,3,20)

out,h = rnn(X,h_0)


print("\n out.shape",out.shape)

print("\n h.shape",h.shape)
      

          out: 包含每个时刻的 隐藏值h_t

           h :    最后一个时刻的隐藏值h_T

  4.3  多层RNN

    

    把当前的隐藏层输出,作为下一层的输入

 第一个隐藏层输出:

               h_t^1= tanh(x_tW_{ih}^1+h_{t-1}^1W_{hh}^1)

第二个隐藏层输出

            h_t^2=tanh(h_t^1W_{ih}^2+h_{t-1}^2W_{hh}^2)

# -*- coding: utf-8 -*-
"""
Created on Mon Jul 24 11:43:30 2023

@author: chengxf2
"""

import torch.nn as nn
import torch
rnn = nn.RNN(input_size=100,  hidden_size=20, num_layers=2)
print(rnn)

x = torch.randn(10,3,100) #默认是[L,N,d]结构

out,h =rnn(x)

print(out.shape, h.shape)


5  nn.RNNCell

     nn.RNN封装了整个RNN实现的过程, PyTorch 还提供了 nn.RNNCell 可以

自己实现RNN 

                x_t\sim [N, dim]

                h_{t-1} \sim [layers, N, dim]

                h_t=rnnCell(x_t,h_{t-1})

               

   5.1  单层RNN

          

# -*- coding: utf-8 -*-
"""
Created on Mon Jul 24 11:43:30 2023

@author: chengxf2
"""
import torch
from torch import nn

def  main():
    model = nn.RNNCell(input_size=10, hidden_size=20)
    
    h1= torch.zeros(3,20)
    
    trainData = torch.randn(8,3,10)
    
    for xt in trainData:
        
         h1= model(xt,h1)
         
    print(h1.shape)


if __name__ == "__main__":
    
    main()
    

 6.2 多层RNN

# -*- coding: utf-8 -*-
"""
Created on Mon Jul 24 11:43:30 2023

@author: chengxf2
"""
import torch
from torch import nn



def  main():
    layer1 = nn.RNNCell(input_size=40, hidden_size=30)
    layer2 = nn.RNNCell(input_size=30, hidden_size=20)
    
    h1= torch.zeros(3,30)
    h2= torch.zeros(3,20)
    
    
    trainData = torch.randn(8,3,40)
    
    for xt in trainData:
        
         h1=  layer1(xt,h1)
         h2 = layer2(h1,h2)
    
    print(h1.shape)
    print(h2.shape)


if __name__ == "__main__":
    
    main()
    

     

参考:

Pytorch 循环神经网络 nn.RNN() nn.RNNCell() nn.Parameter()不同方法实现_老光头_ME2CS的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/786734.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode 面试题 判定是否互为字符重排

⭐️ 题目描述 🌟 leetcode链接:判定是否互为字符重排 思路: 两个字符串的每个字母和数量都相等。那么 s2 一定可以排成 s1 字符串。 代码: bool CheckPermutation(char* s1, char* s2){char hash1[26] {0};char hash2[26] {…

Python深度学习“四大名著”之一【赠书活动|第二期《Python机器学习:基于PyTorch和Scikit-Learn》】

近年来,机器学习方法凭借其理解海量数据和自主决策的能力,已在医疗保健、 机器人、生物学、物理学、大众消费和互联网服务等行业得到了广泛的应用。自从AlexNet模型在2012年ImageNet大赛被提出以来,机器学习和深度学习迅猛发展,取…

不知道零基础小白拥有一个黑客梦有没有机会能够实现

01.简单了解一下网络安全 说白了,网络安全就是指网络系统中的数据受到保护不被破坏。而我们从事网络信息安全工作的安全工程师,主要工作当然是设计程序来维护网络安全了。 网络安全工程师是一个统称,还包含很多职位,像安全产品工…

【代码随想录day19】从前序与中序遍历序列构造二叉树

题目 思路 使用递归建树,流程如下: 取出后序节点创建新树的节点 找到新树的节点在中序中的索引 分割中序序列 分割后序序列 继续递归建立整颗新树 # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftN…

spring-cloud-alibaba——nacos-server搭建

前言:组件版本关系,官方:组件版本关系 1,nacos-server搭建(windows环境),下载地址nacos 选择对应的版本,这里以目前最新版2.2.3为例子,下载后解压 单机模式 修改\nacos-server-2.2.3\nacos\bin\startup.c…

【ribbon】Ribbon的使用与原理

负载均衡介绍 负载均衡(Load Balance),其含义就是指将负载(工作任务)进行平衡、分摊到多个操作单元上进行运行,例如FTP服务器、Web服务器、企业核心应用服务器和其它主要任务服务器等,从而协同…

【全方位解析】如何获取客户端/服务端真实 IP

一、应用场景 1.比如在投票系统开发中,为了防止刷票,我们需要限制每个 IP 地址只能投票一次 2.当网站受到诸如 DDoS(Distributed Denial of Service,分布式拒绝服务攻击)等攻击时,我们需要快速定位攻击者…

星火汇聚丨高效行动,决胜2023!

7月6日至10日,数字韧性领域标杆企业同创永益营销体系大会在长沙召开,闪耀在全国各地的销售之星集结汇聚。本次大会历时四天,包含营销全员大会、各行业业委会专场会议、销售大比武等业务实践议程,以及飞盘竞技、走进毛泽东故居韶山…

【收藏】用Vue.js来构建你的Web3应用,就像开发 Web2 一样熟悉

作为一名涉足去中心化网络的前端 JavaScript 开发人员,您可能遇到过许多 Web3 开发解决方案。但是,这些解决方案通常侧重于钱包集成和交易执行,这就造成了学习曲线,偏离了熟悉的 Web2 开发体验。 但不用担心!有一种解…

ip、域名、DNS、CDN概念

1、概念 ip地址 在网络世界里, 一台服务器或者说一台网络设备对应着一个ip地址, 如果我们需要访问指定的网络设备的资源, 那么我们就需要知道这个ip地址, 然后才能去访问它. 这就好像, 我想去朋友家里, 我必须先知道他家的住址, 才能去拜访它. 在互联网世界中, 所有的通信都是…

Docker数据管理与Dockerfile

目录 Docker 的数据管理 1.数据卷 2.数据卷容器 端口映射 容器互联(使用centos镜像) Docker 镜像的创建 1.基于现有镜像创建 2.基于本地模板创建 3.基于Dockerfile 创建 联合文件系统…

亚马逊会员日过后站内站外怎么做?

在亚马逊的会员日活动中,众多品牌商家都参与了进来,通过优惠力度和活动策划提高了销售额。但是,会员日过后,如何保持销售增长和用户粘性,需要品牌商家在站内和站外进行策略优化。 一、站内优化 1、提高产品质量的同时…

【Nodejs】Express模板使用

1.Express脚手架的安装 安装Express脚手架有两种方式: 使用express-generator安装 使用命令行进入项目目录,依次执行: cnpm i -g express-generator可通过express -h查看命令行的指令含义 express -hUsage: express [options] [dir] Optio…

28.JavaWeb-Elasticsearch

1.Elasticsearch概述 Elasticsearch 是一个分布式的全文检索引擎。采用Java语言开发,基于Apache协议的开源项目,具有实时搜索,稳定,可靠,快速的特点。 1.1 全文检索引擎 分为通用搜索引擎(百度、谷歌&…

苹果发布安全更新,修复了今年第11个零日漏洞!

苹果公司发布了安全更新,修复针对 iPhone、Mac 和 iPad 的零日漏洞。 苹果公司在一份公告中描述了一个 WebKit 漏洞,该漏洞被标记为 CVE-2023-37450,已在本月初的新一轮快速安全响应 (RSR) 更新中得到解决。 本次修补的另一个零日漏洞是一个…

CAD中让时间日期自动填写的方法

图纸的图签中,通常会有一栏是出图日期。有的单位,也会叫做版本号。即哪天出的图。一般情况下,出图日期就是打图当天。 在这样的前期下,图纸由于存在频繁修改,所以出图日期也会存在变化。还有一种情况,就是出…

(四)FLUX语法

以下内容来自 尚硅谷,写这一系列的文章,主要是为了方便后续自己的查看,不用带着个PDF找来找去的,太麻烦! 第 4 章 FLUX语法 4.1 认识FLUX语言 1、Flux是一种函数式的数据脚本语言,它旨在将查询、处理、分…

Docker 网络和资源限制

Docker 网络 一、Docker 网络的概念1、Docker 网络实现原理2、查看容器的输出和日志信息3、Docker 的网络模式:4、容器的网络模式 二、网络模式详解1、host模式2、container模式3、none模式4、bridge模式5、自定义网络 三、资源控制1、CPU 资源控制(1&am…

MQ - 闲聊MQ一二事儿 (Kafka、RocketMQ 、Pulsar )

文章目录 MQ的发展史阶段一:追求解耦阶段二:追求吞吐量与一致性阶段三:追求平台化 MQ的通用架构主题topic、生产者producer、消费者consumer分区partition MQ 存储KafkaGood Design ---> 磁盘顺序写盘Poor Impact---> topic 数量不能过…

Java Spring和Spring集成Mybatis

0目录 1.Spring 2.Spring集成Mybatis 1.Spring 特性 IOC:控制反转 AOP:面向切面 Spring组成部分 在SMM中起到的作用(粘合剂) Spring理念 OOP核心思想【万物皆对象】 Spring核心思想【万物皆Bean组件】 Spring优势 低侵入式 …