《动手学深度学习(PyTorch版)》笔记8.6

news2025/1/23 21:13:52

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter8 Recurrent Neural Networks

8.6 Concise Implementation of RNN

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
import matplotlib.pyplot as plt

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

state = torch.zeros((1, batch_size, num_hiddens))
print(state.shape)

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)#Y不涉及输出层的计算
print(Y.shape, state_new.shape)

class RNNModel(nn.Module):#@save
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数),它的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            #nn.GRU以张量作为隐状态
            #GRU为门控循环单元(Gated Recurrent Unit),是一种流行的循环神经网络变体。
            #GRU使用了一组门控机制来控制信息的流动,包括更新门(update gate)和重置门(reset gate),以更好地捕捉长期依赖关系
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                batch_size, self.num_hiddens),
                                device=device)
        else:
            #nn.LSTM以元组作为隐状态
            #LSTM代表长短期记忆网络(Long Short-Term Memory),是另一种常用的循环神经网络类型。
            #相比于简单的循环神经网络,LSTM引入了三个门控单元:输入门(input gate)、遗忘门(forget gate)和输出门(output gate),以及一个记忆单元(cell state),可以更有效地处理长期依赖性。
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))
            
device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)
plt.show()

训练结果:
在这里插入图片描述

与上一节相比,由于pytorch的高级API对代码进行了更多的优化,该模型在较短的时间内达到了较低的困惑度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1446359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LayoutInflater源码解析及常见相关报错分析

在日常Android开发中,最经常使用的RecyclerView控件是大家都绕不开的,而编写其Adapter时更离不开LayoutInflater的调用。当然,如果你做这一行有些时日了,相信你对其使用一定是炉火纯青了。即使如此,我觉得LayoutInflat…

【C++】STL之string 超详解

目录 1.string概述 2.string使用 1.构造初始化 2.成员函数 1.迭代器 2.容量操作 1.size和length 返回字符串长度 2.resize 调整字符串大小 3.capacity 获得字符串容量 4.reserve 调整容量 5.clear 清除 6.empty 判空 3.string插入、追加 、拼接 1.运算…

LeetCode:67.二进制求和

67. 二进制求和 - 力扣(LeetCode) 又是一道求和题,% / 在求和的用途了解了些, 目录 题目: 思路分析: 博主代码: 官方代码: 每日表情包: 题目: 思路分析&#xf…

第五课[lmdeploy]作业 +第六课[OpenCompass评测]作业

第五课基础作业 如下图,采用api_server部署,并转发端口通过curl提交内容。 第六课基础作业 完了捏?

【Java程序设计】【C00252】基于Springboot的实习管理系统(有论文)

基于Springboot的实习管理系统(有论文) 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于Springboot的实习管理系统 本系统分为前台功能模块、管理员功能模块、教师功能模块、学生功能模块以及实习单位功能模块。 前台功能模块&#xf…

【每日一题】LeetCode——反转链表

📚博客主页:爱敲代码的小杨. ✨专栏:《Java SE语法》 | 《数据结构与算法》 | 《C生万物》 ❤️感谢大家点赞👍🏻收藏⭐评论✍🏻,您的三连就是我持续更新的动力❤️ 🙏小杨水平有…

网络协议与攻击模拟_17HTTPS 协议

HTTPShttpssl/tls 1、加密算法 2、PKI(公钥基础设施) 3、证书 4、部署HTTPS服务器 部署CA证书服务器 5、分析HTTPS流量 分析TLS的交互过程 一、HTTPS协议 在http的通道上增加了安全性,传输过程通过加密和身份认证来确保传输安全性 1、TLS …

C++ 广度优先搜索(bfs)(五十四)【第一篇】

今天我们来学习一下一个新的搜索,广度优先搜索。 1.广度优先搜索的前提 队列(queue) 是一种 操作受限制 的线性表,其限制: 只允许从表的前端(front)进行删除操作; 只允许在表的后端…

基于无线传感器网络的LC-DANSE波束形成算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 4.1LC-DANSE算法原理 4.2 LCMV算法原理 5.完整程序 1.程序功能描述 在无线传感器网络中,通过MATLAB对比LC-DANSE波束形成算法和LCMV波束形成算法。对比SNR,mse等指标…

【go语言】一个简单HTTP服务的例子

一、Go语言安装 Go语言(又称Golang)的安装过程相对简单,下面是在不同操作系统上安装Go语言的步骤: 在Windows上安装Go语言: 访问Go语言的官方网站(golang.org)或者使用国内镜像站点&#xff0…

肯尼斯·里科《C和指针》第13章 高级指针话题(2)函数指针

我们不会每天都使用函数指针。但是,它们的确有用武之地,最常见的两个用途是转换表(jump table)和作为参数传递给另一个函数。本节将探索这两方面的一些技巧。但是,首先容我指出一个常见的错误,这是非常重要的。 简单声明一个函数指…

Linux基础I/O(三)——缓冲区和文件系统

文章目录 什么是C语言的缓冲区理解文件系统理解软硬链接 什么是C语言的缓冲区 C语言的缓冲区其实就是一部分内存 那么它的作用是什么? 下面有一个例子: 你在陕西,你远在山东的同学要过生日了,你打算送给他一份生日礼物。你有两种方…

视觉开发板—K210自学笔记(五)

本期我们来遵循其他单片机的学习路线开始去用板子上的按键控制点亮LED。那么第一步还是先知道K210里面的硬件电路是怎么连接的,需要查看第二节的文档,看看开发板原理图到底是按键是跟哪个IO连在一起。然后再建立输入按键和GPIO的映射就可以开始变成了。 …

VTK 常用坐标系 坐标系 转换

1.VTK 常用坐标系 计算机图形学里常用的坐标系统主要有四种,分别是:Model坐标系统、World坐标系统、View坐标系统和Display坐标系统 在VTK里,Model坐标系统用得比较少,其他三种坐标系统经常使用。它们之间的变换则是由类vtkCoord…

Docker容器输入汉字触发自动补全

一、描述 输入汉字自动触发补全: Display all 952 possibilities? (y or n)是因为容器中没有中文字符集和中文字体导致的,安装中文字体,并设置字符集即可。 二、解决 1、安装字符集 (1)查看系统支持的字符集 lo…

甘肃旅游服务平台:技术驱动的创新实践

✍✍计算机编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java实战 |…

【Java程序设计】【C00260】基于Springboot的企业客户信息反馈平台(有论文)

基于Springboot的企业客户信息反馈平台(有论文) 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于Springboot的企业客户信息反馈平台 本系统分为平台功能模块、管理员功能模块以及客户功能模块。 平台功能模块:在平台首页可…

计算机网络——网络安全

计算机网络——网络安全 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家, [跳转到网站](https://www.captainbed.cn/qianqiu) 小程一言专栏链接: [link](http://t.csdnimg.cn/ZUTXU) 网络安全何…

【AI绘图】初见·小白入门stable diffusion的初体验

首先,感谢赛博菩萨秋葉aaaki的整合包 上手 stable diffusion还是挺好上手的(如果使用整合包的话),看看界面功能介绍简单写几个prompt就能生成图片了。 尝试 我在网上找了一张赛博朋克边缘行者Lucy的cos图,可能会侵…

黄金交易策略(Nerve Nnife.mql4):趋势平仓按钮的作用

当觉得行情不太对路,可以点击右下角按钮,实现趋势单的移动止盈(止损)。点了这个按钮回撤多少平仓是可以在参数里设定的。 完整EA:Nerve Knife.ex4黄金交易策略_黄金趋势ea-CSDN博客