TensorFlow项目练手(三)——基于GRU股票走势预测任务

news2024/11/17 8:40:36

项目介绍

项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下:

  • LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
  • GRU算法:是一种特殊的RNN。和LSTM一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

一、准备数据

1、获取数据

  1. 通过命令行安装yfinance
  2. 通过api获取股票数据
  3. 保存到csv中方便使用
import pandas_datareader.data as web
import datetime
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
plt.rcParams['font.sans-serif']='SimHei' #图表显示中文

import yfinance as yf
yf.pdr_override() #需要调用这个函数

# 1、获取股票数据
#上海的股票代码+.SS;深圳的股票代码+.SZ :
stock = web.get_data_yahoo("601318.SS", start="2022-01-01", end="2023-07-17")
# 保存到csv中
pd.DataFrame(data=stock).to_csv('./stock.csv')

# 2、获取csv中的数据
features = pd.read_csv('stock.csv')
features = features.drop('Adj Close',axis=1)
features.head()

在这里插入图片描述

2、数据可视化

通过绘图的方式查看当前的数据情况

# 3、绘图看看收盘价数据情况
close=features["Close"]
# 计算20天和100天移动平均线:
short_rolling_close = close.rolling(window=20).mean()
long_rolling_close = close.rolling(window=100).mean()
# 绘制
fig, ax = plt.subplots(figsize=(16,9))   #画面大小,可以修改
ax.plot(close.index, close, label='中国平安')   #以收盘价为索引值绘图
ax.plot(short_rolling_close.index, short_rolling_close, label='20天均线')
ax.plot(long_rolling_close.index, long_rolling_close, label='100天均线')
#x轴、y轴及图例:
ax.set_xlabel('日期')
ax.set_ylabel('收盘价 (人民币)')
ax.legend()      #图例
plt.show()      #绘图

在这里插入图片描述

3、数据预处理

取出当前的收盘价,删除无用的日期元素

# 4、取出label值
labels = features['Close']
time = features['Date']
features = features.drop('Date',axis=1)
features.head()

在这里插入图片描述

进行数据的归一化

# 5、数据预处理
from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features

在这里插入图片描述

4、构建数据序列

由于RNN的算法要求我们要有一定的序列,来预测出下一个值,所以我们按照20天的数据作为一个序列

# 6、定义序列,[下标1-20天预测第21天的收盘价]
from collections import deque

x = []
y = []

seq_len = 20
deq = deque(maxlen=seq_len)
for i in input_features:
    deq.append(list(i))
    if len(deq) == seq_len:
        x.append(list(deq))

x = x[:-1] # 取少一个序列,因为最后个序列没有答案
y = features['Close'].values[seq_len: ] #从第二十一天开始(下标为20)
time = time.values[seq_len: ] #从第二十一天开始(下标为20)

x, y, time = np.array(x), np.array(y), np.array(time)
print(x.shape)
print(y.shape)
print(time.shape)

在这里插入图片描述

二、构建模型

1、搭建GRU模型

import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layers

from keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import Adam

# 7、搭建模型
model = tf.keras.Sequential()
model.add(layers.GRU(8,input_shape=(20,5), activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(16, activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(32, activation='relu', return_sequences=False,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(16,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(1))
model.summary()

在这里插入图片描述

2、优化器和损失函数

# 优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss=tf.keras.losses.MeanAbsoluteError(), # 标签和预测之间绝对差异的平均
              metrics = tf.keras.losses.MeanSquaredLogarithmicError()) # 计算标签和预测

3、开始训练

25%的比例作为验证集,75%的比例作为训练集

# 开始训练
model.fit(x,y,validation_split=0.25,epochs=200,batch_size=128)

在这里插入图片描述

4、模型预测

# 预测
y_pred = model.predict(x)
fig = plt.figure(figsize=(10,5))
axes = fig.add_subplot(111)
axes.plot(time,y,'b-',label='actual')
# 预测值,红色散点
axes.plot(time,y_pred,'r--',label='predict')
axes.set_xticks(time[::50])
axes.set_xticklabels(time[::50],rotation=45)
 
plt.legend()
plt.show()

在这里插入图片描述

5、回归指标评估

from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
from math import sqrt

#回归评价指标
# calculate MSE 均方误差
mse=mean_squared_error(y,y_pred)
# calculate RMSE 均方根误差
rmse = sqrt(mean_squared_error(y, y_pred))
#calculate MAE 平均绝对误差
mae=mean_absolute_error(y,y_pred)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)

在这里插入图片描述

源代码

  • 源码查看

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/816879.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JDK 8.x 微服务启动JVM参数调优实战

微服务启动JVM参数调优实战 1.1 配置JVM启动参数1.2 解释1.3 JVM参数优化思路1.3.1 调整堆内存大小1.3.2 年轻代大小1.3.3 Metaspace 大小1.3.4 栈大小1.3.5 垃圾回收器选择1.3.6 垃圾回收参数1.3.7 预分配内存 1.3.8 禁用 ResizePLAB2. 常用JVM参数 1.1 配置JVM启动参数 服务…

每日一题——重建二叉树

重建二叉树 题目描述 给定节点数为 n 的二叉树的前序遍历和中序遍历结果,请重建出该二叉树并返回它的头结点。 例如输入前序遍历序列{1,2,4,7,3,5,6,8}和中序遍历序列{4,7,2,1,5,3,8,6},则重建出如下图所示。 提示: 1.vin.length pre.length 2.pre 和…

颠倒二进制位,颠倒给定的 32 位无符号整数的二进制位。

题记: 颠倒给定的 32 位无符号整数的二进制位。 提示: 请注意,在某些语言(如 Java)中,没有无符号整数类型。在这种情况下,输入和输出都将被指定为有符号整数类型,并且不应影响您的…

ChatPaper全流程加速科研:论文阅读+润色+优缺点分析与改进建议+审稿回复

项目设计集合(人工智能方向):助力新人快速实战掌握技能、自主完成项目设计升级,提升自身的硬实力(不仅限NLP、知识图谱、计算机视觉等领域):汇总有意义的项目设计集合,助力新人快速实…

惊喜!1行Python代码,瞬间测你工作量,分享一个统计代码行数的神器

大家好,这里是程序员晚枫。 **你想不想知道一个项目中,自己写了多少行代码?**我用今天的工具统计了一下开源项目:python-office的代码行数,竟然有21w行! 我们一起看一下怎么用最简单的方法,统…

《吐血整理》进阶系列教程-拿捏Fiddler抓包教程(16)-Fiddler如何充当第三者再识AutoResponder标签-上

1.简介 Fiddler充当第三者,主要是通过AutoResponder标签在客户端和服务端之间,Fiddler抓包,然后改包,最后发送。AutoResponder这个功能可以算的上是Fiddler最实用的功能,可以让我们修改服务器端返回的数据&#xff0c…

Windows10系统还原操作

哈喽,大家好,我是雷工! 复制了下虚拟机的Win10系统,但其中有一些软件,想实现类似手机的格式化出厂操作,下面记录Windows10系统的还原操作。 一、系统环境: 虚拟机内的Windows10,64…

JavaWeb第三章:JavaScript的全面知识

目录 前言 一.JavaScript的简介 💖概念 💖学习内容 二.JavaScript的引入方式 💖内部脚本 💖外部脚本 三.JavaScript的基础语法 💖语法的书写 💖变量 ✨ 全局变量 ✨局部变量 ✨常量 &a…

vue表单筛选

目录 筛选 HTML scss* filterComp 排序 表格 自定义数据样式 inner-table 分页 删除 default-modal 自定义元素的插槽-占位符 .search-wrap {height: 60px;display: flex;align-items: center;overflow: hidden;padding: 0 20px;.selected-options-wrap {flex: 1;.…

PostgreSQL数据库中,查询时提示表不存在的解决办法

最近遇到一个奇怪的问题,以前从来没有遇到过,在postgres SCHEMA下执行select * from table1语句时,提示表不存在,而实际这个表确是存在的,只不过是在public SCHEMA下。在public SCHEMA下执行这个sql语句是没有问题的。…

主成分分析PCA算法

Principal Components Analysis 这个协方差矩阵是一个nXn的,且是对称矩阵,就会有n个特征值λ和特征向量v,每个特征向量也是n维的。第一行特征向量v对应特征值λ1 。 D(yk):表示主成分yk的方差。方差越大,说明携带的信…

如何在不使用脚本和插件的情况下手动删除 3Ds Max 中的病毒?

如何加快3D项目的渲染速度? 3D项目渲染慢、渲染卡顿、渲染崩溃,本地硬件配置不够,想要加速渲染,在不增加额外的硬件成本投入的情况下,最好的解决方式是使用渲云云渲染,在云端批量渲染,批量出结…

【迁移】Mysql数据库备份 迁移

【迁移】Mysql数据库备份 迁移 📔 千寻简笔记介绍 千寻简笔记已开源,Gitee与GitHub搜索chihiro-notes,包含笔记源文件.md,以及PDF版本方便阅读,且是用了精美主题,阅读体验更佳,如果文章对你有…

金蝶云星空任意文件读取漏洞复现(0day)

0x01 产品简介 金蝶云星空是一款云端企业资源管理(ERP)软件,为企业提供财务管理、供应链管理以及业务流程管理等一体化解决方案。金蝶云星空聚焦多组织,多利润中心的大中型企业,以 “开放、标准、社交”三大特性为数字…

【Linux】 UDP网络套接字编程

🍎作者:阿润菜菜 📖专栏:Linux系统网络编程 文章目录 一、网络通信的本质(port标识的进程间通信)二、传输层协议UDP/TCP认识传输层协议UDP/TCP网络字节序问题(规定大端) 三、socket编…

ClickHouse的安装启动

安装步骤 1.关闭防火墙 2.修改资源限制配置文件 2.1 路径:/etc/security/limits.conf 在末尾添加: * soft nofile 65536 #任何用户可以打开的最大的文件描述符数量,默认1024 这里的设置会限制tcp连接数 * hard nofile 65536 * soft nproc…

什么是架构 架构图

如何画架构图_个人渣记录仅为自己搜索用的博客-CSDN博客 什么是架构?要表达的到底是什么? Linus 03 年在聊到拆分和集成时有一个很好的描述: I claim that you want to start communicating between independent modules no sooner than you…

【指针三:穿越编程边界的超能力】

本章重点 9.指针和数组面试题的解析 10. 指针笔试题 九、指针和数组面试题的解析 1、一维数组的sizeof #include<stdio.h> int main() {int a[] { 1,2,3,4 };printf("%d\n", sizeof(a));printf("%d\n", sizeof(a 0));printf("%d\n", s…

探索运营商渠道佣金数字化运营

当前全球经济增长放缓&#xff0c;行业竞争持续加剧已是常态&#xff0c;用户需求越发苛刻、经营成本不断上升。内忧外患&#xff0c;企业经营如何突围&#xff1f;越来越多的企业发现&#xff0c;融合数字化技术的IT解决方案为企业提供了一种解决问题的可能。 数字化运营可以帮…

反转链表(JS)

反转链表 题目 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1]示例 2&#xff1a; 输入&#xff1a;head [1,2] 输出&#xff1a;[2,1]示例 3&…