8.6 循环神经网络的简洁实现

news2024/12/23 6:23:56

在这里插入图片描述
每个步长共用参数

加载数据

虽然 8.5节 对了解循环神经网络的实现方式具有指导意义,但并不方便。 本节将展示如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义(这将在 9.3节中介绍)。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

  1. 设置一个单层 256个隐藏单元的神经网络 rnn
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

单隐藏层, 256个隐藏单元

隐藏层数(Number of Hidden Layers):
定义:隐藏层数指的是神经网络中介于输入层和输出层之间的层级数量。
作用:增加隐藏层数可以帮助神经网络学习更复杂的模式和特征,提高模型的表示能力。
影响:过多的隐藏层可能导致过拟合,而过少的隐藏层可能导致欠拟合。
选择:隐藏层数的选择通常需要根据具体问题的复杂度和数据集的特征来决定,可以通过交叉验证等方法来选择合适的层数。
隐藏单元数(Number of Hidden Units):
定义:隐藏单元数指的是每个隐藏层中包含的神经元数量。
作用:增加隐藏单元数可以增加神经网络的拟合能力,使其能够更好地学习数据的复杂关系。
影响:隐藏单元数的增加会增加模型的复杂度,可能导致过拟合;减少隐藏单元数可能导致模型欠拟合。
选择:隐藏单元数的选择通常需要在训练过程中进行调优,可以通过交叉验证或其他调参方法来确定合适的数量。
在实际应用中,选择合适的隐藏层数和隐藏单元数是一项重要的任务,需要根据具体问题的复杂度、数据集的特征以及计算资源等因素进行权衡和调整,以获得最佳的模型性能

  1. 我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])
  1. 通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。
X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape
(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))
  1. 与 8.5节类似, 我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。
#@save
class RNNModel(nn.Module):
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)
        # 它的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1]))) # 输出层
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)
'time travellerbbabbkabyg'

很明显,这种模型根本不能输出好的结果。 接下来,我们使用 8.5节中 定义的超参数调用train_ch8,并且使用高级API训练模型。

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)
perplexity 1.3, 404413.8 tokens/sec on cuda:0
time travellerit would be remarkably convenient for the historia
travellery of il the hise fupt might and st was it loflers

在这里插入图片描述

#@save
def train_ch8(net, train_iter, vocab, lr, num_epochs, device,
              use_random_iter=False):
    """训练模型(定义见第8章)"""
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',
                            legend=['train'], xlim=[10, num_epochs])
    # 初始化
    if isinstance(net, nn.Module):
        updater = torch.optim.SGD(net.parameters(), lr)
    else:
        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)
    predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)
    # 训练和预测
    for epoch in range(num_epochs):
        ppl, speed = train_epoch_ch8(
            net, train_iter, loss, updater, device, use_random_iter)
        if (epoch + 1) % 10 == 0:
            print(predict('time traveller'))
            animator.add(epoch + 1, [ppl])
    print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')
    print(predict('time traveller'))
    print(predict('traveller'))
#@save
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):
    """训练网络一个迭代周期(定义见第8章)"""
    state, timer = None, d2l.Timer()
    metric = d2l.Accumulator(2)  # 训练损失之和,词元数量
    for X, Y in train_iter:
        if state is None or use_random_iter:
            # 在第一次迭代或使用随机抽样时初始化state
            state = net.begin_state(batch_size=X.shape[0], device=device)
        else:
            if isinstance(net, nn.Module) and not isinstance(state, tuple):
                # state对于nn.GRU是个张量
                state.detach_()
            else:
                # state对于nn.LSTM或对于我们从零开始实现的模型是个张量
                for s in state:
                    s.detach_()
        y = Y.T.reshape(-1)
        X, y = X.to(device), y.to(device)
        y_hat, state = net(X, state)
        l = loss(y_hat, y.long()).mean()
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad()
            l.backward()
            grad_clipping(net, 1)
            updater.step()
        else:
            l.backward()
            grad_clipping(net, 1)
            # 因为已经调用了mean函数
            updater(batch_size=1)
        metric.add(l * y.numel(), y.numel())
    return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()

小结

  • 深度学习框架的高级API提供了循环神经网络层的实现。

  • 高级API的循环神经网络层返回一个输出和一个更新后的隐状态,我们还需要计算整个模型的输出层。

  • 相比从零开始实现的循环神经网络,使用高级API实现可以加速训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1553650.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面向对象:继承

文章目录 一、什么叫继承?二、单继承三、多继承3.1多继承的各种情况3.1.1一般情况3.1.1特殊情况(菱形继承) 四、菱形继承引发的问题4.1 问题1:数据冗余4.2 问题2:二义性(无法确定到底是访问哪个) 五、虚拟继承解决菱形…

PXE批量装centos7系统

1.环境准备: yum -y install tftp-server xinetd #安装并启用 TFTP 服务 #修改TFTP服务的配置文件 vim /etc/xinetd.d/tftp protocol udp # TFTP默认使用UDP协议 wait no # no表示客户机可以多台一起连接&…

JavaScript高级 —— 学习(一)

目录 一、作用域 (一)局部作用域 1.函数作用域 2.块作用域 (二)全局作用域 二、垃圾回收机制 GC (一)生命周期 1.内存分配 2.内存使用 3.内存回收 4.特殊情况——内存泄漏: 注意&…

【MySQL】数据库--表操作

目录 一、创建表 二、查看表 三、修改表 1. 添加字段--add 2.修改表名--rename to 3.修改列名--change 4.修改字段的数据类型--modify 5.删除字段(列)--drop 四、删除表 一、创建表 create [temporary]table[if not exists]table_name [([colu…

答题小程序功能细节揭秘:如何提升用户体验和满足用户需求?

答题小程序功能细节体现 随着移动互联网的快速发展,答题小程序成为了用户获取知识、娱乐休闲的重要平台。一款优秀的答题小程序不仅应该具备简洁易用的界面设计,更应该在功能细节上做到极致,以提升用户体验和满足用户需求。本文将从题库随机…

Java毕业设计-基于springboot开发的游戏分享网站平台-毕业论文+答辩PPT(附源代码+演示视频)

文章目录 前言一、毕设成果演示(源代码在文末)二、毕设摘要展示1、开发说明2、需求分析3、系统功能结构 三、系统实现展示1、系统功能模块2、后台登录2.1管理员功能模块2.2用户功能模块 四、毕设内容和源代码获取总结 Java毕业设计-基于springboot开发的…

类与对象中C++

加油!!! 文章目录 前言 一、类的6个默认成员函数 ​编辑 二、构造函数 1.概念 三、析构函数 1.概念 2.特性 四、拷贝构造函数 1.概念 2.特征 拷贝构造函数典型调用场景 五、赋值运算符重载 1.运算符重载 2.赋值运算符重载 赋值运算符重载格式…

Day24:私信列表、私信详情、发送私信

测试用户:用户名aaa 密码aaa 查询当前用户的会话列表;每个会话只显示一条最新的私信;支持分页显示。 首先看下表结构: conversation_id: 用from_id和to_id拼接,小的放前面去(因为两个人的对话应该在一个会…

如何使用剪映专业版剪辑视频

1.操作界面功能介绍 2.时间线的使用 拖动前端后端缩减时长,有多个素材可以拖动调节前后顺序拼接。 分割视频 删除

新零售SaaS架构:客户管理系统的应用架构设计

客户管理系统的应用架构设计 应用层定义了软件系统的应用功能,负责接收用户的请求,协调领域层能力来执行任务,并将结果返回给用户,功能模块包括: 客户管理:核心功能模块,负责收集和更新客户信息…

FPGA时钟资源详解(3)——全局时钟资源

FPGA时钟系列文章总览:FPGA原理与结构(14)——时钟资源https://ztzhang.blog.csdn.net/article/details/132307564 一、概述 全局时钟是 FPGA 中的一种专用互连网络,旨在将时钟信号分配到 FPGA 内各种资源的时钟输入处。这种设计…

探索数据库--------------mysql主从复制和读写分离

目录 前言 为什么要主从复制? 主从复制谁复制谁? 数据放在什么地方? 一、mysql支持的复制类型 1.1STATEMENT:基于语句的复制 1.2ROW:基于行的复制 1.3MIXED:混合类型的复制 二、主从复制的工作过程 三个重…

IMU评估产后腹直肌分离康复训练

腹直肌分离(Diastasis Recti Abdominis,DRA)是由腹部纤维联结组织——腹白线的过度伸张所致,这项研究的目标是通过惯性测量单元(IMU)感应器信号来分析产后康复。传统的康复方式通常包括针对性的物理疗法&am…

《QT实用小工具·一》电池电量组件

1、概述 项目源码放在文章末尾 本项目实现了一个电池电量控件,包含如下功能: 可设置电池电量,动态切换电池电量变化。可设置电池电量警戒值。可设置电池电量正常颜色和报警颜色。可设置边框渐变颜色。可设置电量变化时每次移动的步长。可设置…

CleanMyMac X 4.15.1 for Mac 最新中文破解版 系统优化垃圾清理工具

CleanMyMac X for Mac 是一款功能更加强大的系统优化清理工具,相比于 CleanMyMac 4.15.1来说,功能增加了不少,此版本为4.15.1官方最新中英文正式破解版本,永久使用,解决了打开软件崩溃问题,最新版4.15.1版本…

neo4j相同查询语句一次查询特慢再次查询比较快。

现象&#xff1a; neo4j相同查询语句一次查询特慢再次查询比较快。 分析&#xff1a; 查询语句 //查询同名方法match(path:Method) where id(path) in [244333030] and NOT path:Constructor//是rpc的方法match(rpc_method:Method)<-[:DECLARES]-(rpc_method_cls:Class) -…

2024年智能版控费系统方案卓健易控

2024年智能版控费系统方案卓健易控 详细可咨询&#xff1a;19138173009 设备智能卓健易控ZJ-V8.0控费方案在科学和技术不断发展的背景下&#xff0c;逐渐实现了更新和迭代。现如今&#xff0c;感应技术、生物识别技术、智能图像识别技术、过程记录技术、监管控制技术等方面的…

Pytorch的hook函数

hook函数是勾子函数&#xff0c;用于在不改变原始模型结构的情况下&#xff0c;注入一些新的代码用于调试和检验模型&#xff0c;常见的用法有保留非叶子结点的梯度数据&#xff08;Pytorch的非叶子节点的梯度数据在计算完毕之后就会被删除&#xff0c;访问的时候会显示为None&…

解析为什么使用celery的task装饰就有delay属性

分析task装饰器原理 from celery.task import periodic_task task源码如下 def task(*args, **kwargs):"""Deprecated decorator, please use :func:celery.task."""return current_app.task(*args, **dict({base: Task}, **kwargs))这里回调…

JavaSE day14笔记

第十四天课堂笔记 课上: 适当做笔记课下 : 总结 , 读代码 , 反复敲代码 , 做练习 数组★★★ 数组 : 存储多个 同一类型 的容器格式 :数组类型 : 引用数据类型, new运算符在堆中 分配一块连续的存储空间 , 系统会给数组元素默认初始化 , 将该数组的引用赋值给数组名 引用数据…