《动手学深度学习 Pytorch版》 8.6 循环神经网络的简洁实现

news2024/10/5 16:22:54
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

8.6.1 定义模型

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)
state = torch.zeros((1, batch_size, num_hiddens))
state.shape  # (隐藏层数,批量大小,隐藏单元数)
torch.Size([1, 32, 256])

通过一个隐状态和一个输入可以用更新后的隐状态计算输出。

需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算:它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape
(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))
#@save
class RNNModel(nn.Module):
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)
        # 它的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

8.6.2 训练与预测

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)
'time travellerffffffffff'
num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)  # 比自己写的跑得快
perplexity 1.3, 213489.4 tokens/sec on cuda:0
time traveller held in his han so withtre scon the thin one mige
travellericho for the prof read haly and hes it nople hat d

在这里插入图片描述

练习

(1)尝试使用高级API,能使循环神经网络模型过拟合吗?

略。


(2)如果在循环神经网络模型中增加隐藏层的数量会发生什么?能使模型正常工作吗?

num_hiddens1 = 1024
rnn_layer1 = nn.RNN(len(vocab), num_hiddens1)

net1 = RNNModel(rnn_layer1, vocab_size=len(vocab))
net1 = net1.to(device)

num_epochs, lr = 500, 1
d2l.train_ch8(net1, train_iter, vocab, lr, num_epochs, device)  # 效果更好了,但是曲线没那么平滑了
perplexity 1.0, 97329.8 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

在这里插入图片描述


(3)尝试使用循环神经网络实现 8.1 节的自回归模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1093523.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1 两数之和

解题思路&#xff1a; \qquad 对每个数nums[i]&#xff0c;仅需在数组中搜索target-nums[i]是否存在。 优化思路&#xff1a; \qquad 首先能想到&#xff0c;利用哈希表O(1)查询target-nums[i]。 \qquad 建立map<int, vector<int>>的表能够处理重复元素&#x…

基于Eigen的位姿转换

位姿中姿态的表示形式有很多种&#xff0c;比如&#xff1a;旋转矩阵、四元数、欧拉角、旋转向量等等。这里基于Eigen实现四种数学形式的相互转换功能。本文利用Eigen实现上述四种形式的相互转换。我这里给出一个SE3&#xff08;4*4&#xff09;(先平移、再旋转)的构建方法&…

Ubuntu - 安装Docker

在Ubuntu上安装Docker分为以下几个步骤&#xff1a; 更新包列表&#xff1a; sudo apt update 安装依赖包&#xff0c;以便允许apt使用HTTPS&#xff1a; sudo apt install apt-transport-https ca-certificates curl software-properties-common 添加Docker官方GPG密钥&a…

在命令行下使用Apache Ant

Apache Ant的帮助文档 离线帮助文档 在<ant的安装目录>/manual下是离线帮助文档 双击index.html可以看到帮助文档的内容&#xff1a; 在线帮助文档 最新发布版本的帮助文档https://ant.apache.org/manual/index.html Apache Ant的命令 ant命令行格式 ant [opt…

在 Windows 平台上启动 MATLAB

目录 在 Windows 平台上启动 MATLAB 选择 MATLAB 图标 从 Windows 系统命令行调用 matlab 从 MATLAB 命令提示符调用 matlab 打开与 MATLAB 相关联的文件 从 Windows 资源管理器工具中选择 MATLAB 可执行文件 在 Windows 平台上启动 MATLAB 选择以下一种方式启动 MATLAB…

6.串口、时钟

预备知识 CC2530在正常运行的时候需要一个高频时钟信号和一个低频的时钟信号 高频时钟信号&#xff0c;主要供给CPU&#xff0c;保证程序的运行。 低频时钟信号&#xff0c;主要供给看门狗、睡眠定时器等偏上外设。 CC2530时钟信号的来源&#xff1a; 高频信号有2个&#xff0…

【C++】:初阶模板

朋友们、伙计们&#xff0c;我们又见面了&#xff0c;本期来给大家解读一下有关Linux的基础知识点&#xff0c;如果看完之后对你有一定的启发&#xff0c;那么请留下你的三连&#xff0c;祝大家心想事成&#xff01; C 语 言 专 栏&#xff1a;C语言&#xff1a;从入门到精通 数…

Stirling-PDF:一款优秀的开源PDF处理工具

最近我的朋友大雄需要将一个PDF转换为Word文档。于是他在网上尝试了多个PDF转换的在线工具&#xff0c;但要么需要会员&#xff0c;要么需要登录等繁琐操作&#xff0c;而且我们的文件也存在泄漏等安全隐患。因此&#xff0c;他向我咨询是否有可私有化部署且易于使用的PDF在线工…

字符函数和字符串函数2(C语言进阶)

字符函数和字符串函数2 三.长度受限制的字符串函数介绍1.strncpy2.strncat3.strncmp 四.字符串查找1.strstr2.strtok 五.错误信息报告1.strerror 六.字符操作七.内存操作函数1.memcpy2.memmove3.memset4.memcmp 三.长度受限制的字符串函数介绍 1.strncpy char * strncpy ( ch…

8.简易无线通信

预备知识 Zigbee无线通信&#xff0c;需要高频的载波来提供发射效率&#xff0c;Zigbee模块之间要可以正常的收发&#xff0c;接收模块必须把接收频率设置和发射模块的载波频率一致。Zigbee有27个载波可以进行通信&#xff0c;载波叫做信道&#xff08;无线通信的通道&#xf…

UE4 EQS环境查询 学习笔记

EQS环境查询对应Actor的范围 EQS环境查询查询对应的类 查询到即有一个蓝色的球在Actor上&#xff0c;里面有位置信息等等 在行为树运行EQS&#xff0c;按键&#xff08;‘&#xff09;可以看到Player的位置已经被标记 运行对应的EQS在这里放如EQS就可以了 Generated Point&…

2023年中国分布式光纤传感产量、需求量及行业市场规模分析[图]

分布式光纤传感器中的光纤能够集传感、传输功能于一体&#xff0c;能够完成在整条光纤长度上环境参量的空间、时间多维连续测量&#xff0c;具有结构简单、易于布设、性价比高、易实现长距离等独特优点&#xff0c;常用的分布式光纤传感器有光时域反射仪、布里渊分析仪、喇曼反…

【AI视野·今日Robot 机器人论文速览 第五十四期】Fri, 13 Oct 2023

AI视野今日CS.Robotics 机器人学论文速览 Fri, 13 Oct 2023 Totally 45 papers &#x1f449;上期速览✈更多精彩请移步主页 Interesting: &#x1f4da;AI与机器人安全, 从攻击界面、伦理法律和人机交互层面进行了论述。(from 密西西比大学) &#x1f4da;机器人与图机器学…

Windows端口号被占用的查看方法及解决办法

Windows端口号被占用的查看方法及解决办法 Error starting ApplicationContext. To display the conditions report re-run your application with debug enabled. 2023-10-14 22:58:32.069 ERROR 6488 --- [ main] o.s.b.d.LoggingFailureAnalysisReporter : ***…

Qt 布局(QLayout 类QStackedWidget 类) 总结

一、QLayout类(基本布局) QLayout类是Qt框架中用于管理和排列QWidget控件的布局类。它提供了一种方便而灵活的方式来自动布局QWidget控件。QLayout类允许您以一种简单的方式指定如何安排控件&#xff0c;并能够自动处理控件的位置和大小&#xff0c;以使其适应更改的父窗口的大…

【HCIA】静态路由综合实验

实验要求&#xff1a; 1、R6为ISP&#xff0c;接口IP地址均为公有地址&#xff0c;该设备只能配置IP地址之后不能再对其进行任何配置 2、R1-R5为局域网&#xff0c;私有IP地址192.168.1.0/24&#xff0c;请合理分配 3、R1、R2、R4&#xff0c;各有两个环回IP地址;R5,R6各有一…

基于 Kubernetes 的 Serverless PaaS 稳定性建设万字总结

作者&#xff1a;许成铭&#xff08;竞霄&#xff09; 数字经济的今天&#xff0c;云计算俨然已经作为基础设施融入到人们的日常生活中&#xff0c;稳定性作为云产品的基本要求&#xff0c;研发人员的技术底线&#xff0c;其不仅仅是文档里承诺的几个九的 SLA 数字&#xff0c…

MyBatis的缓存,一级缓存,二级缓存

10、MyBatis的缓存 10.1、MyBatis的一级缓存 一级缓存是SqlSession级别的&#xff0c;通过同一个SqlSession对象 查询的结果数据会被缓存&#xff0c;下次执行相同的查询语句&#xff0c;就 会从缓存中&#xff08;缓存在内存里&#xff09;直接获取&#xff0c;不会重新访问…

c++string类的赋值问题

来看问题&#xff1a; 为什么呢&#xff1f;是因为定义string a""时候a没有占用空间&#xff0c;所以没有a[0],a[1],a[3]。如果说string a"hhhhhh"&#xff0c;那么图中a[0],a[1],a[3]就有效了。正确的做法是用连接&#xff0c;或者是定义时写成string a(6…

爬虫 | 正则、Xpath、BeautifulSoup示例学习

文章目录 &#x1f4da;import requests&#x1f4da;import re&#x1f4da;from lxml import etree&#x1f4da;from bs4 import BeautifulSoup&#x1f4da;小结 契机是课程项目需要爬取一份数据&#xff0c;于是在CSDN搜了搜相关的教程。在博主【朦胧的雨梦】主页学到很多…