pytorch案例代码-2

news2024/11/18 1:30:26

循环神经网络——基础知识
适合前后有联系的连续数据预测,比如天气预测、股市预测、自然语言等,而这些用DNN、CNN来做计算量就太大或者没法做,h0是先验,也可以前面接上CNN+FC后面连上RNN,就可以完成图像到文本的转换,没有先验时h0也可以设为和h1同维度的全0。
注意RNN Cell是共享的,所有的RNN Cell都是同一个Linear(线性层),将xi的n维映射到hi的m维度,循环使用这个Linear,所以叫循环神经网络。hi即作为时刻i的xi的输出结果,也作为xi+1时刻的输入。
请添加图片描述
Cell运算如下所示,输入xt的维度是input_size × 1,因为需要和ht-1相加所以把它转为hidden_size维度,所以Wih便是hidden_size × input_size大小的,然后Whh是hidden_size × hidden_size的,各自加上偏置,所得俩个向量直接(信息融合)相加传入tanh激活(在RNN中这个常用一些),得到ht。
实际中这俩个运算是一起的,因为W1 × h + W2 × x可转化为[W1 W2] × [h x]^T,实际中会把x和h拼起来得到(hidden_size+input_size)× 1的向量和一个hidden_size × (hidden_size+input_size)的权重矩阵运算。所以本质上还是一个线性层。

请添加图片描述
pytorch中可以自己定义RNN Cell。也可以用自带的,如下所示,注意input和hidden各自分别有俩个参数,第一个是batch,代表样本数,第二个才是维度。
请添加图片描述
如下所示,batch_size表示样本数,seqLen表示每个样本有多少个序列,比如这种情况每个样本包含x1 x2 x3三个四维的序列。
请添加图片描述




循环神经网络——使用循环神经网络
有了以上基础知识就可以用RNN了,RNN不用我们自己写循环,自动计算。注意RNN最关键和最基础的就是搞懂各个参数的size。
使用RNN,多一个num_layers参数,表示层数,如下面第二张图表示三层RNN(层数多了计算比较耗时),虽然看起来很复杂,其实也就只有三个线性层权重。
对于RNN的输入参数来说,inputs需要包含整个输入序列(x1-xn),hidden就是h0(如果有三层,就是三个左边的h0)。
输出的俩个张量参数中,第一个out表示h1-hn,第二个参数hidden表示hn(如果有三层,就是三个左边的hn)。
请添加图片描述
请添加图片描述

写代码时就要关注参数构建:请添加图片描述
TIPS:为了方便构造数据集,RNN还提供了个参数batch_first,设置成True时,batch_size和seq_len参数位置互换。其他的不变
请添加图片描述



循环神经网络——例子使用RNNCell
一个简单的学习输入hello序列输出ohlo序列的例子。但是这些字符网络无法计算,所以一般在nlp中要将其向量化,一般方法是构造字符词典,如果是词级别的就是词的词典。如图2所示,这里给每个字符分配一个索引(可以随机分配,保持唯一就行),再由索引下标,生成独热向量,即只有索引下标为1其他为0,而向量的维度就是词典元素数量,这里是4,所以input_size是4。
因为我们想要输出告诉我们输出的是h e l o四个字符中的哪一个类别,所以输出的size即hidden_size也是4。
请添加图片描述
请添加图片描述
因为只有一个样本拿来训练,所以batch_size为1。然后首先创建字典即字符列表,方便根据索引拿字符,x_data对应h e l l o,y_data对应o h l o l。
然后用一个简单的查询语句构造了输入数据的独热向量,比如输入数据第二个字符索引对应0,就把第0行数据拿出来作为对应向量,所以输入数据独热向量的size是(seq_len, input_size),因为要求inputs中间还有个batch_size,所以view一下,最终保持成(seq_len,batch_size, input_size)的大小,同样把lables也reshape成(seq_len,1)的维度。
这里注意区分inputs与forward中的input,这里input的size是(batch_size, input_size),训练时一个个的input从inputs中取出,即一个个xt,这里forward做的是计算一个个ht = cell(xt, ht-1).
图三的第三个函数做的是生成初试的默认h0,这里生成全0。大小与ht、xt一样都是input的维度。可以看到只有在这里用到了传入的batch_size参数,实际上不传入这个参数,由input的shape也可以得到这个参数。

请添加图片描述
请添加图片描述
请添加图片描述

构造完模型就可以训练了,用这一个样本训练15轮,每轮训练记得要用init_hidden初始化h0,然后写循环遍历inputs拿出一个个input,也就是拿出一个个xt训练。关键还是弄明白一个个size,这里input的size是(batch_size, input_size),inputs是(seq_len,batch_size, input_size);lables是(seq_len,1),lable是(1).
注意这里loss要一个个累加,将整个序列的损失加起来构造计算图。
for循环最后俩行代码是为了输出预测,因为hidden是四维向量,分别代表每个字符的预测概率,因为其是(batch_size, input_size)维度,所以按dim1,即行来查找最大值下标来输出ht对应字符。
最后可以看到loss不断减少,最后结果也正确了。
请添加图片描述
在这里插入图片描述
请添加图片描述



循环神经网络——例子使用RNN
还是上面的例子,如果使用RNN就不用自己写循环计算。
这里隐层h0放在forward中构造,也可以放在外面构造传入forward。这里forward的input是整个inputs,所以是(seq_len,batch_size, input_size),这里hidden的维度是(num_layers, batch_size,hidden_size),因为最后的输出out是(seq_len,batch_size, hidden_size),所以将outview一下,好处是在计算交叉熵的时候需要将其变成一个二维矩阵,lables也是要从(seq_len,batch_size,1)变为(seq_len×batch_size,1)。
请添加图片描述
请添加图片描述
训练步骤就不用写循环,一次传入整个inputs,最下面的代码也是为了将预测输出。
这里有个小问题是,损失函数中传入的是outputs和lables,分别是view过后的((seq_len×batch_size, hidden_size))与(seq_len×batch_size,1)(其实经过pycharm测试,lables是torch.Size([5]),outputs是torch.Size([5, 4])),hidden_size和1怎么能对得上然后求损失函数呢?
其实torch.nn.CrossEntropyLoss(input, target)中的标签target使用的不是one-hot形式,而是类别的序号。形如 target = [1, 3, 2] 表示3个样本分别属于第1类、第3类、第2类。input:预测值,(batch,dim),这里dim就是要分类的总类别数,target:真实值,(batch),这里为啥是1维的?因为真实值并不是用one-hot形式表示,而是直接传类别id。
请添加图片描述
请添加图片描述



循环神经网络——独热向量—>EMBEDDING
缺陷:当词典存词,独热向量可能有上万的维度,且特别稀疏,且属于一种硬编码(不是学习出来的)。
针对以上有了嵌入层EMBEDDING
请添加图片描述
请添加图片描述
可以降维也可以升维,比如独热向量是四维,我们升到五维,创造下图所示矩阵,比如输入的是第二个字符,就把第二行拿出来。
请添加图片描述
所以网络变成下面这样,首先嵌入层会把输入的独热向量转成稠密的形式,嵌入层要求输入是长整型的数据,最后还有一个线性层,因为要保证最后输出的hidden要和分类的类别数一致
请添加图片描述
嵌入层主要关注俩个参数,第一个是输入数据的维度num_emdeddings即独热向量的维度,第二个是输出维度embedding_dim,这俩就构成了上面矩阵的高度和宽度,嵌入层的input是长整型张量,output的size是(*, embedding_dim),表示input的shape,表示可以是任意维度。请添加图片描述
请添加图片描述
然后是线性层:
请添加图片描述
然后是前面提到的交叉熵,target要比input小一维度,然后后面的数据可以是k维的。
请添加图片描述



循环神经网络——使用嵌入层网络
在上面基础之上,网络结构就是加了一个嵌入层和线性层,注意嵌入层的输入数据是长整型的,且size是(batch_size,seqlen),所以网络的输入数据的size得改成这样的,然后嵌入层的输出是(batch_size,seqlen,embedding_size),等于说它内部已经帮你处理好从标量到独热向量转换的过程了,你只需要提供input_size、embedding_size和索引序列x_data。
其次注意RNN构造时,batch_first设置为True了,将batch_size放到seq_len前面。
然后全连接层完成hidden_size到类别数的转换, 最后输出数据view成二维矩阵(batch_size×seqlen,num_class)以传入loss计算。
ouputs是(batch_size×seqlen,num_class),target或者说lables是(batch_size×seqlen)的。
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述



循环神经网络——实现一个循环神经网络的分类器
解决对名字进行分类的问题,用几千个来自18个语言的人的名字进行训练,模型可以根据拼写来预测人名属于哪种语言。请添加图片描述
因为我们只关心最后名字的分类所以模型可以变得更简单,由下面图1转为图2,我们只关心最终的隐层状态,然后在通过一个全连接分成18个类别。
请添加图片描述
请添加图片描述
并且我们将定义的模型改成使用gru,这里x虽然看起来只有一个数据,但是实际上一个Maclean的名字,是由多个单词组成,即x1、x2、x3…组成一个单词,我们输入的是x1、x2、x3…这样的序列,因为名字长短不一,所以输入也不一,这是我们需要考虑的。
请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/30007.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AE VAE 代码和结果记录

Auto Encoder 在MNIST 上记录 直接上代码 import os os.chdir(os.path.dirname(__file__)) import torch import torch.nn as nn import torch.nn.functional as F import torchvision from torchvision import transforms from torchvision.utils import save_image from to…

数据结构-学习-01-线性表之顺序表-初始化、销毁、清理、获取长度、判断为空、获取元素等实现

一、测试环境 名称值cpu12th Gen Intel Core™ i7-12700H操作系统CentOS Linux release 7.9.2009 (Core)内存3G逻辑核数2gcc 版本4.8.5 20150623 二、个人理解 数据结构分为逻辑结构和物理结构(也称为存储结构)。 1、逻辑结构 逻辑结构又可以分为以下…

JS 事件

事件 事件是 JS 和 HTML 交互的桥梁。采用“观察者模式”,使用仅在事件发生时执行的监听器(也叫处理程序)订阅事件 事件流 事件流描述的是页面接收事件的顺序。分为 3 各阶段: 事件捕获:最先触发,可以做…

致敬经典 睛彩再现——AVS产业联盟和中国移动咪咕公司携手推动AVS3视频、音频标准

2022年11月14日,中国移动咪咕公司首发AVS3移动端规模化商用版本咪咕视频6.0.7.00,该版本下设的“致敬经典 睛彩再现”专区、以及“菁彩视听”双Vivid直播视角(Audio Vivid & HDR Vivid),通过国家自主的AVS3、Audio…

回顾复习【矩阵分析】初等因子 和 矩阵的相似 || 由不变因子求初等因子 || 由初等因子和秩求Smith标准形(不变因子)

目录 1. 由不变因子,引出 初等因子的概念2. 【必看】例子:已知 不变因子,求初等因子。3.【必看】 例子:已知 秩和初等因子,求史密斯标准形(不变因子)4. 分块矩阵 初等因子的 求法5. 数字矩阵的相似 与 入-矩阵的等价1. 由不变因子,引出 初等因子的概念 例如,下面两个矩阵…

Kotlin 开发Android app(十):Android控件绑定ViewBinding

上一节中,我们知道了Android的布局,这种把界面和逻辑控制分开,是编程里很好的分离方式,也大大的解耦了界面和逻辑控制,使得编程的逻辑不在和界面挂钩。 有了界面的布局,我们需要把界面和代码部分进行绑定&…

OpenPose训练教程

找遍全网都没有非常完整的OpenPose训练教程 决定自己摸索并且记录下来 openpose作者发布了一份训练代码,下面根据这个来操作 GitHUB地址: openpsoe_train 环境:ubuntu 执行matklab脚本的时候懒得下载新的matlab 就在windows下运行的 感觉没…

品质为先,服务不停,广州流辰信息公司恪守初心,匠心为民!

随着互联网技术的蓬勃发展,越来越多的企业也感受到了日益激烈的竞争,也意识到墨守成规的发展模式必当会让企业停滞不前,只有一步一个脚印,始终跟随市场的脚步创新升级,才有可能在汹涌的市场洪流中站稳脚跟。广州流辰信…

精简 Windows10

下载链接文后评论里找: 旧机福音 极限精简Win10系统Tiny10https://baijiahao.baidu.com/s?id1743901721464184983不想成天折腾操作系统,一直以来都认为跟着微软每月升级就好了。但是现实啪啪的打脸:升级到Windows11 22H2 后, 连…

常见算法设计与分析的简单C++代码实现(排列、二分法搜索、Dijkstra算法、元素换位、单调子序列、硬币问题、运动员最佳匹配问题)

常见算法设计与分析的简单C代码实现(排列、二分法搜索、Dijkstra算法、元素换位、单调子序列、硬币问题、运动员最佳匹配问题)1 一些简单排列问题2 二分法查找3 前后元素换位4 找最长单调递增子序列(O(n2)复杂度)5最小硬币问题一、…

c3p0,DBCP,Druid(德鲁伊)数据库连接池

c3p0,DBCP,Druid(德鲁伊)数据库连接池 每博一文案 佛说:前世 500 次的回眸,才换来今生的一次擦肩而过。 人与人之间的缘分,真的无需强求,并不是所有的感情都能天长地久,…

C#压缩图片

SqlSer数据库设置保存图片字段类型为Image类型 对应保存 方法参数为图片路径&#xff0c;压缩后路径&#xff0c;压缩最大宽度&#xff0c;压缩最大高度 引用类型using System.Data; using System.Drawing; using System.IO; \完整类 /// <summary> /// 按比例缩放&…

七牛qshell 批量上传 mac 本地目录

七牛qshell 批量上传 mac 本地目录下载路径及使用方法(官方)下载到自己指定的文件夹添加环境变量,使qshell在任意地方可以执行添加密钥 生成账户文件下载路径及使用方法(官方) https://developer.qiniu.com/kodo/1302/qshell记录自己部署遇到的问题及操作步骤 下载到自己指定…

音视频开发核心知识点及源码解析,还不赶紧收藏起来

随着基础设施的完善&#xff08;光纤入户、wifi覆盖、5G普及&#xff09;的影响&#xff0c;将短视频、直播、视频会议、在线教育、在线医疗瞬间推到了顶峰&#xff0c;人们对音视频的需求和要求也越来越强烈 音视频开发还具有许多方向&#xff0c;比如&#xff1a; 如果对音视…

C语言:while后加分号与for后加分号的区别

while 后面不能加分号&#xff0c;否则虽然编译可以通过&#xff0c;但是执行程序时会发生死循环#include <stdio.h> int main() { int i1,total0; while(i<100)//不能在 while 后面加分号 { totali; i;//循环…

个人付费专栏上线预热

个人付费专栏上线预热 专栏地址&#xff1a;请点击访问 文章目录一、订阅这个专栏有什么好处&#xff1f;二、实战项目预告1. 活动类站点 &#xff08;已完成前端后端&#xff09;2. 电商项目 &#xff08;筹备中&#xff0c;一比一还原设计图&#xff09;3. 论坛问答系统 &…

每日三题-爬楼梯、买卖股票的最佳时机、正则表达式匹配

&#x1f468;‍&#x1f4bb;个人主页&#xff1a; 才疏学浅的木子 &#x1f647;‍♂️ 本人也在学习阶段如若发现问题&#xff0c;请告知非常感谢 &#x1f647;‍♂️ &#x1f4d2; 本文来自专栏&#xff1a; 算法 &#x1f308; 算法类型&#xff1a;Hot100题 &#x1f3…

IP 摄像机移动应用 SDK 开发入门教程(安卓版)

涂鸦智能安卓版摄像机&#xff08;IP Camera&#xff0c;简称 IPC&#xff09;SDK 是基于智能生活 App SDK 开发而成。 通过移动应用控制物理网设备是常见的使用场景&#xff0c;但由于设备的品类丰富&#xff0c;增大了应用开发难度。因此 智能生活 App SDK 提供了常见的垂直…

支付宝支付内网穿透

支付宝支付&内网穿透一 沙箱环境二 python第三方模块python-alipay-sdk三 python-alipay-sdk二次封装四 支付接口五 内网穿透5.1 cpolar软件5.2 测试支付宝post回调一 沙箱环境 注册认证沙箱环境&#xff1a;https://openhome.alipay.com/platform/appDaily.htm?tabinfo …

【FileZila】实现windows与Linux系统文件互传

1、下载安装FileZila客户端 根据自己的PC系统版本&#xff0c;下载对应的FileZila客户端https://www.filezilla.cn/download/client 2、Linux服务端&#xff0c;安装配置vsftpd 2.1 安装ftp服务 sudo apt-get install vsftpd2.2 配置ftp服务 &#xff08;1&#xff09;打开ft…