pytorch-RNN实战-正弦曲线预测

news2024/9/21 12:30:26

目录

  • 1. 正弦数据生成
  • 2. 构建网络
  • 3. 训练
  • 4. 预测
  • 5. 完整代码
  • 6. 结果展示

1. 正弦数据生成

曲线如下图:
在这里插入图片描述
代码如下图:

  • 50个点构成一个正弦曲线
  • 随机生成一个0~3之间的一个值(随机的原因是防止每次都从相同的点开始,50个点的正弦曲线一样,被模型记住),值的范围区间是[start, start+10]
  • 输入x范围[0,48],预测值y范围是[1,49]

在这里插入图片描述

2. 构建网络

下图是构建的网络,注意out维度扩展出一个维度,是为了和y维度一致
在这里插入图片描述

3. 训练

loss计算采用均方差MSE,优化器采用Adam
注意:hidden_prev的自更新
在这里插入图片描述

4. 预测

预测是循环一个点一个点的预测,每次预测的点的结果作为下次点的输入,直到预测出全部点,放到predictions中。
input = x[:,0,:] 去掉了x[1,seq,1]中的seq维度,变成[1,1]
在这里插入图片描述

5. 完整代码

import  numpy as np
import  torch
import  torch.nn as nn
import  torch.optim as optim
from    matplotlib import pyplot as plt


num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr=0.01



class Net(nn.Module):

    def __init__(self, ):
        super(Net, self).__init__()

        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        for p in self.rnn.parameters():
          nn.init.normal_(p, mean=0.0, std=0.001)

        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev):

       out, hidden_prev = self.rnn(x, hidden_prev)
       # [b, seq, h]
       out = out.view(-1, hidden_size)
       out = self.linear(out)
       out = out.unsqueeze(dim=0)
       return out, hidden_prev




model = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)

hidden_prev = torch.zeros(1, 1, hidden_size)

for iter in range(6000):
    start = np.random.randint(3, size=1)[0]
    time_steps = np.linspace(start, start + 10, num_time_steps)
    data = np.sin(time_steps)
    data = data.reshape(num_time_steps, 1)
    x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
    y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)

    output, hidden_prev = model(x, hidden_prev)
    hidden_prev = hidden_prev.detach()

    loss = criterion(output, y)
    model.zero_grad()
    loss.backward()
    # for p in model.parameters():
    #     print(p.grad.norm())
    # torch.nn.utils.clip_grad_norm_(p, 10)
    optimizer.step()

    if iter % 100 == 0:
        print("Iteration: {} loss {}".format(iter, loss.item()))

start = np.random.randint(3, size=1)[0]
time_steps = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)

predictions = []
input = x[:, 0, :]
for _ in range(x.shape[1]):
  input = input.view(1, 1, 1)
  (pred, hidden_prev) = model(input, hidden_prev)
  input = pred
  predictions.append(pred.detach().numpy().ravel()[0])

x = x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())

plt.scatter(time_steps[1:], predictions)
plt.show()

6. 结果展示

图中黄色点是预测点,蓝色为实际点,前面的曲线是start不随机预测的效果,说明曲线已经被模型记住了;后面的曲线是start随机预测的效果,基本趋势和真实点是一致的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1916331.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云手机批量操作使用场景,从Amazon、TK等软件分析

云手机目前所具备的群控,批量操作,自动化等功能,对于电商,软测,办公,直播,营销等行业有很好的减负作用。 针对于具体的海外APP,云手机具体可以做哪些事情来帮助我们减轻压力&#x…

Docker拉取失败,利用github将镜像推送到阿里云

GITHUB配置 fork https://github.com/tech-shrimp/docker_image_pusher 该项目到自己的账户下。 设置环境变量,其路径如下图 在该项目中 .github/workflows/docker.yaml 找到 env 标签 ALIYUN_REGISTRY: "${{ secrets.ALIYUN_REGISTRY }}"ALIYUN_NAME_S…

AC修炼计划(AtCoder Regular Contest 180) A~C

A - ABA and BAB A - ABA and BAB (atcoder.jp) 这道题我一开始想复杂了,一直在想怎么dp,没注意到其实是个很简单的规律题。 我们可以发现我们住需要统计一下类似ABABA这样不同字母相互交替的所有子段的长度,而每个字段的的情况有&#xff…

600Kg大载重起飞重量多旋翼无人机技术详解

600Kg大载重起飞重量的多旋翼无人机是一种高性能的无人驾驶旋翼飞行器,具有出色的载重能力和稳定的飞行特性。该无人机采用先进的飞行控制系统和高效的动力系统,能够满足各种复杂任务的需求,广泛应用于物资运输、应急救援、森林防火等领域。 …

西门子大手笔又买一家公司,2024年“两买”和“两卖”的背后……

导语 大家好,我是社长,老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》 更多的海量【智能制造】相关资料,请到智能制造online知识星球自行下载。 今年,这家全球工业巨头不仅精准出击&#xff0c…

MACOS查看硬盘读写量

一、安装Homebrew 按照提示进行安装 /bin/zsh -c "$(curl -fsSL https://gitee.com/cunkai/HomebrewCN/raw/master/Homebrew.sh)"二、安装smartmontools brew install smartmontools三、查看硬盘读写量等信息 sudo smartctl -a /dev/disk0

Python8:线程和进程

1.并发和并行 并发:在逻辑上具备同时处理多个任务的能力(其实每时刻只有一个任务) 并行:物理上在同一时刻执行多个并发任务 2.线程与进程 一个进程管多个线程,一个进程至少有一个线程 python多线程是假的&#xf…

UML-各种图

什么是类图 定义系统中的类,描述类的内部结构(属性、方法等),表示类之间的关系(泛化、实现、依赖、关联、聚合、组合)。 UML表示类图 上图中左侧图形是一个常见的类图, 类名:在顶…

Qt:15.布局管理器(QVBoxLayout-垂直布局、QHBoxLayout-水平布局、QGridLayout-网格布局、拉伸系数,控制控件显示的大小)

目录 一、QVBoxLayout-垂直布局: 1.1QVBoxLayout介绍: 1.2 属性介绍: 1.3细节理解: 二、QHBoxLayout-水平布局: 三、QGridLayout-网格布局: 3.1QGridLayout介绍: 3.2常用方法&#xff1a…

iMazing 3.0.3.1Mac中文破解版下载安装激活

今天,小编要分享的是Mac下一款可以帮助用户管理IOS设备的软件——iMazing,之前,小编也分享的过类似的软件,iMazing却有独特之处。小子这次带来的是3.0.3.1版本。 iMazing 3是一款iOS设备管理软件,该软件支持对基于iOS…

【STM32学习】cubemx配置,串口的使用,串口发送接收函数使用,以及串口重定义、使用printf发送

1、串口的基本配置 选择USART1,选择异步通信,设置波特率 选择后,会在右边点亮串口 串口引脚是用来与其他设备通信的,如在程序中打印发送信息,电脑上打开串口助手,就会收到信息。 串口的发送接收&#xff0…

机器学习筑基篇,容器调用显卡计算资源,Ubuntu 24.04 快速安装 NVIDIA Container Toolkit!...

[ 知识是人生的灯塔,只有不断学习,才能照亮前行的道路 ] Ubuntu 24.04 安装 NVIDIA Container Toolkit 什么是 NVIDIA Container Toolkit? 描述:NVIDIA Container Toolkit(容器工具包)使用户能够构建和运行 GPU 加速的容器,该工具包括一个容器运行时库和实用程序,用于自动…

新能源汽车充电站远程监控系统S275钡铼技术无线RTU

新能源汽车充电站的远程监控系统在现代城市基础设施中扮演着至关重要的角色,而钡铼技术的S275无线RTU作为一款先进的物联网数据监测采集控制短信报警终端,为充电站的安全运行和高效管理提供了强大的技术支持。 技术特点和功能 钡铼S275采用了基于UCOSI…

【PTA天梯赛】L1-006 连续因子(20分)

作者:指针不指南吗 专栏:算法刷题 🐾或许会很慢,但是不可以停下来🐾 文章目录 题目题解题意步骤 总结 题目 题目链接 题解 题意 求解n的最长连续因子 和因子再相乘的积无关,真给绕进去了 步骤 双重循…

D-走一个大整数迷宫(牛客月赛97)

题意:给两个n行m列的矩阵a和b,计数器,只有当计数器的值模(p-1)时出口才打开,要从左上走到右下,求最快多久走出迷宫。 分析:无论2的bij次方有多大p的2的bij次方的次方取模&#xff0…

前端vue 实现取色板 的选择

大概就是这样的 一般的web端框架 都有自带的 的 比如 ant-design t-design 等 前端框架 都是带有这个的 如果遇到没有的我们可以自己尝试开发一下 简单 的 肯定比不上人家的 但是能用 能看 说的过去 我直接上代码了 其实这个取色板 就是一个input type 是color 的input …

云视频监控中的高效视频转码策略:视频汇聚EasyCVR平台H.265自动转码H.264能力解析

随着科技的快速发展,视频监控技术已经广泛应用于各个领域,如公共安全、商业管理、教育医疗等。与此同时,视频转码技术作为视频处理的关键环节,也在不断提高视频的质量和传输效率。 一、视频监控技术的演进 视频监控技术的发展历…

【基于R语言群体遗传学】-16-中性检验Tajima‘s D及连锁不平衡 linkage disequilibrium (LD)

Tajimas D Test 已经开发了几种中性检验,用于识别模型假设的潜在偏差。在这里,我们将说明一种有影响力的中性检验,即Tajimas D(Tajima 1989)。Tajimas D通过比较数据集中的两个𝜃 4N𝜇估计值来…

nssm的下载和使用

nssm(Non-Sucking Service Manager)是一个用于在Windows系统上管理服务的工具。它允许你将.exe文件和.bat文件转换为Windows服务,并提供了一些功能来管理这些服务。 下载和安装 首先,你需要从nssm官方网站(https://n…

顺序结构 ( 四 ) —— 标准数据类型 【互三互三】

序 C语言提供了丰富的数据类型,本节介绍几种基本的数据类型:整型、实型、字符型。它们都是系统定义的简单数据类型,称为标准数据类型。 整型(integer) 在C语言中,整型类型标识符为int。根据整型变量的取值范…