pytorch-20_1 LSTM在股价数据集上的预测实战

news2025/1/27 13:00:00

LSTM在股价数据集上的预测实战

  • 使用完整的JPX赛题数据,并向大家提供完整的lstm流程。

导包

import numpy as np #数据处理
import pandas as pd #数据处理
import matplotlib as mlp
import matplotlib.pyplot as plt #绘图
from sklearn.preprocessing import MinMaxScaler #·数据预处理
from sklearn.metrics import mean_squared_error
import torch 
import torch.nn as nn #导入pytorch中的基本类
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import torch.utils.data as data
# typing 模块提供了一些类型,辅助函数中的参数类型定义
from typing import Union,List,Tuple,Iterable
from sklearn.preprocessing import LabelEncoder,MinMaxScaler
from decimal import ROUND_HALF_UP, Decimal

一、数据加载与处理

# 一、数据加载与处理
# 1、查看数据集信息
stock= pd.read_csv('stock_prices.csv')          # (2332531,12) 
stock_list = pd.read_csv('stock_list.csv')      # (4417,16)

stock["SecuritiesCode"].unique().__len__()      #2000支股票

# 2、为了效率我们抽取其中的10支股票
selected_codes = stock['SecuritiesCode'].drop_duplicates().sample(n=10)
stock = stock[stock['SecuritiesCode'].isin(selected_codes)]     # (9833,12)
stock["SecuritiesCode"].unique().__len__()      #只有10支股票了

stock.isnull().sum() #查看缺失值

# 3、预处理数据集
#将Target名字修改为Sharpe Ratio
stock.rename(columns={'Target': 'Sharpe Ratio'}, inplace=True)

#将Close列添加到最后
close_col = stock.pop('Close')
stock.loc[:,'Close'] = close_col

#填补Dividend缺失值、删除具有缺失值的行
stock["ExpectedDividend"] = stock["ExpectedDividend"].fillna(0)
stock.dropna(inplace=True)

#恢复索引
stock.index = range(stock.shape[0])

二、数据分割与数据重组

# 二、数据分割与数据重组
# 1、数据分割
train_size = int(len(stock) * 0.67)
test_size = len(stock) - train_size
train, test = stock[:train_size], stock[train_size:] # train (6580,12) test(3242,12)

# 2、带标签滑窗
def create_multivariate_dataset_2(dataset, window_size, pred_len):  # 
    """
    将多变量时间序列转变为能够用于训练和预测的数据【带标签的滑窗】
    
    参数:
        dataset: DataFrame,其中包含特征和标签,特征从索引3开始,最后一列是标签
        window_size: 滑窗的窗口大小
        pred_len:多步预测的预测范围/预测步长
    """
    X, y, y_indices = [], [], []
    for i in range(len(dataset) - window_size - pred_len + 1):                      # (len-ws-pl+1) --> (6580-30-5+1) = 6546
        # 选取从第4列到最后一列的特征和标签
        feature_and_label = dataset.iloc[i:i + window_size, 3:].values              # (ws,fs_la) --> (30,9)
        # 下一个时间点的标签作为目标
        target = dataset.iloc[(i + window_size):(i + window_size + pred_len), -1]   # pred_len --> 5
        # 记录本窗口中要预测的标签的时间点
        target_indices = list(range(i + window_size, i + window_size + pred_len))   # pl*(len-ws-pl+1) --> 5*6546 = 32730 

        X.append(feature_and_label)
        y.append(target)
        #将每个标签的索引添加到y_indices列表中
        y_indices.extend(target_indices)
    
    X = torch.FloatTensor(np.array(X, dtype=np.float32))
    y = torch.FloatTensor(np.array(y, dtype=np.float32))
    
    return X, y, y_indices

# 3、数据重组
window_size = 30        #窗口大小
pred_len = 5            #多步预测的步数

X_train_2, y_train_2, y_train_indices = create_multivariate_dataset_2(train, window_size, pred_len)     # x(6546,30,9) y(6546,5) (32730,)
X_test_2, y_test_2, y_test_indices = create_multivariate_dataset_2(test, window_size, pred_len)         # x(3208,30,9) y(3208,5) (16040,)

三、网络架构与参数设置

# 三、网络架构与参数设置
# 1、定义架构
class MyLSTM(nn.Module):
    def __init__(self,input_dim, seq_length, output_size, hidden_size, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x, _ = self.lstm(x)
        #现在我要的是最后一个时间步,而不是全部时间步了
        x = self.linear(x[:,-1,:])
        return x

# 2、参数设置
input_size = 9          #输入特征的维度
hidden_size = 20        #LSTM隐藏状态的维度
n_epochs = 2000         #迭代epoch
learning_rate = 0.001   #学习率
num_layers = 1          #隐藏层的层数
output_size = 5

#设置GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# 加载数据,将数据分批次 
loader = data.DataLoader(data.TensorDataset(X_train_2, y_train_2), shuffle=True, batch_size=8) 

# 3、实例化模型
model = MyLSTM(input_size, window_size, pred_len,hidden_size, num_layers).to(device)
optimizer = optim.Adam(model.parameters(),lr=learning_rate) #定义优化器
loss_fn = nn.MSELoss() #定义损失函数
loader = data.DataLoader(data.TensorDataset(X_train_2, y_train_2)
                         #每个表单内部是保持时间顺序的即可,表单与表单之间可以shuffle
                         , shuffle=True
                         , batch_size=8) #将数据分批次

四、实际训练流程

# 四、实际训练流程
# 初始化早停参数
early_stopping_patience = 3  # 设置容忍的epoch数,即在这么多epoch后如果没有改进就停止
early_stopping_counter = 0  # 用于跟踪没有改进的epoch数
best_train_rmse = float('inf')  # 初始化最佳的训练RMSE

train_losses = []
test_losses = []

for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in loader:
        y_pred = model(X_batch.to(device))
        loss = loss_fn(y_pred, y_batch.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    #验证与打印
    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            y_pred = model(X_train_2.to(device)).cpu()
            train_rmse = np.sqrt(loss_fn(y_pred, y_train_2))
            y_pred = model(X_test_2.to(device)).cpu()
            test_rmse = np.sqrt(loss_fn(y_pred, y_test_2))
        print("Epoch %d: train RMSE %.4f, test RMSE %.4f" % (epoch, train_rmse, test_rmse))
        
        # 将当前epoch的损失添加到列表中
        train_losses.append(train_rmse)
        test_losses.append(test_rmse)
    
        # 早停检查
        if  train_rmse < best_train_rmse:
            best_train_rmse = train_rmse
            early_stopping_counter = 0  # 重置计数器
        else:
            early_stopping_counter += 1  # 增加计数器
            if early_stopping_counter >= early_stopping_patience:
                print(f"Early stopping triggered after epoch {epoch}. Training RMSE did not decrease for {early_stopping_patience} consecutive epochs.")
                break  # 跳出训练循环

结果显示:

Epoch 0: train RMSE 1470.9308, test RMSE 1692.0652
Epoch 5: train RMSE 1415.7896, test RMSE 1639.1147
Epoch 10: train RMSE 1364.8196, test RMSE 1590.2207
......
Epoch 100: train RMSE 654.3458, test RMSE 904.7958
Epoch 105: train RMSE 638.2536, test RMSE 886.3511
Epoch 110: train RMSE 625.7336, test RMSE 870.9800
......
Epoch 200: train RMSE 598.3364, test RMSE 820.4078
Epoch 205: train RMSE 598.3354, test RMSE 820.3406
Epoch 210: train RMSE 598.3349, test RMSE 820.2874
......
Epoch 260: train RMSE 598.3341, test RMSE 820.1312
Epoch 265: train RMSE 598.3341, test RMSE 820.1294
Early stopping triggered after epoch 265. Training RMSE did not decrease for 3 consecutive epochs.

五、可视化结果

# 五、可视化结果
# 1、损失曲线
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train RMSE')
plt.plot(test_losses, label='Test RMSE')
plt.xlabel('Epochs')
plt.ylabel('RMSE')
plt.title('Train and Test RMSE Over Epochs')
plt.legend()
plt.show()

在这里插入图片描述
结果分析:预测效果不是很好,考虑进行数据预处理和特征工程

【扩展】股票数据的数据预处理与特征工程(后续更新~)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1688908.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unreal Engine5 Landscape地形材质无法显示加载

UE5系列文章目录 文章目录 UE5系列文章目录前言一、解决办法 前言 在使用ue5做地形编辑的时候&#xff0c;明明刚才就保存的Landscape地形完全消失不见&#xff0c;或者是地形的材质不见了。重新打开UE5发现有时候能解决&#xff0c;但大多数时候还是没有解决&#xff0c;我下…

有效的变位词

如果哈希表的键的取值范围是固定的&#xff0c;并且范围不是很大&#xff0c;则可以用数组来模拟哈希表。数组的下标和哈希表的键相对应&#xff0c;而数组的值和哈希表的值相对应。 英文小写字母只有26个&#xff0c;因此可以用一个数组来模拟哈希表。 class Solution {publi…

中国主要城市房价指数数据集(2011-2024)

数据来源&#xff1a;东方财富网 时间跨度&#xff1a;2011年1月 - 2024年4月 数据范围&#xff1a;中国主要城市 包含指标&#xff1a; 日期、城市 新建商品住宅价格指数-同比 新建商品住宅价格指数-环比 新建商品住宅价格指数-定基 二手住宅价格指数-环比 二手住宅价格指…

CS西电高悦计网课设——校园网设计

校园网设计 一&#xff0c;需求分析 所有主机可以访问外网 主机可以通过域名访问Web服务器 为网络配置静态或者动态路由 图书馆主机通过DHCP自动获取IP参数 为办公楼划分VLAN 为所有设备分配合适的IP地址和子网掩码&#xff0c;IP地址的第二个字节使用学号的后两位。 二…

学习平台|基于Springboot+vue的学习平台系统的设计与实现(源码+数据库+文档)

学习平台系统 目录 基于Springboot&#xff0b;vue的学习平台系统的设计与实现 一、前言 二、系统设计 三、系统功能设计 1系统功能模块 2管理员功能模块 3学生功能模块 4教师功能模块 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八…

基于STM32实现智能风扇控制系统

目录 文章主题环境准备智能风扇控制系统基础代码示例&#xff1a;实现智能风扇控制系统 PWM控制风扇速度温度传感器数据读取串口通信控制应用场景&#xff1a;智能家居与环境调节问题解决方案与优化收尾与总结 1. 文章主题与命名 文章主题 本教程将详细介绍如何在STM32嵌入式…

layui扩展件(xm-select)实现下拉框

layui扩展件&#xff08;xm-select&#xff09;实现下拉框 扩展组件 xm-select 效果图 html代码 <div class"layui-inline"><label class"layui-form-label">职位</label><div class"layui-input-inline" style"wid…

你以为的私域是真正的私域嘛??你的私域流量真的属于你嘛?

大家好 我是一个软件开发公司的产品经理 专注私域电商行业7年有余 您的私域流量是真正的属于你自己嘛&#xff1f; 私域的定义 私域的界定&#xff1a;一个互联网私有数据&#xff08;资产&#xff09;积蓄的载体。这个载体的数据权益私有&#xff0c;且具备用户规则制定权…

法那科机器人M-900iA维修主要思路

发那科工业机器人是当今制造业中常用的自动化设备之一&#xff0c;而示教器是发那科机器人操作和维护的重要组成部分。 一、FANUC机械手示教器故障分类 1. 硬件故障 硬件故障通常是指发那科机器人M-900iA示教器本身的硬件问题&#xff0c;如屏幕损坏、按键失灵、电源故障等。 2…

脆皮之“字符函数与字符串函数”宝典

hello&#xff0c;大家好呀&#xff0c;感觉我之前有偷偷摸鱼了&#xff0c;今天又开始学习啦。加油&#xff01;&#xff01;&#xff01; 文章目录 1. 字符分类函数2. 字符转换函数3. strlen的使用和模拟实现3.1 strlen 的使用3.1 strlen 的模拟1.计算器方法2.指针-指针的方…

【Spring Security + OAuth2】身份认证

Spring Security OAuth2 第一章 Spring Security 快速入门 第二章 Spring Security 自定义配置 第三章 Spring Security 前后端分离配置 第四章 Spring Security 身份认证 第五章 Spring Security 授权 第六章 OAuth2 1、用户认证信息 1.1、基本概念 在Spring Security框架中…

Axure RP 9 for Mac/win:重新定义交互原型设计的未来

在当今数字化时代&#xff0c;交互原型设计已成为产品开发中不可或缺的一环。Axure RP 9作为一款功能强大的交互原型设计软件&#xff0c;凭借其出色的性能和用户友好的界面&#xff0c;赢得了广大设计师的青睐。 Axure RP 9不仅支持Mac和Windows两大主流操作系统&#xff0c;…

PMP 学习笔记(增量更新中)

PMP 作为最流行的项目管理方法论&#xff0c;是项目管理领域的对话基础&#xff0c;了解它能帮助我理解术语和规范的管理过程&#xff0c;也许后面会考一个认证。感谢 B 站视频《 PMP 认证考试课程最新完整免费课程零基础一次通过项目管理 PMP 考试》的作者&#xff0c;我通过它…

【简单介绍下深度神经网络】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

人工智能应用-实验7-胶囊网络分类minst手写数据集

文章目录 &#x1f9e1;&#x1f9e1;实验内容&#x1f9e1;&#x1f9e1;&#x1f9e1;&#x1f9e1;代码&#x1f9e1;&#x1f9e1;&#x1f9e1;&#x1f9e1;分析结果&#x1f9e1;&#x1f9e1;&#x1f9e1;&#x1f9e1;实验总结&#x1f9e1;&#x1f9e1; &#x1f9…

vue3+ts实战

目录 一、ts语法练习 1.1、安装 1.2、语法 二、vue3ts 2.1、项目创建 2.1.1、项目创建(建议node版本在16.及以上) 2.1.2、下载路由、axios 2.1.3、引入element-plus 2.1.4、报错解决 (1)文件路径下有红色波浪 (2)组件名称下有红色波浪 (3)引入模块下有红色波浪 2.…

快速幂算法6

eg: n10&#xff0c;10%20, 10/25, 5%21,4* 5/22, 2%20,4*256 0/20, 1024 递归算法 #include<iostream> using namespace std; long long quick_pow(int b,int e) {if(b0)return 0;if(e0)return 1;if(e%20){int tempquick_pow(b,e/2);return temp*temp;}if(e%2!0)…

大数据学习之安装并配置maven环境

什么是Maven Maven字面意&#xff1a;专家、内行Maven是一款自动化构建工具&#xff0c;专注服务于Java平台的项目构建和依赖管理。依赖管理&#xff1a;jar之间的依赖关系&#xff0c;jar包管理问题统称为依赖管理项目构建&#xff1a;项目构建不等同于项目创建 项目构建是一…

【SQL国际标准】ISO/IEC 9075:2023 系列SQL的国际标准详情

目录 &#x1f30a;1. 前言 &#x1f30a;2. ISO/IEC 9075:2023 系列SQL的国际标准详情 &#x1f30a;1. 前言 ISO&#xff08;国际标准化组织&#xff0c;International Organization for Standardization&#xff09;是一个独立的、非政府间的国际组织&#xff0c;其宗旨是…