基于LSTM的A股股票价格预测系统(torch) :从数据获取到模型训练的完整实现

news2025/1/16 16:53:48

在这里插入图片描述

1. 项目简介

本文介绍了一个使用LSTM(长短期记忆网络)进行股票价格预测的完整系统。该系统使用Python实现,集成了数据获取、预处理、模型训练和预测等功能。
这个代码使用的是 LSTM (Long Short-Term Memory) 模型,这是一种特殊的循环神经网络 (RNN)

2. 技术栈

  • Python 3.x
  • PyTorch (深度学习框架)
  • AKShare (股票数据获取)
  • Pandas (数据处理)
  • NumPy (数值计算)
  • Scikit-learn (数据预处理)

3. 系统架构

3.1 数据获取模块

def get_stock_data(stock_code, start_date, end_date, stock_name):
    """获取股票历史数据"""
    print(f"正在获取 {stock_name}{stock_code})的数据...")
    try:
        df = ak.stock_zh_a_hist(symbol=stock_code, 
                               period="daily", 
                               start_date=start_date, 
                               end_date=end_date, 
                               adjust="qfq")  # 使用前复权数据
        # ... 数据处理代码
        return df
    except Exception as e:
        print(f"获取{stock_name}数据时发生错误:{str(e)}")
        return None

3.2 LSTM模型定义

class StockRNN(nn.Module):
    """股票预测的LSTM模型"""
    def __init__(self, input_size, hidden_size, num_layers):
        super(StockRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

3.3 数据预处理

def prepare_data(df, sequence_length):
    """准备训练数据"""
    scaler = MinMaxScaler()
    scaled_data = scaler.fit_transform(df[['close']].values)
    
    X, y = [], []
    for i in range(len(scaled_data) - sequence_length):
        X.append(scaled_data[i:(i + sequence_length)])
        y.append(scaled_data[i + sequence_length])
    
    return np.array(X), np.array(y), scaler

4. 主要功能

  1. 股票数据获取和分析
  2. 市场状态评估
  3. 数据预处理和归一化
  4. LSTM模型训练
  5. 股价预测

5. 使用方法

  1. 运行程序
  2. 输入股票代码(或使用默认值)
  3. 设置日期范围
  4. 等待模型训练
  5. 获取预测结果
# 示例使用
stock_code = "002830"  # 股票代码
start_date = "20230101"  # 起始日期
end_date = "20240120"   # 结束日期
predict_date = "20241209"  # 预测日期

6. 模型参数

  • 序列长度:10天
  • LSTM隐藏层大小:64
  • LSTM层数:2
  • 训练轮数:100
  • 学习率:0.001

7. 风险提示

  1. 预测结果仅供参考,不构成投资建议
  2. 长期预测的准确性会显著降低
  3. 股市受多种因素影响,模型无法预测突发事件

8. 可能的改进方向

  1. 增加更多特征(如交易量、技术指标等)
  2. 优化模型架构
  3. 添加更多市场分析指标
  4. 实现实时数据更新
  5. 添加可视化功能

9. 总结

本项目展示了如何使用深度学习技术进行股票价格预测。通过整合数据获取、预处理和模型训练等功能,为股票分析提供了一个完整的解决方案。虽然预测结果仅供参考,但项目的实现过程对理解金融数据分析和深度学习应用具有重要的学习价值。

10. 环境配置与安装

10.1 Python环境要求

  • Python 3.8+

10.2 依赖包安装

# 创建虚拟环境(推荐)
python -m venv myvenv
source myvenv/bin/activate  # Linux/Mac
# 或
myvenv\Scripts\activate  # Windows

# 安装依赖包
pip install akshare
pip install torch
pip install pandas
pip install numpy
pip install scikit-learn

11. 完整代码实现

11.1 股票预测主程序 (stock_prediction_akshare.py)

import akshare as ak  # 导入akshare库,用于获取股票数据
import pandas as pd   # 导入pandas库,用于数据处理
import numpy as np    # 导入numpy库,用于数值计算
import torch         # 导入PyTorch库,用于深度学习
import torch.nn as nn  # 导入神经网络模块
import torch.optim as optim  # 导入优化器模块
from sklearn.preprocessing import MinMaxScaler  # 导入数据归一化工具
import random  # 导入随机数模块

[此处是完整的 stock_prediction_akshare.py 代码,与之前提供的相同]

11.2 下跌五日监控程序 (downfive5.py)

# 如果有 downfive5.py 的代码,请提供给我,我会添加到这里

11.3 Web应用接口 (app.py)

# 如果有 app.py 的代码,请提供给我,我会添加到这里

12. 运行说明

  1. 克隆或下载代码到本地
git clone [repository_url]
cd stock-prediction-system
  1. 安装依赖
pip install -r requirements.txt
  1. 运行股票预测程序
python stock_prediction_akshare.py
  1. 按提示输入:
    • 股票代码(例如:002830)
    • 起始日期(格式:YYYYMMDD)
    • 结束日期(格式:YYYYMMDD)
    • 预测日期(格式:YYYYMMDD)

13. 常见问题解答

  1. 数据获取失败

    • 检查网络连接
    • 确认股票代码是否正确
    • 验证日期格式
  2. 模型训练时间过长

    • 可以减少训练轮数(epochs)
    • 缩短历史数据范围
    • 使用GPU加速(如果可用)
  3. 预测结果异常

    • 检查数据预处理步骤
    • 调整模型参数
    • 验证输入数据的质量

14. 维护与更新

本项目仍在持续改进中,计划添加的功能包括:

  1. 多股票同时预测
  2. 更多技术指标支持
  3. 预测结果可视化
  4. 实时数据更新
  5. Web界面支持

15. 贡献指南

感谢以下开源项目:

  • AKShare
  • PyTorch
  • Pandas
  • NumPy
  • Scikit-learn

注意:

  1. 本项目仅供学习和研究使用
  2. 股市投资有风险,预测结果仅供参考
  3. 实际投资决策需要考虑多种因素

15. 代码:

import akshare as ak  # 导入akshare库,用于获取股票数据
import pandas as pd   # 导入pandas库,用于数据处理
import numpy as np    # 导入numpy库,用于数值计算
import torch         # 导入PyTorch库,用于深度学习
import torch.nn as nn  # 导入神经网络模块
import torch.optim as optim  # 导入优化器模块
from sklearn.preprocessing import MinMaxScaler  # 导入数据归一化工具
import random  # 导入随机数模块

def set_random_seed(seed=42):
    """设置随机种子,确保实验结果可重复"""
    random.seed(seed)  # 设置Python随机数种子
    np.random.seed(seed)  # 设置NumPy随机数种子
    torch.manual_seed(seed)  # 设置PyTorch CPU随机数种子
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # 设置PyTorch GPU随机数种子

class StockRNN(nn.Module):
    """定义股票预测的RNN模型类"""
    def __init__(self, input_size, hidden_size, num_layers):
        """
        初始化模型参数
        input_size: 输入特征维度
        hidden_size: LSTM隐藏层大小
        num_layers: LSTM层数
        """
        super(StockRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        # 定义LSTM层
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # 定义全连接层,用于输出预测值
        self.fc = nn.Linear(hidden_size, 1)
    
    def forward(self, x):
        """
        定义前向传播过程
        x: 输入数据,形状为(batch_size, sequence_length, input_size)
        """
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        
        # LSTM前向传播
        out, _ = self.lstm(x, (h0, c0))
        # 取最后一个时间步的输出进行预测
        out = self.fc(out[:, -1, :])
        return out

def get_stock_name(stock_code):
    """
    获取股票名称
    stock_code: 股票代码
    返回: 股票名称
    """
    try:
        # 获取A股实时行情数据
        df = ak.stock_zh_a_spot_em()
        # 查找对应股票代码的信息
        stock_info = df[df['代码'] == stock_code]
        if not stock_info.empty:
            return stock_info.iloc[0]['名称']
        
        # 如果没找到,尝试添加市场前缀
        if stock_code.startswith('6'):
            full_code = f"sh{stock_code}"
        else:
            full_code = f"sz{stock_code}"
        
        stock_info = df[df['代码'] == stock_code]
        if not stock_info.empty:
            return stock_info.iloc[0]['名称']
        
        return f"股票{stock_code}"
            
    except Exception as e:
        print(f"获取股票名称时发生错误:{e}")
        return f"股票{stock_code}"

def analyze_market_state(df, stock_name):
    """
    分析股票市场状态
    df: 股票数据DataFrame
    stock_name: 股票名称
    """
    # 计算日收益率
    daily_returns = df['close'].pct_change()
    
    # 计算年化波动率
    volatility = daily_returns.std() * np.sqrt(252)
    
    # 计算整体涨跌幅
    trend = (df['close'].iloc[-1] - df['close'].iloc[0]) / df['close'].iloc[0]
    
    # 计算最大回撤
    cummax = df['close'].cummax()  # 计算历史最高价
    drawdown = (cummax - df['close']) / cummax  # 计算回撤比例
    max_drawdown = drawdown.max()  # 计算最大回撤
    
    # 打印分析结果
    print(f"\n{stock_name}市场状态分析:")
    print(f"样本数量: {len(df)} 天")
    print(f"年化波动率: {volatility:.2%}")
    print(f"整体趋势: {trend:.2%}")
    print(f"最大回撤: {max_drawdown:.2%}")
    print(f"起始价格: {df['close'].iloc[0]:.2f}")
    print(f"结束价格: {df['close'].iloc[-1]:.2f}")

def get_stock_data(stock_code, start_date, end_date, stock_name):
    """
    获取股票历史数据
    stock_code: 股票代码
    start_date: 起始日期
    end_date: 结束日期
    stock_name: 股票名称
    """
    print(f"正在获取 {stock_name}{stock_code})的数据...")
    try:
        # 使用akshare获取股票历史数据
        df = ak.stock_zh_a_hist(symbol=stock_code, 
                               period="daily", 
                               start_date=start_date, 
                               end_date=end_date, 
                               adjust="qfq")  # 使用前复权数据
        
        # 重命名列
        df = df.rename(columns={'收盘': 'close', '日期': 'date'})
        df['date'] = pd.to_datetime(df['date'])  # 转换日期格式
        df.set_index('date', inplace=True)  # 设置日期为索引
        
        # 分析市场状态
        analyze_market_state(df, stock_name)
        
        return df
        
    except Exception as e:
        print(f"获取{stock_name}数据时发生错误:{str(e)}")
        return None

def prepare_data(df, sequence_length):
    """
    准备模型训练数据
    df: 股票数据DataFrame
    sequence_length: 序列长度(用多少天数据预测下一天)
    """
    # 数据归一化
    scaler = MinMaxScaler()
    scaled_data = scaler.fit_transform(df[['close']].values)
    
    # 创建序列数据
    X, y = [], []
    for i in range(len(scaled_data) - sequence_length):
        X.append(scaled_data[i:(i + sequence_length)])  # 输入序列
        y.append(scaled_data[i + sequence_length])      # 预测目标
    
    return np.array(X), np.array(y), scaler

def predict_next_day(model, last_sequence, scaler):
    """
    预测下一个交易日的价格
    model: 训练好的模型
    last_sequence: 最后一个序列数据
    scaler: 归一化器
    """
    with torch.no_grad():  # 不计算梯度
        # 准备输入数据
        last_sequence_tensor = torch.FloatTensor(last_sequence).unsqueeze(0)
        # 进行预测
        predicted_scaled = model(last_sequence_tensor)
        # 将预测结果转换回原始价格范围
        predicted_price = scaler.inverse_transform(predicted_scaled.numpy())
    return predicted_price[0][0]

def get_input_with_default(prompt, default_value):
    """
    获取用户输入,支持默认值
    prompt: 提示信息
    default_value: 默认值
    """
    user_input = input(f"{prompt} [默认: {default_value}]: ").strip()
    return user_input if user_input else default_value

def main():
    """主函数"""
    # 设置随机种子
    set_random_seed(42)
    
    try:
        # 设置默认参数
        default_stock_code = "002830"
        default_start_date = "20230101"
        default_end_date = "20240120"
        default_predict_date = "20241209"
        
        # 获取用户输入
        stock_code = get_input_with_default(
            "请输入股票代码", 
            default_stock_code
        )
        
        start_date = get_input_with_default(
            "请输入起始日期(格式:YYYYMMDD)", 
            default_start_date
        )
        
        end_date = get_input_with_default(
            "请输入结束日期(格式:YYYYMMDD)", 
            default_end_date
        )
        
        predict_date = get_input_with_default(
            "请输入预测日期(格式:YYYYMMDD)", 
            default_predict_date
        )
        
        # 获取股票名称
        stock_name = get_stock_name(stock_code)
        
        # 设置模型参数
        sequence_length = 10  # 使用10天数据预测下一天
        hidden_size = 64     # LSTM隐藏层大小
        num_layers = 2       # LSTM层数
        epochs = 100         # 训练轮数
        learning_rate = 0.001  # 学习率
        
        # 获取并分析数据
        df = get_stock_data(stock_code, start_date, end_date, stock_name)
        if df is None:
            return
            
        # 准备训练数据
        X, y, scaler = prepare_data(df, sequence_length)
        X_tensor = torch.FloatTensor(X)
        y_tensor = torch.FloatTensor(y)
        
        # 初始化模型
        model = StockRNN(input_size=1, hidden_size=hidden_size, num_layers=num_layers)
        criterion = nn.MSELoss()  # 使用均方误差损失
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 使用Adam优化器
        
        # 训练模型
        for epoch in range(epochs):
            optimizer.zero_grad()  # 清空梯度
            outputs = model(X_tensor)  # 前向传播
            loss = criterion(outputs, y_tensor)  # 计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数
            
            # 每10轮打印一次损失
            if (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
        
        # 预测未来价格
        last_sequence = X[-1]
        predicted_price = predict_next_day(model, last_sequence, scaler)
        print(f"\n预测 {stock_name}{stock_code})在 {predict_date} 的收盘价: {predicted_price:.2f}")
        
        # 打印风险提示
        print("\n风险提示:")
        print(f"1. {stock_name}的预测是基于历史数据的模型预测,不构成投资建议")
        print("2. 长期预测(超过一周)的准确性会显著降低")
        print("3. 股市受多种因素影响,模型无法预测突发事件的影响")
        
    except Exception as e:
        print(f"程序运行出错:{str(e)}")

if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2256552.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【ArcGIS微课1000例】0134:ArcGIS Earth实现二维建筑物的三维完美显示

文章目录 一、加载数据二、三维显示三、三维符号化一、加载数据 加载配套实验数据(0134.rar中的建筑物,2d或3d都可以),方法如下:点击添加按钮。 点击【Add Files】,在弹出的Open对话框中,选择建筑物,点击确定,完成添加。 默认二维显示: 二、三维显示 右键建筑物图层…

汽车EEA架构:发展历程

1.发展历程的基本逻辑 汽车电子电气的发展历程中,其使用的基本逻辑是IPO(Input-Processing-Output)模型,如下图1所示: 图 1 那什么是IPO模型了?我们从控制器的原理入手解释IPO模型,控制器的主要用途如下: 根据给定的逻…

python拆分Excel文件

按Sheet拆分Excel 或 按照某一列的不同值拆分Excel。文档样式如下: 结果:红色是按照Sheet名拆出的,蓝色和橙色是某个Sheet按照某列的不同值拆分的。 代码: # -*- coding: utf-8 -*- """ 拆分excel文件——按照…

存内架构IR-DROP问题详解-电容电导补偿

一、总述 电容、电导补偿作为大规模数字电路的关键设计理念,是 CIM 架构优化的核心技术。在 CIM 中,平衡电容或电导并实现计算的精准映射,对能效提升和计算精度保障具有关键作用。本文基于近期文献探讨电容、电导补偿在 CIM 中的具体补偿策…

汽车网络安全 -- IDPS如何帮助OEM保证车辆全生命周期的信息安全

目录 1.强标的另一层解读 2.什么是IDPS 2.1 IDPS技术要点 2.2 车辆IDPS系统示例 3.车辆纵深防御架构 4.小结 1.强标的另一层解读 在最近发布的国家汽车安全强标《GB 44495》,在7.2节明确提出了12条关于通信安全的要求,分别涉及到车辆与车辆制造商云平台通信、车辆与车辆…

【数字化】华为企业数字化转型-认知篇

导读:企业数字化转型的必要性在于,它能够帮助企业适应数字化时代的需求,提升运营效率,创新业务模式,增强客户互动,从而在激烈的市场竞争中保持领先地位并实现可持续发展。通过学习华为企业数字化转型相关理…

用C#开发程序进行ASCII艺术制作

我一直很喜欢 ASCII 艺术,而我对制作 ASCII 艺术的热情促使我探索 .NET 框架中的 GDI。在本文中, 我将向您展示如何通过三个简单的步骤从 JPEG/Bitmap 图像生成 ASCII 艺术。 1、加载并调整图像大小。 2、读取每个像素,获取其颜色并将其转换…

第23周:机器学习及文献阅读

目录 摘要 Abstract 一、理论知识 1、逻辑提升 2、分类任务 3、10倍交叉验证法 二、文献阅读 1、模型方法——MLT (1)特征选择 (2)决策树剪枝 2、分类任务——逻辑回归 3、实验部分 数据集的选取 代码实践 模型…

2020年国赛高教杯数学建模E题校园供水系统智能管理解题全过程文档及程序

2020年国赛高教杯数学建模 E题 校园供水系统智能管理 原题再现 校园供水系统是校园公用设施的重要组成部分,学校为了保障校园供水系统的正常运行需要投入大量的人力、物力和财力。随着科学技术的发展,校园内已经普遍使用了智能水表,从而可以…

React开发高级篇 - React Hooks以及自定义Hooks实现思路

Hooks介绍 Hooks是react16.8以后新增的钩子API; 目的:增加代码的可复用性,逻辑性,弥补无状态组件没有生命周期,没有数据管理状态state的缺陷。 为什么要使用Hooks? 开发友好,可扩展性强&#…

摩尔线程 国产显卡 MUSA 并行编程 学习笔记-2024/12/03

Learning Roadmap: Section 1: Intro to Parallel Programming & MUSA Deep Learning Ecosystem(摩尔线程 国产显卡 MUSA 并行编程 学习笔记-2024/11/30-CSDN博客)UbuntuDriverToolkitcondapytorchtorch_musa环境安装(2024/11/24-Ubunt…

如何使用Docker轻松搭建高颜值无广告音乐播放器SPlayer随时随地听歌

前言 在快节奏的生活环境中,音乐成为了许多人放松和享受的重要方式。本文将介绍如何在Linux Ubuntu系统中使用Docker快速部署一款高颜值无广告的某抑云音乐播放器——SPlayer,并结合Cpolar内网穿透工具实现出门在外也能远程访问本地服务,随时…

C# Decimal

文章目录 前言1. Decimal 的基本特性2. 基本用法示例3. 特殊值与转换4. 数学运算示例5. 精度处理示例6. 比较操作示例7. 货币计算示例8. Decimal 的保留小数位数9. 处理 Decimal 的溢出和下溢10. 避免浮点数计算误差总结 前言 decimal 是 C# 中一种用于表示高精度十进制数的关键…

【理论·专业课】第三次作业

第1题(存储管理_内存碎片) 请指出内部碎片与外部碎片的区别。 ANS: 内部碎片是分配给进程但未被进程使用且无法被其他进程利用的内存空间 外部碎片是内存中因进程分配释放内存形成的不连续小块,虽总和够但因不连续无…

最新的springboot 3.x的支持s3协议的2.x方法的minio上传文件方法

拉取镜像 docker pull registry.cn-hangzhou.aliyuncs.com/qiluo-images/minio:latest运行命令 docker run -d \--name minio \-p 10087:9000 \-p 10088:9001 \-e MINIO_ROOT_USERminioadmin \-e MINIO_ROOT_PASSWORDY6HYraaphfZ9k8Lv \-v /data/minio/data:/data \-v /data/…

cocos creator接入字节跳动抖音小游戏JSAPI敏感词检测(进行文字输入,但输入敏感词后没有替换为*号)

今天更新了某个抖音小游戏的版本,增加了部分剧情,半天过后一条短信审核未通过,emmm…抖音总是能给开发者惊喜…打开电脑看看这次又整什么幺蛾子… 首先是一脸懵逼,后端早已接入了官方的内容安全检测能力了(https://de…

Origin快速拟合荧光寿命、PL Decay (TRPL)数据分析处理-方法二

1.先导入数据到origin 2.导入文件的时候注意:名字短的这个是,或者你打开后看哪个里面有800,因为我的激光重频是1.25Hz(应该是,不太确定单位是KHz还是MHz),所以对应的时间是800s。 3.选中两列直接…

17. 面向对象的特征

一、面向对象的三大特征 面向对象的三大特征指的是 封装、继承、多态。 封装(encapsulation,有时称为数据隐藏)是处理对象的一个重要概念。从形式上看,封装就是将数据和行为组合在一个包中,并对对象的使用者隐藏具体的…

Apache Dolphinscheduler可视化 DAG 工作流任务调度系统

Apache Dolphinscheduler 关于 一个分布式易扩展的可视化 DAG 工作流任务调度系统。致力于解决数据处理流程中错综复杂的依赖关系,使调度系统在数据处理流程中开箱即用。 DolphinScheduler 的主要特性如下: 易于部署,提供四种部署方式&am…

第二部分:基础知识 6.函数 --[JavaScript 新手村:开启编程之旅的第一步]

JavaScript 函数是可重用的代码块,用于执行特定任务。函数可以接受参数(输入数据),并且可以返回一个值。JavaScript 提供了多种定义函数的方式,下面将详细介绍这些方式,并给出一些示例。 1. 函数声明 下面…