Pytorch搭建循环神经网络RNN(简单实战)

news2024/9/23 21:25:37

Pytorch搭建循环神经网络RNN(简单实战)

去年写了篇《循环神经网络》,里面主要介绍了循环神经网络的结构与Tensorflow实现。而本篇博客主要介绍基于Pytorch搭建RNN。

通过Sin预测Cos

import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt

首先,我们定义一些超参数

TIME_STEP = 10  # rnn 时序步长数
INPUT_SIZE = 1  # rnn 的输入维度
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
H_SIZE = 64  # of rnn 隐藏单元个数
EPOCHS = 100  # 总共训练次数
h_state = None  # 隐藏层状态

使用Numpy生成Sin和Cos函数

steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)

可视化数据

plt.figure(1)
plt.suptitle('Sin and Cos', fontsize='18')
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()

定义网络结构

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=H_SIZE,
            num_layers=1,
            batch_first=True,
        )
        self.out = nn.Linear(H_SIZE, 1)

    def forward(self, x, h_state):
        r_out, h_state = self.rnn(x, h_state)
        outs = []  # 保存所有的预测值
        for time_step in range(r_out.size(1)):  # 计算每一步长的预测值
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state
rnn = RNN().to(DEVICE)
optimizer = torch.optim.Adam(rnn.parameters())  # Adam优化,几乎不用调参
criterion = nn.MSELoss()  # 因为最终的结果是一个数值,所以损失函数用均方误差

rnn.train()
plt.figure(2)
for step in range(EPOCHS):
    start, end = step * np.pi, (step+1)*np.pi  # 一个时间周期
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
    x_np = np.sin(steps)
    y_np = np.cos(steps)
    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])  # shape (batch, time_step, input_size)
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
    x = x.to(DEVICE)
    prediction, h_state = rnn(x, h_state) # rnn output
    # 这一步非常重要
    h_state = h_state.data  # 重置隐藏层的状态, 切断和前一次迭代的链接
    loss = criterion(prediction.cpu(), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (step+1) % 20 == 0:  # 每训练20个批次可视化一下效果,并打印一下loss
        print("EPOCHS: {},Loss:{:4f}".format(step, loss))
        plt.plot(steps, y_np.flatten(), 'r-')
        plt.plot(steps, prediction.cpu().data.numpy().flatten(), 'b-')
        plt.draw()
        plt.pause(0.01)

运行结果如下:

EPOCHS: 19,Loss:0.052745

EPOCHS: 39,Loss:0.016266

EPOCHS: 59,Loss:0.005471

EPOCHS: 79,Loss:0.001329

EPOCHS: 99,Loss:0.002216

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1019880.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

医学影像相关开源数据集资源汇总

CT 医学图像 下载链接:http://suo.nz/2tQehH 该数据集旨在允许测试不同的方法来检查与使用对比度和患者年龄相关的 CT 图像数据的趋势。基本思想是识别与这些特征密切相关的图像纹理、统计模式和特征,并可能构建简单的工具,在这些图像被错误…

MyEclipse项目导入与导出

一、项目导出 1、右键选择项目名称,弹出菜单中选择“export”,如下图所示 2、选择“恶心“export”,弹出菜单如下;在“General“选项中,选择“File System”选项 3、点击“next”,进入保存位置选择界面&am…

异步FIFO设计的仿真与综合技术(5)

概述 本文主体翻译自C. E. Cummings and S. Design, “Simulation and Synthesis Techniques for Asynchronous FIFO Design 一文,添加了笔者的个人理解与注释,文中蓝色部分为笔者注或意译。前文链接: 异步FIFO设计的仿真与综合技术&#xf…

小目标检测高效解决方案汇总,附19篇原论文&开源代码

目标检测发展至今,涌现出了许多非常实用的方法,但在小目标检测领域, 由于小目标经常存在图片模糊、信息少、分辨率低等问题,性能水平仍然难以提升。 不过在近几年间,已经有许多有效的解决方法被提出,我今天…

前端录入音频并上传

目录 纯 js 实现(有问题)使用插件 recorder-core (没问题) 纯 js 实现(有问题) 上传音频文件时 blob 数据中 size 一直是0,导致上传之后音频不可播放(本地录制后本地是可以播放的&am…

什么是CORS(跨源资源共享)?如何解决前端中的CORS问题?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ CORS(跨源资源共享)⭐ 解决前端中的CORS问题的方法⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 欢迎来到前端入门之旅!感兴趣的可以订阅本专栏哦!这个专栏是为…

【前端知识】Three 学习日志(三)—— 光源对物体表面的影响

Three 学习日志(三)—— 光源对物体表面的影响 一、设置材质为受光照影响 //MeshLambertMaterial受光照影响 const material new THREE.MeshLambertMaterial();此时,场景中一片漆黑,无法看到原来的物体,需要设置光源…

24v转5v稳压芯片-5A大电流输出ic

这款24V转5V5A汽车充电芯片具有以下特性和参数: - 宽输入电压范围:4.5V至36V - 最大输出电流:5.0A - 高达92%的转换效率 - 恒流/恒压模式控制 - 最大占空比100% - 可调输出电压 - 2%的输出电压精度 - 集成40mΩ高侧开关 - 集成18mΩ低侧开关 …

网络安全深入学习第六课——热门框架漏洞(RCE— Weblogic反序列化漏洞)

文章目录 一、Weblogic介绍二、Weblogic反序列化漏洞历史三、Weblogic框架特征1、404界面2、登录界面 四、weblogic常用弱口令账号密码五、Weblogic漏洞介绍六、Weblogic漏洞手工复现1、获取账号密码,这是一个任意文件读取的漏洞1)读取SerializedSystemI…

K8s(Kubernetes)学习(六)——Ingress

第六章 Ingress 什么是 IngressIngress 和 Service 区别Ingress 控制器 Traefik 使用Ingress Route的定义 1 简介 https://kubernetes.io/zh-cn/docs/concepts/services-networking/ingress/ Ingress 是一种 Kubernetes 资源类型,它允许在 Kubernetes 集群中暴露…

浏览器报错内容:Provisional headers are shown

浏览器报错内容:Provisional headers are shown 如下图: 解决方法:nginx 443 启用HTTP/2模式,如下图: server {listen 443 ssl http2;server_name callcenterda.umworks.com;client_max_body_size 200M;ssl_session_…

Idea注释相关配置模板

设置-编辑器-实时模板。 这里可以自己建立一个文件夹,建立自己的模板 1、普通多行注释 2、方法注释 我的方法注释模板文本: ** *$param$$return$ **/ 点击编辑变量: 两个默认值分别为: groovyScript("if(\"${…

倾情奉献,纯css(无图,无JS)原创中秋贺卡!!!

🪴 背景故事 中秋节马上就要到了,在这里我提前祝大家生活美满万年长,阖家幸福永平安!🥳 好了进入正题,最近掘金出了一个“中秋创意投稿”活动,我向来对这种可以写一些具有创意性的代码的活动很…

问题记录:两台Ubuntu之间传输文件(SCP)

1.查看IP地址 首先查看目标设备的 IP 地址:要把文件传到哪台机器上,就看哪台机器的 IP 地址,有两种方法 1.在终端输入 ifconfig: 2.设置里面看 2. 在自己的PC端 ping 一下目标机器的 IP 地址,看是否可以连接 ping 172.17.160…

使用ROS与Movelt实现myCobot 280运动轨迹规划和控制

ROS的技术案例 Introduction 今天这篇文章将记录我使用myCobot 280 M5stack 在ROS当中是如何使用的。为什么使用ROS呢,因为提及到机器人都离不开ROS这个操作系统,今天是我们第一次使用ROS这个系统。 今天我将从ROS的介绍,环境的配置以及mycob…

DPU加速AI应用“遍地开花”,中科驭数亮相2023全球AI芯片峰会

9月15日,2023全球AI芯片峰会(GACS 2023)在深圳举行,聚集了全球AI芯片产业的领军者和中坚力量,共探AI芯片的求新、求变、求索之径。中科驭数高级副总裁张宇应邀在智算中心算力与网络高峰论坛发表题为《基于DPU的高效AI大…

《2023年中国数字孪生行业报告》丨附下载_三叠云

✦ ✦✦ ✦✦ ✦✦ ✦ 随着近年来人工智能、物联网、虚拟现实等技术的持续发展以及元宇宙概念的兴起,数字孪生概念进一步完善,适用范围不断拓宽。然而非业界人士对数字孪生概念仍缺乏统一认知。对此,本报告介绍数字孪生概念、行业情况、市场…

java开发之个微机器人的开发

简要描述: 根据消息回调收到的xml转发文件消息,适用于同内容大批量发送 请求URL: http://域名地址/sendRecvFile 请求方式: POST 请求头Headers: Content-Type:application/jsonAuthorization&#…

uniapp 使用subNVue原生子窗体显示弹框或悬浮框

效果展示 在uniapp中,我们可以使用subNVue原生子窗体来解决web-view等原生页面中弹框无法显示的问题。 subNVue原生子窗体是uniapp提供的一种原生组件,可以在uniapp中嵌入原生页面,并且可以与uniapp页面进行通信。我们可以在原生页面中使用…

web浏览器公网远程访问jupyter notebook【内网穿透】

文章目录 前言1. Python环境安装2. Jupyter 安装3. 启动Jupyter Notebook4. 远程访问4.1 安装配置cpolar内网穿透4.2 创建隧道映射本地端口 5. 固定公网地址 前言 Jupyter Notebook,它是一个交互式的数据科学和计算环境,支持多种编程语言,如…