深度学习之搭建LSTM模型预测股价

news2024/11/28 20:40:31

      大家好,我是带我去滑雪!

     本期利用Google股价数据集,该数据集中GOOG_Stock_Price_Train.csv为训练集,GOOG_Stock_Price_Test.csv为测试集,里面有开盘价、最高股价、最低股价、收盘价、调整后的收盘价、成交量,2021年11月以前,可以在美国Yahoo网站下载股价历史数据,但现在对中国已经禁用了,可以去其他地方进行下载。本次使用调整后的收盘价进行预测。

目录

1、导入相关模块和数据集

2、产生训练所需的特征和标签数据

3、转换数据为(样本数,时步、特征)的张量

4、定义LSTM模型

5、使用已经训练好的LSTM模型预测股价

6、绘制真实股价与预测股价的对比图


1、导入相关模块和数据集

import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import Dense, Dropout,LSTM,SimpleRNN,GRU

# 载入Google股价数据集 
df_train = pd.read_csv(r'E:\工作\硕士\博客\博客37-\GOOG_Stock_Price_Train.csv',index_col="Date",parse_dates=True)
print(df_train)
df_test = pd.read_csv(r'E:\工作\硕士\博客\博客37-\GOOG_Stock_Price_Test.csv',index_col="Date",parse_dates=True)
print(df_test )

输出结果:

            Open        High         Low       Close   Adj Close  \
Date                                                                     
2012-01-03  324.360352  331.916199  324.077179  330.555054  330.555054   
2012-01-04  330.366272  332.959412  328.175537  331.980774  331.980774   
2012-01-05  328.925659  329.839722  325.994720  327.375732  327.375732   
2012-01-06  327.445282  327.867523  322.795532  322.909790  322.909790   
2012-01-09  321.161163  321.409546  308.607819  309.218842  309.218842   
...                ...         ...         ...         ...         ...   
2016-12-23  790.900024  792.739990  787.280029  789.909973  789.909973   
2016-12-27  790.679993  797.859985  787.656982  791.549988  791.549988   
2016-12-28  793.700012  794.229980  783.200012  785.049988  785.049988   
2016-12-29  783.330017  785.929993  778.919983  782.789978  782.789978   
2016-12-30  782.750000  782.780029  770.409973  771.820007  771.820007   

              Volume  
Date                  
2012-01-03   7400800  
2012-01-04   5765200  
2012-01-05   6608400  
2012-01-06   5420700  
2012-01-09  11720900  
...              ...  
2016-12-23    623400  
2016-12-27    789100  
2016-12-28   1153800  
2016-12-29    742200  
2016-12-30   1770000  

[1258 rows x 6 columns]
                  Open        High         Low       Close   Adj Close  \
Date                                                                     
2017-01-03  778.809998  789.630005  775.799988  786.140015  786.140015   
2017-01-04  788.359985  791.340027  783.159973  786.900024  786.900024   
2017-01-05  786.080017  794.479980  785.020020  794.020020  794.020020   
2017-01-06  795.260010  807.900024  792.203979  806.150024  806.150024   
2017-01-09  806.400024  809.966003  802.830017  806.650024  806.650024   
...                ...         ...         ...         ...         ...   
2017-04-24  851.200012  863.450012  849.859985  862.760010  862.760010   
2017-04-25  865.000000  875.000000  862.809998  872.299988  872.299988   
2017-04-26  874.229980  876.049988  867.747986  871.729980  871.729980   
2017-04-27  873.599976  875.400024  870.380005  874.250000  874.250000   
2017-04-28  910.659973  916.849976  905.770020  905.960022  905.960022   

             Volume  
Date                 
2017-01-03  1657300  
2017-01-04  1073000  
2017-01-05  1335200  
2017-01-06  1640200  
2017-01-09  1272400  
...             ...  
2017-04-24  1372500  
2017-04-25  1672000  
2017-04-26  1237200  
2017-04-27  2026800  
2017-04-28  3219500  

[81 rows x 6 columns]

2、产生训练所需的特征和标签数据

X_train_set = df_train.iloc[:,4:5].values 
#数据归一化
sc = MinMaxScaler() 
X_train_set = sc.fit_transform(X_train_set)
 
def create_dataset(ds, look_back=1):
    X_data, Y_data = [],[]
    for i in range(len(ds)-look_back):
        X_data.append(ds[i:(i+look_back), 0])
        Y_data.append(ds[i+look_back, 0])
    return np.array(X_data), np.array(Y_data)
look_back = 60
print("回看天数:", look_back)
 
# 分割成特征数据和标签数据
X_train, Y_train = create_dataset(X_train_set, look_back)
 
X_train
Y_train

输出结果:

回看天数: 60

Out[5]:

array([0.08291369, 0.07626093, 0.0815312 , ..., 0.94758974, 0.94336851,
       0.92287887])

3、转换数据为(样本数,时步、特征)的张量

X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_train.shape 

输出结果:

(1198, 60, 1)

4、定义LSTM模型

      在编译模型中,损失函数为MSE,优化器为adam。在训练模型中,训练周期为100,批次尺寸为32。 

model = Sequential()
model.add(LSTM(50, return_sequences=True, 
               input_shape=(X_train.shape[1], 1)))
model.add(Dropout(0.2))
model.add(LSTM(50, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(50))
model.add(Dropout(0.2))
model.add(Dense(1))
model.summary()  
#编译模型
model.compile(loss="mse", optimizer="adam") 
#训练模型
model.fit(X_train, Y_train, epochs=100, batch_size=32)

输出结果:

38/38 [==============================] - 2s 46ms/step - loss: 0.0013
Epoch 94/100
38/38 [==============================] - 2s 46ms/step - loss: 0.0013
Epoch 95/100
38/38 [==============================] - 2s 47ms/step - loss: 0.0012
Epoch 96/100
38/38 [==============================] - 2s 46ms/step - loss: 0.0013
Epoch 97/100
38/38 [==============================] - 2s 46ms/step - loss: 0.0013
Epoch 98/100
38/38 [==============================] - 2s 47ms/step - loss: 0.0013
Epoch 99/100
38/38 [==============================] - 2s 46ms/step - loss: 0.0012
Epoch 100/100
38/38 [==============================] - 2s 46ms/step - loss: 0.0013

5、使用已经训练好的LSTM模型预测股价

       测试集为2017年1月到3月的股价,因为使用的是前60天的股价数据,使用预测的是4月份股价 。

X_test_set = df_test.iloc[:,4:5].values
 
# 产生标签数据
_, Y_test = create_dataset(X_test_set, look_back)
 
#特征数据和标准化
X_test_s = sc.transform(X_test_set)
X_test,_ = create_dataset(X_test_s, look_back)
 
# 转换成(样本数, 时步, 特征)张量
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
X_test_pred = model.predict(X_test)
 
#  将预测值转换回股价
X_test_pred_price = sc.inverse_transform(X_test_pred)
X_test_pred_price

输出结果:

array([[814.5596 ],
       [819.2384 ],
       [821.1239 ],
       [823.5624 ],
       [824.0013 ],
       [822.3476 ],
       [819.3523 ],
       [816.00055],
       [813.82117],
       [812.62726],
       [812.6262 ],
       [812.9471 ],
       [817.2544 ],
       [821.539  ],
       [824.44244],
       [826.5891 ],
       [828.0157 ],
       [834.4217 ],
       [843.3087 ],
       [849.4051 ],
       [852.694  ]], dtype=float32)

6、绘制真实股价与预测股价的对比图

import matplotlib.pyplot as plt
plt.plot(Y_test, color="red", label="Real Stock Price")
plt.plot(X_test_pred_price, color="blue", label="Predicted Stock Price")
plt.title("2017 Google Stock Price Prediction")
plt.xlabel("Time")
plt.ylabel("Google Time Price")
plt.legend()
plt.savefig("E:\工作\硕士\博客\博客37-/squares1.png",
            bbox_inches ="tight",
            pad_inches = 1,
            transparent = True,
            facecolor ="w",
            edgecolor ='w',
            dpi=300,
            orientation ='landscape')

输出结果:

 


更多优质内容持续发布中,请移步主页查看。

   点赞+关注,下次不迷路!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/550642.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flutter项目webview加载没有HTTPS证书的网页在Android和iOS设备上无法显示的解决方案

一、问题描述 Flutter项目使用谷歌官方webview库 webview_flutter,加载自签名证书、证书失效、无证书等HTTPS网页地址时,在Android或pc浏览器中提示证书失效,在iOS设备上为空白页,为了加载自签名证书的网页,需要饶过i…

AVR单片机ATemga328P中断原理的介绍

1、一AVR单片机中断原理的介绍 ATmega328P微控制器具有两个外部中断引脚,分别是INT0和INT1。 外部中断0(INT0):它对应的引脚是PD2(数字引脚2)。INT0可以用于响应外部信号的边沿触发(上升沿、下…

【服务器】使用Nodejs搭建HTTP web服务器

Yan-英杰的主页 悟已往之不谏 知来者之可追 C程序员,2024届电子信息研究生 目录 前言 1.安装Node.js环境 2.创建node.js服务 3. 访问node.js 服务 4.内网穿透 4.1 安装配置cpolar内网穿透 4.2 创建隧道映射本地端口 5.固定公网地址 [TOC] 转载自内网穿透…

Unity Addressables学习笔记(1)---创建远程服务器加载资源

例子1:加载一个图片 1.首先创建一个UI Image,空白图片,资源打包方式选择真是部署的 2.修改远程发布和加载配置 Bulid Path选择RemoteBuildPath Load Path我选择了custom,地址是http://localhost:8080/WebGL/ 遇坑1 :最开始我选择的Build Path 是 Loca…

windows安装mysql 5.7.41

前言 要学mysql,肯定得本地装上一个玩一玩啦,下面一起来安装mysql吧 一、下载 https://downloads.mysql.com/archives/community/ 顺便说一下,下载按钮下方有个md5,可以验证下文件是否被篡改,理论上官网下载的应该问…

初识结构体

目录 结构体的声明 结构体的基础知识 结构体的声明 结构体成员的类型 结构体变量的定义和初始化 定义 初始化 结构体成员的访问 结构体变量访问成员 结构体指针访问指向变量的成员 结构体传参 传地址 传结构体 结论 结构体的声明 结构体的基础知识 数组&#xff…

【ChatGPT】IOS如何下载注册使用ChatGPT的APP(教学)

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化 👉专__注👈:专注主流机器人、人工智能等相关领域的开发、…

iptables 防火墙

iptables概述 Linux系统的防火墙:ip信息过滤系统,它实际上由两个组件netfilter和iptables组成。 主要工作在网络层,针对IP数据包。体现在对包内的IP地址、端口、协议等信息的处理上 netfilter / iptables关系: netfilter:属于…

Electron中如何创建模态窗口?

目录 前言一、模态窗口1.Web页面模态框2.Electron中的模态窗口3.区分父子窗口与模态窗口 二、实际案例使用总结 前言 模态框是一种常用的交互元素,无论是在 Web 网站、桌面应用还是移动 APP 中,都有其应用场景。模态框指的是一种弹出窗口,它…

【TES714】JFM7K325T(复旦微FPGA)+HI3531DV200(华为海思)的综合视频处理平台设计原理图及调试经验

板卡概述 TES714是自主研制的一款5路HD-SDI视频采集图像处理平台,该平台采用上海复旦微的高性能Kintex系列FPGA加上华为海思的高性能视频处理器HI3531DV200来实现。 华为海思的HI3531DV200是一款集成了ARM A53四核处理器性能强大的神经网络引擎,支持多种…

【运维知识进阶篇】集群架构-Nginx动静分离详解

我们先前将静态资源放到NFS,动态资源放到MySQL,一是为了提高我们Web服务器性能,减轻它的压力,另一面如果Web宕机了,我们的静态和动态资源还可以访问到。但是之前方式不管是静态还是动态文件,都是走的代码文…

ssl vpn 与 ipsec vpn 区别

VPN 安全协议有两种主要类型,IPsec 和 SSL,了解它们之间的区别对于确保客户的安全至关重要。在本文中,我们将解释IPsec 和 SSL VPN 协议之间的区别,以及如何选择合适的协议来满足客户的需求。了解更多SSL技术最新信息,…

Linux_证书_Openssl实现对称加密、非对称加密、CA颁布证书

文章目录 OpenSSLopenssl实现对称加密openssl实现非对称加密生成密钥对非对称加密数字签名小结 根据CA颁布证书生成ca私钥和ca证书根据ca生成证书 尾声 OpenSSL 常用证书生成工具包括三个:ssh-keygen、cfssl、openssl。这里介绍 OpenSSL , OpenSSL 是一个开源项目&…

【Python从入门到进阶】20、HTML页面结构的介绍

接上篇《19、Python异常处理》 上一篇我们学习了Python中有关异常(捕获异常、处理异常等)的知识。从本篇开始,我们进入Python的实战教程,学习爬虫的相关技术,本篇主要讲解要爬取的HTML页面的结构。 一、一个场景 假设…

Godot引擎 4.0 文档 - 入门介绍 - Godot 编辑器

本文为Google Translate英译中结果,DrGraph在此基础上加了一些校正。英文原版页面: First look at Godots editor — Godot Engine (stable) documentation in English Godot的编辑器 本页将为您简要介绍 Godot 的界面。我们将查看不同的主屏幕和停靠栏…

C语言:字符函数和字符串函数详解及部分函数的模拟实现(前篇)

文章目录 求字符串长度strlenstrlen函数的模拟实现: 长度不受限制的字符串函数strcpystrcatstrcmp总结 长度受限制的字符串函数介绍strncpystrncatstrncmp 前言: C语言中对字符和字符串的处理很是频繁,但是C语言本身是没有字符串类型的,字符串…

【LeetCode】382. 链表随机节点

382. 链表随机节点(中等) 方法一 思路 定义两个链表,一个origin,用于每次调用 getRandom() 时进行初始化,一个 l 用于每次调用 getRandom() 时进行遍历,找到随机选定的元素。首先在 Solution() 的时候&am…

SpringBoot原理——起步依赖与自动装配

文章目录 SpringBoot原理一、起步依赖二、自动配置2.1 概述2.2 工具类准备工作2.2.2 HeaderConfig2.2.3 HeaderGenerator2.2.4 HeaderParser2.2.5 MyImportSelector2.2.6 TokenParser2.2.7 pom.xml文件 2.3 自动配置原理2.3.1 引入工具类2.3.2 案例 : 访问第三方Bea…

GPT专业应用:撰写工作简报

●图片由Lexica 生成,输入:Workers working overtime 工作简报,作为一种了解情况、沟通信息的有效手段,能使上级机关和领导及时了解、掌握所属部门的政治学习、军事训练、行政管理等方面的最新情况;同时,能…

BERT输入以及权重矩阵形状解析

以下用形状来描述矩阵。对于向量,为了方便理解,也写成了类似(1,64)这种形状的表示形式,这个你理解为64维的向量即可。下面讲的矩阵相乘都是默认的叉乘。 词嵌入矩阵形状:以BERT_BASE为例,我们知道其有12层Encoder&…