时序预测demo 代码快速实现 MLP效果比LSTM 好,简单模拟数据

news2024/9/24 1:26:45

【PyTorch修炼】用pytorch写一个经常用来测试时序模型的简单常规套路(LSTM多步迭代预测)

层数的理解:
LSTM(长短期记忆)的层数指的是在神经网络中堆叠的LSTM单元的数量。层数决定了网络能够学习的复杂性和深度。每一层LSTM都能够捕捉和记忆不同时间尺度的依赖关系,因此增加层数可以使网络更好地理解和处理复杂的序列数据。
在这里插入图片描述

LSTM方法:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

x = torch.linspace(0, 999, 1000)
y = torch.sin(x*2*3.1415926/70)

plt.xlim(-5, 1005)
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.title("sin")
plt.plot(y.numpy(), color='#800080')
plt.show()

x = torch.linspace(0, 999, 1000)
y = torch.sin(x * 2 * 3.1415926 / 100) + 0.3 * torch.sin(x * 2 * 3.1415926 / 25) + 0.8 * np.random.normal(0, 1.5)

plt.plot(y.numpy(), color='#800080')
plt.title("Sine-Like Time Series")
plt.xlabel('Time')
plt.ylabel('Value')
plt.show()

train_y= y[:-70]
test_y = y[-70:]

def create_data_seq(seq, time_window):
    out = []
    l = len(seq)
    for i in range(l-time_window):
        x_tw = seq[i:i+time_window]
        y_label = seq[i+time_window:i+time_window+1]
        out.append((x_tw, y_label))
    return out
time_window = 60
train_data = create_data_seq(train_y, time_window)


class MyLstm(nn.Module):
    def __init__(self, input_size=1, hidden_size=128, out_size=1):
        super(MyLstm, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=self.hidden_size, num_layers=1, bidirectional=False)
        self.linear = nn.Linear(in_features=self.hidden_size, out_features=out_size, bias=True)
        self.hidden_state = (torch.zeros(1, 1, self.hidden_size), torch.zeros(1, 1, self.hidden_size))

    def forward(self, x):
        out, self.hidden_state = self.lstm(x.view(len(x), 1, -1), self.hidden_state)
        pred = self.linear(out.view(len(x), -1))
        return pred[-1]


time_window = 60
train_data = create_data_seq(train_y, time_window)

learning_rate = 0.00001
epoch = 13
multi_step = 70

model=MyLstm()
mse_loss = nn.MSELoss()
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate,betas=(0.5,0.999))

for i in range(epoch):
    for x_seq, y_label in train_data:
        x_seq = x_seq 
        y_label = y_label 
        model.hidden_state = (torch.zeros(1, 1, model.hidden_size) ,
                              torch.zeros(1, 1, model.hidden_size) )
        pred = model(x_seq)
        loss = mse_loss(y_label, pred)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {i} Loss: {loss.item()}")
    preds = []
    labels = []
    preds = train_y[-time_window:].tolist()
    for j in range(multi_step):
        test_seq = torch.FloatTensor(preds[-time_window:]) 
        with torch.no_grad():
            model.hidden_state = (torch.zeros(1, 1, model.hidden_size) ,
                                  torch.zeros(1, 1, model.hidden_size) )
            preds.append(model(test_seq).item())
    loss = mse_loss(torch.tensor(preds[-multi_step:]), torch.tensor(test_y))
    print(f"Performance on test range: {loss}")

    plt.figure(figsize=(12, 4))
    plt.xlim(700, 999)
    plt.grid(True)
    plt.plot(y.numpy(), color='#8000ff')
    plt.plot(range(999 - multi_step, 999), preds[-multi_step:], color='#ff8000')
    plt.show()


class SimpleMLP(nn.Module):
    def __init__(self, input_size=60, hidden_size=128, output_size=1):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


mlp_model = SimpleMLP()
mse_loss = nn.MSELoss()
optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.0001)
for i in range(epoch):
    for x_seq, y_label in train_data:
        x_seq = x_seq
        y_label = y_label
        pred = mlp_model(x_seq)
        loss = mse_loss(y_label, pred)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {i} Loss: {loss.item()}")
    preds = []
    labels = []
    preds = train_y[-time_window:].tolist()
    for j in range(multi_step):
        test_seq = torch.FloatTensor(preds[-time_window:])
        with torch.no_grad():
            preds.append(mlp_model(test_seq).item())
    loss = mse_loss(torch.tensor(preds[-multi_step:]), torch.tensor(test_y))
    print(f"Performance on test range: {loss}")

    plt.figure(figsize=(12, 4))
    plt.xlim(700, 999)
    plt.grid(True)
    plt.plot(y.numpy(), color='#8000ff')
    plt.plot(range(999 - multi_step, 999), preds[-multi_step:], color='#ff8000')
    plt.show()

生成的一个带些随机数的正弦波:y = torch.sin(x * 2 * 3.1415926 / 100) + 0.3 * torch.sin(x * 2 * 3.1415926 / 25) + 0.8 * np.random.normal(0, 1.5)

结果发现:MLP效果比LSTM好?!
MLP:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
偶然有不是很准,但大部分非常准

LSTM:
就很奇怪?

在这里插入图片描述
在这里插入图片描述

但是如果是纯正弦波 y = torch.sin(x23.1415926/70) ,规律太明显了,好像效果都还行:
MLP:
简单聪明的MLP第一轮就学会了
在这里插入图片描述
LSTM:
开始几轮还有些懵
在这里插入图片描述
后边就悟了
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1463746.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PLC_博图系列☞基本指令“插入输入”

PLC_博图系列☞基本指令“插入输入” 文章目录 PLC_博图系列☞基本指令“插入输入”背景介绍插入输入说明参数示例 关键字: PLC、 西门子、 博图、 Siemens 、 插入输入 背景介绍 这是一篇关于PLC编程的文章,特别是关于西门子的博图软件。我并不是专…

基于ORB-SLAM2与YOLOv8剔除动态特征点(三种方法)

基于ORB-SLAM2与YOLOv8剔除动态特征点(三种方法) 写上篇文章时测试过程比较乱,写的时候有些地方有点失误,所以重新写了这篇 本文内容均在RGB-D环境下进行程序测试 本文涉及到的动态特征点剔除速度均是以https://cvg.cit.tum.de/data/datasets/rgbd-dat…

Java学习-21 网络编程

什么是网络编程? 可以让设备中的程序与网络上其他设备中的程序进行数据交互(实现网络通信的) 基本的通信架构 基本的通信架构有2种形式: CS架构(Client客户端/Server服务端) BS架构(Browser浏览器/Server服务端)。 网络通信三要素 IP …

粉色ui微信小程序源码/背景图/头像/壁纸小程序源码带流量主

云开发版粉色UI微信小程序源码,背景图、头像、壁纸小程序源码,带流量主功能。 云开发小程序源码无需服务器和域名即可搭建小程序另外还带有流量主功能噢!微信平台注册小程序就可以了。 这套粉色UI非常的好看,里面保护有背景图、…

【国际化】用JQuery-i18next的国际化demo,引入json

参考: 使用 i18next 的 jQuery 国际化 (i18n) 渐进式指南 (locize.com) i18next-http-backend/example/jquery/index.html at master i18next/i18next-http-backend (github.com) 文档 可能需要解决一下跨域问题,因为浏览器读取本…

基于Java+小程序点餐系统设计与实现(源码+部署文档)

博主介绍: ✌至今服务客户已经1000、专注于Java技术领域、项目定制、技术答疑、开发工具、毕业项目实战 ✌ 🍅 文末获取源码联系 🍅 👇🏻 精彩专栏 推荐订阅 👇🏻 不然下次找不到 Java项目精品实…

ABCDE联合创始人BMAN确认出席Hack .Summit() 2024香港Web3盛会

ABCDE联合创始人和普通合伙人BMAN确认出席Hack .Summit() 2024! ABCDE联合创始人和普通合伙人BMAN确认出席由 Hack VC 主办,并由 AltLayer 和 Berachain 联合主办,与 SNZ 和数码港合作,由 Techub News 承办的Hack.Summit() 2024区…

保护你的Web应用——CSRF攻击与防御

在Web应用开发过程中,保护用户的隐私和安全至关重要。而CSRF(Cross-Site Request Forgery,跨站请求伪造)攻击是一种常见的安全威胁,通过利用受信任用户的身份进行恶意操作,威胁到用户的账户和数据安全。本文…

django配置视图并与模版进行数据交互

目录 安装django 创建一个django项目 项目结构 创建视图层views.py 写入视图函数 创建对应视图的路由 创建模版层 配置项目中的模版路径 创建模版html文件 启动项目 浏览器访问结果 安装django pip install django 创建一个django项目 这里最好用命令行完成&#xf…

为什么需要MDL锁

点击上方蓝字关注我 在数据库管理中,元数据(metadata)的保护至关重要,而MySQL中的"元数据锁"(MDL锁)就是它的守护者。 1. 什么是MDL锁MDL锁,全名Metadata Lock,是MySQL中…

用windbg调试uefi在hyper-v

添加环境变量 CLANG_BINC:\Program Files\NASM\ NASM_PREFIXC:\Program Files\NASM\ 添加path C:\Program Files (x86)\Windows Kits\10\Tools\x64\ACPIVerify 修改edk2-master\Conf\target.txt TARGET_ARCH X64 编译这两个包 #ACTIVE_PLATFORM EmulatorPkg/…

去新加坡旅游,你必须要收藏了解的当地电商欺诈风险!

目录 多元化发展的新加坡电商 平台和消费者面临的欺诈风险 电商平台应如何防控? 2月9日,除夕,中国与新加坡免签正式生效。免签政策简化了持普通护照中国游客入境新加坡的程序,使通关更为便捷。根据协定,双方持普通护照…

【复现】某尔顿 安全审计系统任意文件读取漏洞_56

目录 一.概述 二 .漏洞影响 三.漏洞复现 1. 漏洞一: 四.修复建议: 五. 搜索语法: 六.免责声明 一.概述 某尔顿网络安全审计产品支持1-3线路的internet接入、1-3对网桥;含强大的上网行为管理、审计、监控模块;用…

计算机网络-局域网和城域网(一)

1.什么是局域网? 单一机构所拥有的专用计算机网络,中等规模地理范围,实现多种设备互联、信息交换和资源共享。 2.逻辑链路控制LLC: 目的是屏蔽不同的介质访问控制方法,以向高层(网络层)提供统…

代理模式笔记

代理模式 代理模式代理模式的应用场景先理解什么是代理,再理解动静态举例举例所用代码 动静态的区别静态代理动态代理 动态代理的优点代理模式与装饰者模式的区别 代理模式 代理模式在设计模式中是7种结构型模式中的一种,而代理模式有分动态代理&#x…

使用单一ASM-HEMT模型实现从X波段到Ka波段精确的GaN HEMT非线性仿真

来源:Accurate Nonlinear GaN HEMT Simulations from X- to Ka-Band using a Single ASM-HEMT Model 摘要:本文首次研究了ASM-HEMT模型在宽频带范围内的大信号准确性。在10、20和30 GHz的频率下,通过测量和模拟功率扫描进行了比较。在相同的频…

【C++初阶】系统实现日期类

目录 一.运算符重载实现各个接口 1.小于 (d1)<> 2.等于 (d1d2) 3.小于等于&#xff08;d1<d2&#xff09; 4.大于&#xff08;d1>d2&#xff09; 5.大于等于&#xff08;d1>d2&#xff09; 6.不等于&#xff08;d1!d2&#xff09; 7.日期天数 (1) 算…

Nginx网络服务三-----(三方模块和内置变量)

1.验证模块 需要输入用户名和密码 我们要用htpasswd这个命令&#xff0c;先安装一下httpd 生成文件和用户 修改文件 访问页面 为什么找不到页面&#xff1f; 对应的路径下&#xff0c;没有这个文件 去创建文件 去虚拟机浏览器查看 有的页面不想被别人看到&#xff0c;可以做…

MongoDB的介绍和使用

目录 一、MongoDB介绍 二、MongoDB相关概念 三、MongoDB的下载和安装 四、SpringBoot 整合 MongoDB 一、MongoDB介绍 MongoDB是一种NoSQL数据库管理系统&#xff0c;采用面向文档的数据库模型。它以C语言编写&#xff0c;旨在满足大规模数据存储和高性能读写操作的需求。Mo…

BUGKU-WEB 文件包含

题目描述 题目截图如下&#xff1a; 进入场景看看&#xff1a; 解题思路 你说啥我就干啥&#xff1a;点击一下试试你会想到PHP伪协议这方面去嘛&#xff0c;你有这方面的知识储备吗&#xff1f; 相关工具 解题步骤 查看源码 看到了一点提示信息&#xff1a; ./index.…