PyTorch -- RNN 快速实践

news2025/1/12 1:03:49
  • RNN Layer torch.nn.RNN(input_size,hidden_size,num_layers,batch_first)

    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size]
      • h0:[batch, num_layers, hidden_size]
  • RNN 的输入

    • x:[seq_len, batch, input_size]
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • h0:[num_layers, batch, hidden_size]
  • RNN 的输出

    • y: [seq_len, batch, hidden_size]

在这里插入图片描述


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署
    在这里插入图片描述
    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net

      import numpy as np
      from matplotlib import pyplot as plt
      
      import torch
      import torch.nn as nn
      import torch.optim as optim
      
      
      seq_len     = 50
      batch       = 1
      num_time_steps = seq_len
      
      input_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True 
      
      class Net(nn.Module):  ## model 定义
      	def __init__(self):
      		super(Net, self).__init__()
      		self.rnn = nn.RNN(
      		input_size=input_size,
      		hidden_size=hidden_size,
      		num_layers=num_layers,
      		batch_first=batch_first)
      		# for p in self.rnn.parameters():
      		# 	nn.init.normal_(p, mean=0.0, std=0.001)
      		self.linear = nn.Linear(hidden_size, output_size)
      
      	def forward(self, x, hidden_prev):
      		out, hidden_prev = self.rnn(x, hidden_prev)
      		# out: [batch, seq_len, hidden_size]
      		out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]
      		out = self.linear(out) 			 # [batch*seq_len, output_size]
      		out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]
      		return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01
      
      def tarin_RNN():
          model = Net()
          print('model:\n',model)
          criterion = nn.MSELoss()
          optimizer = optim.Adam(model.parameters(), lr)
          hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化h
      
          l = []
          for iter in range(100):  # 训练100次
              start = np.random.randint(10, size=1)[0]  ## 序列起点
              time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
              data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      
              x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
              y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点
      
              output, hidden_prev = model(x, hidden_prev)
              hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detach
      
              loss = criterion(output, y)
              model.zero_grad()
              loss.backward()
              optimizer.step()
      
              if iter % 100 == 0:
                  print("Iteration: {} loss {}".format(iter, loss.item()))
                  l.append(loss.item())
          #############################绘制损失函数#################################
          plt.plot(l,'r')
          plt.xlabel('训练次数')
          plt.ylabel('loss')
          plt.title('RNN LOSS')
          plt.savefig('RNN_LOSS.png')
          return hidden_prev,model
      
       hidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    
      
      predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):
          input = input.view(1, 1, 1)
          pred, hidden_prev = model(input, hidden_prev)
          input = pred  ## 循环获得每个input点输入网络
          predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

      在这里插入图片描述


【高阶】上述例子比较简单,便于入门以推理到自己的目标任务,实际 RNN 训练可能更有难度,可以添加

  • 对于梯度爆炸的解决:
    for p in model.parameters()"
    	p.grad.nomr()
    	torch.nn.utils.clip_grad_norm_(p, 10)  ## 其中的 norm 后面的_ 表示 in place
    
  • 对于梯度消失的解决:-> LSTM

  • 另一个很好的实例关于飞行轨迹预测- - RNN-博客链接,可供学习参考
  • B站视频参考资料

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1839629.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

探究C语言函数栈帧的创建和销毁

引言 在C语言程序中,每当一个函数被调用时,系统都会在栈上为该函数分配一块内存空间,这块内存空间就被称为栈帧。 栈帧中包含了函数执行所需的所有信息,如局部变量、参数、返回地址等。栈帧的创建和销毁是函数调用的核心部分&am…

【华为HCIA数通网络工程师真题-数据通信与网络基础】

文章目录 选择题判断题 选择题 1、在 VRP 平台上,可以通过下面哪种方式访向上条历史命令? 上光标 (ctrlU 为自定义快捷键,ctrlP 为显示历史缓存区的前一条命令,左光标为移动光标) 2、主机 A (1…

TensorRT-常见问题

1、ModelImporter.cpp:779: ERROR: builtin_op_importers.cpp:3608 In function importResize:[8] Assertion failed: scales.is_weights() && "Resize scales must be an initializer!"解决方法:将TensorRT版本升到可以匹配cuda版本的最高版本&a…

多态性(Java)

本篇学习面向对象语言的第三个特性——多态。 目录 1、多态的概念 2、继承多态实现条件 3、重写 4、重新与重载的区别: 5、向上转移和向下转型 5、1向上转型: 5、2 向下转型 1、多态的概念 多态的概念:通俗来说,就是多种形态…

Servlet实践操作

Servlet运行原理 Tomcat 的代码内置了 main 方法,当我们启动 Tomcat 的时候,就是从 Tomcat 的 main 方法开始执行的 被 WebServlet 注解修饰的类会在 Tomcat 启动的时候就被获取并集中管理 Tomcat 通过反射这样的语法机制来创建被 WebServlet 注解修饰…

Day 27:2596. 检查骑士巡视方案

Leetcode 2596. 检查骑士巡视方案 骑士在一张 n x n 的棋盘上巡视。在 **有效 **的巡视方案中,骑士会从棋盘的 左上角 出发,并且访问棋盘上的每个格子 恰好一次 。 给你一个 n x n 的整数矩阵 grid ,由范围 [0, n * n - 1] 内的不同整数组成&…

超神级!Markdown最详细教程,程序员的福音

超神级!Markdown最详细教程,程序员的福音Markdown最详细教程,关于Markdown的语法和使用就先讲到这里,如果喜欢,请关注“IT技术馆”。馆长会更新​最实用的技术!https://mp.weixin.qq.com/s/fNzhLFyYRd3skG-…

linux环境编程基础学习

Shell编程: 相对的chmod -x xx.sh可以移除权限 想获取变量的值要掏点dollar($) 多位的话要加个花括号 运算:expr 运算时左右两边必须要加空格 *号多个含义必须加转义符 双引号可以加反单,但是发过来就不行 …

containerd手动配置容器网络

containerd手动配置容器网络 机器详情nerdctl启动一个不带网络的容器获取容器ID、PID与network namespace路径准备bridge插件的执行配置文件通过下面的命令调用bridge插件准备tuning插件文件执行下面的命令调用tuning插件准备portmap插件文件执行下面的命令调用portmap插件删除…

算法竞赛数论杂题

menji 和 gcd 题目: 一开始以为是只有l不确定,r是确定的,这样的话我们可以枚举r的所有约数,然后对其每个约数x进行判断,判断是否满足题意,具体做法是先让l % x如果 0则该约数可行,如果不可行…

文件扫描工具都有哪些?职场大佬都在用的文本提取工具大盘点~

回想起刚毕业初入职场那阵子,领导让帮忙把纸质文件扫描提取为文本时,还只会傻乎乎地一点点操作,属实是费劲得很! 好在后面受朋友安利,找到了4个能够快速实现文件扫描文字提取的方法,这才让我的办公效率蹭蹭…

[SCAU 课程设计参考] 活动管理程序

(仅供参考!!!!!!) 废话不多说,直接上代码!(但是量有点多,放前面影响观感,所以还是先不放了,文章末尾有链接) 题目的要求: 提要:我的设计只是一个参考,当时还是大一的时候写的,代码比较青涩&a…

[学习笔记]-MyBatis-Plus简介

简介 Mybatis-Plus(简称 MP)是一个 MyBatis (opens new window)的增强工具,在 MyBatis 的基础上只做增强不做改变,为简化开发、提高效率而生。 简言之就是对单表的增删改查有了很好的封装。基本不用再单独写sql语句了。目前此类…

微博舆情分析系统可以继续完善的基于python 前端vue

微博舆情分析系统可以继续完善的,前后端分离,前端基于vue 后端基于python的flask可以说是非常的简洁,支持实时更新数据。界面如图 主要工作点体现在后端实时更新数据跟数据的处理方面上,后续有空会用hadoop来处理海量数据真…

数据库 | 试卷四

1.数据库系统的特点是 数据共享、减少数据冗余、数据独立、避免了数据不一致和加强了数据保护 2.关系模型的数据结构是二维表结构 3.聚簇索引 cluster index 4. 这里B,C都是主属性,所以B->C不是非主属性对码的部分函数依赖 候选键(AC&a…

Jlink下载固件到RAM区

Jlink下载固件到RAM区 准备批处理搜索exe批处理调用jlink批处理准备jlink脚本 调用执行 环境:J-Flash V7.96g 平台:arm cortex-m3 准备批处理 搜索exe批处理 find_file.bat echo off:: 自动识别脚本名和路径 set "SCRIPT_DIR%~dp0" set &qu…

开发者黑板报#65

第65期 AI 谷歌Gemini 终于,GPT-4独霸时代终结了! 过去一个月里,四款大模型横空出世,在各项关键基准测试中与GPT-4相匹敌,甚至更胜一筹。 谷歌Gemini 1.5突破100万个tokens,是GPT-4的近8倍&#xff0c…

【Docker】——安装镜像和创建容器,详解镜像和Dockerfile

前言 在此记录一下docker的镜像和容器的相关注意事项 前提条件:已安装Docker、显卡驱动等基础配置 1. 安装镜像 网上有太多的教程,但是都没说如何下载官方的镜像,在这里记录一下,使用docker安装官方的镜像 Docker Hub的官方链…

易舟云财务软件:开启云记账新时代

在数字化浪潮的推动下,财务管理正经历着深刻的变革。易舟云财务软件,作为一款引领时代的云记账平台,以其卓越的功能和便捷的操作,为企业带来了全新的财务管理体验。 云记账,财务管理的未来趋势 云记账,即基…

【紫光同创盘古PGX-Nano教程】——(盘古PGX-Nano开发板/PG2L50H_MBG324第十一章)模拟波形实验例程说明

本原创教程由深圳市小眼睛科技有限公司创作,版权归本公司所有,如需转载,需授权并注明出处(www.meyesemi.com) 适用于板卡型号: 紫光同创PG2L50H_MBG324开发平台(盘古PGX-Nano) 一:…