人工智能基础部分13-LSTM网络:预测上证指数走势

news2025/2/25 21:15:34

大家好,我是微学AI,今天给大家介绍一下LSTM网络,主要运用于解决序列问题。

一、LSTM网络简单介绍

LSTM又称为:长短期记忆网络,它是一种特殊的 RNN。LSTM网络主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。对于相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

引入LSTM网络的原因:由于 RNN 网络主要问题是长期依赖,即隐藏状态在时间上传递过程中可能会丢失之前的信息。为了解决这个问题,引入了长短时记忆网络 (LSTM) 和门控循环单元 (GRU)。这两种网络结构在隐藏层中增加了门控机制,能够更好地控制信息的传递。

 其中符号及表示意思如下:

 LSTM中有三个门:
(1)遗忘门f:决定上一个时刻的记忆单元状态需要遗忘多少信息,保留多少信息到当前记忆单元状态。
(2)输入门i:控制当前时刻输入信息候选状态有多少信息需要保存到当前记忆单元状态。
(3)输出门o:控制当前时刻的记忆单元状态有多少信息需要输出给外部状态。

形象的例子让我们更好的理解LSTM的原理:

假设你是一个梦想远大的学生,你想通过学习一门课程获得更多的知识。在学习过程中,LSTM模型帮助你,它就像是一个老师,它的遗忘门就像是老师的提醒,它让你挑出不用的知识,以保持你对重要知识的清晰记忆。它的输入门就像是老师的指导,它会重新审视你学习过的知识,按照自己的逻辑把知识结合起来,进化出更多有用的知识。最后,它的输出门就像老师的监督,它会确保你学习到了有用的知识,不要浪费时间去学习无用的知识。

二、LSTM网络运用-预测上证指数走势

# 使用LSTM预测沪市指数
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout
from pandas import DataFrame
from pandas import concat
from itertools import chain
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']

# 转化为可以用于监督学习的数据
def get_train_set(data_set, timesteps_in, timesteps_out=1):
    train_data_set = np.array(data_set)
    reframed_train_data_set = np.array(series_to_supervised(train_data_set, timesteps_in, timesteps_out).values)

    train_x, train_y = reframed_train_data_set[:, :-timesteps_out], reframed_train_data_set[:, -timesteps_out:]
    # 将数据集重构为符合LSTM要求的数据格式,即 [样本数,时间步,特征]

    train_x = train_x.reshape((train_x.shape[0], timesteps_in, 1))
    return train_x, train_y

"""
将时间序列数据转换为适用于监督学习的数据
给定输入、输出序列的长度
data: 观察序列
n_in: 观测数据input(X)的步长,范围[1, len(data)], 默认为1
n_out: 观测数据output(y)的步长, 范围为[0, len(data)-1], 默认为1
dropnan: 是否删除NaN行
返回值:适用于监督学习的 DataFrame
"""
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):
    print(data.shape)
    n_vars = 1 if type(data) is list else data.shape[1]
    df = DataFrame(data)
    cols, names = list(), list()
    # input sequence (t-n, ... t-1)
    for i in range(n_in, 0, -1):
        cols.append(df.shift(i))
        names += [('var%d(t-%d)' % (j + 1, i)) for j in range(n_vars)]
    # 预测序列 (t, t+1, ... t+n)
    for i in range(0, n_out):
        cols.append(df.shift(-i))
        if i == 0:
            names += [('var%d(t)' % (j + 1)) for j in range(n_vars)]
        else:
            names += [('var%d(t+%d)' % (j + 1, i)) for j in range(n_vars)]
    # 拼接到一起
    agg = concat(cols, axis=1)
    agg.columns = names
    # 去掉NaN行
    if dropnan:
        agg.dropna(inplace=True)
    return agg

# 使用LSTM进行预测
def lstm_model(source_data_set, train_x, label_y, input_epochs, input_batch_size, timesteps_out):
    model = Sequential()
    
    # 第一层, 隐藏层神经元节点个数为128, 返回整个序列
    model.add(LSTM(128, return_sequences=True, activation='tanh', input_shape=(train_x.shape[1], train_x.shape[2])))
    # 第二层,隐藏层神经元节点个数为128, 只返回序列最后一个输出
    model.add(LSTM(128, return_sequences=False))
    model.add(Dropout(0.5))
    # 第三层 因为是回归问题所以使用linear
    model.add(Dense(timesteps_out, activation='linear'))
    model.compile(loss='mean_squared_error', optimizer='adam')

    # LSTM训练 input_epochs次数
    res = model.fit(train_x, label_y, epochs=input_epochs, batch_size=input_batch_size, verbose=2, shuffle=False)

    # 模型预测
    train_predict = model.predict(train_x)
    #test_data_list = list(chain(*test_data))
    train_predict_list = list(chain(*train_predict))

    plt.plot(res.history['loss'], label='train')
    plt.show()
    #print(model.summary())
    plot_img(source_data_set, train_predict)

# 呈现原始数据,训练结果,验证结果,预测结果
def plot_img(source_data_set, train_predict):
    plt.figure(figsize=(24, 8))
    # 原始数据蓝色
    plt.plot(source_data_set[:, -1], c='b',label = '标签')
    # 训练数据绿色
    plt.plot([x for x in train_predict], c='g')
    plt.legend()
    plt.show()

# 设置观测数据input(X)的步长(时间步),epochs,batch_size
timesteps_in = 3
timesteps_out = 3
epochs = 1000
batch_size = 100
data = pd.read_csv('./shanghai_index_1990_12_19_to_2019_12_11.csv')
data_set = data[['Price']].values.astype('float64')
# 转化为可以用于监督学习的数据
train_x, label_y = get_train_set(data_set, timesteps_in=timesteps_in, timesteps_out=timesteps_out)

print(train_x, label_y )
print(train_x.shape)
print(train_x.shape[1], train_x.shape[2])

# 使用LSTM进行训练、预测
lstm_model(data_set, train_x, label_y, epochs, batch_size, timesteps_out=timesteps_out)

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/361742.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CATCTF wife原型链污染

CATCTF wife原型链污染 原型链污染原理:https://drun1baby.github.io/2022/12/29/JavaScript-%E5%8E%9F%E5%9E%8B%E9%93%BE%E6%B1%A1%E6%9F%93/ 如下代码,prototype是newClass类的一个属性。newClass 实例化的对象 newObj 的 .__proto__ 指向 newClass…

数字证书不仅有SSL证书,数字证书类型科普大全

数字证书不仅有SSL证书,数字证书类型科普大全证书不单单有SSL证书,还有SSL证书、代码签名证书以及客户端证书等。具体分类跟着陕西CA一起来看看吧。SSL证书:单域名SSL证书、多域名SSL证书、通配符SSL证书代码签名证书:单位代码签名…

管理系统权限分析以及白屏处理

菜单权限的业务分析 超级管理员:首页、权限模块、商品模块 不同角色能看到的菜单是不一样的。 如何实现菜单的权限 登录时向服务器发请求,服务器会把用户相应的菜单的权限信息,返回给前端,可以根据服务器返回的数据&#xff0…

sql复习(set运算符、高级子查询)

一、set运算符 union:得到两个查询结果的并集,并且⾃动去掉重复⾏。不会排序 union all:得到两个查询结果的并集,不会去掉重复⾏。也不会排序 intersect:得到两个查询结果的交集,并且按照结果集的第⼀个列进…

JavaSEI学习day10 基础班知识点回顾

一. 注释 注释是对代码的一种解释,在程序 的指定位置, 添加的一些说明性信息被注释掉的内容, 不会参与程序的编译和运行. 分类:单行: // 注释信息多行: /* 注释信息 */文档: /** 注释信息 */ 二. 关键字 Java语言中有特殊含义的单词,就是关键字。在后面的课程中…

基于Django的员工管理系统

目录 一、新建项目 二、创建app 三、设计表结构 四、在MySQL中生成表 五、静态文件管理 六、添加页面 七、模板的继承 一、新建项目 django-admin startproject 员工管理系统 二、创建app startapp app01 三、设计表结构 app01/migrations/models.py from django.db impo…

类与类之间的关系有哪几种?

文章目录程序设计要素1.可读性2.健壮性3.优化4.复用性5.可扩展性设计类的关系遵循的原则1、 高内聚低耦合2、面向对象开发中 “针对接口编程优于针对实现编程”,”组合优于继承” 的总体设计类与类之间的关系(即事物关系) A is-a B 泛化&…

模拟用户登录-课后程序(JAVA基础案例教程-黑马程序员编著-第五章-课后作业)

【案例5-3】 模拟用户登录 【案例介绍】 1.任务描述 在使用一些APP时,通常都需要填写用户名和密码。用户名和密码输入都正确才会登录成功,否则会提示用户名或密码错误。 本例要求编写一个程序,模拟用户登录。程序要求如下: 用…

Redis02: Redis基础命令

一、基础命令 先启动redis服务,使用redis-cli客户端连到redis数据库里面 1. 获取符合规则的键: keys 要点: (1)keys 后面可以指定正则表达式 (2)在生产环境下建议禁用keys命令,因为这个命令会查…

为什么要经常阅读和分析计算机SCI期刊论文? - 易智编译EaseEditing

训练阅读与分析期刊论文的能力,可以增加中长期的学术竞争力。 只要能够充分掌握阅读与分析期刊论文的技巧,就可以水到渠成地轻松进行「创新」的工作。 所以,只要深入掌握到阅读与分析期刊论文的技巧,就可以掌握到大学生不曾研习过…

koa中间件的实现原理

koa中间件的实现原理如何?先来看一个例子。koa的执行顺序是这样的:const middleware asyncfunction (ctx, next) {console.log(1)await next()console.log(6) }const middleware2 asyncfunction (ctx, next) {console.log(2)await next()console.log(5…

GCN项目实战1-SimGNN

文章目录SimGNN:快速图相似度计算的神经网络方法1. 数据2. 模型2.1 python文件功能介绍2.2 重要函数和类的实现SimGNN:快速图相似度计算的神经网络方法 原论文名称:SimGNN: A Neural Network Approach to Fast Graph Similarity Computation…

2023年,即时配送迎来黄金年,其他玩家该如何“弯道超车”?

我们知道,面对如此宏达的快递行业,它的市场一直被许多人所看好。尤其近几年,随着电商、物流的发展,各互联网公司纷纷跻身这一领域。现在市面上为大家所熟知的三通一达、极兔以及顺丰等等。 除此之外,在一些细分领域&a…

AWS攻略——Peering连接VPC

文章目录创建IP/CIDR不覆盖的VPC创建VPC创建子网创建密钥对创建EC2创建Peering接受Peering邀请修改各个VPC的路由表修改美东us-east-1 pulic subnet的路由修改悉尼ap-southeast-2路由测试知识点我们回顾下《AWS攻略——VPC初识》中的知识: 一个VPC只能设置在一个Re…

劳特巴赫仿真测试工具Trace32的基本使用(cmm文件)

劳特巴赫 Trace32 调试使用教程 使用PRACTICE 脚本(.cmm) 在TRACE32 中使用PRACTICE 脚本(*.cmm)将帮助你: 在调试器启动时立即执行命令根据您的项目需求自定义TRACE32PowerView用户界面加载应用程序或符号使调试操作具有可重复性, 并可用于验证目的和回归测试 自动启动脚本…

Spring Boot系列--创建第一个Spring Boot项目

1.项目搭建 在IDEA中新建项目,选择Spring Initializr。 填写项目信息: 选择版本和Spring Web依赖: Spring Web插件能为项目集成Tomcat、配置dispatcherServlet和xml文件。此处选择的版本若为3.0.2的话会出现如下错误: java: …

【C++】文件IO流

一起来康康C中的文件IO操作吧 文章目录1.operator bool2.C文件IO流3.文件操作3.0 关于按位与的说明3.1 ifstream3.2 ofstream流插入文本3.3 ostringstream/istringstream3.4 stringstream3.5使用stringstream的注意事项结语1.operator bool 之前写OJ的时候,就已经用…

Python学习-----排序问题1.0(冒泡排序、选择排序、插入排序)

目录 前言: 1.冒泡排序 2.选择排序 3.插入排序 前言: 学过C语言肯定接触过排序问题,我们最常用的也就是冒泡排序、选择排序、插入排序……等等,同样在Python中也有排序问题,这里我也会讲解Python中冒泡排序、选择排…

MvvmLight框架入门

MvvmLight MvvmLight主要程序库 ViewModelLocator 在.netFramework环境下nuget安装MvvmLight会自动安装依赖项MvvmLightLibs库,如果在.net core环境下需要手动安装MvvmLightLibs库。 .netFramework环境下安装完成后,在ViewModel目录下具有一个 ViewMo…

page cache设计及实现

你好,我是安然无虞。 page cache的设计及实现 page cache 本质上也是一个哈希桶, 它是按照页的数量进行映射的. 当 central cache 向 page cache 申请内存时, page cache 先检查对应位置是否有span, 如果没有则向更大页去寻找一个span, 如果找到则分裂成两个. 比如…