基于CNN的股票预测方法【卷积神经网络】

news2025/1/13 13:24:10

基于机器学习方法的股票预测系列文章目录

一、基于强化学习DQN的股票预测【股票交易】
二、基于CNN的股票预测方法【卷积神经网络】


文章目录

  • 基于机器学习方法的股票预测系列文章目录
  • 一、CNN建模原理
  • 二、模型搭建
  • 三、模型参数的选择
    • (1)探究`window_size`的影响
    • (2)探究`kernel_size`的影响
    • (3)探究探究模型结构的影响
    • (4) 模型拟合效果
  • 四、数据处理
    • (1)数据变换
    • (2)Kalman滤波
  • 五、参考资料


本文探讨了利用卷积神经网络(CNN)进行股票预测的建模方法,并详细介绍了模型的搭建、参数选择以及数据处理方法。尽管序列建模通常与递归神经网络(如LSTM和GRU)相关,但本文展示了如何使用CNN进行时间序列数据的预测,完整代码放在GitHub上——Stock-Prediction-Using-Machine-Learing.

一、CNN建模原理

深度学习背景下的序列建模主题主要与递归神经网络架构(如LSTM和GRU)有关,但事实上CNN也可以用于对序列数据的建模。与处理图像所用的二维卷积不同,处理时间序列可以使用一维卷积,用多个以前的数据序列预测下一时刻。如下图所示,Input_length是指定用几个以前的数据来预测下一天的股票价格,用一个一个卷积核来滑动提取特征,最后通过一个线性层得到输出的预测值,具体网络搭建见下一小节。

截屏2022-05-23 下午8.22.12

其中两个关键的参数是:

  1. Input_length: 用几个以前的数据作为输入,来预测下一时刻。(在后文称为Window_size)
  2. Kernel_size: 卷积核大小。

事实上也可以用二维的卷积和来建模,比如输入可以是多只股票,用二维卷积核对多只股票同时建模预测,或者将一只股票的多个特征同时建模预测,本文仅探究用股票的收盘价来预测未来的股票收盘价格,没有利用股票数据的其他技术指标。

二、模型搭建

基于Pytorch深度学习框架,搭建的CNN网络如下所示:

kernel_size=2   #一维卷积核大小

class CNNmodel(nn.Module):
    def __init__(self):
        super(CNNmodel, self).__init__()
        self.conv1 = nn.Conv1d(1, 64, kernel_size=kernel_size)   #1xkersize的卷积核 
        #self.conv2 = nn.Conv1d(64,128,1)
        self.relu = nn.ReLU(inplace=True)
        self.Linear1 = nn.Linear(64*(window_size-kernel_size+1), 10)
        self.Linear2 = nn.Linear(10, 1)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = x.view(-1)
        x = self.Linear1(x)
        x = self.relu(x)
        x = self.Linear2(x)
        return x

model = CNNmodel()
print(model)

选用relu函数作为激励函数,因为股票都是正数,而relu函数的性质,可以很好的避免模型输出值为负值。

三、模型参数的选择

(1)探究window_size的影响

调节CNN模型中window_size参数,并比较不同window_size下训练集与测试集的相对误差率,结果如下表所示:

window_size训练集相对误差率测试集相对误差率
52.22%2.18%
61.69%1.48%
72.30%2.27%
82.36%2.12%
151.65%1.58%
202.01%2.03%
502.37%2.55%
150 (kernel_size=40)3.21%2.72%

分析上表知:

  1. 不同window_size对结果有一定影响
  2. window_size比较大时,误差很大
  3. 在10左右,效果比较好,最终我们选择 window_size=6

(2)探究kernel_size的影响

调节CNN模型中kernel_size参数,并比较不同kernel_size下训练集与测试集的相对误差率,结果如下表所示(window_size=6):

Kernel_size训练集相对误差率测试集相对误差率
21.69%1.48%
31.76%1.58%
41.82%1.94%
51.90%1.73%

分析上表数据知较小的kernel_size能使相对误差率更小,最终我们选择kernel_size=2。

(3)探究探究模型结构的影响

调节CNN模型中模型结构,并比较不同模型结构下的平均误差和平均相对误差率,结果如下表所示:

模型结构平均误差平均相对误差率
两个卷积层0.522.46%
1个卷积层,线性层1000.462.29%
1个卷积层,线性层100.281.38%

由上表知,模型对学习率十分敏感;模型结构过于复杂,不容易学习,且容易过拟合。

(4) 模型拟合效果

通过以上探究得到的模型结构以及参数,以AAPL股票为例,采用原始数据进行训练,其预测结果如下图所示:

截屏2022-05-29 下午10.55.49

由上图知,以原始数据进行训练有不错的拟合效果,但滞后比较明显,神经网络会“偷懒”,这是因为数据序列中产生了变化趋势,而基于滑动时间窗口策略的对发生变化趋势的数据感知是滞后的。

对测试集进行预测:

截屏2022-05-29 下午10.34.21

四、数据处理

(1)数据变换

为了解决预测过程中出现的“滞后”问题,常常对原始数据进行一定的处理。常见的数据处理方法有:

  1. 数据归一化
  2. 不直接给出希望模型预测的未经处理的真实值,对输入样本进行非线性化的处理如,如:平方、开根号、ln等
  3. 差分,预测时间t和t-1处值的差异,而不是直接预测t时刻的值

如以AAPL股票数据为例,对其收盘价取其平方的对数进行训练,最终的预测效果如下图所示:

与上一小节的图对比知,“滞后”现象得到显著的减弱,模型的可信度更好。

(2)Kalman滤波

卡尔曼滤波(Kalman filtering)是一种利用线性系统状态方程,通过系统输入输出观测数据,对系统状态进行最优估计的算法。由于观测数据中包括系统中的噪声和干扰的影响,所以最优估计也可看作是滤波过程。

Kalman滤波原理及数理处理过程如下:

  1. 给定初始估计值、系统输入、初始协方差矩阵和误差的方差 Q Q Q, 首先要计算预测值、预测值和真实值之间误差协方差矩阵:

X ^ k ′ = A X ^ k − 1 + B u k − 1 P k ′ = A P k − 1 A T + Q \begin{aligned} &\hat{X}_{k}^{\prime}=A \hat{X}_{k-1}+B u_{k-1} \\ &P_{k}^{\prime}=A P_{k-1} A^{T}+Q \end{aligned} X^k=AX^k1+Buk1Pk=APk1AT+Q

  1. 然后根据 P k ′ P_{k}^{\prime} Pk 计算卡尔曼增益 K k K_{k} Kk :

K k = P k ′ H T ( H P k ′ H T + R ) − 1 K_{k}=P_{k}^{\prime} H^{T}\left(H P_{k}^{\prime} H^{T}+R\right)^{-1} Kk=PkHT(HPkHT+R)1

  1. 然后根据卡尔曼增益 K k K_{k} Kk X ^ k ′ \hat{X}_{k}{ }^{\prime} X^k 以及测量值 Z k Z_{k} Zk, 调和平均得到估计值:

X ^ k = X ^ k ′ + K k ( Z k − H X ^ k ′ ) \hat{X}_{k}=\hat{X}_{k}^{\prime}+K_{k}\left(Z_{k}-H \hat{X}_{k}^{\prime}\right) X^k=X^k+Kk(ZkHX^k)

  1. 最后还要计算估计值和真实值之间的误差协方差矩阵, 为下次递推做准备:

P k = ( I − K k H ) P k ′ P_{k}=\left(I-K_{k} H\right) P_{k}^{\prime} Pk=(IKkH)Pk

以AAPL股票数据为例,对其收盘价进行kalman滤波后,以CNN模型进行训练,结果如下图所示:

截屏2022-05-29 下午10.59.48

与图3对比可知,图3中平均误差为0.11,相对误差率为2.30%,采用kalman滤波后,平均误差为0.08,相对误差率为1.71%,效果变好。

通过前面2种不同数据处理方法对不同模型效果的影响,我们可以看到,不同数据处理方法对不同模型的影响不一样,但总的来说对数据进行相应的处理后,能够提升模型的性能。而通过实验我们发现Kalman滤波进行数据处理后,模型效果有显著的提升。

五、参考资料

  1. 王宇轩.基于卷积神经网络的股票预测[D].天津工业大学,2019.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1886239.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

8619 公约公倍

这个问题可以通过计算最大公约数 (GCD) 和最小公倍数 (LCM) 来解决。我们需要找到一个整数,它是 a, b, c 的 GCD 的倍数,同时也是 d, e, f 的 LCM 的约数。 以下是解决这个问题的步骤: 1. 计算 a, b, c 的最大公约数。 2. 计算 d, e, f 的最…

SAP MM模块的ATP检查

前面几篇文章都演示和说明ATP的一些设置和操作,通常情况下ATP的检查PP模块,SD模块用的相对来说是比较多的,但是实际上MM模块也会遵循ATP的可用性的检查规则。 当我们在做311、301等移动类型时,系统会根据相应的可用性检查规则&am…

大模型应用开发实战基础

大模型应用开发实战基础 1. 背景 大模型如日中天,各行各业都受它影响,但是作为程序员,除了让它翻译代码不知道用它干什么,就像是拿着锤子的木匠,找不到钉子在哪。一边听着别人说2024是AI元年,一边又不知所…

基于X86+FPGA的精密加工检测设备解决方案

应用场景 随着我国高新技术的发展和国防现代化发展,航空、航天等领域需 要的大型光电子器件,微型电子机械、 光 电信息等领域需要的微型器件,还有一些复杂零件的加工需求日益增加,这些都需要借助精密甚至超精密的加工检测设备 客…

Asp.NET identity以及Owin

》》》Identity是集成到Owin框架中中 ● Microsoft.AspNet.Identity.Core:Identity的核心类库,实现了身份验证的核心功能,并提供了拓展接口。● Microsoft.AspNet.Identity.EntityFramework:Identity数据持久化的EF实现。   ● …

强化学习的数学原理:最优贝尔曼公式

大纲 贝尔曼最优公式是贝尔曼公式的一个特殊情况,但其也非常重要。 本节课很重要的两个概念和一个工具: 工具不用多说,就是贝尔曼最优公式,概念则是 optimal state value(最优状态价值) 和 optimal polic…

文件中各个函数返回----EOF----NULL---非零值>>>>>区分

fopen 返回值:操作正常返回文件指针, 失败返回NULL fclose 返回值:操作正常返回 0 失败返回EOF 不关闭文件会丢失 fgetc 返回值: 成功读入字符 失败返回EOF fputc 返回值:成功输出的字符 失败返回EOF fgets …

香橙派OrangePi AIpro初体验:当小白拿到一块开发板第一时间会做什么?

文章目录 香橙派OrangePi AIpro初体验:当小白拿到一块高性能AI开发板第一时间会做什么前言一、香橙派OrangePi AIpro概述1.简介2.引脚图开箱图片 二、使用体验1.基础操作2.软件工具分析 三、香橙派OrangePi AIpro.测试Demo1.测试Demo1:录音和播音(USB接口…

MySQL的并发控制、事务、日志

目录 一.并发控制 1.锁机制 2.加锁与释放锁 二.事务(transactions) 1.事物的概念 2.ACID特性 3.事务隔离级别 三.日志 1.事务日志 2.错误日志 3.通用日志 4.慢查询日志 5.二进制日志 备份 一.并发控制 在 MySQL 中,并发控制是确…

pandas数据分析(5)

pandas使用Numpy的np.nan代表缺失数据,显示为NaN。NaN是浮点数标准中地Not-a-Number。对于时间戳,则使用pd.NaT,而文本使用的是None。 首先构造一组数据: 使用None或者np.nan来表示缺失的值: 清理DataFrame时&#xf…

ubuntu apt命令 出现红色弹框 Daemons using outdated libraries

1. 弹框没截图,是因为ubuntu22.04一个新特性导致的,由 needrestart 命令触发,默认情况是交互性质的,也就是会中断在这里需要手动要处理提示。 2. 修改/etc/needrestart/needrestart.conf 文件,将 #$nrconf{restart} …

APKDeepLens:一款针对Android应用程序的安全扫描工具

关于APKDeepLens APKDeepLens是一款针对Android应用程序的安全扫描工具,该工具基于Python开发,旨在扫描和识别Android应用程序(APK文件)中的安全漏洞。 APKDeepLens主要针对的是OWASP Top 10移动端安全漏洞,并为开发人…

力扣热100 哈希

哈希 1. 两数之和49.字母异位词分组128.最长连续序列 1. 两数之和 题目:给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。你可以假设每种输入只会对应一个答案。…

计算机缺少d3dcompiler_43.dll无法继续执行代码怎么修复

打开游戏或许软件程序时候,我们会经常遇到各式各样的问题,比如找不到d3dcompiler_43.dll无法继续执行代码就是非常常见的问题,今天我叫大家如何解决遇到d3dcompiler_43.dll丢失问题,也详细介绍d3dcompiler_43.dll文件是什么与丢失…

什么方法能快速分享视频给他人?视频二维码提供预览的制作技巧

现在想要分享一个或者多个视频时,很多人会选择将视频生成二维码的方法来展现视频内容,通过这种方式可以让多人同时扫码查看同一个视频,有效提升其他人获取内容的速度及视频传播的效率。那么视频转换成二维码的方法是什么样的呢? …

replace()方法——替换字符串

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 语法参考 replace()方法用于将某一字符串中一部分字符替换为指定的新字符,如果不指定新字符,那么原字符将被直接去除&#x…

数据库取出来的日期格式是数组格式,序列化日期格式

序列化前,如图所示: 解决方式,序列化日期(localdatetime)格式 步骤一、添加序列化类 package com.abliner.test.common.configure;import com.alibaba.fastjson.serializer.JSONSerializer; import com.alibaba.fas…

[图解]企业应用架构模式2024新译本讲解19-数据映射器1

1 00:00:01,720 --> 00:00:03,950 下一个我们要讲的就是 2 00:00:04,660 --> 00:00:07,420 数据映射器这个模式 3 00:00:09,760 --> 00:00:13,420 这个也是在数据源模式里面 4 00:00:13,430 --> 00:00:14,820 用得最广泛的 5 00:00:16,250 --> 00:00:19,170…

高编:进程间通信 IPC interprocess communicate

一、进程间三大类通信 1、古老的通信方式 无名管道 有名管道 信号 2、IPC对象通信 system v(5) BSD suse fedora kernel.org 消息队列(用的相对少,这里不讨论) 共享内存 信号量集(进程间做互斥与同步semaphore) 3、socket通信 网络通…

AD快速导入立创3D模型

在AD绘制PCB时,可以添加3D模型,在绘制完成PCB后就可以导出3D图给结构工程师核对,方便产品的开发。这里介绍一种可以比较快完成3D导入的方式。 一、PyCharm代码 打开PyCharm,在运行本代码时,需要安装第三方包codecs&a…