《PyTorch深度学习实践10》——循环神经网络-基础篇(Basic-Recurrent Neural Network)

news2024/9/25 13:21:22

目录

    • 一、RNN简介
    • 二、RNN Cell用法
    • 三、RNN用法
    • 三、实例:hello换序
      • 1.RNN Cell
      • 2.RNN
    • 四、Embedding

一、RNN简介

       RNN网络最大的特点就是可以处理序列特征,就是我们的一组动态特征。比如,我们可以通过将前三天每天的特征(是否下雨,是否有太阳等)输入到网络,从而来预测第四天的天气。
       我们可以看RNN的网络结构如下:
在这里插入图片描述

二、RNN Cell用法

在这里插入图片描述

import torch

batch_size = 1 # 批处理大小
seq_len = 3 # 序列长度
input_size = 4 # 输入维度
hidden_size = 2 # 隐藏层维度

cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)

# (seq, batch, features)
dataset = torch.randn(seq_len, batch_size, input_size)
print(dataset)
hidden = torch.zeros(batch_size, hidden_size)
print(hidden)

for idx, input in enumerate(dataset):
    print( '=' * 20, idx, '=' * 20)
    print( 'Input size: ', input.shape)
    hidden = cell(input, hidden)
    print( 'outputs size: ', hidden.shape)
    print(hidden)

在这里插入图片描述

三、RNN用法

在这里插入图片描述

在这里插入图片描述

import torch

batch_size = 1 # 批处理大小
seq_len = 3 # 序列长度
input_size = 4 # 输入维度
hidden_size = 2 # 隐藏层维度
num_layers = 4  # 隐藏层数量

cell = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)

# (seqLen, batchSize, inputSize)
inputs = torch.randn(seq_len, batch_size, input_size)
hidden = torch.zeros(num_layers, batch_size, hidden_size)
out, hidden = cell(inputs, hidden)

print( 'Output size:', out.shape)
print( 'Output:', out)
print( 'Hidden size: ', hidden.shape)
print( 'Hidden: ', hidden)

在这里插入图片描述

batch_first参数:
在这里插入图片描述

三、实例:hello换序

       任务描述:我们需要训练一个模型,输入是“hello”,使输出是“ohlol”。如下图所示:
在这里插入图片描述
       方法描述:首先我们可以将“hello”中每个字母对应一个索引,之后得到输入“hello”和输出“ohlol”的编码分别为10223和31232。对编码中的每一个数字,都可以转换成一个四维张量(通过在对应张量对应索引填充为1,其余填充为0),如下图所示。这样我们的输入序列有5个元素,每个元素的维度为4。
在这里插入图片描述

1.RNN Cell

import torch

input_size = 4 # 输入维度,每个字母对应张量维度
hidden_size = 4
batch_size = 1

# 准备数据集
idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 3, 3] # hello对应编码
y_data = [3, 1, 2, 3, 2] # ohlol对应编码

one_hot_lookup = [[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]]

x_one_hot = [one_hot_lookup[x] for x in x_data]

inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)
labels = torch.LongTensor(y_data).view(-1, 1)
print(inputs.shape, labels.shape)

# 构建模型
class Model(torch.nn.Module):
    def __init__(self, input_size, hidden_size, batch_size):
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnncell = torch.nn.RNNCell(input_size=self.input_size, hidden_size=self.hidden_size)

    def forward(self, inputs, hidden):
        hidden = self.rnncell(inputs, hidden)
        return hidden

    def init_hidden(self):
        return torch.zeros(self.batch_size, self.hidden_size)

net = Model(input_size, hidden_size, batch_size)

# 损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

# 训练
for epoch in range(15):
    loss = 0
    optimizer.zero_grad()
    hidden = net.init_hidden()
    print( 'Predicted string: ', end= '')
    for input, label in zip(inputs, labels):
        hidden = net(input, hidden)
        loss += criterion(hidden, label)
        _, idx = hidden.max(dim=1)
        print(idx2char[idx.item()], end= '')
    loss.backward()
    optimizer.step()
    print( ', Epoch [%d/15] loss=%.4f' % (epoch+1, loss.item()))

在这里插入图片描述

2.RNN

import torch

input_size = 4
hidden_size = 4
batch_size = 1
seq_len = 5
num_layers = 1

# 准备数据集
idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 3, 3] # hello对应编码
y_data = [3, 1, 2, 3, 2] # ohlol对应编码

one_hot_lookup = [[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]

inputs = torch.Tensor(x_one_hot).view(seq_len, batch_size, input_size)
labels = torch.LongTensor(y_data)
print(inputs.shape, labels.shape)

# 构建模型
class Model(torch.nn.Module):
    def __init__(self, input_size, hidden_size, batch_size, num_layers=1):
        super(Model, self).__init__()
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnn = torch.nn.RNN(input_size=self.input_size, hidden_size=self.hidden_size, )

    def forward(self, inputs):
        hidden = torch.zeros(self.num_layers, self.batch_size, self.hidden_size)
        out, _ = self.rnn(inputs, hidden)    # 注意维度是(seqLen, batch_size, hidden_size)
        return out.view(-1, self.hidden_size) # 为了容易计算交叉熵这里调整维度为(seqLen * batch_size, hidden_size)

net = Model(input_size, hidden_size, batch_size)

# 损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

# 训练
for epoch in range(15):
    optimizer.zero_grad()
    outputs = net(inputs)
    # print(outputs.shape, labels.shape)
    # 这里的outputs维度是([seqLen * batch_size, hidden]), labels维度是([seqLen])
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    _, idx = outputs.max(dim=1)
    idx = idx.data.numpy()
    print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
    print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

在这里插入图片描述

四、Embedding

  • One-hot encoding of words and characters:
    (1) The one-hot vectors are high-dimension.
    (2) The one-hot vectors are sparse.
    (3) The one-hot vectors are hardcoded.

  • Do we have a way to associate a vector with a word/character with following specification:
    (1) Lower-dimension
    (2) Dense
    (3) Learned from data

  • A popular and powerful way is called EMBEDDING.

在这里插入图片描述

我们使用embedding将模型变成如下所示:
在这里插入图片描述

import torch

# parameters
num_class = 4
input_size = 4
hidden_size = 8
embedding_size = 10
num_layers = 2
batch_size = 1
seq_len = 5

# 准备数据集
idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 0, 2, 2, 3]]  # (batch, seq_len)
y_data = [3, 1, 2, 3, 2]    # (batch * seq_len)

inputs = torch.LongTensor(x_data)   # Input should be LongTensor: (batchSize, seqLen)
labels = torch.LongTensor(y_data)   # Target should be LongTensor: (batchSize * seqLen)

# 构建模型
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = torch.nn.Embedding(input_size, embedding_size)
        self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, num_class)

    def forward(self, x):
        hidden = torch.zeros(num_layers, x.size(0), hidden_size)
        x = self.emb(x)  # (batch, seqLen, embeddingSize)
        x, _ = self.rnn(x, hidden)  # 输出(𝒃𝒂𝒕𝒄𝒉𝑺𝒊𝒛𝒆, 𝒔𝒆𝒒𝑳𝒆𝒏, hidden_size)
        x = self.fc(x)  # 输出(𝒃𝒂𝒕𝒄𝒉𝑺𝒊𝒛𝒆, 𝒔𝒆𝒒𝑳𝒆𝒏, 𝒏𝒖𝒎𝑪𝒍𝒂𝒔𝒔)
        return x.view(-1, num_class)  # reshape to use Cross Entropy: (𝒃𝒂𝒕𝒄𝒉𝑺𝒊𝒛𝒆×𝒔𝒆𝒒𝑳𝒆𝒏, 𝒏𝒖𝒎𝑪𝒍𝒂𝒔𝒔)

net = Model()

# 损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

# 训练模型
for epoch in range(15):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    _, idx = outputs.max(dim=1)
    idx = idx.data.numpy()
    print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
    print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/382433.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

七、SpringBoot_自动装配

自动装配 官方文档 SpringBoot自动配置尝试根据您添加的 jar 依赖项自动配置您的 Spring 应用程序 Spring Boot auto-configuration attempts to automatically configure your Spring application based on the jar dependencies that you have added. SpringBoot定义了一套接…

详谈IIC

前言 在嵌入式底层系统中,常见的通讯方式,串口,IIC,SPI,IIS等,一般IIC,SPI,IIS更多的采取IO模拟,其余CAN,UART均是硬件设计直接支持,而IIC主要用于多数传感器数据的读写&#xff0c…

【c++基础】

C基础入门统一初始化输入输出输入输出符输入字符串const与指针c和c中const的区别const与指针的关系常变量与指针同类型指针赋值的兼容规则引用引用的特点const引用作为形参替换指针其他引用形式引用和指针的区别inline函数缺省参数函数重载判断函数重载的规则名字粉碎C编译时函…

Goby 征文大擂台,超值盲盒等你来!

001 Goby 技术征文正式启动 Goby 致力于做最好的网络安全工具。为了促进师傅们知识共享和技术交流,现发起关于 Goby 的技术文章征集活动! 欢迎所有师傅们参加,分享您的使用经验或挖洞窍门等,帮助其他人更好地了解和利用 Goby。 …

Winform界面实现控件中英文语言切换

一、业务需求 在Winform项目的开发过程中,涉及到一个基础的功能就是需要对界面中的显示语言内容可以进行选择切换,方便不同地区的使用者快速上手使用;效果如下: 二、需求分析 需要实现对Winform项目界面显示语言可选择切换步骤如下: ①修改控件的显示内容; ②获取到界面显…

【计算机网络】数据链路层

概述 封装成帧 差错检验 可靠传输 实现机制 可靠传输的实现机制 停止等待协议 回退N帧协议 选择重传协议 【计算机网络】MAC帧和PPP帧(定义使用范围区别共同点)_GPNU_Log的博客-CSDN博客_ppp帧 PPP帧和以太网帧 | Mixoo 数据链路层的协议有PPP协…

Rman单实例迁移到单实例

关于同平台同版本数据库之间的迁移操作的实验 ---Source DB[rootoracle-db-19cs ~]# cat /etc/redhat-release CentOS Stream release 8 [rootoracle-db-19cs ~]# --- Target DB[rootoracle-db-19ct ~]# cat /etc/redhat-release CentOS Stream release 8 [rootoracle-db-19ct…

如何使用dlinject将一个代码库实时注入到Linux进程中

关于dlinject dlinject是一款针对Linux进程安全的注入测试工具,在该工具的帮助下,广大研究人员可以在不使用ptrace的情况下,轻松向正在运行的Linux进程中注入一个共享代码库(比如说任意代码)。之所以开发该工具&#…

扬帆优配|新概念火了!时空大数据龙头再冲击涨停

时空大数据近期受资金追捧。 今天早盘,中通邦本开盘后再度冲击涨停,一度封死涨停板,之后涨停打开,截至上午收盘仍上涨超5%。此前,中通邦本已接连两日涨停,公司在互动渠道上表示,参股公司北京邦本…

Spring中Bean生命周期及循环依赖

spring中所说的bean对象 与 我们自己new的对象(原始对象)是不同的;bean对象是指spring框架创建管理的我们的对象生命周期即:何时生,何时死1.实例化 Instantiation:spring通过反射机制以及工厂创建出来的原始对象;2.属性…

【Spring】八种常见Bean加载方式

🚩本文已收录至专栏:Spring家族学习 一.引入 (1) 概述 ​ 关于bean的加载方式,spring提供了各种各样的形式。因为spring管理bean整体上来说就是由spring维护对象的生命周期,所以bean的加载可以从大的方面划分成2种形式&#xff…

2023年融资融券研究报告

第一章 行业概况 融资融券是证券交易市场上的两种金融衍生品交易方式,主要用于股票、债券等证券的融资和投资。 融资是指投资者向证券公司借入资金购买证券,以期望股票价格上涨后卖出获得利润。融资需支付一定的利息和费用,利息根据借入的资…

CSS实现checkbox选中动画

前言 👏CSS实现checkbox选中动画,速速来Get吧~ 🥇文末分享源代码。记得点赞关注收藏! 1.实现效果 2.实现步骤 定义css变量,–checked,表示激活选中色值 :root {--checked: orange; }创建父容器&#xf…

python+pytest接口自动化(6)-请求参数格式的确定

我们在做接口测试之前,先需要根据接口文档或抓包接口数据,搞清楚被测接口的详细内容,其中就包含请求参数的编码格式,从而使用对应的参数格式发送请求。例如某个接口规定的请求主体的编码方式为 application/json,那么在…

Go 实现 AOI 区域视野管理

在游戏中,场景里存在大量的物体.如果我们把所有物体的变化都广播给玩家.那客户端很难承受这么大的压力.因此我们肯定会做优化.把不必要的信息过滤掉.如只关心玩家视野所看到的.减轻客户端的压力,给玩家更流畅的体验. 优化的思路一般是: 第一个是尽量降低向客户端同步对象的数量…

【Java】P1 基础知识与碎碎念

Java 基础知识 碎碎念安装 Intellij IDEAJDK 与 JREJava 运行过程Java 系统配置Java 运行过程Java的三大分类前言 本节内容主要围绕Java基础内容,从Java的安装到helloworld,什么是JDK与什么是JRE,系统环境配置,不深入Java代码知识…

传导EMI抑制-Π型滤波器设计

1 传导电磁干扰简介 在开关电源中,开关管周期性的通断会产生周期性的电流突变(di/dt)和电压突变(dv/dt),周期性的电流变化和电压变化则会导致电磁干扰的产生。 图1所示为Buck电路的电流变化,在Buck电路中上管电流和下…

ubuntu 22.04 mangodb

文章写在2023年3月1日 目前最新的mangodb稳定版本是6.04 1.安装server server安装包为mangodb的程序主体。 服务器deb安装包下载地址 https://www.mongodb.com/try/download/community ubuntu22.04的server deb 文件url https://repo.mongodb.org/apt/ubuntu/dists/jammy/mo…

计算机组成原理 浮点数运算清晰明了

注释:阶码和尾数都需要符号位区分正负 例题1:x 2^-11*0.100101, y 2^-10*(-0.011110),求xy 第零步 补码表示 对于x来说-11 补码表示为 11011; 0.100101补码表示为00.100101对于y来说-10补码表示为 10110&#xff…

【el】表单

elementUI中的表单相关问题一、用法1、动态表单调用接口返回表单&#xff0c;后端的接口返回值如下&#xff1a;这些是渲染后的效果页面使用&#xff08;父组件&#xff09;<el-button size"small" class"Cancelbtn" click"sub(true)">发起…