深度学习——RNN解决回归问题

news2025/1/3 2:38:31

详细代码与注释

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# 有利于复现代码
# torch.manual_seed(1)    # reproducible

# Hyper Parameters
TIME_STEP = 10  # rnn time step
# 输入sin函数的y值,所以输入尺寸为1
INPUT_SIZE = 1  # rnn input size
LR = 0.02  # learning rate

# show data
steps = np.linspace(0, np.pi * 2, 100, dtype=np.float32)  # float32 for converting torch FloatTensor
x_np = np.sin(steps)
y_np = np.cos(steps)
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()


# 搭建RNN
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        # RNN层
        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=32,  # rnn hidden unit
            num_layers=1,  # number of rnn layer
            batch_first=True,  # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)
        )
        # 输出层
        self.out = nn.Linear(32, 1)

    def forward(self, x, h_state):
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)

        outs = []  # save all predictions
        for time_step in range(r_out.size(1)):  # calculate output for each time step
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state

        # instead, for simplicity, you can replace above codes by follows
        # r_out = r_out.view(-1, 32)
        # outs = self.out(r_out)
        # outs = outs.view(-1, TIME_STEP, 1)
        # return outs, h_state

        # or even simpler, since nn.Linear can accept inputs of any dimension
        # and returns outputs with same dimension except for the last
        # outs = self.out(r_out)
        # return outs


rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.MSELoss()

h_state = None  # for initial hidden state

# 绘图
plt.figure(1, figsize=(12, 5))
# 动态绘制
plt.ion()  # continuously plot

for step in range(100):
    start, end = step * np.pi, (step + 1) * np.pi  # time range
    # use sin predicts cos
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32,
                        endpoint=False)  # float32 for converting torch FloatTensor
    x_np = np.sin(steps)
    y_np = np.cos(steps)

    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])  # shape (batch, time_step, input_size)
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])

    prediction, h_state = rnn(x, h_state)  # rnn output
    # !! next step is important !!
    h_state = h_state.data  # repack the hidden state, break the connection from last iteration

    loss = loss_func(prediction, y)  # calculate loss
    optimizer.zero_grad()  # clear gradients for this training step
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients

    # plotting
    plt.plot(steps, y_np.flatten(), 'r-')
    plt.plot(steps, prediction.data.numpy().flatten(), 'b-')
    plt.draw()
    plt.pause(0.05)

plt.ioff()
plt.show()

运行效果

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/760826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

posix ipc之消息队列

note 1.mq_open函数的参数pathname应以/开始&#xff0c;且最多一个/ 2.mq_receive的参数msg_len应大于等于attr.msgsize 3.消息队列写方写时不要求读方就绪&#xff0c;读方读时不要求写方就绪(和管道不同) code #include <fcntl.h> #include <sys/stat.h> #…

汽车销售数据可视化分析实战

1、任务 市场需求&#xff1a;各年度汽车总销量及环比&#xff0c;各车类、级别车辆销量及环比 消费能力/价位认知&#xff1a;车辆销售规模及环比、不同价位车销量及环比 企业/品牌竞争&#xff1a;各车系、厂商、品牌车销量及环比&#xff0c;市占率及变化趋势 热销车型&…

x86架构ubuntu22下运行SFC模拟器zsnet

0. 环境 ubuntu22 1. apt安装 sudo apt install zsnes 2. 运行 zsnet 参考&#xff1a;在Ubuntu上用zsnes玩SFC游戏&#xff0c;https://blog.csdn.net/gqwang2005/article/details/3877121

MyBatis学习笔记之首次开发及文件配置

文章目录 MyBatis概述框架特点 有关resources目录开发步骤从XML中构建SqlSessionFactoryMyBatis中有两个主要的配置文件编写MyBatis程序关于第一个程序的小细节MyBatis的事务管理机制JDBCMANAGED 编写一个较为完整的mybatisjunit测试mybatis集成日志组件 MyBatis概述 框架 在…

win11 系统暂无可用音频设备导致播放失败/音频服务未响应

win11 系统暂无可用音频设备导致播放失败/音频服务未响应 win11再一次更新后音频突然用不了了&#xff0c;驱动和输出设备都显示正常&#xff0c;但每次播放就会出现下面的问题&#xff0c;重启和更新驱动也没用。最后百度了好久终于解决了。 最后发现可能是新的驱动和电脑不兼…

【C++11】function包装器的简单使用

function 1 function包装器使用场景2 包装器3 包装成员函数4 一道例题5 包装器的意义 1 function包装器使用场景 现在有代码如下&#xff1a; 要求声明出这两个函数的类型 int f(int a,int b) {return a b; } struct Functor {int operator(int a,int b){return a b;} }可以…

(atan2)+(最小回文字符统计)

C - Nearest vectors long double eps1e-18 atan2:x正轴旋转的弧度角&#xff0c;y>0为正&#xff0c;y<0为负const long double PIacos(-1.0); const long double eps1e-18; struct node {ll id;long double x,y,z; }t[NN]; bool cmp1(node l,node r) {return l.zep…

布隆过滤器在海量数据去重验证中应用

布隆过滤器在海量数据去重验证中应用 文章目录 布隆过滤器在海量数据去重验证中应用引子面试结束级方案——从数据库中取新手级方案——利用redis的set数据结构专业级方案——利用布隆过滤器 布隆过滤器基本概念优点缺点布隆过滤器的数据结构布隆过滤器的工作流程布隆过滤器的优…

位1的个数:一种高效且优雅的解法

本篇博客会讲解力扣“191. 位1的个数”的解题思路&#xff0c;这是题目链接。 要想求解二进制位中1的个数&#xff0c;首先要了解一个神奇的表达式&#xff1a;n & n - 1。这个表达式的含义是&#xff1a;把n的二进制中最低位的1变成0&#xff0c;也就是消除n的二进制中最低…

day5-环形链表II

142.环形链表II 力扣题目链接 给定一个链表的头节点 head &#xff0c;返回链表开始入环的第一个节点。 如果链表无环&#xff0c;则返回 null。 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次到达&#xff0c;则链表中存在环。 为了表示给定链表中的环&am…

Prometheus实现钉钉报警

1、Prometheus实现钉钉报警 1.1 Prometheus环境 # my global config global:scrape_interval: 15s # Set the scrape interval to every 15 seconds. Default is every 1 minute.evaluation_interval: 15s # Evaluate rules every 15 seconds. The default is every 1 minute…

位运算专题

异或:相同为0&#xff0c;相异为1&#xff0c;无进位相加 约定:给定的所有数从左到右依次是从低位到高位&#xff0c;下标从0开始 一)给定一个数&#xff0c;判断它的二进制位表示中的第X位是0还是1 1)(N>>X)&1&#xff0c;先将x位右移动到最低位&#xff0c;然后再将…

warp框架教程5-Filter系统中各个模块

any 模块 any 模块只有一个方法&#xff0c;就是 any 方法&#xff0c;它可以匹配任何路由的过滤器。我们可以使用 any 方法将一些可克隆的资源转换成一个过滤器&#xff0c;从而允许轻松地将它与其他 Filter 结合在一起。当然也可以使用 any 方法创建适用于多个 Filter 的末尾…

Android Hook技术实战详解

前些天发现了一个蛮有意思的人工智能学习网站,8个字形容一下"通俗易懂&#xff0c;风趣幽默"&#xff0c;感觉非常有意思,忍不住分享一下给大家。 &#x1f449;点击跳转到教程 前言&#xff1a; 什么是Android Hook技术&#xff1f; Android Hook技术是指在Android…

Sentinel流量规则模块(新增)

系统并发能力有限&#xff0c;比如系统A的QPS支持1个请求&#xff0c;如果太多请求过来&#xff0c;那么系统A就应该进行流量控制了&#xff0c;比如其他请求直接拒绝 新增流控规则介绍:新增流控规则窗口 1.资源名&#xff1a;默认请求路径。 2.针对来源&#xff1a;Se…

WebRTC基础

有用的网址&#xff1a; https://webrtc.org/ WebRTC API - Web API 接口参考 | MDN Browser APIs and Protocols: WebRTC - High Performance Browser Networking(OReilly) 浏览器中查看webrtc运行的实时信息&#xff1a; Chrome浏览器&#xff1a;chrome://webrtc-inter…

Linux(Ubuntu2004)实现kitti数据集的点云投影记录

1、新建终端打开roscore roscore 2、在“下载”路径里面播放example_new的bag rosbag play example_new.bag -l 3、在catkin_ws里面运行下列命令 rosrun ros_detection_tracking projector 如果不能运行&#xff0c;可以采用tab的方式查看是否错误 4、新建一个终端运行r…

Kafka 入门到起飞 - 生产者发送消息流程解析

生产者通过send&#xff08;&#xff09;方法发送消息消息会经过拦截器->序列化器->分区器 进行加工然后将消息存在缓冲区当缓冲区中消息达到条件会按批次发送到broker对应分区上broker将接收到的消息进行刷盘持久化消息处理broker会返回给producer响应落盘成功返回元数据…

孟德尔随机化推断暴露因素与健康结局的因果关系

学习视频&#xff1a; 应用孟德尔随机化方法推断暴露因素与健康结局的因果关系 王友信教授 梅斯医学_哔哩哔哩_bilibili http://chinaepi.icdc.cn/zhlxbx/ch/reader/create_pdf.aspx?file_no20170427&flag1&journal_idzhlxbx&year_id2017 1. 孟德尔随机化方法 传…

详细介绍matlab使用支持向量机(SVM)预测股票市场趋势的实例

简介: 股票市场的趋势预测一直是投资者和交易员关注的重要问题之一。支持向量机(SVM)作为一种强大的机器学习算法,被广泛应用于股票市场趋势预测。本实例将介绍如何使用SVM来预测股票市场的涨跌趋势,并提供一个MATLAB代码示例。 数据准备: 首先,我们需要获取股票市场的历…