使用 PyTorch 构建 LSTM 股票价格预测模型

news2025/2/25 6:11:29

目录

      • 引言
      • 准备工作
      • 1. 训练模型(`train.py`)
      • 2. 模型定义(`model.py`)
      • 3. 测试模型和可视化(`test.py`)
      • 使用说明
      • 模型调整
      • 结论

引言

在金融领域,股票价格预测是一个重要且具有挑战性的任务。随着深度学习的发展,长短期记忆网络(LSTM)因其在处理时间序列数据方面的出色表现而受到关注。本篇博客将指导你如何使用PyTorch构建一个LSTM模型来预测股票价格,我们将逐步介绍数据预处理、模型训练和结果可视化的完整流程。

准备工作

  1. 安装依赖
    确保你已经安装了以下 Python 库:

    pip install pandas numpy torch matplotlib scikit-learn
    
  2. 下载数据
    使用 yfinance 库下载你感兴趣的股票的历史数据,并保存为 CSV 文件。我们这里使用 Apple(AAPL)过去五年的数据,文件命名为 AAPL_5y_data.csv。以下是一个下载数据的代码示例:

    import yfinance as yf
    
    # 下载Apple股票过去5年的数据
    data = yf.download('AAPL', start='2019-01-01', end='2024-01-01')
    data.to_csv('AAPL_5y_data.csv')
    

1. 训练模型(train.py

在这个脚本中,我们将读取 CSV 文件,归一化数据,并使用 LSTM 模型进行训练。

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
from model import LSTM  # 导入LSTM类

# 设置随机种子
torch.manual_seed(42)

# 读取CSV文件
file_path = 'AAPL_5y_data.csv'  # 替换为你的CSV文件路径
data = pd.read_csv(file_path)

# 确保日期列是 datetime 类型
data['Date'] = pd.to_datetime(data['Date'])
data.set_index('Date', inplace=True)

# 选择多特征:'Close', 'Open', 'High', 'Low', 'Volume'
features = data[['Close', 'Open', 'High', 'Low', 'Volume']].values

# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(features)

# 准备训练和测试数据
train_size = int(len(scaled_data) * 0.8)
train_data = scaled_data[:train_size]
test_data = scaled_data[train_size:]

def create_dataset(data, time_step=1):
    X, y = [], []
    for i in range(len(data) - time_step - 1):
        a = data[i:(i + time_step)]
        X.append(a)
        y.append(data[i + time_step, 0])  # 预测收盘价
    return np.array(X), np.array(y)

# 创建数据集
time_step = 50  # 时间步长
X_train, y_train = create_dataset(train_data, time_step)

# 转换为PyTorch张量
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).float().view(-1, 1)

# 初始化模型、损失函数和优化器
model = LSTM()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# 训练模型
num_epochs = 300
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 保存模型
torch.save(model.state_dict(), 'lstm_model.pth')
print("模型已保存为 'lstm_model.pth'")

2. 模型定义(model.py

在这个文件中定义 LSTM 模型结构。

import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=5, hidden_size=100, num_layers=2, batch_first=True)
        self.fc = nn.Linear(100, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])  # 取最后时间步的输出
        return out

3. 测试模型和可视化(test.py

在这个脚本中,我们将加载训练好的模型,并使用测试数据进行预测和可视化。

import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from model import LSTM  # 导入LSTM类

# 设置字体为SimHei,用于显示中文
plt.rcParams['font.family'] = 'SimHei'

# 读取CSV文件
file_path = 'AAPL_5y_data.csv'  # 替换为你的CSV文件路径
data = pd.read_csv(file_path)

# 确保日期列是 datetime 类型
data['Date'] = pd.to_datetime(data['Date'])
data.set_index('Date', inplace=True)

# 选择多特征:'Close', 'Open', 'High', 'Low', 'Volume'
features = data[['Close', 'Open', 'High', 'Low', 'Volume']].values

# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(features)

# 准备训练和测试数据
train_size = int(len(scaled_data) * 0.8)
train_data = scaled_data[:train_size]
test_data = scaled_data[train_size:]

def create_dataset(data, time_step=1):
    X, y = [], []
    for i in range(len(data) - time_step - 1):
        a = data[i:(i + time_step)]
        X.append(a)
        y.append(data[i + time_step, 0])  # 预测收盘价
    return np.array(X), np.array(y)

# 创建测试数据集
time_step = 50  # 时间步长
X_test, y_test = create_dataset(test_data, time_step)

# 转换为PyTorch张量
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float().view(-1, 1)

# 加载模型
model = LSTM()
model.load_state_dict(torch.load('lstm_model.pth'))
model.eval()

# 测试模型
with torch.no_grad():
    test_outputs = model(X_test)
    # test_outputs 是预测的收盘价,将其重新归一化为原始价格
    test_outputs = scaler.inverse_transform(np.concatenate((test_outputs.numpy(), np.zeros((test_outputs.shape[0], 4))), axis=1))[:, 0]  # 反归一化收盘价
    y_test_inverse = scaler.inverse_transform(np.concatenate((y_test.numpy(), np.zeros((y_test.shape[0], 4))), axis=1))[:, 0]

# 可视化结果
plt.figure(figsize=(14, 7))
plt.plot(data.index[-len(y_test):], y_test_inverse, label='真实价格', color='blue')
plt.plot(data.index[-len(test_outputs):], test_outputs, label='预测价格', color='red')
plt.title('股票价格预测')
plt.xlabel('日期')
plt.ylabel('价格')
plt.legend()
plt.show()

使用说明

  1. 保存脚本

    • 将训练脚本代码保存为 train.py
    • 将模型定义代码保存为 model.py
    • 将测试脚本代码保存为 test.py
  2. 运行训练

    • 在命令行中运行训练脚本:
      python train.py
      
    • 训练完成后,模型将保存为 lstm_model.pth
  3. 运行测试和可视化

    • 在命令行中运行测试脚本:

      python test.py
      
    • 这将加载已训练的模型,并可视化预测结果。
      在这里插入图片描述
      这只是一个演示,模型的预测效果还有待进一步优化。

模型调整

如果预测的价格和真实价格差距较大,可能是由于以下几个原因:

  1. 数据规模不足

    • 如果训练数据不足,模型可能无法学到市场的长期趋势。
    • 改进:使用更多的历史数据,尽量包括多年的数据。可以尝试增加数据的时间跨度。
  2. 数据预处理问题

    • 数据没有正确归一化,或归一化范围过窄。
    • 改进:检查 MinMaxScaler 的应用。你可以尝试不同的归一化范围,例如 (0, 1)(-1, 1),也可以使用其他标准化方法(例如 StandardScaler)。
  3. 模型复杂度不足

    • 模型的层数或隐藏单元数量可能不足以捕捉数据的复杂性。
    • 改进:增加 LSTM 的隐藏层数量或隐藏单元数量。你还可以考虑添加其他类型的层,例如卷积层(CNN)或全连接层,以提高模型的表达能力。
  4. 超参数调整

    • 学习率、批大小和时间步长等超参数可能需要调整以优化模型性能。
    • 改进:尝试不同的学习率(例如,0.001、0.0001 等)、不同的批大小(如 16、32、64)和时间步长(如 30、60)。
  5. 更改损失函数

    • 在某些情况下,使用不同的损失函数可能有助于模型的收敛。
    • 改进:可以尝试使用其他损失函数,例如 Huber 损失函数(nn.SmoothL1Loss)或自定义损失函数,以更好地适应数据。

结论

通过使用 PyTorch 构建 LSTM 模型,我们成功地实现了股票价格的预测。在这个过程中,我们学习了如何处理时间序列数据,构建和训练深度学习模型,以及如何评估和可视化预测结果。尽管模型的性能可能需要进一步的优化和调整,但这个示例为未来的工作奠定了基础。

希望这篇博客能够帮助你在股票价格预测方面取得更好的成果。欢迎分享你的成果和经验,或者提出你的问题!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2213377.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【观察】超聚变:跨越智能算力“四座大山”,全方位重构“智算底座”

毫无疑问,今天在人工智能的推动下,企业数智化转型已进入规模化“倍增创新”的阶段,尤其是以大模型为代表的AI技术加速演进,以及应用场景的不断拓展加深,都让各类AI创新应用如雨后春笋般涌现,并加速惠及千行…

C++中cout的一些扩展

需要添加<iomanip>头文件 cout有许多扩展功能&#xff0c;比如一直很麻烦的保留小数数位的问题。 这里用几个问题来引入 cout实现保留小数数位 #include<iostream> #include<iomanip> using namespace std; int main(){double x123.345;double y342.324…

【未公开0day】金和OAC6 SignUpload SQL注入漏洞【附poc】

免责声明&#xff1a;本文仅用于技术学习和讨论。请勿使用本文所提供的内容及相关技术从事非法活动&#xff0c;若利用本文提供的内容或工具造成任何直接或间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;所产生的一切不良后果均与文章作者及本账号无关。 fofa语…

2024/10/14 英语每日一段

Advocates of the working pattern—100% of the work done in 80% of the time for 100% pay—claim the shorter working week boosts productivity, public health and builds a society “where we work to live, rather than live to work”. But the TaxPayers’ Allianc…

threejs-UV

一、简介 1.介绍 1.什么是UV映射&#xff1f; UV映射是一种将二维纹理映射到三维模型表面的技术。在这个过程中&#xff0c;3D模型上的每个顶点都会被赋予一个二维坐标&#xff08;U, V&#xff09;。U和V分别表示纹理坐标的水平和垂直方向。这些坐标用于将纹理图像上的像素与…

SQL优化,我就用了这几招

先赞后看&#xff0c;Java进阶一大半 阿里巴巴社区博客最近发表了一篇探究MySQL索引策略的博客&#xff0c;下图是一条查询SQL的执行过程。 我是南哥&#xff0c;相信对你通关面试、拿下Offer有所帮助。 敲黑板&#xff1a;本文总结了MySQL语句优化、索引优化常见的面试题&…

景区卫生间智能刷脸取纸机,灵活设置取纸长度、取纸间隔时间

在旅游景区&#xff0c;卫生间的服务质量直接影响着游客的体验。景区卫生间智能刷脸取纸机的出现&#xff0c;为解决游客用纸需求、提高资源利用效率以及提升景区管理水平带来了创新性的解决方案。 一、智能刷脸取纸机功能 1. 精准取纸&#xff1a;能够根据游客的实际需求&…

“链动2+1+消费增值:用户留存新策略“

大家好&#xff0c;我是吴军&#xff0c;目前在一家以创新为核心的软件开发公司担任产品经理。今天&#xff0c;我将深入探讨一个经受住了时间考验且依然充满活力的商业模式——“链动21”模式&#xff0c;并通过一个实例及相关数据展示它如何巧妙应对用户留存与复购的挑战。 首…

每日OJ题_牛客_HJ63DNA序列_滑动窗口_C++_Java

目录 牛客_HJ63DNA序列_滑动窗口 题目解析 C代码 Java代码 牛客_HJ63DNA序列_滑动窗口 孩子们的游戏(圆圈中最后剩下的数)_牛客题霸_牛客网 描述&#xff1a; 一个 DNA 序列由 A/C/G/T 四个字母的排列组合组成。 G 和 C 的比例&#xff08;定义为 GC-Ratio &#xff09;是…

[SQL] 数据库图形化安装和使用

一 安装 1.1 图形化安装 下载DataGrip安装包 点击此处一直下一步即可。点击免费使用。 进去界面后,选择新建一个项目 点击加号&#xff0c;创建一个Mysql连接。输入Mysql的连接信息。点击DownLoad下载Mysql的驱动 接下来点击创建的mysq项目中后面的三个点&#xff0c;选择…

Facebook的全球影响力:跨文化交流与信息共享的前沿

引言 在数字化时代&#xff0c;社交媒体已成为全球沟通的重要平台。自2004年成立以来&#xff0c;Facebook迅速发展成为拥有超过20亿活跃用户的巨头。其强大的影响力使其成为跨文化交流与信息共享的前沿平台。 跨文化交流的促进 Facebook的多语言支持让来自不同文化背景的用户…

如何在mkdocs-material文档主题下设置多版本文档系统?

引言 前一段时间&#xff0c;参与了PaddleOCR开源项目的文档站点搭建工作&#xff0c;基于mkdocs工具&#xff0c;采用mkdocs-material主题&#xff0c;基于Github Pages来搭建整个文档站点。目前该站点已经搭建完毕&#xff0c; 支持多语言、文档搜索等诸多功能。 最近得知&…

【C++ 算法进阶】算法提升二

算法提升二 最大分配工资问题 &#xff08;贪心&#xff09;题目分析代码详解 数组有序问题 &#xff08;贪心&#xff09;题目分析代码详解 消息流问题 &#xff08;数据结构设计&#xff09;题目分析代码详解 可乐问题 &#xff08;Coding能力&#xff09;题目分析代码详解 司…

YOLOv9下载安装运行

1、进入GitHub的YOLOv9官网 https://github.com/WongKinYiu/yolov92、clone或下载项目 https://github.com/WongKinYiu/yolov9.githttps://codeload.github.com/WongKinYiu/yolov9/zip/refs/heads/main2.1、进入控制台下载项目 git clone https://github.com/WongKinYiu/yol…

在线培训知识库+帮助中心:教育行业智慧学习的创新桥梁

在数字化转型的浪潮中&#xff0c;教育行业正经历着前所未有的变革。为了应对日益增长的学习需求&#xff0c;提升教育质量&#xff0c;构建一个集在线培训知识库与帮助中心于一体的智慧学习环境&#xff0c;已成为教育行业转型升级的重要方向。这一创新模式不仅优化了学习资源…

无人机飞手执照培训费用较高原因分析

无人机飞手执照培训费用较高的原因可以归结为多个方面&#xff0c;以下是对这些原因的具体分析&#xff1a; 一、课程内容的全面性和专业性 无人机飞手执照培训涵盖了从无人机基础知识到高级飞行技巧、从组装调试到故障维修的多个方面。这种全面性和专业性要求培训机构提供高…

猎板PCB测试大讲堂:让你测试的明明白白

在电子研发领域&#xff0c;PCB&#xff08;印刷电路板&#xff09;的检测是确保产品质量的关键环节。主要的检测方式包括飞针测试和测试架测试。以下是这两种技术的详细介绍&#xff0c;旨在为电子研发工程师提供技术资料。 PCB飞针测试&#xff08;Flying Probe Test&#x…

麦克风哪个品牌音质最好,无线领夹麦克风十大品牌推荐

随着科技的进步&#xff0c;无线领夹麦克风的技术也在不断革新。从传统的模拟信号传输到如今的数字信号传输&#xff0c;再到智能降噪、自适应增益控制等先进技术的应用&#xff0c;无线领夹麦克风的录音品质得到了显著提升。然而&#xff0c;市场上仍有一些产品采用过时的技术…

vue2使用pdfjs-dist实现pdf预览(iframe形式,不修改pdfjs原来的ui和控件)

前情提要 在一开始要使用pdf预览的时候&#xff0c;第一次选的是vue-pdf&#xff0c;但是vue-pdf支持的功能太少&#xff0c;缺少了项目中需要的一项-复制粘贴功能 之后我一顿搜搜搜&#xff0c;最终貌似只有pdfjs能用 但是网上支持text-layer的貌似都是用的2.09那个版本。 使…

嵌入式AI博客目录

文章目录 环境搭建ubuntu下载安装c版opencv4.7.0和4.5.0 & 安装opencv4.5.0报错及解决方法ubuntu系统vscode配置c版opencv & 编译运行c播放视频代码&#xff08;包含&#xff1a;vscode使用copencv&#xff0c;创建CmakeList.txt&#xff0c;创建编译项目&#xff09;u…