pytorch(十)循环神经网络

news2025/1/18 16:58:15

文章目录

    • 卷积神经网络与循环神经网络的区别
    • RNN cell结构
    • 构造RNN
    • 例子 seq2seq

卷积神经网络与循环神经网络的区别

卷积神经网络:在卷积神经网络中,全连接层的参数占比是最多的。

卷积神经网络主要用语处理图像、语音等空间数据,它的特点是局部连接和权值共享,减少了参数的数量。CNN通常包括卷积层、池化层和全连接层,卷积层用于提取输入数据的特诊,池化层用语降维和压缩特征,全连接层将特征映射到更加高维度的空间中。

卷积神经网络在处理空间数据的时候,其输入通常是一个二维或者三维的张量,每一个元素对应一个输出。CNN通过卷积操作获取特征,这些特征往往是局部相关的,即每个特征都与输入数据的一个局部区域相关联(相邻的特征往往具有相似性)。

循环神经网络

循环神经网络主要用于处理序列数据,比如自然语言处理、语音识别和时间序列预测(具有时间依赖的数据),它的特点是具有时间轴,能够保留和处理序列中的历史信息。RNN包括循环层,能够接收前一层的信息并且与当前的输入结合,生成当前的输出序列。

循环神经网络处理的数据往往是序列数据,即输入的数据是一个序列,每一个时间步都有一个对应的输出,RNN通过循环层保留和传递历史信息,从而建立时间上的依赖关系。

循环神经网络也使用到了权值共享的想法,RNN在不同的时间位置共享参数(CNN在不同的空间位置共享参数)

RNN cell结构

RNN cell的结构在本质上类似于传统的神经网络模型(输入层、隐藏层、输出层),但是RNN于传统的NN最大的区别在于:隐藏层的输入不仅包括输入层的输出,还包括上一时刻隐藏层的输出,其结构如下:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

之所以叫做循环神经网络,是因为以上的cell都是同一个cell,参数也只有cell里的一份而已,在训练中,cell层使用的是循环结构,使用不同的输入数据进行训练。

构造RNN

第一种方法 RNNCell

cell=torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)
hidden=cell(input,hidden)
# input of shape(batch,input_size)
# hidden of shape(batch,hidden_size)
# output of shape(batch,hidden_size)
# dataset.shape(seqLen,batchSize,inputSize)
import torch

batch_size=1 # N=1
seq_len=3 # x=[x1,x2,x3]
input_size=4 # [4*3]
hidden_size=2 # [2*3]
# 两种构造方式——1
cell=torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)

#(seq,batch=n,features) 序列的长度 * 批量 * 维度
dataset=torch.randn(seq_len,batch_size,input_size)
# 批量 * 维度
hidden=torch.zeros(batch_size,hidden_size)

print('dataset:',dataset)
print('hidden:',hidden)

# 构造循环
for idx,input in enumerate(dataset):
    print('=' *20,idx,'='*20)
    print(input)
    print('Input size:',input.shape)
    
    hidden=cell(input,hidden)
    
    print(hidden)
    print('Outputs size:',hidden.shape)
    print(hidden)

第二种方法

cell=torch.nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers)
out,hidden=cell(input,hidden)
# input:表示整个输入序列;input of shape(seqSize,batch,input_size),这里的seqSize实际上就是上面的循环过程
# hidden:表示h0;hidden of shape(numLayers,batch,hidden_size)
# output of shape(seqLen,batch,hidden_size)
import torch

batch_size=1 # N=1
seq_len=3 # x=[x1,x2,x3]
input_size=4 # [4*3]
hidden_size=2 # [2*3]
num_layers=1
# 两种构造方式——2 num_layers表示层数
cell=torch.nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers)
# cell=torch.nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers)
# inputs=torch.randn(batch_size,seq_len,input_size)

inputs=torch.randn(seq_len,batch_size,input_size)
hidden=torch.zeros(num_layers,batch_size,hidden_size)

out,hidden=cell(inputs,hidden)
print('Output size:',out.shape)
print('Output:',out)
print('Hidden size:',hidden.shape)
print('Hidden:',hidden)

例子 seq2seq

需要把 hello 序列经过训练变成 ohlol 序列,x=[x1,x2,x3,x4,x5]

这里是需要训练一个模型,使得输入“hello”,而输出为“ohlol”,实际上对于输入的数据,判别输出的数据属于哪一个类别,这里只有4种字母,所以输出的维度为4。

在这里插入图片描述

将序列数据转变成向量,一般转换为独热向量

在这里插入图片描述

import torch

batch_size=1
input_size=4
hidden_size=4

idx2char=['e','h','l','o']
# 字典转变
x_data=[1,0,2,2,3]
y_data=[3,1,2,3,2]

one_hot_looup=[[1,0,0,0],
              [0,1,0,0],
              [0,0,1,0],
              [0,0,0,1]]

x_one_hot=[one_hot_looup[x] for x in x_data]

# 两个数据中的-1表示序列的长度seqSize
inputs=torch.Tensor(x_one_hot).view(-1,batch_size,input_size)
print('inputs:',inputs)
labels=torch.LongTensor(y_data).view(-1,1)

class Model(torch.nn.Module):
    def __init__(self,input_size,hidden_size,batch_size):
        super(Model,self).__init__()
        self.batch_size=batch_size
        self.input_size=input_size
        self.hidden_size=hidden_size
        
        self.rnncell=torch.nn.RNNCell(input_size=self.input_size,hidden_size=self.hidden_size)
        
    def forward(self,input,hidden):
        hidden=self.rnncell(input,hidden)
        return hidden
    
    # h0
    def init_hidden(self):
        return torch.zeros(self.batch_size,self.hidden_size)
    
net=Model(input_size,hidden_size,batch_size)
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(),lr=0.1)

for epoch in range(15):
    loss=0
    optimizer.zero_grad()
    hidden=net.init_hidden()
    
    print('Predicted string:',end='')
    # 这个实际上就是RNN
    for input,label in zip(inputs,labels):
        # print('input:',input,'label:',label)
        hidden=net(input,hidden)
        loss+=criterion(hidden,label)
        # 输出可能性最大的那个值,也就是预测的值
        _,idx=hidden.max(dim=1)
        print(idx2char[idx.item()],end='')
        
    loss.backward()
    optimizer.step()
    print(',Epoch [%d/15] loss=%.4f'%(epoch+1,loss.item()))

在这里插入图片描述

独热向量的缺点

  • 独热向量的维度过高,比如26个字母有26维度。给每一个汉字都映射一个独热向量会有维度的诅咒;
  • 独热向量过于稀疏;
  • 独热向量是硬编码,不是学习出来的,每一个向量对应都是确定的。

解决方案:Embedding,也就是把高维度的稀疏的数据映射成低维的稠密的数据中,也就是数据降维

降维后的网络

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1518749.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【分类讨论】【解析几何】【 数学】【推荐】1330. 翻转子数组得到最大的数组值

作者推荐 视频算法专题 本文涉及知识点 分类讨论 解析几何 LeetCode1330. 翻转子数组得到最大的数组值 给你一个整数数组 nums 。「数组值」定义为所有满足 0 < i < nums.length-1 的 |nums[i]-nums[i1]| 的和。 你可以选择给定数组的任意子数组&#xff0c;并将该子…

3月15日ACwing每日一题

789. 数的范围 - AcWing题库 #include <bits/stdc.h> using namespace std; int n,q; const int N100007; int a[N]; void solve(){//lower_bound是大于等于 upper_bound是大于int num;cin>>num;if(lower_bound(a,an,num)!an&&*lower_bound(a,an,num)num)…

fs模块 之 文件读取

fs 文件读取&#xff1a; 利用文件读取而不是直接打开文本查看的目的是为了实现自动化 读取文件的应用场景:电脑开机/程序运行/播放视频音乐/上传文件... 一、异步读取 &#xff08;1&#xff09;语法&#xff1a;fs.readFile(path,[options],callback); 以之前写的文件写…

matlab去除图片上的噪声

本问题来自CSDN-问答板块,题主提问。 如何利用matlab去除图片上的噪声? 一、运行效果图 左边是原图,右边是去掉噪音后的图片。 二、中文说明 中值滤波是一种常见的图像处理技术,用于去除图像中的噪声。其原理如下: 1. 滤波器移动:中值滤波器是一个小的窗口,在图像上移…

红队笔记7--Web机器为Linuxdocker逃逸

其实&#xff0c;不知道大家有没有想过&#xff0c;我们之前练习的都是web机器是windows的版本&#xff0c;但是其实&#xff0c;在现实生活中&#xff0c;服务器一般都是Linux的版本&#xff0c;根本不可能用到windows的版本 那么如果是Linux的话&#xff0c;我们就有很多的困…

express+mysql+vue,从零搭建一个商城管理系统14--快递查询(对接快递鸟)

提示&#xff1a;学习express&#xff0c;搭建管理系统 文章目录 前言一、安装md5&#xff0c;axios&#xff0c;qs二、新建config/logistics.js三、修改routes/order.js四、添加商品到购物车总结 前言 需求&#xff1a;主要学习express&#xff0c;所以先写service部分 快递鸟…

隐藏深的bug发现不了 ,有点挫备感 ,那是你没有进行bug总结 。

1.bug总结的意义 作为功能测试人员来说&#xff0c;可能有一半的时间都花在了和bug打交道上&#xff0c;比如如何发现bug &#xff0c;提交bug &#xff0c;跟踪bug以及回归bug上 。作为测试人员最重要的成果的bug &#xff0c;我们往往更看重的是它的数量 &#xff0c;却很少…

Android 辅助功能 -抢红包(三)

Android 辅助功能 -抢红包(三) 本篇文章继续讲述辅助功能. 主要通过监听通知栏红包消息,来跳转聊天页面,并自动回复对方"谢谢". 上篇文章我们讲述了监听notification, 跳转聊天界面. 具体可查看: Android 辅助功能 -抢红包(二) 1: 使用monitor抓取id. 打开andro…

RabbitMQ 模拟实现【六】:程序模拟实现

文章目录 模拟实现模拟消费者模拟生产者效果展示 启动结果如下&#xff1a; ![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/71841546ad8043f1bd51e4408df791de.png)![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/f6e3e72ff9a4483c978ec48e24f075c2.p…

运营模型—RFM 模型

运营模型—RFM 模型 RFM 是什么其实我们前面的文章介绍过,这里我们不再赘述,可以参考运营数据分析模型—用户分层分析,今天我们要做的事情是如何落地RFM 模型 我们的数据如下,现在我们就开始进行数据处理 数据预处理 因为数据预处理没有一个固定的套路,都是根据数据的实…

Unity类银河恶魔城学习记录10-1 10-2 P89,90 Character stats - Stat script源代码

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释&#xff0c;可供学习Alex教程的人参考 此代码仅为较上一P有所改变的代码 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili Stat.cs using System.Collections; using System.Collections.Generic; us…

C类期刊:基于改进粒子群优化算法的电力系统有功最优潮流程序代码!

程序提出了一种基于改进粒子群优化算法的有功最优潮流模型及求解方法&#xff0c;采用了自适应罚函数法处理最优潮流问题的各种约束条件。通过对IEEE-30节点系统的仿真计算&#xff0c;并且与遗传算法进行比较&#xff0c;验证了提出的模型和方法的有效性。程序算例丰富、注释清…

3.排序查找——2.整数奇偶排序

输入 4 7 3 13 11 12 0 47 34 98 输出 47 13 11 7 3 0 4 12 34 98 【提交地址】 题目分析 关键是找到交换位序的逻辑&#xff0c;有如下几种情况&#xff1a; 左值为奇数&#xff0c;右值为偶数 > 不需要交换左值为偶数&#xff0c;右值为奇数 > 需要交换左值和右值同…

【数据结构】 Map和Set万字总结(搜索树+哈希桶+使用方法+实现方法)

文章目录 Map和Set一、搜索树1.二叉搜索树的查找&#xff08;search&#xff09;2.二叉搜索树的插入3.二叉搜索树的删除4.性能分析 二、搜索方法1.概念 三、Map的使用1.概念&#xff1a;2.Map的常用方法&#xff1a;1.V put(K Key ,V Value )2.V get(Object key)3.V getOrDefau…

YOLOv8旋转目标检测实战:训练自己的数据集

课程链接&#xff1a;https://edu.csdn.net/course/detail/39393 旋转目标检测是计算机视觉领域的一个高级任务&#xff0c;它在传统目标检测的基础上进一步发展。传统目标检测技术主要关注于识别和定位图像中的物体&#xff0c;通常以水平边界框(HBB)来标识目标物体的位置。而…

某阿系影城网爬虫JS逆向

本次逆向目标网站如下,使用base64解码获得 aHR0cHM6Ly9oNWxhcmsueXVla2V5dW4uY29tL2ZpbG0vaW5kZXguaHRtbD93YXBpZD1GWVlDX0g1X1BST0RfU19NUFMmc3RhbXA9MTcxMDExNzc5NDM0NiZzcG09YTJvZjYubG9jYXRpb25faW5kZXhfcGFnZS4wLjA= 打开网站,发起请求后,发现请求参数没有加密,请求头…

【Stable Diffusion】入门-03:图生图基本步骤+参数解读

目录 1 图生图原理2 基本步骤2.1 导入图片2.2 书写提示词2.3 参数调整 3 随机种子的含义4 拓展应用 1 图生图原理 当提示词不足以表达你的想法&#xff0c;或者你希望以一个更为简单清晰的方式传递一些要求的时候&#xff0c;可以给AI输入一张图片&#xff0c;此时图片和文字是…

CorelDRAW2024中文版全新功能和软件使用介绍!

亲爱的用户们&#xff0c;我们非常高兴地向您介绍CorelDRAW 2024的全新功能和软件使用介绍&#xff01;作为一款深受设计师们喜爱的图形设计软件&#xff0c;CorelDRAW一直在不断地优化和升级&#xff0c;力求为您提供更加优秀的创作体验。今天&#xff0c;我们就来一起了解一下…

免费开源的 Vue 拖拽组件 VueDraggablePlus (兼容移动端)

VueDraggablePlus 支持 Vue2 / Vue3&#xff0c;是被尤雨溪推荐了的拖拽组件。我自己试用过了&#xff0c;还挺好用的&#xff0c;兼容移动端。 官网&#xff1a;https://alfred-skyblue.github.io/vue-draggable-plus/ 官网文档里面很详细了&#xff0c;我就不再介绍安装和用…

vitepress里使用gitalk(图文教程)

vitepress里使用gitalk Gitalk 是一个基于 GitHub Issue 和 Preact 开发的评论插件 生成client配置 创建OAuth application 填写完毕&#xff0c;点击 Register application 即可 生成client secrets 一开始没有自动生成 Client secrets&#xff0c;需要手动生成&#xff…