基于强化学习DQN的股票预测【股票交易】

news2024/10/6 14:33:18

强化学习笔记

第一章 强化学习基本概念
第二章 贝尔曼方程
第三章 贝尔曼最优方程
第四章 值迭代和策略迭代
第五章 强化学习实例分析:GridWorld
第六章 蒙特卡洛方法
第七章 Robbins-Monro算法
第八章 多臂老虎机
第九章 强化学习实例分析:CartPole
第十章 时序差分法
第十一章 值函数近似【DQN】
第十二章 基于强化学习DQN的股票预测


文章目录

  • 强化学习笔记
  • 一、DQN
  • 二、软更新
  • 三、实验
  • 四、参考资料


在金融决策问题中,如何制定有效的交易策略一直是一个重要且具有挑战性的问题。近年来,强化学习在这一领域的应用显示出了很大的潜力,比如,强化学习可以帮助我们在股票交易过程中进行决策。

在这里,我想先比较一下监督学习和强化学习在股票交易问题中的不同:

  1. 监督学习主要关注预测,即通过历史数据训练模型,然后对未来的数据进行预测。例如,我们可以通过监督学习预测股票的价格走势。如果要交易还得结合其他策略方法。
  2. 而强化学习不仅仅是预测,它可以进行交易决策。它不仅仅关注于预测未来的股票价格,更重要的是,它可以根据预测结果来制定买卖策略,以最大化我们的收益。

下图给出了强化学习在股票交易问题应用中的主要框架:

image-20240627140425224 其核心问题有以下几点:
  1. 如何定义奖励函数,即Reward如何设置?
  2. 采用强化学习中的哪种模型,DQN、PPO、A2C、DDPG……
  3. 状态空间如何定义?

一、DQN

本文我们介绍用深度强化学习中最经典的模型——DQN来进行建模,完整代码放在GitHub上——DQN-for-Stock-Trading。在DQN模型中,采用了多个全连接线性层,其模型结构如下:

class QNetwork(nn.Module):
    """QNetwork (Deep Q-Network), state is the input, 
        and the output is the Q value of each action.
    """
    def __init__(self, state_size, action_size, fc1_units=128, fc2_units=128, fc3_units=64):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size , fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, fc3_units)
        self.fc4 = nn.Linear(fc3_units, action_size)
        self.dropout = nn.Dropout(0.1)  # Dropout with 20% probability

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

其中:

  1. 输入也就是状态 s s s,建模为股票过去几天的波动情况,也就是相邻两天的差值,输入的维数由给定的一个滑动窗口大小决定;
  2. 输出则是action,这里我设置action有三种0、1、2,分别代表买入,卖出或者不变.

DQN的一个核心思想是经验缓冲池,将数据都放入缓冲池内,训练网络时从这里面采样得到小批量数据,其主要代码如下:

class ReplayBuffer:
    def __init__(self, action_size, buffer_size, batch_size):
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)  # initialize replay buffer
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)

DQN另一个重要思想是用两个神经网络来交替更新参数,其代码如下:

class Agent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

二、软更新

在更新target network时,我们采用软更新的策略。软更新是一种在深度强化学习中更新目标网络参数的方法。目标网络(target network)用于稳定训练过程,其参数并不像本地网络(local network)那样在每一步都更新,而是以较慢的速率进行更新。软更新通过将目标网络的参数逐步向本地网络的参数靠拢来实现这种较慢的更新。具体来说,软更新的公式如下:
θ target ← τ θ local + ( 1 − τ ) θ target \theta_{\text{target}} \leftarrow \tau \theta_{\text{local}} + (1 - \tau) \theta_{\text{target}} θtargetτθlocal+(1τ)θtarget其中:

  • θ target \theta_{\text{target}} θtarget 是目标网络的参数。
  • θ local \theta_{\text{local}} θlocal 是本地网络的参数。
  • τ \tau τ 是软更新的比例系数,通常是一个非常小的值(例如 0.001)。

这个公式表示目标网络的参数是本地网络参数的 τ \tau τ 倍加上目标网络自身参数的 ( 1 − τ ) (1 - \tau) (1τ) 倍。因此,目标网络参数的变化是渐进的,而不是像硬更新(hard update)那样直接将本地网络的参数复制到目标网络。

在代码中,软更新通过 soft_update 方法实现:

def soft_update(self, local_model, target_model, tau):
    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

在DQN算法中,如果目标网络的参数频繁更新,会导致训练过程不稳定,因为目标网络用于计算目标值,而这些目标值需要在一段时间内保持相对稳定。因此,软更新通过缓慢调整目标网络的参数,能够有效地平滑训练过程,提高算法的收敛性和稳定性。

三、实验

在比较简单的环境设置下进行实验,不考虑交易成本,每次买入卖出都是1股股票,reward设置为卖出股票时赚的钱。下图是训练过程的累积收益,我们可以看到随着不断地学习,agent的决策确实使得我们在这只股票上挣钱了!

image-20240627142654564

下图是在训练数据上回测的结果,我们可以看到agent学到了一个简单的“低吸高抛”的策略。

image-20240627142617668

下图是在测试集上的实验,我们发现在没有训练的数据上用刚才的模型也能挣钱,并且策略仍然是低吸高抛.

image-20240627142728266

采用更复杂的交易环境,考虑交易成本,每次买入卖出的数量,奖励函数采用收益率,我们可以得到一个复杂的策略。下图图仍是在训练数据上的回测,我们可以看到相比前面的“低吸高抛”策略稍微复杂了一些,下面条形图表示持仓,可以看到学习的策略在股票价格最低时增大仓位,在股票价格高点时,抛售赚钱。

截屏2024-06-27 14.29.17

四、参考资料

  1. https://www.youtube.com/watch?v=05NqKJ0v7EE

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1869125.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

商家转账到零钱开通指南

商家转账到零钱功能是微信支付开发的一款商家可以直接向个人微信发放零钱的产品,商家可通过此功能手动或者自动向多个微信用户发起转账。不过因为人工审核门槛的问题,不少商家很难自主通过申请,以下是经过我们上万次开通操作的经验总结&#…

观成科技:证券行业加密业务安全风险监测与防御技术研究

摘要:解决证券⾏业加密流量威胁问题、加密流量中的应⽤⻛险问题,对若⼲证券⾏业的实际流量内容进⾏调研分析, 分析了证券⾏业加密流量⾯临的合规性⻛险和加密协议及证书本⾝存在的⻛险、以及可能存在的外部加密流量威 胁,并提出防…

缓冲区溢出

本文作者:杉木涂鸦智能安全实验室 前置知识点 栈 栈(Stack)是计算机中的一种数据结构,用于存储临时数据。它的特点是后入先出(LIFO),只能在栈顶添加或删除数据。在程序中,栈被用于…

【JavaScript】JS对象和JSON

目录 一、创建JS对象 方式一:new Object() 方式二:{属性名:属性值,...,..., 方法名:function(){ } } 二、JSON格式 JSON格式语法: JSON与Java对象互转: 三、JS常见对象 3.1数组对象API 3.2 其它对象API 一、创建JS对象 方式一:new…

创新前沿:Web3如何颠覆传统计算机模式

随着Web3技术的快速发展,传统的计算机模式正面临着前所未有的挑战和改变。本文将深入探讨Web3技术的定义、原理以及它如何颠覆传统计算机模式,以及对全球科技发展的潜在影响。 1. 引言:Web3技术的兴起与背景 Web3不仅仅是技术创新的一种&…

OpenAI 开启买买买模式:接连收购 Rockset 与 Multi,科技巨头创新布局

引言 最近,OpenAI 在科技领域引起了广泛关注,通过接连收购两家初创公司 Rockset 和 Multi,开启了所谓的“买买买模式”。这一战略举措不仅展现了 OpenAI 对于技术发展的深远布局,也预示着未来更多创新产品的推出。本文将详细探讨…

Dataease安装,配置Jenkins自动部署

Dataease安装,配置Jenkins自动部署 一.安装Dataease 安装前准备:1.Ubuntu20.04 LTS国内源安装指定版本Docker 2.docker-compose安装 下载离线安装的安装包,下载地址:https://community.fit2cloud.com/#/download/dataease/v1-…

检测故障电容器

去耦电容与旁路电容 “去耦电容”和“旁路电容”这两个术语经常互换使用,它们的功能重叠,容易造成混淆。实际上,它们的用途相似,但在电路中的应用可能会影响术语。 去耦电容 功能:去耦电容器主要用于通过为交流信号…

【人工智能学习之图像操作(一)】

【人工智能学习之图像操作(一)】 图像读写创建图片并保存视频读取色彩空间与转换色彩空间的转换通道分离理解HSV基本图形绘制 阀值操作OTSU二值化简单阀值自适应阀值 图像读写 图像的读取、显示与保存 import cv2 img cv2.imread(r"1.jpg")…

Wp-scan一键扫描wordpress网页(KALI工具系列三十)

目录 1、KALI LINUX 简介 2、Wp-scan工具简介 3、信息收集 3.1 目标IP(服务器) 3.2kali的IP 4、操作实例 4.1 基本扫描 4.2 扫描已知漏洞 4.3 扫描目标主题 4.4 列出用户 4.5 输出扫描文件 4.6 输出详细结果 5、总结 1、KALI LINUX 简介 Kali Linux 是一…

海外仓一件代发效率提升方案:拣货区规划策略

作为海外仓的核心业务,一件代发处理的效率和准确性,可以说直接影响了海外仓的经济效益。今天我们就会针对大家都比较头疼的一件代发效率问题,给大家分享一些实用建议。 提升一件代发效率要考虑的3个关键要素 对以一件代发为主要业务的海外仓…

【机器学习】机器学习重要方法——迁移学习:理论、方法与实践

文章目录 迁移学习:理论、方法与实践引言第一章 迁移学习的基本概念1.1 什么是迁移学习1.2 迁移学习的类型1.3 迁移学习的优势 第二章 迁移学习的核心方法2.1 特征重用(Feature Reuse)2.2 微调(Fine-Tuning)2.3 领域适…

C++身份证ocr识别、身份证二要素核验接口状态码返回

互联网时代,对个人进行身份证实名认证相信大家都不陌生,那么,对于实名认证功能是如何实现的大家有所了解么?对于开发人员而言,身份证实名认证接口返回的状态码又都代表着什么意思呢?今天,跟着翔…

2024 年最新 Python 基于火山引擎豆包大模型搭建 QQ 机器人详细教程(更新中)

豆包大模型概述 火山引擎官网:https://www.volcengine.com/ 字节跳动推出的自研大模型。通过字节跳动内部50业务场景实践验证,每日千亿级tokens大使用量持续打磨,提供多模态能力,以优质模型效果为企业打造丰富的业务体验。 模型…

【Python机器学习】自动化特征选择——基于模型的特征选择

基于模型的特征选择使用一个监督机器学习模型来判断每个特征的重要性,并且仅保留最重要的特征。用于特征学习的监督模型不需要与用于最终建模的模型相同。特征选择模型需要为每个特征提供某种重要性度量,以便用这个度量对特征进行排序。决策树和基于决策…

Potato(土豆)一款轻量级的开源文本标注工具

项目介绍: Potato 是一款轻量级、可移植的Web文本标注工具,被EMNLP 2022 DEMO赛道接受。它旨在帮助用户快速地从零开始创建和部署各种文本标注任务,无需复杂的编程或网页设计。只需简单配置,团队即可在几分钟内启动并运行标注项目…

互联网寒冬VS基建饱和:计算机专业会重蹈土木工程的覆辙吗?

随着高考落幕,考生和家长们开始着手专业选择与志愿填报,"热门"与"冷门"专业的话题引起了广泛关注。而计算机专业无疑是最受瞩目的专业领域之一。 在过去的十几年里,计算机专业以其出色的就业率和薪酬水平,一…

LAMP架构的源码编译环境下部署Discuz论坛

一、LAMP架构 LAMP架构是一种常见的用于构建动态网站的技术栈 组成功能Linux(操作系统)LAMP 架构的基础,用于托管 Web 服务器和应用程序Apache(Web服务器)接收和处理客户端请求,并将静态和动态内容发送给…

AMEYA360代理:村田电子使用小型振动传感器件,实现设备状态预知检测

株式会社村田制作所近日完成了贴片型振动传感器件“PKGM-200D-R”的商品化。该新产品已开始批量生产供应。 以往FA行业实施的是计划性维护和事后维护,近年来预测性维护逐步受到关注。预测性维护使用各类传感器信息等预测可能发生故障的时间,以便事先采取…

ABAP编程中的参数传递:使用EXPORT/IMPORT与SPA/GPA参数

在ABAP编程中,有效地在程序之间传递数据是实现功能的关键。本文档将介绍两种常用的数据传递方法:EXPORT/IMPORT和SPA/GPA参数,并提供实际示例。 1. 使用EXPORT/IMPORT数据(ABAP/4内存) EXPORT/IMPORT语句允许程序在ABA…