pytorch实现RNN网络

news2024/9/22 2:51:58

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape


#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module
    #初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)
    #vocab_size = len(vocab)
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        #继承父类的属性和方法
        super().__init__(**kwargs)
        self.rnn_layer = rnn_layer
        #词汇表的长度
        self.vocab_size =vocab_size
        self.num_hiddens = self.rnn_layer.hidden_size
        
        #判断是否为双向循环
        if not self.rnn_layer.bidirectional:
            self.num_directions = 1
            #nn.Linear用于定义线性层的类,一般用于全连接层
            self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)
    
    #定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)
    def forward(self, inputs, state):
        #one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)
        #。T转置之后,X.shape=(num_steps,batch_size)
        #one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        #将数据转化为tensor
        X = X.to(torch.float32)
        Y, state = self.rnn_layer(X, state)
        #此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)
        
        #输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,
        outputs = self.linear(Y.reshape(-1, Y.shape[-1]))
        return outputs, state
                              
   # 初始化隐藏状态
    def begin_state(self, device, batch_size=1):
        return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',
                    num_preds=10, 
                    net=rnn_net, 
                    vocab=vocab, 
                    device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, 
                  train_iter=train_iter, 
                  vocab=vocab, 
                  lr=lr, 
                  num_epochs=num_epochs, 
                  device=device)

 

8.个人知识点理解

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2153945.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何解决DataGrip的 Public Key Retrieval is not allowed错误

对于 DataGrip 出现 [08001] Public Key Retrieval is not allowed 错误,原因通常是 MySQL 的安全机制不允许客户端检索公钥。你可以通过以下步骤来解决这个问题: 解决步骤: 修改 DataGrip 中的连接设置: 打开 DataGrip。在左侧导…

CLion/Git版本控制

文章目录 文章介绍准备工具操作首次提交修改代码提交第二版 文章介绍 记录用clion和git做代码的版本控制 准备工具 CLion2024.2.0.1 git 操作 首次提交 该文件夹的打开方式选择clion 全部提交 成功提交后查看分支 修改代码提交第二版

hutool 解压缩读取源文件和压缩文件大小失败导致报错

前言 最近处理老项目中的问题,升级安全jar,发现hutool的jar在解压缩的时候报错了,实际上是很简单的防御zip炸弹攻击的手段,但是却因为hutool的工具包取文件大小有bug,造成了解压缩不能用,报错:…

山东潍坊戴尔存储服务器维修 md3800f raid恢复

山东戴尔存储故障维修 存储型号:DELL PowerVault md3800f 故障问题:存储除尘后通电开机,发现有物理硬盘没有插到位,用户带电拔插了多块物理盘,导致关连的磁盘阵列掉线,卷失败; 处理方式&#xf…

Python基于Django、大数据的北极星招聘数据可视化系统

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

[JavaEE] TCP协议

目录 一、TCP协议段格式 二、TCP确保传输可靠的机制 2.1 确认应答 2.2 超时重传 2.3 连接管理 2.3.1 三次握手 2.3.2 四次挥手 2.4 滑动窗口 2.4.1 基础知识 2.4.2 两种丢包情况 2.4.2.1 数据报已经抵达,ACK丢包 2.4.2.2 数据包丢包 2.5 流量控制…

国标GB28181视频融合监控汇聚平台的方案实现及场景应用

Liveweb国标视频融合云平台基于端-边-云一体化架构,部署轻量简单、功能灵活多样,平台可支持多协议(GB28181/RTSP/Onvif/海康SDK/Ehome/大华SDK/RTMP推流等)、多类型设备接入(IPC/NVR/监控平台),在视频能力上&#xff0…

图解 | 消息认证码(MAC)到底解决了什么问题?还有什么问题是它解决不了的?

消息认证码(Message Authentication Code,MAC)是一种用于验证数据完整性和来源可信性(对消息进行认证)的技术。它通常由一个密钥和被保护的消息通过特定算法计算得出,接收方可以使用相同的密钥(…

C++类之set与get理解

在类中,我们尝尝将一些变量设置为private或者protect里面,而我们经常会遇到在主函数(main.cpp)使用到这些private变量,而往往我们会下意识地在主函数直接调用在private里面的变量,但现实比较残酷&#xff0…

20240921解决使用PotPlayer在WIN10电脑播放4K分辨率10bit的视频出现偏色的问题

20240921解决使用PotPlayer在WIN10电脑播放4K分辨率10bit的视频出现偏色的问题 2024/9/21 10:40 缘起:常见的问题,你下载视频的时候,4K分辨率的视频播放的时候出现偏色异常,但是1080p分辨率的正常呀! 偏色的识别&…

re题(32)BUUCTF-[MRCTF2020]hello_world_go

BUUCTF在线评测 (buuoj.cn) 查壳,无壳,64位elf文件 ida打开是go语言写的,shiftF12看字符串 ctrlF搜索字符串,得到flag 本题是go语言写的,可以用linux打开go语言文件,本题直接把flag放到了字符串表&#xf…

数据结构---二叉搜索树(二叉排序树)

什么是二叉排序树 二叉搜索树又是二叉排序树,当我们的是一颗空树或者具有以下性质时: 左子树不为空,左子树上的值都小于我们的根节点上的值。右子树不为空时,右子树上的值都大于我们的根节点上的值左右子树都是二叉搜索树&#…

我的AI工具箱Tauri版-VideoDuplication视频素材去重

本教程基于自研的AI工具箱Tauri版进行VideoDuplication视频素材去重。 该项目是基于自研的AI工具箱Tauri版的视频素材去重工具,用于高效地处理和去除重复视频内容。用户可以通过搜索关键词"去重"或通过路径导航到"Python音频技术/视频tools"模…

封装的例题

答案A 解析: 选项B说法也正确,但是不如A更有效 选项C 不管采用什么方法,文档是必须要写的 选项D 说法太绝对了,如果封装的内容不适合,开发者可能做软件开发反而难度系数加大

芯片开发(1)---BQ76905---底层参数配置

主要开发思路:AFE主要是采集、保护功能、均衡,所以要逐一去配置芯片的寄存器 采集、均衡功能主要是配置引脚 保护功能主要是参数寄存器配置,至于如何使用命令修改寄存器参数该系列芯片提供了子命令和直接命令两种方式 BQ76905的管脚配置 I、参数配置 …

ubuntu 执行定时任务crontab -e 无法输入的问题

界面显示 GNU nano 4.8 /tmp/crontab.l0A1HJ/crontab # Edit this file to introduce tasks to be run by cron. # # Each task to run has to be defined t…

全国职业院校技能大赛(大数据赛项)-平台搭建hive笔记

在大数据时代,数据量呈爆炸性增长,传统的数据处理工具已难以满足需求。Hive作为一个开源的数据仓库工具,能够处理大规模数据集,提供了强大的数据查询和分析能力,是大数据学习中的关键工具。在全国职业院校技能大赛&…

【图像检索】基于Gabor特征的图像检索,matlab实现

博主简介:matlab图像代码项目合作(扣扣:3249726188) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本次案例是基于Gabor特征的图像检索,用matlab实现。 一、案例背景和算法介绍 这次博…

GPT-4o在matlab编程中性能较好,与智谱清言相比

边标签由矩阵给出 s [1 2 3 3 3 3 4 5 6 7 8 9 9 9 10]; t [7 6 1 5 6 8 2 4 4 3 7 1 6 8 2]; G graph(s,t); plot(G) ------------------- GPT-4o给出的代码可用, clc;clear; % 定义边的起点和终点 s [1 2 3 3 3 3 4 5 6 7 8 9 9 9 10]; t [7 6 1 5 6 8 2 …

您可能一直在寻找的 10 个非常有用的前端库

文章目录 前言正文1.radash2.dayjs3.driver4.formkit/drag-and-drop5.logicflow6.ProgressBar7.tesseract8.zxcvbn9.sunshine-track10.lottie 前言 前端开发中,总有一些重复性的工作让我们疲于奔命。为了提高开发效率,我们精心挑选了10个功能强大、易于…