【python量化】挖掘股价中的图关系:基于图注意力网络的股价预测模型

news2025/1/16 0:15:58

f135eb2240577efe3c9f84e77ae42d2c.png

写在前面

近些年,图神经网络在时间序列预测领域发挥了重要的作用。其中,图注意力网络(GAT)是一种基于注意力机制的图神经网络,能够捕捉图结构数据中节点之间的复杂关系,从而在许多领域中取得了突出的性能。在本文中,我们利用Pytorch以及PyG(PyTorch Geometric 是一个强大而灵活的图神经网络库)框架,实现一个将 GAT 应用于股票价格预测的简单例子

1

前言

随着金融市场的复杂性不断增加,对股价的预测成为了一项巨大的挑战。传统的时间序列分析方法虽然在某些场景下有效,但往往无法捕捉市场中的复杂相互作用和隐藏模式。为了解决这一问题,本文采用了一种全新的视角:将股价时间序列转化为图结构,通过图注意力网络(GAT)以图的视角来建模和分析。图注意力网络是一种强大的图神经网络结构,通过引入注意力机制,能够灵活捕捉图中节点间的相互关系。在股价预测的场景中,股票之间的相互作用和依赖关系可以被自然地建模为图结构,其中股票作为节点,它们之间的相互作用作为边。借助 PyTorch Geometric(PyG)这一先进的图神经网络库,本文展示了如何构建、训练和评估 GAT 模型来预测股票价格。

项目的核心是将股价时间序列数据转化为图结构,然后利用 GAT 的强大表征能力进行分析和预测。本项目旨在作为一个简单演示,展示如何使用图注意力网络来分析股价时间序列,仅作为学习和研究之用,不应用于实际的投资决策

2

环境配置

本地环境:

Python 3.7
IDE:Pycharm

库版本:

numpy 1.18.1
pandas 1.0.3 
torch-geometric 2.0.2
matplotlib 3.2.1
torch 1.10.1
tushare 1.2.60

3

代码实现

总体设计

该项目通过四个主要文件组织,展示了如何使用图注意力网络(GAT)进行股票价格预测。从获取和处理股票数据到构建和训练深度学习模型,再到最终的预测阶段,整个流程被清晰地实现。每个文件都专注于特定的任务,共同构建了一个完整的股价预测解决方案:

1. data_hander.py 数据处理模块

此模块负责从 Tushare 获取股票的收盘价并保存到文件中。主要功能包括:获取指定股票代码和日期范围的收盘价数据。保存和加载收盘价数据。绘制股价曲线图。构建股票间的邻接矩阵,用于图模型。

2. model.py - 模型模块

在这个模块中,定义了基于图注意力网络(GAT)的深度学习模型。主要组成部分包括构建GATPredictor 类,用于构建 GAT 模型,设置不同的隐藏层大小和注意力头数量。定义了模型的前向传播过程。

3. train.py - 训练模块

此文件包括了数据预处理和模型训练的相关函数:数据归一化和分割为训练和测试集。使用滑动窗口方法处理时间序列数据。定义训练和预测的主要函数。

4. main.py - 主程序模块

作为项目的入口点,此文件组织和调用上述三个文件中的功能:定义要预测的股票代码和日期范围。调用数据处理、模型构建、训练和预测的相关函数。定义了整个项目的主要流程和执行逻辑。

数据处理模块

数据处理模块负责从外部源获取股票价格数据,并进行适当的预处理以供模型使用。其中,需要将tushare API token替换到代码中的Your Token处。

def fetch_close_prices(stocks, start_date, end_date, file_name='close_prices.csv'):
    if os.path.exists(file_name):
        close_prices = pd.read_csv(file_name).values
    else:
        ts.set_token('Your Token')
        pro = ts.pro_api()
        close_prices_list = []


        for stock in stocks:
            df = pro.daily(ts_code=stock, start_date=start_date, end_date=end_date)
            close_prices_list.append(df['close'].values)


        close_prices = np.column_stack(close_prices_list)
        pd.DataFrame(close_prices).to_csv(file_name, index=False)


    return close_prices

为了使用图注意力网络(GAT),需要将股票数据转化为图结构。这里将每只股票的收盘价的Pearson系数作为构建邻接矩阵来表示股票之间的相关性的依据,并设定阈值来确定是否具有连边。

def build_adjacency_matrix(close_prices, threshold=0.5):
    N = close_prices.shape[1]
    adj_matrix = np.zeros((N, N))


    for i in range(N):
        for j in range(N):
            correlation = np.corrcoef(close_prices[:, i], close_prices[:, j])[0, 1]
            adj_matrix[i, j] = 1 if abs(correlation) > threshold else 0
    return adj_matrix

模型模块

这个模块负责定义和实现一个简单的基于图注意力网络的股价预测模型。通过构建包括图注意力层、隐藏层和输出层的结构,来实现模型的前向传递。其中,由于PyG的GAT层只能接受二维的数据,所以为了提升并行运算速度,这里将一个batch的数据合成一张大图进行运算,最后通过view转换后再通过linear层进行输出预测值。

class GATPredictor(nn.Module):
    def __init__(self, node_features, node_nums, hidden_size=32, num_heads=1):
        super(GATPredictor, self).__init__()
        self.node_nums = node_nums
        self.gat1 = GATConv(node_features, hidden_size, heads=num_heads)
        self.gat2 = GATConv(hidden_size * num_heads, hidden_size, concat=False)
        self.out = nn.Linear(hidden_size, 1)


    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        batch_size = torch.max(batch).item() + 1
        x = torch.relu(self.gat1(x, edge_index))
        x = torch.relu(self.gat2(x, edge_index))
        x = x.view(batch_size, self.node_nums, -1)
        x = self.out(x)
        return x.squeeze(2)

训练模块

训练模块主要负责处理数据到模型的指定输入形式以及模型训练的逻辑。其中实现了一些对数据进行归一化处理以及划分滑动窗口的函数,以及模型训练和测试的函数。

def normalize_and_split(closing_prices, test_size=0.2):
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(closing_prices)
    train_data, test_data = train_test_split(normalized_data, test_size=test_size, shuffle=False)
    return train_data, test_data, scaler


def sliding_window(data, window_size):
    windows = []
    for i in range(len(data) - window_size):
        x = data[i:i + window_size]
        y = data[i + window_size]
        windows.append((x, y))
    return window

构建节点之间的连边时,每个股票作为一个节点,其窗口的数据作为节点特征,并且对每个滑动窗口构建相同的图关系,最后将其转换为PyG的指定输入数据形式。

def create_graph_data(windows, adj_matrix):
    graph_data = []
    edge_index = torch.tensor(np.where(adj_matrix != 0), dtype=torch.long)


    for window in windows:
        x, y = window
        x_tensor = torch.tensor(x.T, dtype=torch.float)
        y_tensor = torch.tensor(y, dtype=torch.float)  


        data = Data(x=x_tensor, y=y_tensor, edge_index=edge_index)
        graph_data.append(data)


    return graph_data


def train(model, train_loader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            y = batch.y.view(out.size())
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
        print(f'Epoch: {epoch+1}, Loss: {loss.item()}'

主程序模块

这个模块根据给定的股票代码和日期范围获取历史收盘价数据,这里简单取了上证2022年的几只股票。然后通过滑动窗口方法和邻接矩阵将其转化为图数据格式。接下来,初始化并训练一个基于图注意力网络的预测模型,最后使用该模型进行预测并将预测结果与真实值进行可视化对比。

# 股票代码和日期范围
stocks = ['000001.SZ', '000002.SZ', '000006.SZ', '000005.SZ', '000008.SZ', '000009.SZ']
start_date = "2022-01-01"
end_date = "2023-01-01"
window_size = 10


# 获取收盘价并创建邻接矩阵
close_prices = fetch_close_prices(stocks, start_date, end_date)
adj_matrix = build_adjacency_matrix(close_prices)


# 数据归一化和划分
train_data, test_data, scaler = normalize_and_split(close_prices)


# 滑动窗口数据准备
train_windows = sliding_window(train_data, window_size)
test_windows = sliding_window(test_data, window_size)


# 转换为图数据
train_dataset = create_graph_data(train_windows, adj_matrix)
test_dataset = create_graph_data(test_windows, adj_matrix)


print(adj_matrix)


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


# 创建模型
model = GATPredictor(node_features=window_size, node_nums=len(stocks))


# 优化器和损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()


# 训练
train(model, train_loader, optimizer, criterion, epochs=50)


# 预测
predictions, groundtruth = predict(model, test_loader, scaler)


# 可视化预测结果
plot_results(predictions, groundtruth, stocks)

运行效果

接下来对模型进行训练跟测试,经过多个epoch的迭代,mse逐渐收敛。

Epoch: 1, Loss: 0.3193877637386322
Epoch: 2, Loss: 0.21741190552711487
Epoch: 3, Loss: 0.16308942437171936
Epoch: 4, Loss: 0.09306434541940689
Epoch: 5, Loss: 0.052657563239336014
Epoch: 6, Loss: 0.02670634351670742
Epoch: 7, Loss: 0.019804660230875015
Epoch: 8, Loss: 0.016364283859729767
Epoch: 9, Loss: 0.017054198309779167
Epoch: 10, Loss: 0.01848340593278408
...
Epoch: 44, Loss: 0.009028978645801544
Epoch: 45, Loss: 0.013259928673505783
Epoch: 46, Loss: 0.01599966734647751
Epoch: 47, Loss: 0.011484017595648766
Epoch: 48, Loss: 0.010215471498668194
Epoch: 49, Loss: 0.01579277403652668
Epoch: 50, Loss: 0.011635522358119488

之后对多个股票的预测结果进行可视化,可以看出有些股票的趋势拟合效果较好。由于这六只股票只是简单的选取,所以其关联性可能不会很强,所以也可以选择多个相同板块的股票进行实验。除此之外,模型的结构也很简单,并没有引入更多的股价特征,所以可以进一步改进的点还有很多。

f528cb1297400fff18ebf067d1a7e70a.png

4

总结


在现代金融领域,随着大数据和机器学习技术的日益成熟,传统的股价预测方法正面临着前所未有的挑战和机遇。传统的股价预测往往依赖于单一股票的历史数据,忽略了股票间的互动和影响。这篇文章探讨了如何利用图注意力网络 (GAT) 挖掘股票之间的潜在关系,为股价预测提供了新的视角。其中将邻接矩阵与滑动窗口方法相结合,将时间序列的股价数据转化为图数据格式,从而捕捉股票间的复杂相互作用。本文中的实验只是一个简单的demo,还存在许多可以改进的空间,感兴趣的读者可以进一步研究。

本文内容仅仅是技术探讨和学习,并不构成任何投资建议。

获取完整代码与数据以及其他历史文章完整源码与数据可加入《人工智能量化实验室》知识星球。

往期推荐阅读

WWW 2023 | 量化交易相关论文(附论文链接)

KDD 2023 | 量化交易相关论文(附论文链接)

AAAI 2022 | 量化交易相关论文(附论文链接)

IJCAI 2022 | 量化交易相关论文(附论文链接)

WWW 2022 | 量化交易相关论文(附论文链接)

KDD 2022 | 量化交易相关论文(附论文链接)

解读:ChatGPT在股票市场预测方面的应用

解读:通过挖掘概念间共享信息,实现股票趋势预测的图模型框架

解读:机器学习预测收益模型应该采取哪种度量指标

解读:基于订单流、技术分析与神经网络的期货短期走势预测模型

【python量化】基于backtrader的深度学习模型量化回测框架

【python量化】将Transformer模型用于股票价格预测

【python量化】搭建一个CNN-LSTM模型用于股票价格预测

【python量化】用python搭建一个股票舆情分析系统

【python量化】将Informer用于股价预测

【python量化】将DeepAR用于股票价格多步概率预测

9992a579714fc702da17a79279e8aa84.png

《人工智能量化实验室》知识星球

fd428cefb5057f46318f072a6b0b706b.png

加入人工智能量化实验室知识星球,您可以获得:(1)定期推送最新人工智能量化应用相关的研究成果,包括高水平期刊论文以及券商优质金融工程研究报告,便于您随时随地了解最新前沿知识;(2)公众号历史文章Python项目完整源码;(3)优质Python、机器学习、量化交易相关电子书PDF;(4)优质量化交易资料、项目代码分享;(5)跟星友一起交流,结交志同道合朋友。(6)向博主发起提问,答疑解惑。

8cf1db2f2f113a3b63d329aae0a835f5.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1016491.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32窗口看门狗 WWDG

目录 1.什么是窗口看门狗? 2.窗口看门狗工作原理 3.WWDG框图 4.控制寄存器(WWDG_CR) 5.配置寄存器(WWDG_CFR) 6.状态寄存器(WWDG_SR) 7.超时时间计算 8.窗口看门狗实验 9.独立看门狗和…

Python:web框架之Tornado的Hello World示例

一、安装Tornado pip install tornado 安装完成后会看到显示tornado的版本号。 二、编写Hello World程序 import tornado.ioloop #导入tornado包 import tornado.web class MainHandle(tornado.web.RequestHandler):def get(self): #定义请求函数self.write("He…

从零基础学习PyQt5软件app开发

常见的GUI框架的梳理 GUI,全称为图形⽤户界⾯,⼜称为图形⽤户接⼝,是⼀种⼈与计算机通信的界⾯显示格式。 ⽤户打开应⽤程序或⽹站时看到的第⼀个东⻄,并与之交互。 ⽤户界⾯通常包括许多视觉元素,如图标、按钮、图形…

python学习随笔3

range的使用 range()在python很常用&#xff0c;可以进行初始化和遍历等。 # range(st,ed) # [st, ed)# range(st,ed,step) # range(st, ed, step) i,i step, i 2 * step ... () < ed切片 跟range类似。 ll[st:ed:step]容器 元组 python中的元组中内容不可以进行更…

浅谈C++|运算符重载

重载原因 C 中的运算符重载是一种特性&#xff0c;允许程序员定义自定义类类型的运算符操作。通过运算符重载&#xff0c;可以对类对象执行类似于内置类型的操作&#xff0c;例如加法、减法、乘法等。 运算符重载通过定义特定的成员函数或非成员函数来实现。成员函数的运算符重…

html怎么设置按钮返回顶部

在 HTML 中&#xff0c;我们可以通过一些代码和 CSS 样式来创建一个这样的按钮。 <button onclick"topFunction()" id"myBtn">返回顶部</button> <style> #myBtn { display: none; position: fixed; bottom: 20px; right: 30px; z-inde…

高性能 Python 编译器 -- Codon

众所周知&#xff0c;Python 是一门简单易学、具有强大功能的编程语言&#xff0c;在各种用户使用统计榜单中总是名列前茅。相应地&#xff0c;围绕 Python&#xff0c;研究者开发了各种便捷工具&#xff0c;以更好的服务于这门语言。 编译器充当着高级语言与机器之间的翻译官&…

一封来自江苏省电力设计院的表扬信

近日&#xff0c;中新赛克海睿思收到了一封来自江苏省电力设计院公司&#xff08;以下简称“江苏院”&#xff09;的表扬信。 海睿思与江苏院自达成合作以来&#xff0c;双方团队经过共同努力&#xff0c;克服了项目交付过程中的诸多困难。不仅通过数据工程的整体咨询帮助江苏院…

pt26django教程

admin 后台数据库管理 django 提供了比较完善的后台管理数据库的接口&#xff0c;可供开发过程中调用和测试使用 django 会搜集所有已注册的模型类&#xff0c;为这些模型类提拱数据管理界面&#xff0c;供开发者使用 创建后台管理帐号: [rootvm mysite2]# python3 manage.…

什么是函数重载?作用是什么?如何使用?

函数重载是指在同一个作用域内&#xff0c;允许存在多个同名函数&#xff0c;但这些函数的参数列表必须不同。根据传入的参数类型、数量或顺序的不同&#xff0c;编译器可以区分调用哪个函数。 函数重载的作用主要有以下几点&#xff1a; 提高代码的可读性和可维护性&#xff…

openlayers-17-卷帘对比

实现卷帘对比功能&#xff0c;没有进一步测试版本兼容问题&#xff0c;不错从ol的官网来看&#xff0c;ol6之前的版本的示例与ol6及其之后的版本示例并不相同 ol5 示例https://openlayers.org/en/v5.3.0/examples/layer-swipe.html?qlayerswipeol6示例 https://openlayers.org…

GIS跟踪监管系统

GIS跟踪监管系统 系统架构功能模块1. 基本功能2. 仓库管理3. 物资查询 系统采用B/S架构&#xff0c;前端使用的技术为HTMLCSSJavaScript&#xff08;Leaflet、jQuery、bootstrap等&#xff09;&#xff0c;后台采用.NET框架。 系统架构 救援物资跟踪监管系统的架构如图所示&am…

Matplotlib入门

基本使用 基本用法 import matplotlib.pyplot as plt import numpy as npxnp.linspace(-1,1,50) y2*x1plt.figure()#定义一个图像窗口 plt.plot(x,y)#画&#xff08;x&#xff0c;y&#xff09;曲线 plt.show()#显示图像figure图像 import matplotlib.pyplot as plt import …

nat的基础配置(动态nat,nat server)

目录 1.静态nat 2.动态nat &#xff08;1&#xff09;配置公网地址池 &#xff08;2&#xff09;配置acl&#xff0c;匹配做nat转换的源 &#xff08;3&#xff09;将源转换为公网地址&#xff0c;其中no-pat表示不做端口转化&#xff0c;只做一对一的地址转换 3.nat ser…

《向量数据库指南》——向量数据库Milvus Cloud为什么选择开源?

开源对我们来说是一种信仰。从最早开始研发向量数据库的时候&#xff0c;我们就相信应该让更多人了解并使用优秀的技术&#xff0c;这是我们选择做开源的原因。 无论是在 AI 领域还是其他领域&#xff0c;我们希望技术不会被少数大公司垄断。在向量数据库问世之前&#xff0c;阿…

python:优化一EXCEL统计用类封装一下

# encoding: utf-8 # 版权所有 2023 涂聚文有限公司 # 许可信息查看&#xff1a; # 描述&#xff1a; # Author : geovindu,Geovin Du 涂聚文. # IDE : PyCharm 2023.1 python 311 # Datetime : 2023/9/17 5:40 # User : geovindu # Product : PyCharm # Proj…

JSON和全局异常处理

目录 1️⃣JSON 一、什么是json&#xff1f; 二、与javascript的关系 三、语法格式 四、注意事项 五、总结 六&#xff0c;使用json 1导入pom.xml依赖 2.配置spring-mvc.xml 3. ResponseBody注解使用 创建一个web层控制器 编写ClazzBiz 实现接口 测试&#xff1a; …

C#,数值计算——Hashfn2的计算方法与源程序

1 文本格式 using System; using System.Collections; using System.Collections.Generic; namespace Legalsoft.Truffer { public class Hashfn2 { private static ulong[] hashfn_tab { get; set; } new ulong[256]; private ulong h { get; set;…

【2023年11月第四版教材】第13章《资源管理》(第三部分)

第13章《资源管理》&#xff08;第部分&#xff09; 4 管理过程4.1 数据表现★★★4.2 资源管理计划★★★4.2 团队章程★★★ 5 估算活动资源 4 管理过程 组过程输入工具和技术输出规划1.规划资源管理1.项目章程2.项目管理计划&#xff08;质量管理计划、范围基准&#xff09…

elasticsearch5-RestAPI操作

个人名片&#xff1a; 博主&#xff1a;酒徒ᝰ. 个人简介&#xff1a;沉醉在酒中&#xff0c;借着一股酒劲&#xff0c;去拼搏一个未来。 本篇励志&#xff1a;三人行&#xff0c;必有我师焉。 本项目基于B站黑马程序员Java《SpringCloud微服务技术栈》&#xff0c;SpringCloud…