统计软件与数据分析Lesson17----利用pytorch构建LSTM预测股票收益率详细教程

news2024/9/22 9:57:57

利用pytorch构建LSTM预测股票收益率详细教程

  • 1. 整体实现思路
  • 2.代码编写
    • 2.1 step1:导入所需的库
    • 2.2 step2: 读取数据、构建训练样本
    • 2.3 step3: 定义部分辅助函数
    • 2.4 step4:LSTM模型构建
    • 2.5 step5:模型训练
    • 2.6 step6:模型预测和评估
  • 3. 小结

1. 整体实现思路

step1:导入所需的库

  • torch:用于构建神经网络和进行模型训练
  • torch.nn:包含用于构建神经网络的类和函数
  • numpy:用于数据处理和数组操作
  • pandas:用于数据读取和处理
  • matplotlib.pyplot:用于数据可视化
  • sklearn:用于数据预处理和评估模型性能
  • torchviz:用于绘制动态计算图

step2:读取数据、构建训练样本

  • 从文件中读取股票收益率数据
  • 进行必要的数据预处理,如时间序列排序和数据划分
  • 定义一个函数,将收益率序列转换为训练样本
  • 使用滑动窗口方法,将收益率序列划分为输入特征和目标值
  • 将输入特征和目标值转换为PyTorch张量

step3:定义部分辅助函数

  • 定义模型动态计算图绘制函数
  • 定义时序数据序列化处理函数

step4:LSTM模型构建

  • 定义一个LSTM类,继承自torch.nn.Module
  • 在构造函数中定义LSTM层和全连接层
  • 实现forward方法,定义前向传播过程

step5:模型训练

  • 设置超参数,包括隐藏层大小、学习率、迭代次数等
  • 创建LSTM模型实例、损失函数和优化器
  • 迭代训练模型,计算损失并进行反向传播更新模型参数
  • 记录训练过程中的损失值,并绘制损失变化的折线图
  • 绘制模型动态计算图并保存为png

step6:模型预测和评估

  • 使用训练好的模型对测试数据进行预测
  • 计算预测结果与实际值之间的均方误差(MSE)
  • 保存预测结果与真实值到指定文件中
  • 可视化预测结果和实际值的对比

下面对各个步骤进行详细的code编写和实现。

2.代码编写

2.1 step1:导入所需的库

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchviz import make_dot

# 设置随机种子
seed = 1234
torch.manual_seed(seed)
np.random.seed(seed)

2.2 step2: 读取数据、构建训练样本

601318.csv数据下载

# 导入数据
df = pd.read_csv('./data/601318.csv')
df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d') 
df.set_index('trade_date', inplace=True) 
df.sort_index(inplace=True)
pct_chg = df['pct_chg'].values.astype(float)

# 定义训练集和测试集的比例
train_ratio = 0.8
train_size = int(len(pct_chg) * train_ratio)

# 划分数据集
train_data = pct_chg[:train_size]
test_data = pct_chg[train_size:]
#查看数据基本特征
print(df[['pct_chg']].describe())
# 绘制收益率曲线
df['pct_chg'].plot(figsize=(10,6))
plt.title('601318 daily return')
plt.xlabel('trade_date')
plt.ylabel('return')
plt.show()
输出:
   pct_chg
count  2006.000000
mean      0.038926
std       1.907410
min     -10.000000
25%      -0.974750
50%      -0.030000
75%       0.913225
max      10.020000

17-1

2.3 step3: 定义部分辅助函数

记得在自己代码的同级目录下创建一个名为result的文件夹用于保存相关的结果和可视化图

# 生成训练样本
def generate_samples(data, sequence_length):
    X = []
    y = []
    for i in range(len(data) - sequence_length):
        X.append(data[i:i+sequence_length])
        y.append(data[i+sequence_length])
    return torch.tensor(X).unsqueeze(2), torch.tensor(y).unsqueeze(1)

# 使用torchviz生成动态计算图
def save_model_graph(model, input_size, hidden_size, output_size):
    # 创建一个随机输入张量
    example_input = torch.randn(1, input_size) 
    dot = make_dot(model(example_input), params=dict(model.named_parameters()))
    # 保存计算图为图片
    dot.render('./result/lstm_model_graph', format='png')

    print("模型的动态计算图已保存为LSTM_Model_graph.png") 

2.4 step4:LSTM模型构建

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])
        return out

2.5 step5:模型训练

# 设置超参数和其他参数
input_size = 1
hidden_size = 32
output_size = 1
sequence_length = 10
num_epochs = 5000
learning_rate = 0.0003
batch_size = 100
patience = 20
best_loss = float('inf')
early_stop_counter = 0

# 创建模型实例
model = LSTM(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 将训练数据转换为样本
train_input, train_target = generate_samples(train_data, sequence_length)
train_input = train_input.float()
train_target = train_target.float()
train_data_size = len(train_input)

loss_history = []
# 开始训练
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch_start in range(0, train_data_size, batch_size):
        batch_end = batch_start + batch_size
        if batch_end > train_data_size:
            batch_end = train_data_size

        batch_input = train_input[batch_start:batch_end]
        batch_target = train_target[batch_start:batch_end]

        outputs = model(batch_input)
        loss = criterion(outputs, batch_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / (train_data_size // batch_size)
    loss_history.append(avg_loss)
    
    # 更新最佳损失值并进行early-stop判断
    if avg_loss < best_loss:
        best_loss = avg_loss
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break
    # 每隔100个epoch 打印一次当前的loss
    if (epoch + 1) % 100 == 0:
        print(f'Epoch: {epoch + 1}/{num_epochs}, Loss: {avg_loss}')
        
# 保存模型
torch.save(model.state_dict(), f'./result/lstm_model_{num_epochs}.pth')

# 绘制损失值变化的折线图
plt.plot(figsize=(12,6))
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title(f'Training Loss (epochs={num_epochs})')
plt.savefig('./result/loss_plot.png')
plt.show()

# 绘制动态计算图并保存为图片
inputs = torch.tensor(train_input, dtype=torch.float32)
dot = make_dot(model(inputs), params=dict(model.named_parameters()))
dot.render('./result/LSTM_dynamic_graph', format='png')

17-2
17-3

2.6 step6:模型预测和评估

# 加载模型
model.load_state_dict(torch.load(f'./result/lstm_model_{num_epochs}.pth'))

# 测试集数据转换为张量
test_input, test_target = generate_samples(test_data, sequence_length)
test_input = test_input.float()
test_target = test_target.float()

# 关闭梯度追踪
with torch.no_grad():
    model.eval()
    # 进行预测
    predicted = model(test_input)
    # 计算均方误差
    mse = criterion(predicted, test_target[-1].view(-1, 1, 1))
    print(f'Test MSE: {mse.item():.4f}')

# 可视化预测结果
predicted = predicted.view(-1).numpy()

# 保存测试集的真实值和预测值到csv
test_result = pd.DataFrame({'actual':pct_chg[train_size+ sequence_length:], 'predicted': predicted})
test_result.to_csv(f'./result/test_result{num_epochs}.csv', index=False)

# 绘制预测结果
plt.plot(figsize=(12,6))
plt.plot(range(train_size, len(pct_chg)), pct_chg[train_size:], label='Actual')
plt.plot(range(train_size + sequence_length, len(pct_chg)), predicted, label='Predicted')
plt.xlabel('Time')
plt.ylabel('Daily Return')
plt.title('Actual vs Predicted in TestData')
plt.legend()
plt.savefig(f'./result/test_prediction_plot{num_epochs}.png')
plt.show()

17-4

3. 小结

完整代码LSTM_ReturnPrediction.py下载

思考如何将完整代码模块化,并在命令行窗口运行含参数设置的py文件

1. 总结模型的性能和训练过程中的变化:

  • 1.可在不设置early_stop条件时,设置不同的训练次数num_epochs,分析训练过程中损失函数的变化情况,观察是否存在收敛和过拟合的现象。
  • 2.可设置不同的学习率,观察模型的训练速度以及最优的迭代次数。
  • 3.可比较预测结果与实际值之间的均方误差(MSE),评估模型在测试集上的预测性能。
  • 4.可以绘制预测结果与实际值之间的对比图,以直观了解模型的预测效果。

2. 探讨可能的改进方法:

  • 1.调整模型结构:
    • 增加/减少LSTM层的数量,改变隐藏层的大小。
    • 尝试不同的激活函数,如ReLU、Tanh等。
    • 添加正则化技术,如Dropout层,以防止过拟合。
  • 2.超参数调优:
    • 对学习率、迭代次数、滑动窗口大小等超参数进行调优,以提高模型性能。
    • 可以使用网格搜索或随机搜索等方法来搜索最佳超参数组合。
  • 3.数据预处理和特征工程:
    • 考虑对输入数据进行更复杂的特征工程,如技术指标的计算或时间序列的平滑处理。
    • 尝试不同的数据标准化或归一化方法,以改善模型的训练效果。
  • 4.模型集成:
    • 考虑使用模型集成方法,如投票、堆叠等,结合多个不同的LSTM模型,以提高预测性能和鲁棒性。

通过以上改进方法的尝试和实验,可以进一步提升模型的性能和泛化能力。在实践中,可以进行多次迭代和实验,根据实际情况进行调整和优化,以获得更好的预测结果。尝试优化模型以提高预测精度~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/594154.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

对抗样本攻击

目录 一、对抗样本攻击的基本原理 1.1 什么是对抗样本攻击和对抗样本 1.2 对抗样本攻击的基本思路 1.3 对抗样本攻击的分类 1.3.1 按攻击效果分类 1.3.2 按攻击者能力分类 1.3.3 按攻击环境分类 1.4 对抗扰动的衡量 二、对抗样本攻击方法 一、对抗样本攻击的基本原理 …

华为OD机试真题B卷 Java 实现【最少交换次数】,附详细解题思路

一、题目描述 给出数字K&#xff0c;请输出所有小于K的整数组合到一起的最小交换次数。 组合一起是指满足条件的数字相邻&#xff0c;不要求相邻后在数组中的位置。 取值范围&#xff1a; -100 < K < 100 -100 < 数组中的数值 < 100 二、输入描述 第一行输入…

网络安全合规-ISO 27001(一)

实施ISO27001认证的步骤 在长期实践过程中&#xff0c;总结创新了一套高效可行的ISO27001/ISMS项目实施的规范流程。 一、现状调研分析&#xff1a;我方派咨询师去企业了解基本情况&#xff1b;本阶段主要是前期的准备和计划工作&#xff0c;包括明确评估目标&#xff0c;确定…

如何远程控制电脑,远程控制电脑的设置方法

很多人无论是在工作还是生活中使用电脑的时候都需要用到远程控制&#xff0c;因为它可以方便我们解决很多需要到现场操作的问题&#xff0c;在很大方面提升了我们的工作效率&#xff0c;下面来跟大家分享一下&#xff0c;如何远程控制电脑&#xff0c;远程控制电脑的设置方法 …

Web应用技术(第十五周/持续更新)

本次练习基于how2j和课本&#xff0c;进行SSM的初步整合&#xff0c;理解SSM整合的原理、好处。 SSM整合应用 1.简单的实例项目&#xff1a;2.原理分析&#xff1a;3.浅谈使用SSM框架化&#xff1a; 1.简单的实例项目&#xff1a; how2j 2.原理分析&#xff1a; 具体见流程图…

【网络】基础知识1

目录 网络发展 独立模式 网络互联 局域网LAN 广域网WAN 什么是协议 初识网络协议 协议分层 OSI七层模型 TCP/IP四层&#xff08;或五层&#xff09;模型 OSI和TCP/IP对比 网络传输流程 什么是报头 局域网通信原理 同网段的主机通讯 跨网段的主机通讯 数据包封装…

Kali搭建GVM完整版-渗透测试模拟环境(7)

上一篇:OpenVAS、GSA配置验证-渗透测试模拟环境(6)_luozhonghua2000的博客-CSDN博客 在bt5上面进行了安装,调试等配置验证,这篇在kali上面继续安装调试卸载等配置验证,中途版本问题,依赖问题,脚本编写都一一解决。 特别是因网络原因造成的rsync: [Receiver] safe_read f…

Sinkhorn-Knopp算法

Sinkhorn-Knopp是为了解决最优传输问题所提出的。 Sinkhorn算法原理 最优运输问题的目标就是以最小的成本将一个概率分布转换为另一个概率分布。即将概率分布 c 以最小的成本转换到概率分布 r&#xff0c;此时就要获得一个分配方案 P ∈ R n m 其中需满足以下条件&#xff1…

数据分析应该怎么学习?适合什么人学?

先来分享下适合学习数据分析的人群&#xff1a; 数据爱好者&#xff1a;对数据比较感兴趣&#xff0c;喜欢从数据中发现问题&#xff0c;有一定的见解&#xff0c;那么数据分析可以让这类小伙伴能够更好的理解和解释数据。市场营销、运营、业务分析&#xff1a;这类小伙伴学习…

SAP从入门到放弃系列之MRP区域

注&#xff1a;MRP AREA&#xff0c;本文中MRP范围或MRP区域都是指MRP AREA。另外MRP组和MRP区域是两个概念。 目录 MRP区域-库位层级 MRP区域-分包 其他事项 MRP区域-库位层级 除了在单个工厂级别、物料级别或产品组级别运行 MRP 之外&#xff0c;如果业务需要为以下运行 …

NLPChatGPTLLMs技术、源码、案例实战210课

NLP&ChatGPT&LLMs技术、源码、案例实战210课 超过12.5万行NLP/ChatGPT/LLMs代码的AI课程 讲师介绍 现任职于硅谷一家对话机器人CTO&#xff0c;专精于Conversational AI 在美国曾先后工作于硅谷最顶级的机器学习和人工智能实验室 CTO、杰出AI工程师、首席机器学习工程…

【机器学习】浅析过拟合

过度拟合 我们来想象如下一个场景&#xff1a;我们准备了10000张西瓜的照片让算法训练识别西瓜图像&#xff0c;但是这 10000张西瓜的图片都是有瓜梗的&#xff0c;算法在拟合西瓜的特征的时候&#xff0c;将西瓜带瓜梗当作了一个一般性的特征。此时出现一张没有瓜梗的西瓜照片…

探索Java面向对象编程的奇妙世界(七)

⭐ 字符串 String 类详解⭐ 阅读 API 文档⭐ String 类常用的方法⭐ 字符串相等的判断⭐ 内部类 ⭐ 字符串 String 类详解 String 是最常用的类&#xff0c;要掌握 String 类常见的方法&#xff0c;它底层实现也需要掌握好&#xff0c;不然在工作开发中很容易犯错。 &#x…

UI设计师必备的远程软件有哪些?

远程工作时&#xff0c;选择高效的远程软件非常重要。以下是3款提高工作效率的远程软件&#xff0c;希望对你有所帮助&#xff01; 1、即时设计协同设计 是国内首款集合原型、设计、交付、协作和资源管理于一体的高效远程设计软件。它提供实时在线协作功能&#xff0c;使用户…

14肖特基二极管

目录 一、介绍 二、结构 三、关键参数 1、导通压降VF 2、反向饱和漏电流IR 3、额定电流Io/IF 4、最大浪涌电流IFSM 5、最大反向峰值电压VRM 6、最大直流反向电压VR 7、最高工作频率fM 8、反向恢复时间Trr 9、最大耗散功率P 四、特点 1、反向恢复时间 2、缺点 五…

vue router 拆分路由 自动导入

目录 目录结构&#xff1a;拆分路由&#xff1a;自动导入&#xff1a;配置路由&#xff1a; 不求甚解&#xff0c;直接照搬就行了。 目录结构&#xff1a; 拆分路由&#xff1a; // danweiRouter.js export default {path: /danwei,name: danwei,component: () > import(.…

详解RGB和YUV色彩空间转换

前言 首先指出本文中的RGB指的是非线性RGB&#xff0c;意思就是经过了伽马校正&#xff0c;按照行业规矩应当写成RGB&#xff0c;但是为了书写方便&#xff0c;仍写成RGB。关于YUV有多种叫法&#xff0c;分别是YUV&#xff0c;YPbPr&#xff0c;YCbCr。因此本文将首先指出他们之…

这 13 种职业用AI提效的 40 类场景盘点

随着人工智能技术的发展&#xff0c;职业领域出现了诸如我们“小蜜蜂助手Beezy”等神奇的工具&#xff0c;大幅度提升了各行各业里从业人员的工作效率。 笔者今天将详述13种常见职业&#xff0c;分别是如何利用这些工具在实际工作过程中来帮助自己提升效率的。大量干货和私藏宝…

2419286-92-1,Sulfo-Cy5.5 NHS ester,磺酸基Cyanine5.5-活性酯,用于标记抗体

Sulfo-Cyanine5.5 NHS ester&#xff0c;sulfo Cy5.5(Et) NHS&#xff0c;sulfo Cy5.5 SE&#xff0c;磺酸基Cy5.5-活性酯 &#xff08;文章资料汇总来源于&#xff1a;陕西新研博美生物科技有限公司小编MISSwu&#xff09;​ 产品结构式&#xff1a; 产品规格&#xff1a; 1…

Maven高级2-聚合与继承

1. 聚合 注意打包方式&#xff0c;不是默认的jar包形式&#xff0c;也不是web的war包形式&#xff0c;而是pom形式&#xff1b; <groupId>org.example</groupId> <artifactId>springmvc_08_parent</artifactId> <version>1.0-SNAPSHOT</versi…