【现代深度学习技术】循环神经网络06:循环神经网络的简洁实现

news2025/4/26 1:46:22

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、定义模型
    • 二、训练与预测
    • 小结


  虽然循环神经网络的从零开始实现对了解循环神经网络的实现方式具有指导意义,但并不方便。本节将展示如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

一、定义模型

  高级API提供了循环神经网络的实现。我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。事实上,我们还没有讨论多层循环神经网络的意义(这将在深度循环神经网络中介绍)。现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

  我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shape

在这里插入图片描述

  通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算:它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

在这里插入图片描述

  与循环神经网络的从零开始实现类似,我们为一个完整的循环神经网络模型定义了一个RNNModel类。注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)
        # 它的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens), device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

二、训练与预测

  在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

在这里插入图片描述

  很明显,这种模型根本不能输出好的结果。接下来,我们使用循环神经网络的从零开始实现中定义的超参数调用train_ch8,并且使用高级API训练模型。

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

在这里插入图片描述
在这里插入图片描述

  与上一节相比,由于深度学习框架的高级API对代码进行了更多的优化,该模型在较短的时间内达到了较低的困惑度。

小结

  • 深度学习框架的高级API提供了循环神经网络层的实现。
  • 高级API的循环神经网络层返回一个输出和一个更新后的隐状态,我们还需要计算整个模型的输出层。
  • 相比从零开始实现的循环神经网络,使用高级API实现可以加速训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2342847.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【办公类-89-02】20250424会议记录模版WORD自动添加空格补全下划线

背景需求 4月23日听了一个MJB的征文培训,需要写会议记录 把资料黏贴到模版后,发现每行需要有画满下划线 原来做这套资料,就是手动按空格到一行末,有空格才会出现下划线,也就是要按很多的空格(凑满一行&…

解释器模式:自定义语言解析与执行的设计模式

解释器模式:自定义语言解析与执行的设计模式 一、模式核心:定义语言文法并实现解释器处理句子 在软件开发中,当需要处理特定领域的语言(如数学表达式、正则表达式、自定义配置语言)时,可以通过解释器模式…

AI催生DLP新战场 | 天空卫士连续6年入选Gartner 全球数据防泄漏(DLP)市场指南

“管理数据外泄风险仍然是企业的重大挑战之一,客户处出于各种因素寻求DLP。最近,一些组织对使用DLP控制机器对敏感信息的访问表现出很大兴趣。 随着生成式人工智能(GenAI)的运用和数据的不断扩散,数据外泄的问题变得更…

Adobe After Effects的插件--------Optical Flares之Lens Objects参数

Lens Objects,即【镜头对象】。 通用设置 全局参数发光多光圈光圈条纹微光反射钉球闪光圆环箍焦散镜头球缩放✔✔✔✔✔✔✔✔✔✔✔✔✔缩放偏移✔长宽比✔✔✔✔✔✔✔✔✔✔✔✔✔混合模式✔颜色✔全局种子✔亮度✔✔✔✔✔✔✔✔✔✔✔✔拉伸✔✔✔✔✔✔✔✔✔✔✔✔距离…

【问题】解决docker的方式安装n8n,找不到docker.n8n.io/n8nio/n8n:latest镜像的问题

问题概览 用docker方式安装n8n,遇到错误,安装不了的问题: Unable to find image docker.n8n.io/n8nio/n8n:latest locally docker: Error response from daemon: Get "https://registry-1.docker.io/v2/": net/http: request can…

系统与网络安全------弹性交换网络(1)

资料整理于网络资料、书本资料、AI,仅供个人学习参考。 Trunk原理与配置 Trunk原理概述 Trunk(虚拟局域网中继技术)是指能让连接在不同交换机上的相同VLAN中的主机互通。 VLAN内通信 实现跨交换的同VLAN通信,通过Trunk链路&am…

10天学会嵌入式技术之51单片机-day-3

第九章 独立按键 按键的作用相当于一个开关,按下时接通(或断开),松开后断开(或接通)。实物图、原理图、封装 9.2 需求描述 通过 SW1、SW2、SW3、SW4 四个独立按键分别控制 LED1、LED2、LED3、LED4 的亮…

深入解析微软MarkitDown:原理、应用与二次开发指南

一、项目背景与技术定位 微软开源的MarkitDown并非简单的又一个Markdown解析器,而是针对现代文档处理需求设计的工具链核心组件。该项目诞生于微软内部大规模文档系统的开发实践,旨在解决以下技术痛点: 大规模文档处理性能:能够高…

【PVCodeNet】《Palm Vein Recognition Network Combining Transformer and CNN》

[1]吴凯,沈文忠,贾丁丁,等.融合Transformer和CNN的手掌静脉识别网络[J].计算机工程与应用,2023,59(24):98-109. 文章目录 1、Background and Motivation2、Related Work3、Advantages / Contributions4、Method5、Experiments5.1、Datasets and Metrics5.2、Hyper-parameters5.…

x-cmd install | brows - 终端里的 GitHub Releases 浏览器,告别繁琐下载!

目录 核心功能与优势安装适用场景 还在为寻找 GitHub 项目的特定 Release 版本而苦恼吗?还在网页上翻来覆去地查找下载链接吗?现在,有了 brows,一切都将变得简单高效! brows 是一款专为终端设计的 GitHub Releases 浏览…

多模态知识图谱:重构大模型RAG效能新边界

当前企业级RAG(Retrieval-Augmented Generation)系统在非结构化数据处理中面临四大核心问题: 数据孤岛效应:异构数据源(文档/表格/图像/视频)独立存储,缺乏跨模态语义关联,导致知识检…

实验八 版本控制

实验八 版本控制 一、实验目的 掌握Git基本命令的使用。 二、实验内容 1.理解版本控制工具的意义。 2.安装Windows和Linux下的git工具。 3.利用git bash结合常用Linux命令管理文件和目录。 4.利用git创建本地仓库并进行简单的版本控制实验。 三、主要实验步骤 1.下载并安…

JavaWeb:Web介绍

Web开篇 什么是web? Web网站工作流程 网站开发模式 Web前端开发 初识web Web标准 HtmlCss 什么是Html? 什么是CSS?

教育行业网络安全:守护学校终端安全,筑牢教育行业网络安全防线!

教育行业面临的终端安全问题日益突出,主要源于教育信息化进程的加速、终端设备多样化以及网络环境的开放性。 以下是教育行业终端安全面临的主要挑战: 1、设备类型复杂化 问题:教育机构使用的终端设备包括PC、服务器等,操作系统…

Spring Boot知识点详解

打包部署 <!‐‐ 这个插件&#xff0c;可以将应用打包成一个可执行的jar包&#xff1b;‐‐> <build><plugins> <plugin> <groupId>org.springframework.boot</groupId><artifactId>spring‐boot‐maven‐plugin</artifactId&g…

DNS主从同步及解析

DNS 域名解析原理 域名系统的层次结构 &#xff1a;DNS 采用分层树状结构&#xff0c;顶级域名&#xff08;如.com、.org、.net 等&#xff09;位于顶层&#xff0c;下面是二级域名、三级域名等。例如&#xff0c;在域名 “www.example.com” 中&#xff0c;“com” 是顶级域名…

在Windows11上用wsl配置docker register 镜像地址

一、下载软件 1、下载wsl:安装 WSL | Microsoft Learn,先按照旧版 WSL 的手动安装步骤 | Microsoft Learn的步骤走 注:如果wsl2怎么都安装不下来,可能是Hyper-V没有打开,打开控制面板->程序和功能->启用或关闭Windows功能,勾选Hyper-V 如果Windows功能里面没有Hyp…

【Linux网络】构建UDP服务器与字典翻译系统

&#x1f4e2;博客主页&#xff1a;https://blog.csdn.net/2301_779549673 &#x1f4e2;博客仓库&#xff1a;https://gitee.com/JohnKingW/linux_test/tree/master/lesson &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01; &…

【PGCCC】Postgres 故障排除:修复重复的主键行

如何从表中删除不需要的重复行。这些重复行之所以“不需要”&#xff0c;是因为同一个值在指定为主键的列中出现多次。自从 glibc 好心地改变了排序方式后&#xff0c;我们发现这个问题有所增加。当用户升级操作系统并修改底层 glibc 库时&#xff0c;这可能会导致无效索引。 唯…

DeepSeek+Cursor+Devbox+Sealos项目实战

黑马程序员DeepSeekCursorDevboxSealos带你零代码搞定实战项目开发部署视频教程&#xff0c;基于AI完成项目的设计、开发、测试、联调、部署全流程 原视频地址视频选的项目非常基础&#xff0c;基本就是过了个web开发流程&#xff0c;但我在实际跟着操作时&#xff0c;ai依然会…