【机器学习实战入门】使用LSTM机器学习预测股票价格

news2025/2/23 23:20:30

在这里插入图片描述
在这里插入图片描述
机器学习在股票价格预测中有重要的应用。在这个机器学习项目中,我们将讨论如何预测股票的收益。这是一个非常复杂的任务,充满了不确定性。我们将会把这个项目分成两部分进行开发:

首先,我们将学习如何使用 LSTM 神经网络预测股票价格。
然后,我们将使用 Plotly Dash 构建一个用于股票分析的仪表板。
在这里插入图片描述

股票价格预测项目仪表板

股票价格预测项目
数据集
为了构建股票价格预测模型,我们将使用“印度国家证券交易所(NSE)TATA GLOBAL”数据集。这是来自印度国家证券交易所的 Tata 全球饮料有限公司的 Tata 饮料数据集:
为了构建股票分析的仪表板,我们将使用另一个包含多个股票(如苹果、微软、脸书)的数据集:
源代码
下载地址:链接: 源代码 及 Tata 饮料数据集 多个股票(如苹果、微软、脸书)的数据集

使用 LSTM 预测股票价格

  1. 导入:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.pylab import rcParams
rcParams['figure.figsize']=20,10
from keras.models import Sequential
from keras.layers import LSTM,Dropout,Dense
from sklearn.preprocessing import MinMaxScaler
  1. 读取数据集:
df=pd.read_csv("NSE-TATA.csv")
df.head()

读取股票数据
在这里插入图片描述

  1. 从数据框中分析收盘价:
df["Date"]=pd.to_datetime(df.Date,format="%Y-%m-%d")
df.index=df['Date']
plt.figure(figsize=(16,8))
plt.plot(df["Close"],label='Close Price history')

分析股票价格
在这里插入图片描述

  1. 按日期时间排序并筛选“Date”和“Close”列:
data=df.sort_index(ascending=True,axis=0)
new_dataset=pd.DataFrame(index=range(0,len(df)),columns=['Date','Close'])
for i in range(0,len(data)):
    new_dataset["Date"][i]=data['Date'][i]
    new_dataset["Close"][i]=data["Close"][i]
  1. 对新的筛选数据集进行归一化:
scaler=MinMaxScaler(feature_range=(0,1))
final_dataset=new_dataset.values
train_data=final_dataset[0:987,:]
valid_data=final_dataset[987:,:]
new_dataset.index=new_dataset.Date
new_dataset.drop("Date",axis=1,inplace=True)
scaler=MinMaxScaler(feature_range=(0,1))
scaled_data=scaler.fit_transform(final_dataset)
x_train_data,y_train_data=[],[]
for i in range(60,len(train_data)):
    x_train_data.append(scaled_data[i-60:i,0])
    y_train_data.append(scaled_data[i,0])
    
x_train_data,y_train_data=np.array(x_train_data),np.array(y_train_data)
x_train_data=np.reshape(x_train_data,(x_train_data.shape[0],x_train_data.shape[1],1))
  1. 构建和训练 LSTM 模型:
lstm_model=Sequential()
lstm_model.add(LSTM(units=50,return_sequences=True,input_shape=(x_train_data.shape[1],1)))
lstm_model.add(LSTM(units=50))
lstm_model.add(Dense(1))
inputs_data=new_dataset[len(new_dataset)-len(valid_data)-60:].values
inputs_data=inputs_data.reshape(-1,1)
inputs_data=scaler.transform(inputs_data)
lstm_model.compile(loss='mean_squared_error',optimizer='adam')
lstm_model.fit(x_train_data,y_train_data,epochs=1,batch_size=1,verbose=2)
  1. 从数据集中抽取样本,利用 LSTM 模型进行股票价格预测:
X_test=[]
for i in range(60,inputs_data.shape[0]):
    X_test.append(inputs_data[i-60:i,0])
X_test=np.array(X_test)
X_test=np.reshape(X_test,(X_test.shape[0],X_test.shape[1],1))
predicted_closing_price=lstm_model.predict(X_test)
predicted_closing_price=scaler.inverse_transform(predicted_closing_price)
  1. 保存 LSTM 模型:
lstm_model.save("saved_model.h5")
  1. 真实股票成本与预测股票成本对比可视化:
train_data=new_dataset[:987]
valid_data=new_dataset[987:]
valid_data['Predictions']=predicted_closing_price
plt.plot(train_data["Close"])
plt.plot(valid_data[['Close',"Predictions"]])

可以看到,LSTM 模型预测的股票价格与实际股票价格相当接近。
在这里插入图片描述

使用 Plotly Dash 构建仪表板
在本节中,我们将构建一个仪表板用于分析股票。Dash 是一个 Python 框架,它在 Flask 和 React.js 之上提供了一层抽象,用于构建分析型 Web 应用程序。
在继续之前,你需要安装 Dash。在终端运行以下命令。

pip3 install dash
pip3 install dash-html-components
pip3 install dash-core-components

现在创建一个新的 Python 文件 stock_app.py 并粘贴以下脚本:

import dash
import dash_core_components as dcc
import dash_html_components as html
import pandas as pd
import plotly.graph_objs as go
from dash.dependencies import Input, Output
from keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
import numpy as np

app = dash.Dash()
server = app.server
scaler = MinMaxScaler(feature_range=(0,1))
df_nse = pd.read_csv("./NSE-TATA.csv")
df_nse["Date"] = pd.to_datetime(df_nse.Date, format="%Y-%m-%d")
df_nse.index = df_nse['Date']
data = df_nse.sort_index(ascending=True, axis=0)
new_data = pd.DataFrame(index=range(0, len(df_nse)), columns=['Date', 'Close'])
for i in range(0, len(data)):
    new_data["Date"][i] = data['Date'][i]
    new_data["Close"][i] = data["Close"][i]
new_data.index = new_data.Date
new_data.drop("Date", axis=1, inplace=True)
dataset = new_data.values
train = dataset[0:987, :]
valid = dataset[987:, :]
scaler = MinMaxScaler(feature_range=(0,1))
scaled_data = scaler.fit_transform(dataset)
x_train, y_train = [], []
for i in range(60, len(train)):
    x_train.append(scaled_data[i-60:i, 0])
    y_train.append(scaled_data[i, 0])
    
x_train, y_train = np.array(x_train), np.array(y_train)
x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))
model = load_model("saved_model.h5")
inputs = new_data[len(new_data)-len(valid)-60:].values
inputs = inputs.reshape(-1, 1)
inputs = scaler.transform(inputs)
X_test = []
for i in range(60, inputs.shape[0]):
    X_test.append(inputs[i-60:i, 0])
X_test = np.array(X_test)
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
closing_price = model.predict(X_test)
closing_price = scaler.inverse_transform(closing_price)
train = new_data[:987]
valid = new_data[987:]
valid['Predictions'] = closing_price
df = pd.read_csv("./stock_data.csv")

app.layout = html.Div([
    html.H1("股票价格分析仪表板", style={"textAlign": "center"}),
    dcc.Tabs(id="tabs", children=[
        dcc.Tab(label='NSE-TATAGLOBAL 股票数据', children=[
            html.Div([
                html.H2("实际收盘价", style={"textAlign": "center"}),
                dcc.Graph(
                    id="Actual Data",
                    figure={
                        "data": [
                            go.Scatter(
                                x=train.index,
                                y=valid["Close"],
                                mode='markers'
                            )
                        ],
                        "layout": go.Layout(
                            title='散点图',
                            xaxis={'title': '日期'},
                            yaxis={'title': '收盘价'}
                        )
                    }
                ),
                html.H2("LSTM 预测收盘价", style={"textAlign": "center"}),
                dcc.Graph(
                    id="Predicted Data",
                    figure={
                        "data": [
                            go.Scatter(
                                x=valid.index,
                                y=valid["Predictions"],
                                mode='markers'
                            )
                        ],
                        "layout": go.Layout(
                            title='散点图',
                            xaxis={'title': '日期'},
                            yaxis={'title': '收盘价'}
                        )
                    }
                )                
            ])                
        ]),
        dcc.Tab(label='脸书股票数据', children=[
            html.Div([
                html.H1("脸书最高价与最低价对比", 
                        style={'textAlign': 'center'}),
                dcc.Dropdown(id='my-dropdown',
                             options=[{'label': '特斯拉', 'value': 'TSLA'},
                                      {'label': '苹果', 'value': 'AAPL'}, 
                                      {'label': '脸书', 'value': 'FB'}, 
                                      {'label': '微软', 'value': 'MSFT'}], 
                             multi=True, value=['FB'],
                             style={"display": "block", "margin-left": "auto", 
                                    "margin-right": "auto", "width": "60%"}),
                dcc.Graph(id='highlow'),
                html.H1("脸书市场交易量", style={'textAlign': 'center'}),
                dcc.Dropdown(id='my-dropdown2',
                             options=[{'label': '特斯拉', 'value': 'TSLA'},
                                      {'label': '苹果', 'value': 'AAPL'}, 
                                      {'label': '脸书', 'value': 'FB'},
                                      {'label': '微软', 'value': 'MSFT'}], 
                             multi=True, value=['FB'],
                             style={"display": "block", "margin-left": "auto", 
                                    "margin-right": "auto", "width": "60%"}),
                dcc.Graph(id='volume')
            ], className="container"),
        ])
    ])
])

@app.callback(Output('highlow', 'figure'),
              [Input('my-dropdown', 'value')])
def update_graph(selected_dropdown):
    dropdown = {"TSLA": "特斯拉", "AAPL": "苹果", "FB": "脸书", "MSFT": "微软"}
    trace1 = []
    trace2 = []
    for stock in selected_dropdown:
        trace1.append(
            go.Scatter(x=df[df["Stock"] == stock]["Date"],
                       y=df[df["Stock"] == stock]["High"],
                       mode='lines', opacity=0.7, 
                       name=f'高 {dropdown[stock]}', textposition='bottom center'))
        trace2.append(
            go.Scatter(x=df[df["Stock"] == stock]["Date"],
                       y=df[df["Stock"] == stock]["Low"],
                       mode='lines', opacity=0.6,
                       name=f'低 {dropdown[stock]}', textposition='bottom center'))
    traces = [trace1, trace2]
    data = [val for sublist in traces for val in sublist]
    figure = {'data': data,
              'layout': go.Layout(colorway=["#5E0DAC", '#FF4F00', '#375CB1', 
                                            '#FF7400', '#FFF400', '#FF0056'],
              height=600,
              title=f"随时间变化的高低价:{', '.join(str(dropdown[i]) for i in selected_dropdown)}",
              xaxis={"title": "日期",
                     'rangeselector': {'buttons': list([{'count': 1, 'label': '1M', 
                                                        'step': 'month', 
                                                        'stepmode': 'backward'},
                                                       {'count': 6, 'label': '6M', 
                                                        'step': 'month', 
                                                        'stepmode': 'backward'},
                                                       {'step': 'all'}])},
                     'rangeslider': {'visible': True}, 'type': 'date'},
              yaxis={"title": "价格(美元)"}}}
    return figure

@app.callback(Output('volume', 'figure'),
              [Input('my-dropdown2', 'value')])
def update_graph(selected_dropdown_value):
    dropdown = {"TSLA": "特斯拉", "AAPL": "苹果", "FB": "脸书", "MSFT": "微软"}
    trace1 = []
    for stock in selected_dropdown_value:
        trace1.append(
            go.Scatter(x=df[df["Stock"] == stock]["Date"],
                       y=df[df["Stock"] == stock]["Volume"],
                       mode='lines', opacity=0.7,
                       name=f'交易量 {dropdown[stock]}', textposition='bottom center'))
    traces = [trace1]
    data = [val for sublist in traces for val in sublist]
    figure = {'data': data, 
              'layout': go.Layout(colorway=["#5E0DAC", '#FF4F00', '#375CB1', 
                                            '#FF7400', '#FFF400', '#FF0056'],
              height=600,
              title=f"随时间变化的市场交易量:{', '.join(str(dropdown[i]) for i in selected_dropdown_value)}",
              xaxis={"title": "日期",
                     'rangeselector': {'buttons': list([{'count': 1, 'label': '1M', 
                                                        'step': 'month', 
                                                        'stepmode': 'backward'},
                                                       {'count': 6, 'label': '6M',
                                                        'step': 'month', 
                                                        'stepmode': 'backward'},
                                                       {'step': 'all'}])},
                     'rangeslider': {'visible': True}, 'type': 'date'},
              yaxis={"title": "交易量"}}}
    return figure

if __name__ == '__main__':
    app.run_server(debug=True)

现在运行此文件并打开浏览器中的应用:

python3 stock_app.py

股票价格预测项目仪表板
在这里插入图片描述

摘要
股票价格预测是一个适合机器学习初学者的项目;在本教程中,我们学习了如何开发股票价格预测模型以及如何构建用于股票分析的交互式仪表板。我们实现了基于 LSTM 模型的股市预测。另一方面,我们使用了 Python 的 Plotly Dash 框架来构建仪表板。

参考文献及资料链接

参考资料链接
股票价格预测基础https://example.com/ml-basics
LSTM 神经网络教程https://example.com/lstm-tutorial
TensorFlow 官方文档https://tensorflow.org/docs
Keras 官方文档https://keras.io/zh/
Scikit-learn 文档https://scikit-learn.org/stable/
NSE TATA GLOBAL 数据集https://example.com/tata-global-dataset
股票数据集https://example.com/stocks-dataset
运行 Flask 扩展https://flask.palletsprojects.com/en/2.3.x/extensions/
Plotly 官方网站https://plotly.com/python/
Plotly 冲浪式图表 (Dash) 官方文档https://dash.plotly.com/
Pandas 官方文档https://pandas.pydata.org/pandas-docs/stable/
Numpy 官方文档https://numpy.org/doc/stable/
LSTM 股票预测实践https://medium.com/@example_lstm_pred
Dash 股票分析仪表板案例https://blog.plotly.com/dash-stock-examples/
源代码与数据集介绍

股票价格预测项目

在这个机器学习项目中,我们将开发一个基于神经网络的股票预测模型,用于预测股票收益。

学习如何开发股票价格预测模型,并构建一个用于股票分析的交互式仪表板。我们使用 LSTM 模型实现股票市场预测,并使用 Plotly Dash Python 框架构建仪表板。

类别:机器学习、深度学习
编程语言:Python
工具与库:Plotly Dash、LSTM
IDE:Jupyter
前端:Plotly Dash(用于可视化)
后端:无
先决条件:Python、机器学习、深度学习、神经网络
目标受众:教育、开发人员、数据工程师、数据科学家

股票价格数据

该数据集包含关于塔塔全球饮料有限公司(Tata Global Beverages Limited)的股票价格记录。数据集中还包含按日期排列的股票价格,包括开盘价、收盘价、最高价和最低价,以及当天的交易量和成交额。

对于想要尝试数据可视化、数据分析以及多种形式的数据处理技术的人来说,这是一个极好的数据库。

示例数据
NSE 塔塔全球饮料有限公司

数据格式

  • Date:日期
  • Open:开盘价
  • High:最高价
  • Low:最低价
  • Last:最新价
  • Close:收盘价
  • Total Trade Quantity:总交易量
  • Turnover (Lacs):成交额(单位:十万卢比)
    在这里插入图片描述

股票价格数据

该历史数据集包含关于苹果(Apple)、微软(Microsoft)、脸书(Facebook)等多家公司股票价格的记录。数据集中还包含按日期排列的股票价格,包括开盘价、收盘价、最高价和最低价,以及当天的交易量。

对于想要尝试数据可视化、数据分析以及多种形式的数据处理技术的人来说,这是一个极好的数据库。

示例数据
股票数据集

数据格式

  • Date:日期
  • Open:开盘价
  • High:最高价
  • Low:最低价
  • Close:收盘价
  • Volume:交易量
  • OpenInt:未平仓合约(适用于期货和期权)
  • Stock:股票名称或代码

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2280038.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition 中的空洞自注意力机制

空洞自注意力机制 文章目录 摘要1. 模型解释1.1. 滑动窗口扩张注意力1.2. 多尺度扩张注意力 2. 代码3. 流程图3.1. MultiDilatelocalAttention3.2. DilateAttention3.3. MLP 摘要 本文针对DilateFormer中的空洞自注意力机制原理和代码进行详细介绍,最后通过流程图梳…

大模型GUI系列论文阅读 DAY2续:《一个具备规划、长上下文理解和程序合成能力的真实世界Web代理》

摘要 预训练的大语言模型(LLMs)近年来在自主网页自动化方面实现了更好的泛化能力和样本效率。然而,在真实世界的网站上,其性能仍然受到以下问题的影响:(1) 开放领域的复杂性,(2) 有限的上下文长度&#xff…

Qt按钮美化教程

前言 Qt按钮美化主要有三种方式:QSS、属性和自绘 QSS 字体大小 font-size: 18px;文字颜色 color: white;背景颜色 background-color: rgb(10,88,163); 按钮边框 border: 2px solid rgb(114,188,51);文字对齐 text-align: left;左侧内边距 padding-left: 10…

云IDE:开启软件开发的未来篇章

敖行客一直致力于将整个研发协作流程线上化,从而打破物理环境依赖,让研发组织模式更加灵活、自由且高效,今天就来聊聊AT Work(一站式研发协作平台)的重要组成部分-云IDE。 在科技领域,历史常常是未来的风向…

AI agent 在 6G 网络应用,无人机群控场景

AI agent 在 6G 网络应用,无人机群控场景 随着 6G 时代的临近,融合人工智能成为关键趋势。借鉴 IT 行业 AI Agent 应用范式,提出 6G AI Agent 技术框架,包含多模型融合、定制化 Agent 和插件式环境交互理念,构建了涵盖四层结构的框架。通过各层协同实现自主环境感知等能力…

【Linux 重装】Ubuntu 启动盘 U盘无法被识别,如何处理?

背景 U盘烧录了 Ubuntu 系统作为启动盘,再次插入电脑后无法被识别 解决方案(Mac 适用) (1)查找 USB,(2)格式化(1)在 terminal 中通过 diskutil list 查看是…

【优选算法篇】2----复写零

---------------------------------------begin--------------------------------------- 这道算法题相对于移动零,就上了一点点强度咯,不过还是很容易理解的啦~ 题目解析: 这道题如果没理解好题目,是很难的,但理解题…

高效建站指南:通过Portainer快速搭建自己的在线网站

文章目录 前言1. 安装Portainer1.1 访问Portainer Web界面 2. 使用Portainer创建Nginx容器3. 将Web静态站点实现公网访问4. 配置Web站点公网访问地址4.1公网访问Web站点 5. 固定Web静态站点公网地址6. 固定公网地址访问Web静态站点 前言 Portainer是一个开源的Docker轻量级可视…

redis性能优化参考——筑梦之路

基准性能测试 redis响应延迟耗时多长判定为慢? 比如机器硬件配置比较差,响应延迟10毫秒,就认为是慢,机器硬件配置比较高,响应延迟0.5毫秒,就认为是慢。这个没有固定的标准,只有了解了你的 Red…

Python 入门教程(2)搭建环境 | 2.3、VSCode配置Python开发环境

文章目录 一、VSCode配置Python开发环境1、软件安装2、安装Python插件3、配置Python环境4、包管理5、调试程序 前言 Visual Studio Code(简称VSCode)以其强大的功能和灵活的扩展性,成为了许多开发者的首选。本文将详细介绍如何在VSCode中配置…

Trimble三维激光扫描-地下公共设施维护的新途径【沪敖3D】

三维激光扫描技术生成了复杂隧道网络的高度详细的三维模型 项目背景 纽约州北部的地下通道网络已有100年历史,其中包含供暖系统、电线和其他公用设施,现在已经开始显露出老化迹象。由于安全原因,第三方的进入受到限制,在没有现成纸…

【强化学习】策略梯度(Policy Gradient,PG)算法

📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅…

Apache SeaTunnel 2.3.9 正式发布:多项新特性与优化全面提升数据集成能力

近日,Apache SeaTunnel 社区正式发布了最新版本 2.3.9。本次更新新增了Helm 集群部署、Transform 支持多表、Zeta新API、表结构转换、任务提交队列、分库分表合并、列转多行 等多个功能更新! 作为一款开源、分布式的数据集成平台,本次版本通过…

4 AXI USER IP

前言 使用AXI Interface封装IP,并使用AXI Interface实现对IP内部寄存器进行读写实现控制LED的demo,这个demo是非常必要的,因为在前面的笔记中基本都需哟PS端与PL端就行通信互相交互,在PL端可以通过中断的形式来告知PS端一些事情&…

B站评论系统的多级存储架构

以下文章来源于哔哩哔哩技术 ,作者业务 哔哩哔哩技术. 提供B站相关技术的介绍和讲解 1. 背景 评论是 B站生态的重要组成部分,涵盖了 UP 主与用户的互动、平台内容的推荐与优化、社区文化建设以及用户情感满足。B站的评论区不仅是用户互动的核心场所&…

电子科大2024秋《大数据分析与智能计算》真题回忆

考试日期:2025-01-08 课程:成电信软学院-大数据分析与智能计算 形式:开卷 考试回忆版 简答题(4*15) 1. 简述大数据的四个特征。分析每个特征所带来的问题和可能的解决方案 2. HDFS的架构的主要组件有哪些&#xff0…

多选multiple下拉框el-select回显问题(只显示后端返回id)

首先保证v-model的值对应options数据源里面的id <el-form-item prop"subclass" label"分类" ><el-select v-model"formData.subclass" multiple placeholder"请选择" clearable :disabled"!!formData.id"><e…

JavaWeb开发(十五)实战-生鲜后台管理系统(二)注册、登录、记住密码

1. 生鲜后台管理系统-注册功能 1.1. 注册功能 &#xff08;1&#xff09;创建注册RegisterServlet&#xff0c;接收form表单中的参数。   &#xff08;2&#xff09;service创建一个userService处理业务逻辑。   &#xff08;3&#xff09;RegisterServlet将参数传递给ser…

【MySQL系列文章】Linux环境下安装部署MySQL

前言 本次安装部署主要针对Linux环境进行安装部署操作,系统位数64 getconf LONG_BIT 64MySQL版本&#xff1a;v5.7.38 一、下载MySQL MySQL下载地址&#xff1a;MySQL :: Download MySQL Community Server (Archived Versions) 二、上传MySQL压缩包到Linuxx环境&#xff0c…

嵌入式硬件篇---基本组合逻辑电路

文章目录 前言基本逻辑门电路1.与门&#xff08;AND Gate&#xff09;2.或门&#xff08;OR Gate&#xff09;3.非门&#xff08;NOT Gate&#xff09;4.与非门&#xff08;NAND Gate&#xff09;5.或非门&#xff08;NOR Gate&#xff09;6.异或门&#xff08;XOR Gate&#x…