跟着刘二大人学pytorch(第---13---节课之RNN高级篇)

news2025/1/1 23:57:58

文章目录

  • 0 前言
    • 0.1 课程视频链接:
    • 0.2 课件下载地址:
  • 1 本节课任务描述
    • 模型的处理过程
      • 训练循环
        • 初始化分类器
        • 是否使用GPU
        • 构造损失函数和优化器
        • 每个epoch所要花费的时间
        • 遍历每个epoch时进行训练和测试
        • 记录每次测试的准确率加入到列表中
  • 具体实现(经典的4步)
    • 1、准备数据
      • 1.1 输入数据的处理
      • 1.2 输出数据的处理
        • 1.2.1 编写数据集类
        • 1.2.2 实例化数据集类并加载DataLoader
    • 2、设计模型
      • Bi-direction RNN/LSTM/GRU的实现
        • 将name转换为tensor
        • 实现-测试
          • 实现-训练模型的结果展示
          • 练习13-1 电影评论的情感分析
  • 本系列的总结

0 前言

0.1 课程视频链接:

《PyTorch深度学习实践》完结合集
大佬的笔记:大佬的笔记
pytorch=0.4

0.2 课件下载地址:

链接:https://pan.baidu.com/s/1_J1f5VSyYl-Jj2qIuc1pXw
提取码:wyhu

1 本节课任务描述

本节将构建一个RNN来做分类
在这里插入图片描述
具体是名字的分类
在这里插入图片描述
复习RNN
嵌入层负责将词的one-hot表示(高维稀疏)映射到一个低维稠密的向量中
隐层的h不一定和输出一致,因此需要添加一个线性层,把输出映射成和要求一致的
在这里插入图片描述
有些任务是不需要对所有的隐层状态h都做线性映射的,比如本节的目标是让RNN最后的输出是一个分类,对o1,o2,o3,o4是没有要求的,
实际上可以让网络简化为下面(分类结果只与最后一个隐藏状态相关):
在这里插入图片描述
具体如下:
在这里插入图片描述
每个名字中的字母就是序列中的x1,x2,。。。
序列有的长有的短,之后还要解决序列长度不一的情况

模型的处理过程

在这里插入图片描述

训练循环

初始化分类器

在这里插入图片描述
RNNClassifier是自己定义的分类模型
N_CHARS:字符的数量
HIDDEN_SIZE:GRU输出的维度
N_COUNTRY:国籍的类别数量
N_LAYER:GRU的层数

是否使用GPU

在这里插入图片描述

构造损失函数和优化器

在这里插入图片描述分类问题使用交叉熵损失函数

每个epoch所要花费的时间

elapsed:消逝
在这里插入图片描述
这里定义的time_since函数将当前时间减去since时间,再转换成分钟

遍历每个epoch时进行训练和测试

在这里插入图片描述

记录每次测试的准确率加入到列表中

在这里插入图片描述
记录到列表中后便于之后画图查看准确率变化情况

具体实现(经典的4步)

1、准备数据

1.1 输入数据的处理

在这里插入图片描述
77是M的ASCII值,M的one-hot表示则为索引为77的那一位为1,其余值为0的向量
长短不一:使用padding进行扩充
RNN的输入需要保证是张量,不进行padding的话不是矩阵
填充的长度为一个batch中最长的那个单词的字符数
在这里插入图片描述

1.2 输出数据的处理

1.2.1 编写数据集类

将所有的国家名字做成一个词典,给每个国家一个索引值,字典的格式为键为国家名,值为索引值
在这里插入图片描述
is_train_set=True:判断文件是否是训练集
使用了gzip和csv这两个包进行读取数据集,如果数据集的文件格式是pickle或者是HD5格式,此时需要用相应的包进行读取,需随机应变
在这里插入图片描述
rows是一个列表,这个列表中每一个元素都是元组(name,name所归属的国家)
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
此时name是一个字符串,国家是一个索引值
len:返回数据集的长度
在这里插入图片描述
在这里插入图片描述
函数getCountryDict()可以将国家与国家对应的索引值做出来,其中国家名为键,索引为值
其他的工具函数:
idx2country():给一个索引需要将国家这个字符串拿出来
在这里插入图片描述
getCountriesNum():返回国家的数量
在这里插入图片描述

1.2.2 实例化数据集类并加载DataLoader

在这里插入图片描述
在这里插入图片描述
参数配置:
在这里插入图片描述

2、设计模型

隐藏层的维度设置
在这里插入图片描述
Embedding层的维度设置
在这里插入图片描述
GRU的输入维度是hidden_size,输出维度也是hidden_size
在这里插入图片描述
bidirectional是设置这个GRU网络是否是双向的,如果为True则为双向,如果为False则为单向
在这里插入图片描述
设置n_directions值,如果bidirectional值为True则n_directions=2,否则n_directions=1
在这里插入图片描述

Bi-direction RNN/LSTM/GRU的实现

x_N-1只包含了前面的词的影响,也应该考虑后面的词的影响

下面的图中h0f中的f意思是序列的前向处理,h0b中的b就是序列的反向处理,将正向和反向的信息进行拼接

双向的GRU的输出信息包含两种内容:output和hidden
out包含h0,h1,…,hN
在这里插入图片描述
hidden只包含hNb和hNf
在这里插入图片描述
这里的hidden_size*self.n_directions是指隐藏层h维度是拼接起来的,所以需要乘以一个层数
在这里插入图片描述
_init_hidden是隐层的初始化,初始设置为一个全是0的张量
n_directions是指方向的个数,如果值为2,则为双向,如果值为1则为单项
n_linear是指网络的层数
.t():是张量的转置
在这里插入图片描述
转置之后的结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

对序列长度不一的数据进行padding0的操作,0在实际计算的时候其实不需要参与计算,不计算的话可以提高神经网络的运算速度
在pytorch中的一个函数是pack_padded_sequence(),这个函数就可以实现上面的功能,它需要Embedding和序列长度seq_lengths,这个函数的工作原理如下:

在这里插入图片描述
左侧的每个字都会变成一个向量,如上红色画出的部分,像padding的0都是映射为同一个向量,如下图右侧深色部分,如果显示不是0.11开头,那是up主写错了,所以他在图中标识出来了
在这里插入图片描述
Embeddin函数会将左边的图片变成右边的方块(seqLen,batchsize,hiddensize),如下
在这里插入图片描述
pack_padded_sequence()会返回一个打包的对象,
在这里插入图片描述
这个对象的文档解释:
在这里插入图片描述
看不懂的话看下面:
在这里插入图片描述

在处理长短不一的情况下,需要将0对应的词向量拿出去不参与计算,把非0的词向量给依次的堆叠在一起,等于说是一个这样的筛选工作,如果没有按照非0词的个数进行排列的话,pack_padded_sequence是不能work的

但是pack_padded_sequence这个函数要求按照非0个数进行排列,人话讲就是按照句子中词的多少来排序这个词向量,见下图:
首先需要按照序列长度进行排序
在这里插入图片描述
排序映射为向量:
在这里插入图片描述
之后按照下面的排序顺序进行排列,中间省去了两个满的,下面的10,10,10,9,6,表示左边那个图一横行有多少个非padding0的词向量,应该是up主写错了,应该写成9,9,9,8,5,数错了应该是
在这里插入图片描述
pack_padded_sequence将上图中的dataseq_length两个数据进行打包,其中seq_length会变成batch_sizes,即data和batch_sizes将作为gru的输入,gru会根据batch_sizes中存储的数字来每次从data中取出来多少的向量

通过上面的操作得到了下面gru的输入:gru_input
在这里插入图片描述
之后将gru_input作为输入得到最终的hidden,如果是双向GRU的话hidden如下:
hidden[-1]:GRU的最后一个隐藏层状态
hidden[-2]:GRU的倒数第二个隐藏层状态,
之后再将hidden放进全连接层,转换为想要的那种维度
在这里插入图片描述
不懂可以再看下单层的GRU结构:
在这里插入图片描述
双向的GRU的hidden如下操作:
在这里插入图片描述
如果使用的是双向的GRU,则需要将前向和后向的hidden进行拼接
在这里插入图片描述
最后使用一个全连接网络
在这里插入图片描述

将name转换为tensor

这个过程的最终结果需要得出3个值,一个batch的大小B,一个每个name的向量长度S,一个是每个name的向量表示中非0元素的个数,
在这里插入图片描述
过程:
字符串—>字符—>ASCII码—>填充—>转置—>排序
字符串—>字符—>ASCII码:
在这里插入图片描述
ASCII码—>填充:
在这里插入图片描述
填充—>转置:
在这里插入图片描述
转置—>排序:
在这里插入图片描述
name2list将名字name中的每个字符转成ASCII值
在这里插入图片描述
在这里插入图片描述
序列长度装成longtensor:
在这里插入图片描述
padding0的操作:先创造一个全是0的张量,然后将非0的张量给贴到0张量上
在这里插入图片描述
排序
在这里插入图片描述
将这三个都转为张量
在这里插入图片描述
一个epoch上进行训练训练
在这里插入图片描述
经典5步骤:
在这里插入图片描述
打印输出:
在这里插入图片描述

实现-测试

在这里插入图片描述
测试时不需要求梯度,因此使用with torch.no_grad():
在这里插入图片描述
计算输出:
在这里插入图片描述

output.max()这里是计算预测对了多少
在这里插入图片描述

实现-训练模型的结果展示

大概20个epoch之后模型在测试集上的效果表现最好,此时可以将模型保存下来
在这里插入图片描述

练习13-1 电影评论的情感分析

在这里插入图片描述
数据集下载地址:kaggle
在这里插入图片描述
数据示例展示
最后一列的数字表示这句话的情感分类
在这里插入图片描述

本系列的总结

1、通读pytorch文档
2、多读新的文献
3、多动手写代码

文章看不懂或者视频看不懂的可以参考以下优秀的博文:

1、【Pytorch深度学习实践】B站up刘二大人课程笔记——目录与索引(已完结)(该博主的博文建议多看看)
2、玩一玩 Typora
3、刘二大人《PyTorch深度学习实践》循环神经网络RNN高级篇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1829447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

哇塞,超好吃的麻辣片,一口就爱上

最近,我发现了一款让人欲罢不能的美食——食家巷麻辣片!😍 一打开包装,那浓郁的麻辣香气就扑鼻而来,瞬间刺激着我的嗅觉神经。😃食家巷麻辣片的外观色泽鲜艳,红通通的一片,看着就特…

Verilog综合出来的图

Verilog写代码时需要清楚自己综合出来的是组合逻辑、锁存器还是寄存器。 甚至,有时写的代码有误,vivado不能识别出来,这时打开综合后的schematic简单查看一下是否综合出想要的结果。 比如:误将一个always模块重复一遍,…

Java环境安装

下载JDK https://www.oracle.com/cn/java/technologies/downloads/#jdk22-windows 点开那个下载都可以但是要记住下载的路径因为下一步要添加环境变量 选择编辑系统环境变量 点击环境变量 点击新建 新建环境变量JAVA_HOME 并输入JDK在计算机保存的路径 打开cmd 输入java -…

深度解析Spring事务管理:从源码到实际应用

引言 Spring框架的事务管理是Java企业级应用开发中不可或缺的一部分。它提供了一种声明式和编程式的事务管理方式,极大地简化了事务的处理。本文将深入探讨Spring事务的底层实现原理,通过源码分析,揭示其内部工作机制。 EnableTransactionMan…

举例说明 如何判断Spark作业的瓶颈

首先看哪个Job执行时间长: 例如下图中明显Job 2时间执行最长,这个对rdd作业是直观有效的。 对于sql作业可能不准确,sql需要关注stage的详情耗时。 然后看执行时间长的Job中哪个stage执行时间长: 明显stage 7和stage 13执行时间长&…

Excel中多条件判断公式怎么写?

在Excel里,这种情况下的公式怎么写呢? 本题有两个判断条件,按照题设,用IF函数就可以了,这样查看公式时逻辑比较直观: IF(A2>80%, 4, IF(A2>30%, 8*(A2-30%),0)) 用IF函数写公式,特别是当…

单列集合顶层接口Collection及五类遍历方式(迭代器)

collection add方法细节: remove方法细节: contains方法细节: 如果集合中存储的是自定义对象, student之类的, 也想通过contains进行判断, 就必须在javaBean中重写equals方法 contains在arrayList中源代码:在底层调用了equals方…

爱了爱了,11款超良心App推荐!

AI视频生成:小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频https://aitools.jurilu.com/今天,我们向你推荐十款与众不同但又不错的win10软件,它们都有各自的功能和优点,相信你一定会喜欢。 1.图片处…

大数据开发流程解析

大数据开发是一个复杂且系统的过程,涉及需求分析、数据探查、指标管理、模型设计、ETL开发、数据验证、任务调度以及上线管理等多个阶段。本文将详细介绍每个阶段的内容,并提供相关示例和代码示例,帮助理解和实施大数据开发流程。 本文中的示…

通义千问调用笔记

如何使用通义千问API_模型服务灵积(DashScope)-阿里云帮助中心 package com.ruoyi.webapp.utils;import com.alibaba.dashscope.aigc.generation.Generation; import com.alibaba.dashscope.aigc.generation.GenerationOutput; import com.alibaba.dashscope.aigc.generation.G…

期末算法复习

0-1背包问题(动态规划) 例题 算法思想: 动态规划的核心思想是将原问题拆分成若干个子问题,并利用已解决的子问题的解来求解更大规模的问题。 主要是状态转移方程和状态 算法描述: 初始化一个二维数组dp&#xff0…

深度学习 --- stanford cs231学习笔记三(卷积神经网络CNN)

卷积神经网络CNN 1,有效的利用了图像的空间信息/局部感受野 全连接神经网络中的神经是由铺平后的所有像素计算决定。 由于计算时是把图像的所有像素拉成了一条线,因此在拉伸的同时也损失了图像像素之间固有的空间信息。 卷积层中的神经只由5x5x3(假设fil…

JavaFX文本

另一个基本的JavaFX节点是Text节点,它允许我们在场景图上显示文本。要创建Text节点,请使用javafx.scene.text.Text类。 所有JavaFX场景节点都从javafx.scene.Node中扩展,并且它们继承了许多功能,例如缩放,翻译或旋转的…

稀疏矩阵是什么 如何求

稀疏矩阵是一种特殊类型的矩阵,其中大多数元素都是零。由于稀疏矩阵中非零元素的数量远少于零元素,因此可以使用特定的数据结构和算法来高效地存储和处理它们,从而节省存储空间和计算时间。 RowPtr 数组中的每个元素表示对应行的第一个非零元…

计算机缺失msvcr110.dll如何解决,这6种解决方法可有效解决

电脑已经成为我们生活和工作中不可或缺的工具,然而在使用电脑的过程中,我们常常会遇到一些问题,其中之一就是电脑找不到msvcr110.dll文件。这个问题可能会给我们带来一些困扰,但是只要我们了解其原因并采取相应的解决方法&#xf…

C 语言连接MySQL 数据库

前提条件 本机安装MySQL 8 数据库 整体步骤 第一步:开启Windows 子系统安装Ubuntu 22.04.4,安装MySQL 数据库第三方库执行 如下命令: sudo aptitude install libmysqlclient-dev wz2012LAPTOP-8R0KHL88:/mnt/e/vsCode/cpro$ sudo aptit…

使用Java Spring Boot生成二维码与条形码

个人名片 🎓作者简介:java领域优质创作者 🌐个人主页:码农阿豪 📞工作室:新空间代码工作室(提供各种软件服务) 💌个人邮箱:[2435024119qq.com] &#x1f4f1…

导出excle表

文章目录 导出excle表需求场景引入依赖具体代码 导出excle表 需求场景 假设我们有一个需求,现在数据库中有一些用户信息,我们想要把这些信息导出到excle表格中,然后存储到本地磁盘中。要求:excle表格的第一行需要有黄色背景&…

系统报错vcruntime140_1.dll文件缺失怎么回事?多种解决方法让你对比

一、vcruntime140_1.dll常见问题与错误信息 错误信息类型 启动错误:应用程序在启动时提示缺少 vcruntime140_1.dll 文件。 运行时错误:应用程序在运行过程中突然崩溃,提示 vcruntime140_1.dll 错误。 兼容性错误:新旧版本的 V…

7z及7zip-cpp最高压缩比的免费开源压缩软件

7z介绍 7z是一种主流高效的压缩格式,它拥有极高的压缩比。在计算机科学中,7z是一种可以使用多种压缩算法进行数据压缩的档案格式。该格式最初由7-Zip实现并采用,但这种档案格式是公有的,并且7-Zip软件本身亦在GNU宽通用公共许可证…