【量化课程】08_2.深度学习量化策略基础实战

news2025/1/23 9:27:32

文章目录

    • 1. 深度学习简介
    • 2. 常用深度学习模型架构
      • 2.1 LSTM 介绍
      • 2.2 LSTM在股票预测中的应用
    • 3. 模块分类
      • 3.1 卷积层
      • 3.2 池化层
      • 3.3 全连接层
      • 3.4 Dropout层
    • 4. 深度学习模型构建
    • 5. 策略实现

1. 深度学习简介

深度学习是模拟人脑进行分析学习的神经网络。

2. 常用深度学习模型架构

  • 深度神经网络(DNN)
  • 卷积神经网络(CNN)
  • 马尔可夫链(MC)
  • 玻尔兹曼机(BM)
  • 生成对抗网络(GAN)
  • 长短期记忆网络(LSTM)

2.1 LSTM 介绍

长短期记忆网络(LSTM)是一种常用于处理序列数据的循环神经网络(RNN)的变体,被广泛应用于自然语言处理、语音识别、时间序列预测等任务中。

LSTM通过门控机制解决了传统RNN中的梯度问题,能够有效地处理序列数据,并在多个领域取得了显著的成果。

2.2 LSTM在股票预测中的应用

LSTM在量化预测股票方面被广泛应用。它可以利用历史股票价格和交易量等数据来学习股票价格的趋势和波动,从而进行未来的预测。

在股票预测中,LSTM可以接受时间序列数据作为输入,并通过递归地更新隐藏状态来捕获长期依赖关系。它可以通过学习历史价格和交易量等特征的模式,对未来的股票价格进行预测。

通过将股票历史数据作为训练样本,LSTM可以学习不同时间尺度上的模式,例如每日、每周或每月的波动情况。它还可以利用技术指标、市场情绪数据等辅助信息,以提高预测准确性。

在实际应用中,研究人员和投资者通过训练LSTM模型来预测股票的价格趋势、波动情况和交易信号。这些预测结果可以用于制定投资策略、风险管理和决策制定等方面。

需要注意的是,股票市场受到多种因素的影响,包括经济因素、政治事件和市场心理等。LSTM在股票预测中的应用并不是完全准确的,因此在实际应用中需要结合其他因素进行综合分析和决策。此外,过度依赖LSTM模型所做的预测结果也可能存在风险,投资者仍需谨慎分析和评估。

3. 模块分类

3.1 卷积层

卷积层是深度学习中的基本层之一,通过卷积操作对输入数据进行特征提取和特征映射,并利用参数共享和局部连接等机制提高模型的参数效率。

  • 一维卷积层
  • 二维卷积层
  • 三维卷积层

3.2 池化层

平均池化和最大池化是卷积神经网络中常用的池化操作,用于减少特征图的维度,并提取出重要的特征信息。

  • 平均池化
  • 最大池化

3.3 全连接层

全连接层是神经网络中的一种常见层类型。在全连接层中,每个输入神经元与输出层中的每个神经元都有连接。每个连接都有一个权重,用于调整输入神经元对于输出神经元的影响。全连接层的输出可以通过激活函数进行非线性变换。

3.4 Dropout层

Dropout层是一种正则化技术,用于在训练过程中随机丢弃一部分输入神经元,以减少过拟合的风险。Dropout层通过随机断开神经元之间的连接来实现丢弃操作。在每个训练迭代中,Dropout层会随机选择一些神经元进行丢弃,并在前向传播和反向传播过程中不使用这些丢弃的神经元。

4. 深度学习模型构建

  1. 通过模块堆叠将输入层、中间层、输出层连接,然后构建模块进行初始化
  2. 训练模型
  3. 模型预测

5. 策略实现

本部分将介绍如何在BigQuant实现一个基于LSTM的选股策略

from biglearning.api import M
from biglearning.api import tools as T
from bigdatasource.api import DataSource
from biglearning.module2.common.data import Outputs
from zipline.finance.commission import PerOrder
import pandas as pd
import math


# LSTM模型训练和预测
def m4_run_bigquant_run(input_1, input_2, input_3):
    df =  input_1.read_pickle()
    feature_len = len(input_2.read_pickle())
    
    df['x'] = df['x'].reshape(df['x'].shape[0], int(feature_len), int(df['x'].shape[1]/feature_len))
    
    data_1 = DataSource.write_pickle(df)
    return Outputs(data_1=data_1)


# LSTM模型训练和预测的后处理
def m4_post_run_bigquant_run(outputs):
    return outputs


# LSTM模型训练和预测
def m8_run_bigquant_run(input_1, input_2, input_3):
    df =  input_1.read_pickle()
    feature_len = len(input_2.read_pickle())
    
    df['x'] = df['x'].reshape(df['x'].shape[0], int(feature_len), int(df['x'].shape[1]/feature_len))
    
    data_1 = DataSource.write_pickle(df)
    return Outputs(data_1=data_1)


# LSTM模型训练和预测的后处理
def m8_post_run_bigquant_run(outputs):
    return outputs


# 模型评估和排序
def m24_run_bigquant_run(input_1, input_2, input_3):
    pred_label = input_1.read_pickle()
    df = input_2.read_df()
    df = pd.DataFrame({'pred_label':pred_label[:,0], 'instrument':df.instrument, 'date':df.date})
    df.sort_values(['date','pred_label'],inplace=True, ascending=[True,False])
    return Outputs(data_1=DataSource.write_df(df), data_2=None, data_3=None)


# 模型评估和排序的后处理
def m24_post_run_bigquant_run(outputs):
    return outputs


# 初始化策略
def m19_initialize_bigquant_run(context):
    # 从options中读取数据
    context.ranker_prediction = context.options['data'].read_df()
    # 设置佣金费率
    context.set_commission(PerOrder(buy_cost=0.0003, sell_cost=0.0013, min_cost=5))
    stock_count = 30
    # 根据股票数量设置权重
    context.stock_weights = T.norm([1 / math.log(i + 2) for i in range(0, stock_count)])
    context.max_cash_per_instrument = 0.9
    context.options['hold_days'] = 5


# 处理每个交易日的数据
def m19_handle_data_bigquant_run(context, data):
    # 获取当日的预测结果
    ranker_prediction = context.ranker_prediction[
        context.ranker_prediction.date == data.current_dt.strftime('%Y-%m-%d')]

    is_staging = context.trading_day_index < context.options['hold_days']
    cash_avg = context.portfolio.portfolio_value / context.options['hold_days']
    cash_for_buy = min(context.portfolio.cash, (1 if is_staging else 1.5) * cash_avg)
    cash_for_sell = cash_avg - (context.portfolio.cash - cash_for_buy)
    positions = {e.symbol: p.amount * p.last_sale_price
                 for e, p in context.perf_tracker.position_tracker.positions.items()}

    if not is_staging and cash_for_sell > 0:
        equities = {e.symbol: e for e, p in context.perf_tracker.position_tracker.positions.items()}
        instruments = list(reversed(list(ranker_prediction.instrument[ranker_prediction.instrument.apply(
                lambda x: x in equities and not context.has_unfinished_sell_order(equities[x]))])))

        for instrument in instruments:
            context.order_target(context.symbol(instrument), 0)
            cash_for_sell -= positions[instrument]
            if cash_for_sell <= 0:
                break

    buy_cash_weights = context.stock_weights
    buy_instruments = list(ranker_prediction.instrument[:len(buy_cash_weights)])
    max_cash_per_instrument = context.portfolio.portfolio_value * context.max_cash_per_instrument
    for i, instrument in enumerate(buy_instruments):
        cash = cash_for_buy * buy_cash_weights[i]
        if cash > max_cash_per_instrument - positions.get(instrument, 0):
            cash = max_cash_per_instrument - positions.get(instrument, 0)
        if cash > 0:
            context.order_value(context.symbol(instrument), cash)

# 准备工作
def m19_prepare_bigquant_run(context):
    pass

# 获取2020年至2021年股票数据
m1 = M.instruments.v2(
    start_date='2020-01-01',
    end_date='2021-01-01',
    market='CN_STOCK_A',
    instrument_list=' ',
    max_count=0
)

# 使用高级自动标注器获取标签
m2 = M.advanced_auto_labeler.v2(
    instruments=m1.data,
    label_expr="""
shift(close, -5) / shift(open, -1)-1

clip(label, all_quantile(label, 0.01), all_quantile(label, 0.99))

where(shift(high, -1) == shift(low, -1), NaN, label)
""",
    start_date='',
    end_date='',
    benchmark='000300.SHA',
    drop_na_label=True,
    cast_label_int=False
)

# 标准化标签数据
m13 = M.standardlize.v8(
    input_1=m2.data,
    columns_input='label'
)

# 输入特征
m3 = M.input_features.v1(
    features="""close_0/mean(close_0,5)
close_0/mean(close_0,10)
close_0/mean(close_0,20)
close_0/open_0
open_0/mean(close_0,5)
open_0/mean(close_0,10)
open_0/mean(close_0,20)"""
)

# 抽取基础特征
m15 = M.general_feature_extractor.v7(
    instruments=m1.data,
    features=m3.data,
    start_date='',
    end_date='',
    before_start_days=30
)

# 提取派生特征
m16 = M.derived_feature_extractor.v3(
    input_data=m15.data,
    features=m3.data,
    date_col='date',
    instrument_col='instrument',
    drop_na=True,
    remove_extra_columns=False
)

# 标准化基础特征
m14 = M.standardlize.v8(
    input_1=m16.data,
    input_2=m3.data,
    columns_input='[]'
)

# 合并标签和特征
m7 = M.join.v3(
    data1=m13.data,
    data2=m14.data,
    on='date,instrument',
    how='inner',
    sort=False
)

# 将特征转换成二进制数据
m26 = M.dl_convert_to_bin.v2(
    input_data=m7.data,
    features=m3.data,
    window_size=5,
    feature_clip=5,
    flatten=True,
    window_along_col='instrument'
)

# 使用m4_run_bigquant_run函数运行缓存模式
m4 = M.cached.v3(
    input_1=m26.data,
    input_2=m3.data,
    run=m4_run_bigquant_run,
    post_run=m4_post_run_bigquant_run,
    input_ports='',
    params='{}',
    output_ports=''
)

# 获取2021年至2022年股票数据
m9 = M.instruments.v2(
    start_date=T.live_run_param('trading_date', '2021-01-01'),
    end_date=T.live_run_param('trading_date', '2022-01-01'),
    market='CN_STOCK_A',
    instrument_list='',
    max_count=0
)

# 抽取基础特征
m17 = M.general_feature_extractor.v7(
    instruments=m9.data,
    features=m3.data,
    start_date='',
    end_date='',
    before_start_days=30
)

# 提取派生特征
m18 = M.derived_feature_extractor.v3(
    input_data=m17.data,
    features=m3.data,
    date_col='date',
    instrument_col='instrument',
    drop_na=True,
    remove_extra_columns=False
)

# 标准化基础特征
m25 = M.standardlize.v8(
    input_1=m18.data,
    input_2=m3.data,
    columns_input='[]'
)

# 将特征转换成二进制数据
m27 = M.dl_convert_to_bin.v2(
    input_data=m25.data,
    features=m3.data,
    window_size=5,
    feature_clip=5,
    flatten=True,
    window_along_col='instrument'
)

# 使用m8_run_bigquant_run函数运行缓存模式
m8 = M.cached.v3(
    input_1=m27.data,
    input_2=m3.data,
    run=m8_run_bigquant_run,
    post_run=m8_post_run_bigquant_run,
    input_ports='',
    params='{}',
    output_ports=''
)

# 构造LSTM模型的输入层
m6 = M.dl_layer_input.v1(
    shape='7,5',
    batch_shape='',
    dtype='float32',
    sparse=False,
    name=''
)

# 构造LSTM模型的LSTM层
m10 = M.dl_layer_lstm.v1(
    inputs=m6.data,
    units=32,
    activation='tanh',
    recurrent_activation='hard_sigmoid',
    use_bias=True,
    kernel_initializer='glorot_uniform',
    recurrent_initializer='Orthogonal',
    bias_initializer='Zeros',
    unit_forget_bias=True,
    kernel_regularizer='None',
    kernel_regularizer_l1=0,
    kernel_regularizer_l2=0,
    recurrent_regularizer='None',
    recurrent_regularizer_l1=0,
    recurrent_regularizer_l2=0,
    bias_regularizer='None',
    bias_regularizer_l1=0,
    bias_regularizer_l2=0,
    activity_regularizer='None',
    activity_regularizer_l1=0,
    activity_regularizer_l2=0,
    kernel_constraint='None',
    recurrent_constraint='None',
    bias_constraint='None',
    dropout=0,
    recurrent_dropout=0,
    return_sequences=False,
    implementation='0',
    name=''
)

# 构造LSTM模型的Dropout层
m12 = M.dl_layer_dropout.v1(
    inputs=m10.data,
    rate=0.2,
    noise_shape='',
    name=''
)

# 构造LSTM模型的全连接层1
m20 = M.dl_layer_dense.v1(
    inputs=m12.data,
    units=30,
    activation='tanh',
    use_bias=True,
    kernel_initializer='glorot_uniform',
    bias_initializer='Zeros',
    kernel_regularizer='None',
    kernel_regularizer_l1=0,
    kernel_regularizer_l2=0,
    bias_regularizer='None',
    bias_regularizer_l1=0,
    bias_regularizer_l2=0,
    activity_regularizer='None',
    activity_regularizer_l1=0,
    activity_regularizer_l2=0,
    kernel_constraint='None',
    bias_constraint='None',
    name=''
)

# 构造LSTM模型的Dropout层2
m21 = M.dl_layer_dropout.v1(
    inputs=m20.data,
    rate=0.2,
    noise_shape='',
    name=''
)

# 构造LSTM模型的全连接层2
m22 = M.dl_layer_dense.v1(
    inputs=m21.data,
    units=1,
    activation='tanh',
    use_bias=True,
    kernel_initializer='glorot_uniform',
    bias_initializer='Zeros',
    kernel_regularizer='None',
    kernel_regularizer_l1=0,
    kernel_regularizer_l2=0,
    bias_regularizer='None',
    bias_regularizer_l1=0,
    bias_regularizer_l2=0,
    activity_regularizer='None',
    activity_regularizer_l1=0,
    activity_regularizer_l2=0,
    kernel_constraint='None',
    bias_constraint='None',
    name=''
)

# 初始化LSTM模型
m34 = M.dl_model_init.v1(
    inputs=m6.data,
    outputs=m22.data
)

# 训练LSTM模型
m5 = M.dl_model_train.v1(
    input_model=m34.data,
    training_data=m4.data_1,
    optimizer='RMSprop',
    loss='mean_squared_error',
    metrics='mae',
    batch_size=256,
    epochs=5,
    n_gpus=0,
    verbose='2:每个epoch输出一行记录'
)

# 使用LSTM模型进行预测
m11 = M.dl_model_predict.v1(
    trained_model=m5.data,
    input_data=m8.data_1,
    batch_size=1024,
    n_gpus=0,
    verbose='2:每个epoch输出一行记录'
)

# 使用m24_run_bigquant_run函数运行缓存模式
m24 = M.cached.v3(
    input_1=m11.data,
    input_2=m18.data,
    run=m24_run_bigquant_run,
    post_run=m24_post_run_bigquant_run,
    input_ports='',
    params='{}',
    output_ports=''
)

# 执行交易
m19 = M.trade.v4(
    instruments=m9.data,
    options_data=m24.data_1,
    start_date='',
    end_date='',
    initialize=m19_initialize_bigquant_run,
    handle_data=m19_handle_data_bigquant_run,
    prepare=m19_prepare_bigquant_run,
    volume_limit=0.025,
    order_price_field_buy='open',
    order_price_field_sell='close',
    capital_base=1000000,
    auto_cancel_non_tradable_orders=True,
    data_frequency='daily',
    price_type='后复权',
    product_type='股票',
    plot_charts=True,
    backtest_only=False,
    benchmark='000300.SHA'
)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/877245.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DevOps系列文章 之 SpringBoot整合GitLab-CI实现持续集成

在企业开发过程中&#xff0c;我们开发的功能或者是修复的BUG都需要部署到服务器上去&#xff0c;而这部分部署操作又是重复且繁琐的工作&#xff0c;GitLab-CI 持续集成为我们解决了这一痛点&#xff0c;将重复部署的工作自动化&#xff0c;大大的节省了程序员们的宝贵时间。本…

群晖 NAS 十分精准的安装 Mysql 远程访问连接

文章目录 1. 安装Mysql2. 安装phpMyAdmin3. 修改User 表4. 本地测试连接5. 安装cpolar6. 配置公网访问地址7. 固定连接公网地址 转载自cpolar极点云文章&#xff1a;群晖NAS 安装 MySQL远程访问连接 群晖安装MySQL具有高效、安全、可靠、灵活等优势&#xff0c;可以为用户提供一…

C++初阶之一篇文章教会你queue和priority_queue(理解使用和模拟实现)

queue和priority_queue&#xff08;理解使用和模拟实现&#xff09; 什么是queuequeue的使用1.queue构造函数2.empty()3.size()4.front()5.back();6.push7.emplace8.pop()9.swap queue模拟实现什么是priority_queuepriority_queue的使用1.priority_queue构造函数1.1 模板参数 C…

如何切换goland之中的版本号(升级go 到1.20)

go 安装/版本切换_go 切换版本_云满笔记的博客-CSDN博客 用brew就行&#xff1a; echo export PATH"/opt/homebrew/opt/go1.20/bin:$PATH" >> ~/.zshrc

无涯教程-Perl - scalar函数

描述 此函数强制EXPR的判断在标量context中进行,即使它通常在列表context中也可以使用。 语法 以下是此函数的简单语法- scalar EXPR返回值 此函数返回标量。 例 以下是显示其基本用法的示例代码- #!/usr/bin/perl -wa (1,2,3,4); b (10,20,30,40);c ( a, b ); prin…

如何与 Anheuser-Busch 建立 EDI 连接?

Anheuser-Busch 于 1852 年创立&#xff0c;总部位于美国密苏里州圣路易斯市&#xff0c;出产的百威啤酒(Budweiser)名扬世界&#xff0c;深受各国消费者喜爱。推进数字化转型&#xff0c;科技赋能降费增效。在诸多举措之中&#xff0c;可以看到 EDI 的身影&#xff0c;借助 ED…

Java智慧工地APP源码带AI识别

智慧工地为建筑全生命周期赋能&#xff0c;用创新的可视化与智能化方法&#xff0c;降低成本&#xff0c;创造价值。 一、智慧工地APP概述 智慧工地”立足于互联网&#xff0c;采用云计算&#xff0c;大数据和物联网等技术手段&#xff0c;针对当前建筑行业的特点&#xff0c;…

Apipost接口自动化控制器使用详解

测试人员在编写测试用例以及实际测试过程中&#xff0c;经常会遇到两个棘手的问题&#xff1a; •稍微复杂一些的自动化测试逻辑&#xff0c;往往需要手动写代码才能实现&#xff0c;难以实现和维护 •测试用例编写完成后&#xff0c;需要手动执行&#xff0c;难以接入自动化体…

对比学习论文综述总结

第一阶段:百花齐放(18-19中) 有InstDisc(Instance Discrimination)、CPC、CMC代表工作。在这个阶段方法模型都还没有统一,目标函数也没有统一,代理任务也没有统一,所以说是一个百花齐放的时代 1 判别式代理任务---个体判别任务 1.1 Inst Dict---一个编码器+一个memory…

springboot结合element-ui实现增删改查,附前端完整代码

实现功能 前端完整代码 后端接口 登录&#xff0c;注册&#xff0c;查询所有用户&#xff0c;根据用户名模糊查询&#xff0c;添加用户&#xff0c;更新用户&#xff0c;删除用户 前端 注册&#xff0c;登录&#xff0c;退出&#xff0c;用户增删改查&#xff0c;导航栏&#…

关于单行文本input和多行文本textarea唤起自动完成功能

若不要自动完成功能&#xff0c;则增加 autocomplete"off" 属性到控件或窗体中&#xff0c;默认 autocomplete"on" 处于开启状态。 实测过程中&#xff0c;单行文本可以有自动完成功能&#xff0c;多行文本无论如何实验都不行。查了查资料&#xff0c;MDN…

算法提高-线段树

线段树 线段树和树状数组线段树的五个操作单点修改&#xff08;不需要懒标记&#xff09;要求的答案就是我们要维护的属性&#xff0c;不需要维护其他的属性帮助我们获得答案要求的答案还需要其他属性去维护 区间修改&#xff08;需要懒标记&#xff0c;pushdown&#xff09;有…

小爱同学今日起开启邀请测试, 小米Al 大模型团队整装待发

在小米雷军年度演讲中&#xff0c;小米宣布未来将在技术研发上投入超过200亿元人民币&#xff0c;并预计在2023年达到这一目标。除此之外&#xff0c;雷军还强调了5G技术在小米发展中的重要性&#xff0c;称其为必要标准&#xff0c;并预测小米将进入世界前十的位置。 雷军还透…

配置 yum/dnf 置您的系统以使用默认存储库

题目 给系统配置默认存储库&#xff0c;要求如下&#xff1a; YUM 的 两 个 存 储 库 的 地 址 分 别 是 &#xff1a; ftp://host.domain8.rhce.cc/dvd/BaseOS ftp://host.domain8.rhce.cc/dvd/AppStream vim /etc/yum.repos.d/redhat.repo [base] namebase baseurlftp:/…

釉面陶瓷器皿SOR/2016-175标准上架亚马逊加拿大站

亲爱的釉面陶瓷器皿和玻璃器皿制造商和卖家&#xff0c;亚马逊加拿大站将执行SOR/2016-175法规。这是一份新的法规&#xff0c;规定了含有铅和镉的釉面陶瓷器和玻璃器皿需要满足的要求。让我们一起来看一看&#xff0c;为什么要实行SOR/2016-175法规&#xff1f;这是一个保护消…

Unity游戏源码分享-中国象棋Unity5.6版本

Unity游戏源码分享-中国象棋Unity5.6版本 项目地址&#xff1a; https://download.csdn.net/download/Highning0007/88215699

linux系统服务学习(一)Linux高级命令扩展

文章目录 Linux高级命令&#xff08;扩展&#xff09;一、find命令1、find命令作用2、基本语法3、*星号通配符4、根据文件修改时间搜索文件☆ 聊一下Windows中的文件时间概念&#xff1f;☆ 使用stat命令获取文件的最后修改时间☆ 创建文件时设置修改时间以及修改文件的修改时间…

母牛的故事

一、题目 有一头母牛&#xff0c;它每年年初生一头小母牛。每头小母牛从第四个年头开始&#xff0c;每年年初也生一头小母牛。请编程实现在第n年的时候&#xff0c;共有多少头母牛&#xff1f; Input 输入数据由多个测试实例组成&#xff0c;每个测试实例占一行&#xff0c;包…

11 个 Python 编码习惯

让你成为糟糕程序员的 11 个 Python 编码习惯 简介 Python 因其简洁性和可读性而备受推崇&#xff0c;但即使是最有经验的程序员也可能会陷入影响代码质量的习惯中。 在本博客中&#xff0c;我们将探讨 10 种常见的编码习惯&#xff0c;它们会降低您作为 Python 程序员的效率。…

深入探析设计模式:工厂模式的三种姿态

深入探析设计模式&#xff1a;工厂模式的三种姿态 1. 简单工厂模式1.1 概念1.2 案例1.3 优缺点 2. 抽象工厂模式2.1 概念2.2 案例&#xff1a;跨品牌手机生产2.3 优缺点 3. 超级工厂模式3.1 概念3.2 案例&#xff1a;动物园游览3.3 优缺点 4. 总结 欢迎阅读本文&#xff0c;今天…