循环神经网络-简洁实现

news2024/9/30 21:24:17

参考:
https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-concise.html
https://pytorch.org/docs/stable/generated/torch.nn.RNN.html?highlight=rnn#torch.nn.RNN

RNN

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

batch_size, num_steps = 32, 35  # num_steps: sequence length
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps) #  vocab:Vocab 26

# 1 定义模型
# 构造一个具有256个隐藏层的循环神经网络 rnn_layer
# 此处先仅设计一层循环神经网络,以后讨论多层神经网络
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab),num_hiddens) # RNN(28,256)
"""input_size – The number of expected features in the input x
hidden_size – The number of features in the hidden state h
num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1
nonlinearity – The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'
bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details. Default: False
dropout – If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0
bidirectional – If True, becomes a bidirectional RNN. Default: False
"""
# 2.我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)
state = torch.zeros((1,batch_size,num_hiddens))
print(state.shape)  #(torch.size([1,32,256]))

#3. 通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。
# 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。
X=torch.rand(size=(num_steps,batch_size,len(vocab)))  #torch.Size([35, 32, 28])   # (L,N,H(in)) L:sequence length  N batch size Hin: input_size
Y,state_new = rnn_layer(X,state)
print(Y.shape,state_new.shape) #torch.Size([35, 32, 256]) torch.Size([1, 32, 256])

class RNNModel(nn.Module):
    """循环神经网络"""
    def __init__(self,rnn_layer,vocab_size,**kwargs):
        super(RNNModel,self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的,num_directions 应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens,self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens*2,self.vocab_size)

    def forward(self,inputs,state):
        X = F.one_hot(inputs.T.long(),self.vocab_size)
        X = X.to(torch.float32)
        Y,state = self.rnn(X,state)

        # 全连接首层将Y的形状改为(时间步数*批量大小,隐藏单元数)
        output = self.linear(Y.reshape((-1,Y.shape[-1])))
        return output,state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

# 训练
device = d2l.try_gpu()
net = RNNModel(rnn_layer,vocab_size=len(vocab))
net = net.to(device)
num_epochs ,lr = 500,1
d2l.train_ch8(net,train_iter,vocab,lr,num_epochs,device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1025195.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安防监控视频云存储平台EasyNVR对接EasyNVS时,一直不上线该如何解决?

视频安防监控平台EasyNVR可支持设备通过RTSP/Onvif协议接入,并能对接入的视频流进行处理与多端分发,包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等多种格式。 近期有用户在使用安防视频平台EasyNVR对接上级平台EasyNVS时,出现了一直不上线…

文档丢失怎么找回?学会这3个方法就足够!

场景1:“不是吧!我辛辛苦苦写的文档好像忘记保存就退出了!谁能救救我!帮我找回丢失的文档?” 场景2:“电脑里的文档太多了,每次在清理时都容易误删。有什么方法可以找回我丢失的文档吗&#xff…

iPhone密码忘了怎么办?这3招已足矣

很急!之前改了手机密码,现在完全想不起来。该试的数字也都试过了,根本没用,求一个能解锁iPhone手机的方法!感谢! iPhone手机的锁屏密码是一个保护用户隐私的功能。如果没有锁屏密码给手机上一道“锁”&…

JavaScript系列从入门到精通系列第二篇:JavaScript书写位置、注释和结束符

文章目录 一:JavaScript书写位置 1:CSS书写位置 (一):行内样式表 (二):内部样式表 (三):外部样式表 2:Js书写位置 (一):行内样式表 (二):内部样式表 (三):外部样…

Cesium 地球(1)-概览

​ 参考: CesiumJS 2022^ 源码解读[4] - 最复杂的地球皮肤 影像与地形的渲染与下载过程 Cesium 地球(1)-概览 相关类的从属关系: 地球由 影像数据,和地形数据共同组成。 流程概览: // Scene.jsfunction render() {// ① 更新影像图层的可见性globe.update();/…

java-decompiler

Java Decompiler GitHub F:\Document_JD-GUI\jd-gui-windows-1.4.0

73家央国企专场培训|第38期信创专业人员-精华班在京成功举办

9月8日-10日,由太极计算机股份有限公司-太极信创研习院(以下简称“太极股份”)主办,北京慧点科技有限公司协办的“信息技术应用创新专业人员(ITAIP)-第38期信创精华班(央国企专场培训)”在北京市…

[游戏开发][Shader]ShaderToy通用模板转Unity-CG语言

这个通用模板貌似是Candcat写的,漏了几个宏定义,我这给补一下,例如: #define iTime _Time.y #define atan atan2 对照表如下 代码如下 Shader "Shadertoy/Template" {Properties{iMouse("Mouse Pos", Vec…

恩智浦为稳固地位,将扩大投资4国家 | 百能云芯

车用芯片制造商恩智浦,今天宣布了一项重大计划,旨在进一步深耕欧洲市场。该公司将利用欧洲微电子和通信技术共同利益重点计划(IPCEI ME/CT)的支持,加强其在奥地利、德国、荷兰和罗马尼亚的研发能力,并将根据…

MySQL数据库详解 五:用户管理

文章目录 1. 数据库的用户管理1.1 新建用户1.2 重命名用户1.3 删除用户1.4 修改用户密码1.5 忘记用户密码的解决方法1.6 数据库用户授权1.6.1 授权用户权限类别1.6.2 添加权限1.6.2 撤销权限 2. mysql命令 1. 数据库的用户管理 1.1 新建用户 create user 用户名来源地址 [ide…

性能测试必备知识-使用MySQL存储过程构造大量数据:实例解析

在软件开发过程中,测试是一个不可或缺的环节。通过测试,我们可以发现并修复软件中的各种问题,提高软件的质量和稳定性。然而,手动编写大量的测试用例是一项耗时且容易出错的任务。为了解决这个问题,我们需要学会使用批…

一文了解线上展厅设计与搭建要点,线上展厅有哪些应用

引言: 线上展厅已经成为了现代营销领域中不可或缺的一部分。通过巧妙的设计与搭建,企业可以与潜在客户建立更深入的联系,提高品牌知名度,从而提高商务成交量。 一、线上展厅设计要点 线上展厅的设计是关键的一步,因为…

架构师面试必备:高并发限流算法全攻略

Hello大家好,我是小米!今天我要和大家聊一聊一个在技术面试中经常被问到的问题——高并发限流算法!这个话题非常有趣,也是我们在日常工作中经常会碰到的挑战之一。在本文中,我将详细介绍一些常见的高并发限流算法&…

无涯教程-JavaScript - SUMIF函数

描述 您可以使用SUMIF函数对满足指定条件的范围内的值求和。 语法 SUMIF (range, criteria, [sum_range])争论 Argument描述Required/Optionalrange 您要通过条件判断的单元格范围。 每个范围中的单元格必须是数字或包含数字的名称,数组或引用。 空白和文本值将被忽略。 所…

AMEYA360:村田土壤传感器新增功能

村田制作所新增了土壤传感器功能,除了以前的普通土壤外,还可对人工培养土岩棉、椰糠进行测量。 近年来,对番茄、草莓等农作物广泛使用配制营养土岩棉及椰糠等人工培养土。相较普通培养土,此类培养土的保水力非常高,且难…

面试官:Vue3.0 所采用的 Composition Api 与 Vue2.x 使用的 Options Api 有什么不同?

🎬 岸边的风:个人主页 🔥 个人专栏 :《 VUE 》 《 javaScript 》 ⛺️ 生活的理想,就是为了理想的生活 ! 目录 开始之前 正文 一、Options Api 二、Composition Api 三、对比 逻辑组织 Options API Compostion API 逻辑…

vue select联动 设置filterable坑

需求: 平台改变 获取服务类目List 服务类目改变 获取模板标题List 模板标题改变 获取关键词List 由于模板标题List数据条数较多,因此需要设置可搜索选择 问题:由于模板标题加了【filterable】属性。当服务类目改变时,模板标题要么…

肖sir__mysql之索引__010

mysql之索引 一、什么是索引? 索引是一种数据结构设计 一个索引是存储的表中数据结构; 索引是建立在表字段上, 索引包含了一列值,这个值保存在一个数据结构中 二、索引作用 1、保证数据记录的唯一性 2、实现表与表之间的参照性 3…

openGauss学习笔记-72 openGauss 数据库管理-创建和管理分区表

文章目录 openGauss学习笔记-72 openGauss 数据库管理-创建和管理分区表72.1 背景信息72.2 操作步骤72.2.1 使用默认表空间72.2.1.1 创建分区表(假设用户已创建tpcds schema)72.2.1.2 插入数据72.2.1.3 修改分区表行迁移属性72.2.1.4 删除分区72.2.1.5 增…

java学习--day6(数组)

文章目录 day5作业今天的内容1.数组1.1开发中为啥要有数组1.2在Java中如何定义数组1.3对第二种声明方式进行赋值1.4对数组进行取值1.5二维数组【了解】1.6数组可以当成一个方法的参数【重点】1.7数组可以当成一个方法的返回值1.8数组在内存中如何分配的【了解】 2.数组方法循环…