人工智能|机器学习——循环神经网络的简洁实现

news2024/11/19 0:36:26

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
 
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shape

torch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)
 
    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)
        # 它的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state
 
    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

 很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1252675.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4-20mA高精度采集方案

下载链接!https://mp.weixin.qq.com/s?__bizMzU2OTc4ODA4OA&mid2247557466&idx1&snb5a323285c2629a41d2a896764db27eb&chksmfcfaf28dcb8d7b9bb6211030d9bda53db63ab51f765b4165d9fa630e54301f0406efdabff0fb&token976581939&langzh_CN#rd …

明道云伙伴成果与展望

摘要:这篇文章介绍了明道云在过去一年的成果以及未来的计划。明道云将把更多资源和精力投入到伙伴身上,提供更全面的支持,包括产品特性、展业支持和 GTM (Go-To-Market)支持三个方面。在产品特性方面,明道云…

【数据结构实验】排序(一)冒泡排序改进算法 Bubble及其性能分析

文章目录 1. 引言2. 冒泡排序算法原理2.1 传统冒泡排序2.2 改进的冒泡排序 3. 实验内容3.1 实验题目(一)输入要求(二)输出要求 3.2 算法实现 4. 实验结果5. 实验结论 1. 引言 排序算法是计算机科学中一个重要而基础的研究领域&…

03:2440--UART

目录 一:UART 1:概念 2:工作模式 3:逻辑电平 4:串口结构图 5:时间的计算 二:寄存器 1:简单的UART传输数据 A:GPHCON--配置引脚 B:GPHUP----使能内部上拉​编辑 C: UCON0---设置频率115200 D: ULCON0----数据格式8n1 E:发送数据 A:UTRSTAT0 B:UTXHO--发送数据输…

makefile 学习(5)完整的makefile模板

参考自: (1)深度学习部署笔记(二): g, makefile语法,makefile自己的CUDA编程模板(2)https://zhuanlan.zhihu.com/p/396448133(3) 一个挺好的工程模板,(https://github.com/shouxieai/cpp-proj-template) 1. c 编译流…

linux嵌入式时区问题

目录 操作说明实验参考 最近有个针对时区的需求,研究了下。 查询网上的一些设置,发现基本都是系统中自带的一些文件,然后开机时解析,或者是有个修改的命令。 操作 但针对嵌入式常用到的 busybox 制作的最小系统,并没…

图论|知识图谱——详解自下而上构建知识图谱全过程

导读:知识图谱的构建技术主要有自顶向下和自底向上两种。其中自顶向下构建是指借助百科类网站等结构化数据源,从高质量数据中提取本体和模式信息,加入到知识库里。而自底向上构建,则是借助一定的技术手段,从公开采集的…

activiti流程回退与跳转

学习连接 【工作流Activiti7】3、Activiti7 回退与会签 【工作流Activiti7】4、Activiti7 结束/终止流程 Activiti-跳转到指定节点、回退 ativiti6.0 流程节点自由跳转实现、拒绝/不同意/返回上一节点、流程撤回、跳转、回退等操作(通用实现,亲测可用…

基于python+TensorFlow+Django算法模型的车辆车型识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介简介技术栈主要模块1. 数据预处理2. 模型构建3. 模型训练4. 模型集成5. 用户界面 系统工作流程未来改进计划 二、功能三、系统四. 总结 一项目简介 # 车辆车…

【linux】基本指令(中篇)

echo指令 将引号内容打印到显示屏上 输出的重定向 追加的重定向 输出的重定向 我们学习c语言的时候当以写的方式创建一个文件,就会覆盖掉该文件之前的内容 当我们以追加的方式打开文件的时候,原文件内容不会被覆盖而是追加 more指令 10.more指令…

cephadm部署ceph quincy版本

环境说明 IP主机名角色 存储设备 192.168.2.100 master100 mon,mgr,osd,mds,rgw 大于5G的空设备192.168.2.101node101mon,mgr,osd,mds,rgw大于5G的空设备192.168.2.102node102mon,mgr,osd,mds,rgw大于5G的空设备 关闭防火墙 关闭并且禁用selinux 配置主机名/etc/hosts …

⑦【Redis GEO 】Redis常用数据类型:GEO [使用手册]

个人简介:Java领域新星创作者;阿里云技术博主、星级博主、专家博主;正在Java学习的路上摸爬滚打,记录学习的过程~ 个人主页:.29.的博客 学习社区:进去逛一逛~ Redis GEO ⑦Redis GEO 基本操作命令1.geoadd …

原生JS实现计算器(内含源码)

前言 本文讲解了JavaScript如何在一小时内实现一个简易计算器,这里最大的亮点就在于,我在JS中只用了一个事件,就实现了计算器的效果和功能,那么好文本正式开始。 布局和样式流程 首先是HTMLCSS结构:这里主要用到的…

基于uQRCode封装的Vue3二维码生成插件

标题:基于uQRCode封装的Vue3二维码生成插件 摘要:本文介绍了一种基于uQRCode封装的Vue3二维码生成插件,可以在Javascript运行环境下生成二维码并返回图片地址。该插件适用于所有Javascript运行环境,并且支持微信小程序。本文将详…

在Python中matplotlib函数的plt.plot()函数的颜色参数设置,以及可以直接运行的程序代码!

文章目录 前言一、使用字符串颜色:二、使用十六进制颜色:三、使用RGB元组:四、使用颜色映射:总结 前言 在matplotlib中,plt.plot()函数可以接受颜色参数,可以设置为字符串颜色(如red&#xff0…

笔记:pycharm当有多个plt.show()时候,只显示第一个plt.show()

import matplotlib.pyplot as plt import numpy as np# 创建数据 x np.linspace(0, 10, 100) y1 np.sin(x) y2 np.cos(x) y3 np.tan(x) y4 np.exp(x)# 创建一个2x2的子图网格 # fig plt.figure() fig,((ax1, ax2), (ax3, ax4)) plt.subplots(nrows2, ncols2, figsize(8,…

Redis:事务操作

目录 Redis事务定义相关命令事务的错误处事务冲突的问题Redis事务三特性 Redis事务定义 redis事务是一个单独的隔离操作,事务中的所有命令都会序列化、按顺序地执行,事务在执行的过程中,不会被其他客户端发送来的命令请求所打断。 redis事务…

车载电子电器架构 ——电子电气架构设计方案概述

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 注:本文1万多字,认证码字,认真看!!! 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证…

作为Java初学者,如何快速学好Java?

作为Java初学者,如何快速学好Java? 开始的一些话 对于初学者来说,编程的学习曲线可能相对陡峭。这是正常现象,不要感到沮丧。逐步学习,循序渐进。 编程是一门实践性的技能,多写代码是提高的唯一途径。尽量…

【一起来学kubernetes】7、k8s中的ingress详解

引言配置示例负载均衡的实现负载均衡策略实现模式实现方案Nginx类型Ingress实现Treafik类型Ingress实现HAProxy类型ingress实现Istio类型ingress实现APISIX类型ingress实现 更多 引言 Ingress是Kubernetes集群中的一种资源类型,用于实现用域名的方式访问Kubernetes…