基于LSTM及其变体的回归预测

news2024/11/23 16:49:17

1 所用模型

       代码中用到了以下模型:

      1. LSTM(Long Short-Term Memory):长短时记忆网络,是一种特殊的RNN(循环神经网络),能够解决传统RNN在处理长序列时出现的梯度消失或爆炸的问题。LSTM有门控机制,可以选择性地记住或忘记信息。

       2. FC-LSTM:全连接的LSTM,与传统的LSTM相比,其细胞单元之间采用全连接的方式。

       3. Coupled LSTM:耦合LSTM,是一种特殊的LSTM结构,其中每个LSTM单元被分解为两个交互的子单元。

       4. GRU(Gated Recurrent Unit):门控循环单元,与LSTM类似,但结构更简单,参数更少,通常训练更快,但可能不如LSTM准确。

       5. ConvLSTM:卷积LSTM,将卷积神经网络(CNN)与LSTM结合,可以捕捉时空特征,常用于处理图像和视频数据。

       6. Deep LSTM:深层LSTM,包含多个LSTM层的堆叠,可以捕捉更复杂的模式。

       7. DB-LSTM(Bidirectional LSTM):双向LSTM,有两个方向的LSTM层,一个按时间顺序,一个逆序,可以同时获取过去和未来的信息。

       8. SRU(SimpleRNN):简单循环神经网络,是最基本的RNN形式。

       9. TPA-LSTM:时间感知LSTM,通过改变LSTM的内部计算方式,使其更加关注时间序列的特性。

       10. ConvGRU:卷积GRU,与ConvLSTM类似,但使用GRU代替LSTM。

       这些模型都是用于处理序列数据的深度学习模型,特别适用于时间序列预测、自然语言处理等领域。

2 运行结果

       左边是Epoch=50次的效果,右边是Epoch=15次的效果:

a1e88c48c6f645eea96360f59b239c00.jpg

 图2-1 训练损失

3623cb88b9294ce796d7dbacd244f481.jpg

 图2-2 测试损失

d9ab03d1196542bf9235bafc58288e07.jpg

 图2-3 预测结果

3 代码

     

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LSTM, GRU, SimpleRNN, Bidirectional, TimeDistributed, Conv1D, Attention
from keras.layers import Flatten, Dropout, BatchNormalization
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from tensorflow.keras.layers import Conv1D
# 读取数据
data = pd.read_excel('A.xlsx')
data=data.dropna()
data = data['A'].values.reshape(-1, 1)
# 数据预处理
scaler = MinMaxScaler()
data = scaler.fit_transform(data)

# 划分训练集和测试集
train_size = int(len(data) * 0.8)
train, test = data[:train_size], data[train_size:]

# 转换数据格式以适应LSTM输入
def create_dataset(dataset, look_back=1):
    X, Y = [], []
    for i in range(len(dataset) - look_back - 1):
        X.append(dataset[i:(i + look_back), 0])
        Y.append(dataset[i + look_back, 0])
    return np.array(X), np.array(Y)
 
look_back = 1
X_train, y_train = create_dataset(train, look_back)
X_test, y_test = create_dataset(test, look_back)
 
# 重塑输入数据的维度以适应LSTM模型
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))
# 定义模型函数
def create_model(name):
    model = Sequential()
    if name == 'LSTM':
        model.add(LSTM(50, activation='relu', input_shape=(1, 1)))
    elif name == 'FC-LSTM':
        model.add(LSTM(50, activation='relu', input_shape=(1, 1), recurrent_activation='sigmoid'))
    elif name == 'Coupled LSTM':
        model.add(LSTM(50, activation='relu', input_shape=(1, 1), implementation=2))
    elif name == 'GRU':
        model.add(GRU(50, activation='relu', input_shape=(1, 1)))
    elif name == 'ConvLSTM':
        model.add(Conv1D(filters=64, kernel_size=1, activation='relu', input_shape=(1, 1)))
        model.add(LSTM(50, activation='relu'))
    elif name == 'Deep LSTM':
        model.add(LSTM(50, return_sequences=True, activation='relu', input_shape=(1, 1)))
        model.add(LSTM(50, activation='relu'))
    elif name == 'DB-LSTM':
        model.add(Bidirectional(LSTM(50, activation='relu'), input_shape=(1, 1)))
    elif name == 'SRU':
        model.add(SimpleRNN(50, activation='relu', input_shape=(1, 1)))
    elif name == 'TPA-LSTM':
        model.add(LSTM(50, activation='relu', input_shape=(1, 1), unroll=True))
    elif name == 'ConvGRU':
        model.add(Conv1D(filters=64, kernel_size=1, activation='relu', input_shape=(1, 1)))
        model.add(GRU(50, activation='relu'))
    model.add(Dense(1))
    model.compile(optimizer=Adam(), loss='mse')
    return model

# 训练模型并绘制损失图
names = ['LSTM', 'FC-LSTM', 'Coupled LSTM', 'GRU', 'ConvLSTM', 'Deep LSTM', 'DB-LSTM','SRU', 'TPA-LSTM', 'ConvGRU']
train_losses = []
test_losses = []
predictions = []

for name in names:
    model = create_model(name)
    history = model.fit(train, train, epochs=15, batch_size=32, validation_data=(test, test), verbose=0)
    train_losses.append(history.history['loss'])
    test_losses.append(history.history['val_loss'])
    pred = model.predict(test)
    predictions.append(pred)
    
    
import matplotlib.pyplot as plt

# 设置不同的marker
markers = ['o', '.', '_', '^', '*', '>', '+', '1', 'p', '_', '8']
linestyles = ['-', '--', '--', ':', '-', '-.', '-.', ':', '-', '--']
# 绘制训练损失图
plt.figure(figsize=(16, 20))
for i, loss in enumerate(train_losses):
    plt.plot(loss, color='black',label=names[i], marker=markers[i], linestyle=linestyles[i])
plt.title('Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(fontsize=8,loc='best')
plt.show()
# 绘制测试损失图
for i, loss in enumerate(test_losses):
    plt.plot(loss, color='black',label=names[i], marker=markers[i], linestyle=linestyles[i])
plt.title('Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(fontsize=8,loc='best')
plt.show()
# 绘制预测结果折线图
for i, pred in enumerate(predictions):
    plt.plot(pred, color='black',label=names[i], marker=markers[i], linestyle=linestyles[i])
# 绘制真实值折线图
plt.plot(y_test, color='black', label='True Value')
plt.title('Predictions and True Values')
plt.xlabel('x')
plt.ylabel('value')
plt.legend(fontsize=8, loc='best')
# 显示图像
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1930996.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GB35114国密算法-GMSSL

C有个三方库-GMSSL是可以进行GB35114所需要的SM2、SM3、SM4等加解密算法的,但是使用国密算法是需要申请报备的 GmSSL是由北京大学自主开发的国产商用密码开源库,实现了对国密算法、标准和安全通信协议的全面功能覆盖,支持包括移动端在内的主流…

SpringBoot整合Swagger报错:Failed to start bean ‘documentationPluginsBootstrapper

文章目录 1 问题背景2 问题原因3 修改SpringBoot配置文件 application.properties参考 1 问题背景 Swagger是SpringBoot中常用的API文档工具,在刚接触使用的时候,按照通用的代码进行配置,发现报错了 [main] ERROR org.springframework.boot…

【ARM AMBA AXI 入门 5.1 - QoS是什么?QoS是怎么工作的? 】

请阅读【ARM AMBA AXI 总线 文章专栏导读】 转自:揭秘数通知识:QoS是什么?QoS是怎么工作的?(一) 文章目录 QoS 概述综合服务和差分服务 QoS 工具报文分类报文标记流量监管和整形工具拥塞管理工具拥塞避免工…

JuiceFS缓存特性

缓存 对于一个由对象存储和数据库组合驱动的文件系统,缓存是本地客户端与远端服务之间高效交互的重要纽带。读写的数据可以提前或者异步载入缓存,再由客户端在后台与远端服务交互执行异步上传或预取数据。相比直接与远端服务交互,采用缓存技…

Linux 线程初步解析

1.线程概念 在一个程序里的一个执行路线就叫做线程(thread)。更准确的定义是:线程是“一个进程内部的控制序列。在linux中,由于线程和进程都具有id,都需要调度等等相似性,因此都可以用PCB来描述和控制,线程含有PCB&am…

[RuoYi-Vue] - 5. 部分代码分析

文章目录 🍴1. 后端部分代码分析1.1 BaseController1.2 TableDataInfo1.3 AjaxResult1.4 BaseEntity 🍸2. 权限注解🍺3. 前后端交互3.1 跨域问题 🍴1. 后端部分代码分析 1.1 BaseController 1.2 TableDataInfo 分页查询统一返回对…

如何用个人电脑搭建一台本地服务器,并部署云原生开发工具TitanIDE到服务器详细教程

服务器是一种高性能计算机,作为网络的节点,它存储、处理网络上80%的数据、信息,因此也被称为网络的灵魂。与普通计算机相比,服务器具有高速CPU运算能力、长时间可靠运行、强大I/O外部数据吞吐能力以及更好的扩展性。 服务器的主要…

Day07-ES集群加密,kibana的RBAC实战,zookeeper集群搭建,zookeeper基本管理及kafka单点部署实战

Day07-ES集群加密,kibana的RBAC实战,zookeeper集群搭建,zookeeper基本管理及kafka单点部署实战 0、昨日内容回顾:1、基于nginx的反向代理控制访问kibana2、配置ES集群TSL认证:3、配置kibana连接ES集群4、配置filebeat连接ES集群5、配置logsta…

自建Web网站部署——案例分析

作者主页: 知孤云出岫 目录 作者主页:如何自建一个Web网站一、引言二、需求分析三、技术选型四、开发步骤1. 项目初始化初始化前端初始化后端 2. 前端开发目录结构示例代码App.jsHome.js 3. 后端开发目录结构示例代码app.jsproductRoutes.jsProduct.js 4. 前后端连接安装axio…

ospf复习综合小实验

实验要求: 1,R4为ISP,其上只能配置IP地址;R4与其他所有直连设备间均使用公有IP 2,R3-R5/6/7为MGRE环境,R3为中心站点; 3,整个OSPF环境IP基于172.16.0.0/16划分; 4&#…

redis其他类型和配置文件

很多博客只讲了五大基本类型,确实,是最常用的,而且百分之九十的程序员对于Redis只限于了解String这种最常用的。但是我个人认为,既然Redis官方提供了其他的数据类型,肯定是有相应的考量的,在某些特殊的业务…

1.MQ介绍

MQ 消息队列,本质是一个队列,先进先出,只不过队列中存放的内容是message而已。 为啥学习MQ 1.流量消峰 如果一个订单系统最多每秒能处理一万次订单,正常情况下我们下单1秒后就能返回结果。但是在高峰期,如果有两万…

NLP入门——RNN、LSTM模型的搭建、训练与预测

在卷积语言模型建模时,我们选取上下文长度ctx_len进行训练,预测时选取句子的最后ctx_len个分词做预测,这样句子的前0~seql-1-ctx_len个词对于预测没有任何帮助,这对于语言处理来说显然是不利的。 在词袋语言模型建模时&#xff0c…

Milvus 核心设计(5)--- scalar indexwork mechanism

目录 背景 Scalar index 简介 属性过滤 扫描数据段 相似性搜索 返回结果 举例说明 1. 属性过滤 2. 扫描数据段 3. 相似性搜索 实际应用中的考虑 Scalar Index 方式 Auto indexing Inverted indexing 背景 继续Milvus的很细设计,前面主要阐述了Milvu…

【排序算法】1.冒泡排序-C语言实现

冒泡排序(Bubble Sort)是最简单和最通用的排序方法,其基本思想是:在待排序的一组数中,将相邻的两个数进行比较,若前面的数比后面的数大就交换两数,否则不交换;如此下去,直…

C++ 入门基础:开启编程之旅

文章目录 引言一、C的第⼀个程序二、命名空间1、namespace2、namespace的定义 三、C输入 与 输出四、缺省参数五、函数重载六、引用1、引用的概念和定义2、引用的特性3、指针和引用的关系七、inline八、nullptr 引言 C 是一种高效、灵活且功能强大的编程语言,广泛应…

【java】力扣 合并两个有序数组

文章目录 题目链接题目描述代码第一种第二种 题目链接 88.合并两个有序数组 题目描述 代码 第一种 public void merge(int[] nums1, int m, int[] nums2, int n) {for(int i 0;i<n;i){nums1[mi] nums2[i];}Arrays.sort(nums1);}第二种 public void merge(int[] nums1,…

【数据结构】二叉树全攻略,从实现到应用详解

​ &#x1f48e;所属专栏&#xff1a;数据结构与算法学习 &#x1f48e; 欢迎大家互三&#xff1a;2的n次方_ ​ &#x1f341;1. 树形结构的介绍 树是一种非线性的数据结构&#xff0c;它是由n&#xff08;n>0&#xff09;个有限结点组成一个具有层次关系的集合。把它叫做…

JVM垃圾回收-----垃圾分类

一、垃圾分类定义 垃圾分类是JVM垃圾分类中的第一步&#xff0c;这一步将堆中的对象分为存活对象和垃圾对象两类。 在垃圾分类阶段&#xff0c;JVM会从一组根对象开始&#xff0c;通过对象之间的引用关系&#xff0c;遍历所有的对象&#xff0c;并将所有存活的对象进行标记。…

QT使用QPainter绘制多边形维度图

多边形统计维度图是一种用于展示多个维度的数据的图表。它通过将各个维度表示为图表中的多边形的边&#xff0c;根据数据的大小和比例来确定各个维度的长度。 一、简述 本示例实现六边形战力统计维度图&#xff0c;一种将六个维度的战力统计以六边形图形展示的方法。六个维度是…