80. 循环神经网络的简洁实现

news2024/11/25 16:00:46

虽然从零开始实现循环神经网络对了解循环神经网络的实现方式具有指导意义,但并不方便。 本节将展示如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

1. 定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
# pytorch定义的RNN中,输入输出是len(vocab),隐藏层有num_hiddens个隐藏单元
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)

state = torch.zeros((1, batch_size, num_hiddens))
state.shape

运行结果:

在这里插入图片描述

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
# ps:这里Y并不是输出,而是最后一个隐藏层,所以这里的维度是256,而不是len(vocab)
# Y的形状是(num_steps,batch_size,num_hiddens)
# 是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。
Y, state_new = rnn_layer(X, state)
# state的形状是(隐藏层数,批量大小,隐藏单元数)
Y.shape, state_new.shape

与从零实现类似, 我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层

class RNNModel(nn.Module):
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            # 创造单独的输出层
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size) # one-hot编码
        X = X.to(torch.float32)
        # 得到的Y是中间的隐藏状态,形状是(时间步数,批量大小,隐藏单元数)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)
        # output的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

2. 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

运行结果:

在这里插入图片描述

很明显,这种模型根本不能输出好的结果。 接下来,我们使用从零实现的代码中定义的超参数调用train_ch8,并且使用高级API训练模型

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

运行结果:

在这里插入图片描述

可以看出,速度更快,原因是:从零开始实现的时候是一堆小矩阵的乘法做的。而用框架实现可以把小矩阵乘法变成大矩阵乘法,就把小矩阵concat成大矩阵,然后做一次矩阵乘法就出去了。因为在同样的计算量下,多次小矩阵乘法的开销要大于一次大矩阵乘法。

与上一节相比,由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。

3. Q&A

Q1: num_steps是什么?

A1:给定一个长为num_steps的序列,要去一次预测之后每一个长为num_steps的序列。一个序列长度为num_steps,所以要做num_steps次分类。

Q2:为什么是批量大小 x 时间长度?

A2: 因为每一个批量的每一个样本的长度是T(也就是时间长度),然后实际上是要做T次分类。如果从多分类的角度来讲,给一个小批量,其实要做的分类次数是批量大小 x 时间长度,就是对任何一个样本在任何一个时间点都要做一次分类,所以等价于一个小批量中有 批量大小 x 时间长度个样本需要进行分类。

Q3:H是每个step都在变化吗?

A3:是的,H是在每个时间维度变化的。如果一个batch里面,有batch_size个样本,一个样本长度是T的话,那么在batch里面H会被更新T x batch_size次。但是根据采样方式的不一样,当前batch的H要不要丢给下一个batch,还是直接砍断了,需要根据下一个batch和前一个batch的序列是否接在一起来判断。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/169209.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SpringCloud13】SpringCloud Config分布式配置中心

1.概述 1.1 分布式系统面临的配置问题 微服务意味着要将单体应用中的业务拆分成一个个子服务,每个服务的粒度相对较小,因此系统中会出现大量的服务。由于每个服务都需要必要的配置信息才能运行,所以一套集中式的、动态的配置管理设施是必不…

PointNext论文解读

论文地址:https://arxiv.org/abs/2206.04670 github地址:GitHub - guochengqian/PointNeXt: [NeurIPS22] PointNeXt: Revisiting PointNet with Improved Training and Scaling Strategies 本文主要提出优化PointNet的两大关键点. 1) 好的训练策略 2…

如何搭建一个专业的知识库

当客户跟你达成合作关系后,需要持续的关系维护,在一定的销售点,定期和客户沟通,据调查,赢得一个新客户的成本可能是保留一个现有客户的5到25倍,作为营销策略,客户服务支持必须满足他们的期望。建…

[BJDCTF2020]Easy MD5(浅谈PHP弱类型hash比较缺陷)

目录 信息收集 构造payload PHP弱类型hash比较缺陷 0e碰撞 数组MD5 总结 信息收集 看题目应该和MD5加密相关 select * from admin where passwordmd5($pass,true) PHP的MD5函数 string必需。规定要计算的字符串。raw 可选。规定十六进制或二进制输出格式: …

2023-01-17 PostgreSQL 并行查询概述

简介: 大数据时代,人们使用数据库系统处理的数据量越来越大,请求越来越复杂,对数据库系统的大数据处理能力和混合负载能力提出更高的要求。PostgreSQL 作为世界上最先进的开源数据库,在大数据处理方面做了很多工作&…

详谈ORB-SLAM2的单目初始化器Initializer

单目初始化器Initializer类,这个类只用于单目初始化,因为这是ORB-SLAM里遗留的一个类,也是祖传代码,双目和RGBD相机只需要一帧就能初始化,因为双目和RGBD相机拍到的点都是有信息的,但是单目相机就不一定了&…

六种方法在云平台和远程桌面中使用Kali

一、说明 本篇主要介绍方便在云服务器,或者以远程桌面(GUI)形式使用kali配置教程,帮助渗透更加方便顺利。 二、方法 2.1 方法一 云服务提供商预装 备注:预算充足,可以首考虑此方法 优点: 云服…

java 探花交友项目实战 day3 完善个人信息 阿里云OSS文件存储 百度人脸识别

完善用户信息 业务概述 阿里云OSS Data ConfigurationProperties(prefix "tanhua.oss") public class OssProperties { private String accessKey; private String secret; private String bucketName; private String url; //域名 private Strin…

微分方程的特征值解法:斯图姆-刘维尔方程

一.基础概念 前置:福克斯定理和奇点理论 常点的级数解 奇异点的级数解 则至少存在一个如下形式的解(弗罗贝尼乌斯级数): 19世纪中期,常微分方程的研究到了新的阶段,存在定理和斯图姆-刘维尔理论都假设微分方程区域内含解析函数或至少包含连续函数,而另一方面,以前研究…

东莞注塑MES管理系统具有哪些功能

伴随着人们对于物质生活的品质要求越来越高,日用品、医疗保健、汽车工业、电子行业、新能源、家电、包装行业以及建筑等行业对注塑产品的需求量日益突出。注塑企业提供的各种各样的塑料产品已渗透到经济生活的各个领域,为国家经济的各个部门包括轻工业和…

ARM SD卡启动详解

一、主流的外存设备介绍 内存和外存的区别:一般是把这种 RAM(random access memory,随机访问存储器,特点是任意字节读写,掉电丢失)叫内存,把 ROM(read only memory,只读存储器,类似…

15子空间投影

子空间投影 从向量的投影入手,延伸到高维投影,并将投影使用矩阵形式给出。做投影也即向另一个向量上做垂线。上一章讨论的Axb无解时的最优解求解时,并没有解释这个最优解为何“最优”,本节课给出相应的解释。相对简单的二维空间的…

MyBatis -- resultType 和 resultMap

MyBatis -- resultType 和 resultMap一、返回类型&#xff1a;resultType二、返回字典映射&#xff1a;resultMap一、返回类型&#xff1a;resultType 绝⼤数查询场景可以使用 resultType 进⾏返回&#xff0c;如下代码所示&#xff1a; <select id"getNameById"…

企业如何借助制造业ERP系统,做好生产排产管理?

随着市场竞争越来越激烈&#xff0c;生产制造行业订单零碎化趋势越发突出。面对品种多&#xff0c;数量小&#xff0c;批次多&#xff0c;个性化需求也多的生产方式&#xff0c;PMC生产排产管理变得非常困难&#xff1b;同时生产过程还会有各种不确定的临时性因素出现&#xff…

详解pandas的read_csv函数

一、官网参数 pandas官网参数网址&#xff1a;pandas.read_csv — pandas 1.5.2 documentation 如下所示&#xff1a; 二、常用参数详解 1、filepath_or_buffer(文件) 一般指读取文件的路径。比如读取csv文件。【必须指定】 import pandas as pddf_1 pd.read_csv(r"C:…

Xilinx FPGA电源设计与注意事项

1 引言随着半导体和芯片技术的飞速发展&#xff0c;现在的FPGA集成了越来越多的可配置逻辑资源、各种各样的外部总线接口以及丰富的内部RAM资源&#xff0c;使其在国防、医疗、消费电子等领域得到了越来越广泛的应用。当采用FPGA进行设计电路时&#xff0c;大多数FPGA对上电的电…

软件测试复习06:基于经验的测试

作者&#xff1a;非妃是公主 专栏&#xff1a;《软件测试》 个性签&#xff1a;顺境不惰&#xff0c;逆境不馁&#xff0c;以心制境&#xff0c;万事可成。——曾国藩 文章目录软件缺陷基于缺陷分类的测试缺陷模式探索性测试软件缺陷 主要由以下几种原因造成&#xff1a; 疏…

Redux相关知识(什么是redux、redux的工作原理、redux的核心概念、redux的基本使用)(十一)

系列文章目录 第一章&#xff1a;React基础知识&#xff08;React基本使用、JSX语法、React模块化与组件化&#xff09;&#xff08;一&#xff09; 第二章&#xff1a;React基础知识&#xff08;组件实例三大核心属性state、props、refs&#xff09;&#xff08;二&#xff0…

Arduino 开发ESP8266(ESP12F)模块

①ESP12F模块的硬件说明如上图所示&#xff0c;其他引脚均引出。②准备好硬件之后就是要下载Arduino IDE&#xff0c;目前版本为2.0.3&#xff0c;下载地址为&#xff1a;https://www.arduino.cc/en/software&#xff0c;如下图所示③安装Arduino IDE较为简单&#xff0c;安装之…

aws cloudformation 在堆栈中使用 waitcondition 协调资源创建和相关操作

参考资料 https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-waitcondition.htmlhttps://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-waitcondition.html 本文介绍cloudformation的waitcondition条件&#xff0c;wait…