风速预测(四)基于Pytorch的EMD-Transformer模型

news2024/11/17 11:41:00

目录

前言

1 风速数据EMD分解与可视化

1.1 导入数据

1.2 EMD分解

2 数据集制作与预处理

2.1 先划分数据集,按照8:2划分训练集和测试集

2.2 设置滑动窗口大小为7,制作数据集

3 基于Pytorch的EMD-Transformer模型预测

3.1 数据加载,训练数据、测试数据分组,数据分batch

3.2 定义EMD-Transformer预测模型

3.3 定义模型参数

3.4 模型训练

3.5 结果可视化


往期精彩内容:

风速预测(一)数据集介绍和预处理

风速预测(二)基于Pytorch的EMD-LSTM模型

风速预测(三)EMD-LSTM-Attention模型

前言

本文基于前期介绍的风速数据(文末附数据集),先经过经验模态EMD分解,然后通过数据预处理,制作和加载数据集与标签,最后通过Pytorch实现EMD-Transformer模型对风速数据的预测。风速数据集的详细介绍可以参考下文:

风速预测(一)数据集介绍和预处理-CSDN博客

1 风速数据EMD分解与可视化

1.1 导入数据

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rc("font", family='Microsoft YaHei')

# 读取已处理的 CSV 文件
df = pd.read_csv('wind_speed.csv' )
# 取风速数据
winddata = df['Wind Speed (km/h)'].tolist()
winddata = np.array(winddata) # 转换为numpy
# 可视化
plt.figure(figsize=(15,5), dpi=100)
plt.grid(True)
plt.plot(winddata, color='green')
plt.show()

1.2 EMD分解

from PyEMD import EMD

# 创建 EMD 对象
emd = EMD()
# 对信号进行经验模态分解
IMFs = emd(winddata)

# 可视化
plt.figure(figsize=(20,15))
plt.subplot(len(IMFs)+1, 1, 1)
plt.plot(winddata, 'r')
plt.title("原始信号")

for num, imf in enumerate(IMFs):
    plt.subplot(len(IMFs)+1, 1, num+2)
    plt.plot(imf)
    plt.title("IMF "+str(num+1), fontsize=10)
# 增加第一排图和第二排图之间的垂直间距
plt.subplots_adjust(hspace=0.8, wspace=0.2)
plt.show()

2 数据集制作与预处理

2.1 先划分数据集,按照8:2划分训练集和测试集

2.2 设置滑动窗口大小为7,制作数据集

3 基于Pytorch的EMD-Transformer模型预测

3.1 数据加载,训练数据、测试数据分组,数据分batch

# 加载数据
import torch
from joblib import dump, load
import torch.utils.data as Data
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# 参数与配置
torch.manual_seed(100)  # 设置随机种子,以使实验结果具有可重复性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载数据集
def dataloader(batch_size, workers=2):
    # 训练集
    train_set = load('train_set')
    train_label = load('train_label')
    # 测试集
    test_set = load('test_set')
    test_label = load('test_label')

    # 加载数据
    train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_set, train_label),
                                   batch_size=batch_size, num_workers=workers, drop_last=True)
    test_loader = Data.DataLoader(dataset=Data.TensorDataset(test_set, test_label),
                                  batch_size=batch_size, num_workers=workers, drop_last=True)
    return train_loader, test_loader

batch_size = 64
# 加载数据
train_loader, test_loader = dataloader(batch_size)

3.2 定义EMD-Transformer预测模型

注意:输入风速数据形状为 [64, 10, 7], batch_size=64,  维度10维代表10个分量,7代表序列长度(滑动窗口取值)。

3.3 定义模型参数

# 定义模型参数
batch_size = 64
input_len = 7     # 输入序列长度为7 (窗口值)
input_dim = 10    # 输入维度为10个分量
hidden_dim = 100  # Transformer隐层维度
num_layers = 4   # 编码器层数
num_heads = 2   # 多头注意力头数
output_size = 1 # 单步输出

model = EMDTransformerModel(batch_size, input_len, input_dim, hidden_dim, num_layers, num_heads, output_size=1)  

# 定义损失函数和优化函数 
model = model.to(device)
loss_function = nn.MSELoss()  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), learn_rate)  # 优化器

3.4 模型训练

训练结果

采用两个评价指标:MSE 与 MAE 对模型训练进行评价,100个epoch,MSE 为0.01627,MAE  为 0.0005549,EMD-Transformer预测效果良好,适当调整模型参数,还可以进一步提高模型预测表现。EMD-Transformer参数量不到LSTM模型的十分之一,效果相近,可见EMD-Transformer性能的优越性。

注意调整参数:

  • 可以适当增加Transformer堆叠编码器层数和隐藏层的维度,微调学习率;

  • 调整多头注意力头数,增加更多的 epoch (注意防止过拟合)

  • 可以改变滑动窗口长度(设置合适的窗口长度)

3.5 结果可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1314222.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode做题笔记2415. 反转二叉树的奇数层

给你一棵 完美 二叉树的根节点 root ,请你反转这棵树中每个 奇数 层的节点值。 例如,假设第 3 层的节点值是 [2,1,3,4,7,11,29,18] ,那么反转后它应该变成 [18,29,11,7,4,3,1,2] 。 反转后,返回树的根节点。 完美 二叉树需满足…

vue2 tailwindcss jit模式下热更新失效

按照网上教程安装的tailwindcss,但是修改类名后热更新的时候样式没有生效,参考了大佬的文章,解决了该问题。 安装cross-env 修改前 "dev": " vue-cli-service serve", 修改后 "dev": "cross-env TAILWIN…

【通用】Linux,VSCode,IDEA,Eclipse等资源相对位置

正文 不论是 IDEA、Linux、VSCode、cmd等等吧,都遵循这个规则: 如果以斜杠开头,表示从根开始找: IDEA的根是classpath(classpath就是项目被编译后,位于 target下的 classes文件夹,或者位于ta…

Android 12.0 Launcher3定制化之动态时钟图标功能实现

1.概述 在12.0的系统产品rom定制化开发中,在Launcher3中的定制化的一些功能中,对于一些产品要求需要实现动态时钟图标功能,这就需要先绘制时分秒时针表盘,然后 每秒刷新一次时钟图标,时钟需要做到实时更新,做到动态时钟的效果,接下来就来分析这个功能的实现 如图: 2.动…

docker镜像与容器的迁移

docker容器迁移有两组命令,分别是 save & load :操作的是images, 所以要先把容器commit成镜像export & import:直接操作容器 我们先主要看看他们的区别: 一 把容器打包为镜像再迁移到其他服务器 如把mysq…

用于自动驾驶的基于深度学习的图像 3D 物体检测:综述

论文地址:https://ieeexplore.ieee.org/abstract/document/10017184/ 背景 准确、鲁棒的感知系统是理解自动驾驶和机器人驾驶环境的关键。自动驾驶需要目标的 3D 信息,包括目标的位置和姿态,以清楚地了解驾驶环境。 摄像头传感器因其颜色和…

力扣每日一题:2415. 反转二叉树的奇数层(2023-12-15)

力扣每日一题 题目:2415. 反转二叉树的奇数层 日期:2023-12-15 用时:6 m 51 s 时间:0 ms 内存:46.97 MB 代码: /*** Definition for a binary tree node.* public class TreeNode {* int val;* T…

Cockpit upload文件上传漏洞(CVE-2023-1313)

0x01 产品简介 Cockpit 是一个自托管、灵活且用户友好的无头内容平台,用于创建自定义数字体验。 0x02 漏洞概述 Cockpit assetsmanager/upload接口处存在文件上传漏洞,攻击者可通过该漏洞在服务器端任意上传代码,写入后门,获取服务器权限,进而控制整个web服务器。 0x0…

如何使用ArcGIS Pro拼接影像

为了方便数据的存储和传输,我们在网上获取到的影像一般都是分块的,正式使用之前需要对这些影像进行拼接,这里为大家介绍一下ArcGIS Pro中拼接影像的方法,希望能对你有所帮助。 数据来源 本教程所使用的数据是从水经微图中下载的…

【1.6计算机组成与体系结构】存储系统

目录 1.层次化存储结构2.Cache2.1 Cache的介绍2.2 局部性原理2.3 Cache应用 1.层次化存储结构 由 ⬆ CPU:寄存器。 快 ⬆ Cache:按内容存取(相联存储器)。 到 ⬆内存(主存):DRAM。 慢 ⬆ 外存(辅存&#…

蓝牙物联网智慧工厂解决方案

蓝牙物联网智慧工厂解决方案是一种针对工厂管理的智能化解决方案,通过蓝牙、物联网、大数据、人工智能等技术,实现工厂人员的定位、物资的定位管理、车间的智慧巡检、智慧安防以及数据的可视化等功能。 蓝牙物联网智慧工厂解决方案构成: 人员…

CSS盒子的浮动与网页布局(重点,有电影页面案例)

浮动适用于那种盒子的并列布局 CSS 提供了三种传统布局方式(简单说,就是盒子如何进行排列顺序):  普通流(标准流)  浮动  定位 标准流(普通流/文档流) 所谓的标准流: 就是标签按照规定好默认方式排列. 1. 块级…

算法基础概念之数据结构

邻接表 每个点作为头节点接一条链表 链表中元素均为该头节点指向的点 优先队列 参数: ①储存元素类型 ②底层使用的存储结构(一般为vector) ③比较方式(默认小于)

华为数通——网络参考模型

OSI参考模型 七层 应用层:最靠近用户的一层,为应用程序提供网络服务。 六层 表示层:数据格式转换编码格式UTF-8。 五层 会话层:双方之间建立、管理和终止会话。 四层 传输层:建立、维护和取消端到端的数据传输过…

锁--07_1----插入意向锁-Insert加锁过程

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 插入意向锁MySQL执行插入Insert时的加锁过程MySQL官方文档MySQL insert加锁流程1.加插入意向锁2.判断插入记录是否有唯一键3. 插入记录并对记录加X锁插入意向锁----…

使用opencv的Sobel算子实现图像边缘检测

1 边缘检测介绍 图像边缘检测技术是图像处理和计算机视觉等领域最基本的问题,也是经典的技术难题之一。如何快速、精确地提取图像边缘信息,一直是国内外的研究热点,同时边缘的检测也是图像处理中的一个难题。早期的经典算法包括边缘算子方法…

瑞芯微rv1126边缘计算盒子高性价比2.0TOPS INT8/INT16

边缘计算盒子 瑞芯微rv1126 | 2.0TOPS INT8/INT16 ● 集成了NPU,算力高达2.0TOPsINT8/INT16。 ● 支持H.264/H.265/MJPEG视频编解码;支持多级别视频质量配置及编码复杂度设置。 ● 编解码性能 3840 x 216030 fps 1080p30 fps encoding 3840 x 21603…

力扣题目学习笔记(OC + Swift) 12. 整数转罗马数字

12. 整数转罗马数字 罗马数字包含以下七种字符: I, V, X, L,C,D 和 M。 字符 数值 I 1 V 5 X 10 L 50 C 100 D 500 M 1000 例如, 罗马数字 2 写做 II ,即为两个并列的 1。12 写做 XI…

eclipse连接mysql数据库(下载eclipse,下载安装mysql,下载mysql驱动)

前言: 使用版本:eclipse2017,mysql5.7.0,MySQL的jar建议使用最新的,可以避免警告! 1:下载安装:eclipse,mysql在我之前博客中有 http://t.csdnimg.cn/UW5fshttp://t.csdn…

使用 Matplotlib 和 mplcursors 创建交互式数据可视化,鼠标悬停动态显示数据

在本博客中,我们将探讨如何使用 Matplotlib(Python 中流行的绘图库)和 mplcursors(一个为 Matplotlib 图表添加交互式数据光标的库)创建交互式数据可视化。 效果图: 环境设置 首先,请确保已…