用循环神经网络预测股价

news2024/11/19 13:19:27

循环神经网络可以用来对时间序列进行预测,之前我们在介绍循环神经网络RNN,LSTM和GRU的时候都用到了正弦函数预测的例子,其实这个例子就是一个时间序列。而在众多的时间序列例子中,最普遍的就是股价的预测了,股价序列是一种很明显的时间序列,价格随时间变化,每天都有一个收盘价。本文就打算使用简单循环神经网络RNN和长短期记忆网络LSTM来对股价进行一下预测。

我们打算利用前N天的股票收盘价来预测下一日的股票收盘价,所以首先需要获取股票数据,这里我使用akshare接口来获取数据,个人觉得比tushare好用。

虽然变量名取了df_hs300,但我没有用沪深300指数,我选择了浙大网新这个股票,毕竟是自己学校下面的企业,支持一下:)

df_hs300 = ak.stock_zh_a_hist(symbol="600797", period="daily", start_date="20210101", end_date=datetime.datetime.today().strftime("%Y%m%d"), adjust="")

获取了从2021年1月1日到当前的股票数据,我们可以输出这个数据看一下:

而我这里只需要收盘价以及日期两个字段,并把收盘价进行归一化处理,更便于训练:

close_list = df_hs300['收盘'].values
date_list = df_hs300['日期'].values
close_list_norm=[price/max(close_list) for price in close_list]

可以打印出来看一下

%matplotlib inline
import matplotlib.pyplot as plt

plt.plot(close_list_norm)
plt.title('hs_300')
plt.xlabel('date')
plt.ylabel('colse price')
plt.show()

下面,根据这个数据集定义一个Dataset和DataLoader,我选择用前10天的收盘价来预测下一个交易日的收盘价,所以时间步选择了10,并用前700个数据作为训练数据集。

from torch.utils.data import Dataset, DataLoader  

class StockDataset(Dataset):
    def __init__(self, data_list, time_step = 10, transform=None):
        self.data = data_list
        self.features = []
        self.targets = []
        for i in range(len(self.data)-time_step):
            feature = [x for x in self.data[i:i+time_step]]
            y = self.data[i+time_step]
            #feature = torch.Tensor(feature)
            #feature = feature.unsqueeze(1)
            y = torch.tensor(y)  
            y = y.reshape(-1)
            self.features.append(feature)
            self.targets.append(y)
        self.features = torch.tensor(self.features)
        self.features = self.features.reshape(-1, time_step, 1)
        
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]
        
transform = transforms.Compose([transforms.ToTensor()])
dataset = StockDataset(close_list_norm[:700], time_step=10, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=True)

我先用简单循环神经网络来训练一下该数据集,下面定义了这个RNN模型以及一些初始化参数:

time_step = 10
batch_size = 1
#设计网络(单隐藏层Rnn)
input_size,hidden_size,output_size=1,20,1
#Rnn初始隐藏单元hidden_prev初始化
hidden_prev=torch.zeros(1,batch_size,hidden_size).cuda()
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.rnn=nn.RNN(
            input_size=input_size,    #输入特征维度,当前特征为股价,维度为1
            hidden_size=hidden_size,  #隐藏层神经元个数,或者也叫输出的维度
            num_layers=1,
            batch_first=True
        )
        self.linear=nn.Linear(hidden_size,output_size)

    def forward(self,X,hidden_prev):
        out,ht=self.rnn(X,hidden_prev)

        batch_size, seq,  hidden_size = out.shape

        out = self.linear(out[:, -1, :])  # 其实就是取出输出的序列长度中的最后一个去进行线性运算,得到输出
        return out

定义一个训练方法:

model=Net()
model=model.cuda()
criterion=nn.MSELoss()
learning_rate,epochs=0.01,500
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)
for epoch in range(epochs):
    losses = []
    for X,y in dataloader:
        X = X.cuda()
        y = y.cuda()
        y=y.to(torch.float32)
        X=X.to(torch.float32)
        #print("X.shape: ",X.shape)
        #print("y.shape: ",y.shape)
        optimizer.zero_grad()
        yy=model(X,hidden_prev)
        yy=yy.cuda()
        #print("yy.shape: ",yy.shape)
        #print(yy)
        #print(y)
        loss = criterion(y, yy)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    epoch_loss=sum(losses)/len(losses)
    if epoch%50==0:   #保留验证集损失最小的模型参数
        print("epoch:{},loss:{:.8f}".format(epoch+1,epoch_loss))
torch.save(model, "model2.pt")
# 输出:
epoch:1,loss:0.00685027
epoch:51,loss:0.00065118
epoch:101,loss:0.00120512
epoch:151,loss:0.00215360
epoch:201,loss:0.00149827
epoch:251,loss:0.00173493
epoch:301,loss:0.00188238
epoch:351,loss:0.00167589
epoch:401,loss:0.00165730
epoch:451,loss:0.00160637

我们把训练后的模型在验证数据集上测试一下,首先定义验证数据集,训练数据集选择所有数据的前700个数据,验证数据集就选择700个以后的数据作为验证数据集。

transform = transforms.Compose([transforms.ToTensor()])
val_dataset = StockDataset(close_list_norm[700:], time_step=10, transform=transform)
val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)

定义验证方法

model = torch.load('model2.pt')

Val_y,Val_predict=[],[]
#将归一化后的数据还原
Val_max_price=max(close_list) 
for X,y in val_dataloader:
    with torch.no_grad():
        X = X.cuda()
        y=y.to(torch.float32)
        X=X.to(torch.float32)
        print("X: ",X)
        predict=model(X,hidden_prev)
        y=y.cpu()
        predict=predict.cpu()
        print("y: ",y)
        print("predict: ",predict)
        # 把股价还原为归一化之前的股价
        Val_y.append(y[0][0]*Val_max_price) 
        Val_predict.append(predict[0][0]*Val_max_price)

fig=plt.figure(figsize=(8,5),dpi=80)
# 红色表示真实值,绿色表示预测值
plt.plot(Val_y,linestyle='--',color='r')
plt.plot(Val_predict,color='g')
plt.title('stock price')
plt.xlabel('time')
plt.ylabel('price')
plt.show()

我们可以看到,总体趋势是一致的,但是真实值和预测值之间的差距确实有点大,那么我们接下来看一下LSTM网络的模型表现如何:

class Net_LSTM(nn.Module):
    def __init__(self):
        super(Net_LSTM,self).__init__()
        self.lstm=nn.LSTM(
            input_size=input_size,    #输入特征维度,当前特征为股价,维度为1
            hidden_size=hidden_size,  #隐藏层神经元个数,或者也叫输出的维度
            num_layers=1,
            batch_first=True
        )
        self.linear=nn.Linear(hidden_size,output_size)

    def forward(self,X):
        out,ht=self.lstm(X)      
        batch_size, seq,  hidden_size = out.shape
        out = self.linear(out[:, -1, :])  # 其实就是取出输出的序列长度中的最后一个去进行线性运算,得到输出
        return out

因为训练和验证的函数和用RNN训练和验证的函数基本是一致的,我就不赘述了,我们来看看利用LSTM进行训练后的模型,在验证集上的表现如何:

可以看到,这个效果比起用简单循环神经网络RNN好上了很多,可见LSTM的效果确实比简单RNN要提高了不少。这只是一个例子而已,不建议根据这个结果去进行投资,因为预测结果在细节上和原始数据还是有不少差别的,而且只能验证下一个交易日的情况,如果预测时间稍微拉长,效果就会急剧下降,并当前预测都是在前期的数据集的基础上进行预测,如果有突发事件的发生,模型是捕捉不到的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1710115.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【PG16】后 EL 7 时代,PG 16 如何在 CentOS 7 上运行

↑ 关注“少安事务所”公众号,欢迎⭐收藏,不错过精彩内容~ ★ 本文写于 2023-09-29 PostgreSQL 16 Released 9/14, PostgreSQL 16 正式发布。从发布公告^1 和 Release Notes^2 可以看到 PG16 包含了诸多新特性和增强改进。 性能提升,查询计划…

ssm超市管理系统java超市进销存管理系统jsp项目

文章目录 超市进销存管理系统一、项目演示二、项目介绍三、系统部分功能截图四、七千字项目文档五、部分代码展示六、底部获取项目源码和七千字项目文档(9.9¥带走) 超市进销存管理系统 一、项目演示 超市进销存管理系统 二、项目介绍 角色分…

Dynadot API调整一览

关于Dynadot Dynadot是通过ICANN认证的域名注册商,自2002年成立以来,服务于全球108个国家和地区的客户,为数以万计的客户提供简洁,优惠,安全的域名注册以及管理服务。 Dynadot平台操作教程索引(包括域名邮…

算法设计第七周(应用哈夫曼算法解决文件归并问题)

一、【实验目的】 (1)进一步理解贪心法的设计思想 (2)掌握哈夫曼算法的具体应用 (3)比较不同的文件归并策略,探讨最优算法。 二、【实验内容】 设S{f1,…,fn}是一组不同的长度的有序文件构…

vue脚手架与创建vue项目

一、前言 vue脚手架的安装与创建vue项目需要先行安装配置node与npm,详情可以看node、npm的下载、安装、配置_node 下载安装-CSDN博客 二、vue脚手架的使用 1、vue与vue脚手架的版本 Vue脚手架(Vue CLI)是Vue.js官方提供的一个命令行工具&…

打乱一维数组中数据(小练习)

int[] tempArr{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}; 要求:打乱一维数组的数据,并按照4个一组的方式添加到二维数组中。 package chengyu4; import java.util.Random; public class Test{public static void main(String args[]) {int[] temArr {1…

CM2038A 3W 双通道立体声音频功率放大器芯片IC

功能说明: CM2038A是一双路音频功率放大器,它能够在5V 电源电压下给一个4Ω负载提供THD小于10%、最大平均值为3W的输出功率。另外,在驱动立体声耳机时耳机输入引脚可以使放 大器工作在单边模式。 CM2038A是为提供高保真音频输出而专门设计…

【每日一坑】KiCAD 覆铜区域约束

【每日一坑】 1.螺丝孔周围不想要要铜皮; 2、首先在CTRLshiftK;画一个区域,比如铺一个GND; 3、选择CUTOUT; 4、画线,画好闭合图形;如下图 5、就是这样了,就是还没有画圆或者异形的;

如何制定一个有效的现货黄金投资策略(EEtrade)

制定一个有效的现货黄金投资策略涉及多方面的考量。以下是几个步骤和考虑因素,可以帮助您建立一个坚实的投资策略: 1. 设立清晰的投资目标 决定您投资现货黄金的主要目的。是否是为了短期利润,长期保值增值,还是为了投资组合的多…

如何通过网络性能监控和流量回溯分析提升网络效率?

目录 网络性能监控的重要性 什么是网络性能监控? 为什么需要网络性能监控? 流量回溯分析的应用 什么是流量回溯分析? 流量回溯分析的优势 实现网络性能监控和流量回溯分析的方法 使用高性能的分析工具 部署网络监控系统 结论 在当今…

Windows内核函数 - 创建关闭注册表

在驱动程序的开发中,经常会用到对注册表的操作。与Win32的API不同,DDK提供另外一套对注册表操作的相关函数。首先明确一下注册表里的几个概念,避免在后面混淆。 图1 注册表概念 有5个概念需要重申一下: * 注册表项: 注…

关于c++的通过cin.get()维持黑框的思考

1.前言 由于本科没有学过c语言,研究生阶段接触c上手有点困难,今天遇到关于通过cin.get()来让黑框维持的原因。 2.思考 cin.get()维持黑框不消失的原因一言蔽之就是等待输入。等待键盘的输入内容并回车(一般是回车)后cin.get()才…

2024下半年BRC-20铭文发展趋势预测分析

自区块链技术诞生以来,其应用场景不断扩展,代币标准也在不断演进。BRC-20铭文作为基于比特币区块链的代币标准,自其推出以来,因其安全性和去中心化特性,受到了广泛关注和使用。随着区块链技术和市场环境的不断变化&…

NFTScan 正式上线 Mint NFTScan 浏览器和 NFT API 数据服务

2024 年 5 月 20 号,NFTScan 团队正式对外发布了 Mint NFTScan 浏览器,将为 Mint 生态的 NFT 开发者和用户提供简洁高效的 NFT 数据搜索查询服务。NFTScan 作为全球领先的 NFT 数据基础设施服务商,Mint 是继 Bitcoin、Ethereum、BNBChain、Po…

maven聚合工程整合springboot+mybatisplus遇到的问题

前言(可以直接跳过看下面解决方法) 项目结构 两个module: yema-terminal-boot 是springboot项目,子包有:controller、service、dao 等等。属于经典三层架构。那么,该module可以理解为是一个单体项目&…

python打造自定义汽车模块:从设计到组装的全过程

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、引言 二、定义汽车模块与核心类 三、模拟汽车组装过程 四、抽象与封装 五、完整汽车…

液氮罐内部会污染吗

液氮罐是一种常见的存储液态氮的设备,广泛应用于科研、生物医药、食品冷冻等领域。但是,人们对于液氮罐内部是否会产生污染一直存在疑问。 我们来看液氮罐内部可能的污染源。液氮罐内部主要存在以下几种潜在的污染来源:气体污染、杂质污染、…

飞睿智能高精度、低功耗测距,无线室内定位UWB芯片如何改变智能家居

在数字化和智能化快速发展的今天,定位技术已经成为我们日常生活中不可或缺的一部分。然而,传统的GPS定位技术在室内环境中往往束手无策,给我们的生活带来了诸多不便。幸运的是,随着科技的不断进步,一种名为UWB&#xf…

Octo:伯克利开源机器人开发框架

【摘要】在各种机器人数据集上预先训练的大型策略有可能改变机器人学习:这种通用机器人策略无需从头开始训练新策略,只需使用少量领域内数据即可进行微调,但具有广泛的泛化能力。然而,为了广泛应用于各种机器人学习场景、环境和任…

FM1800隧道广播插播控制器

隧道广播插播控制器是一款群载波&应急广播插播控制器采用SDR软件无线电技术,产生独立的插播信号与“群载波”信号,本设备可通过软件无线电技术将音频信号调制成调频载波或“群载波”信号,分别送入插播主机,实现隧道广播远端机…