深度学习之循环神经网络(RNN)实现股票预测

news2024/11/25 20:51:03

深度学习训练营之循环神经网络(RNN)实现股票预测

  • 原文链接
  • 环境介绍
  • 前置工作
    • 设置GPU
    • 数据加载
    • 划分数据集
  • 模型训练
    • 数据预处理
      • 归一化
      • 对样本进行构建
    • 构建模型
    • 激活模型
    • 对模型进行训练
  • 结果可视化
  • 预测
  • 模型评估

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P1周:实现mnist手写数字识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.9.13
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前置工作

设置GPU

如果

# K同学啊深度学习练习
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

数据加载

import os,math
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
from sklearn.preprocessing   import MinMaxScaler
from sklearn                 import metrics
import numpy             as np
import pandas            as pd
import tensorflow        as tf
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
data = pd.read_csv('SH600519.csv')  # 读取股票文件

data

把对应的文件放在一个目录当中
在这里插入图片描述

划分数据集

我调整了一下测试集和训练集的数量,因为我后续训练的结果的训练集存在一定的欠拟合问题

"""
前(2426-300=2126)天的开盘价作为训练集,表格从0开始计数,2:3 是提取[2:3)列,前闭后开,故提取出C列开盘价
后300天的开盘价作为测试集
"""
training_set = data.iloc[0:2426 - 450, 2:3].values  
test_set = data.iloc[2426 - 450:, 2:3].values  

模型训练

数据预处理

归一化

sc           = MinMaxScaler(feature_range=(0, 1))
training_set = sc.fit_transform(training_set)
test_set     = sc.transform(test_set) 

对样本进行构建

x_train = []
y_train = []

x_test = []
y_test = []

"""
使用前60天的开盘价作为输入特征x_train
    第61天的开盘价作为输入标签y_train
    
for循环共构建2426-450-60=1916组训练数据。
       共构建450-60=390组测试数据
"""
for i in range(60, len(training_set)):
    x_train.append(training_set[i - 60:i, 0])
    y_train.append(training_set[i, 0])
    
for i in range(60, len(test_set)):
    x_test.append(test_set[i - 60:i, 0])
    y_test.append(test_set[i, 0])
    
# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
"""
将训练数据调整为数组(array)

调整后的形状:
x_train:(1916, 60, 1)
y_train:(1916,)
x_test :(390, 60, 1)
y_test :(390,)
"""
x_train, y_train = np.array(x_train), np.array(y_train) # x_train形状为:(2066, 60, 1)
x_test,  y_test  = np.array(x_test),  np.array(y_test)

"""
输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
"""
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
x_test  = np.reshape(x_test,  (x_test.shape[0], 60, 1))

构建模型

model = tf.keras.Sequential([
    SimpleRNN(100, return_sequences=True), #布尔值。是返回输出序列中的最后一个输出,还是全部序列。
    Dropout(0.1),                         #防止过拟合
    SimpleRNN(100),
    Dropout(0.1),
    Dense(1)
])

激活模型

# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='mean_squared_error')  # 损失函数用均方误差

对模型进行训练

训练时间相对来说还是比较快的

history = model.fit(x_train, y_train, 
                    batch_size=64, 
                    epochs=20, 
                    validation_data=(x_test, y_test), 
                    validation_freq=1)                  #测试的epoch间隔数

model.summary()

model.summary()对模型结果进行描述

在这里插入图片描述

结果可视化

plt.plot(history.history['loss']    , label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss by 无你想你')
plt.legend()
plt.show()

如果按照原本的数据集划分的方法我就会得到这样的结果,存在较大的欠拟合问题,所以我改了一下数据集划分的方法请添加图片描述

在这里插入图片描述
更改数据集划分
请添加图片描述

预测

predicted_stock_price = model.predict(x_test)                       # 测试集输入模型进行预测
predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 对预测数据还原---从(0,1)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])              # 对真实数据还原---从(0,1)反归一化到原始范围

# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted Stock Price')
plt.title('Stock Price Prediction by 无你想你')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()

请添加图片描述

模型评估

"""
MSE  :均方误差    ----->  预测值减真实值求平方后求均值
RMSE :均方根误差  ----->  对均方误差开方
MAE  :平均绝对误差----->  预测值减真实值求绝对值后求均值
R2   :决定系数,可以简单理解为反映模型拟合优度的重要的统计量

详细介绍可以参考文章:https://blog.csdn.net/qq_38251616/article/details/107997435
"""
MSE   = metrics.mean_squared_error(predicted_stock_price, real_stock_price)
RMSE  = metrics.mean_squared_error(predicted_stock_price, real_stock_price)**0.5
MAE   = metrics.mean_absolute_error(predicted_stock_price, real_stock_price)
R2    = metrics.r2_score(predicted_stock_price, real_stock_price)

print('均方误差: %.5f' % MSE)
print('均方根误差: %.5f' % RMSE)
print('平均绝对误差: %.5f' % MAE)
print('R2: %.5f' % R2)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/336561.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python案例实战】爬虫能做哪些很酷很有趣很有用的事情?爱了爱了,简直是神仙代码~(淘宝秒杀、VIP视频解析、wiwi破解等)

前言 🚀 作者 :“程序员梨子” 🚀 **文章简介 **:本篇文章主要是写了opencv的人脸检测、猫脸检测小程序。 🚀 **文章源码免费获取 : 为了感谢每一个关注我的小可爱💓每篇文章的项目源码都是无 偿…

Spring Cloud组件

1.服务治理 Spring Cloud Eureka 概念 Eureka提供了服务端组件,我们也称为注册中心。每个服务都向Eureka的服务注册中心,登记自己提供服务的元数据,包括服务的ip地址、端口号、版本号、通信协议等。 原理 服务注册中心,还会以心跳…

【DOCKER】容器概念基础

文章目录1.容器1.概念2.特点3.与虚拟机的对比2.docker1.概念2.命名空间3.核心概念3.命令1.镜像命令2.仓库命令1.容器 1.概念 1.不同的运行环境,底层架构是不同的,这就会导致测试环境运行好好的应用,到了生产环境就会出现bug(就像…

DevOps在项目交付场景下的应用

DevOps介绍 DevOps一词是由development和operation两个单词组合而来,代表着研发和交付运营的一体化。DevOps在2009年就被提出,但在学术界和工业界还没有一个广泛认可的定义,一些有代表性的总结,比如John Willis从文化、自动化、度…

策略编制解决方案

策略编制 NetIQ 使您能够将 AD 策略扩展到整个 IT 生态系统,并集中管理和监控策略与配置更改。 一、优点 提高 IT 管理员效率 集中管理策略控制。 借助可见性降低风险 借助实时监控减少漏洞。 审计和合规性报告 采取措施以满足安全条例和策略的合规要求。 二、…

应用部署初探:微服务的3大部署模式

在之前的文章中,我们已经充分了解了应用部署的4种常见模式(金丝雀部署、蓝绿部署、滚动部署及影子部署)。随着云原生技术逐步成熟,企业追求更为灵活和可扩展的系统,微服务架构大行其道。 微服务固然有诸多优点&…

计算组合数Cnk即从n个不同数中选出k个不同数共有多少种方法math.comb(n,k)

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 计算组合数Cnk 即从n个不同数中选出k个不同数共有多少种方法 math.comb(n,k) 以下python代码输出结果是? import math print("【执行】print(math.comb(3,1))") print(math.comb(…

一文了解kafka消息队列,实现kafka的生产者(Producer)和消费者(Consumer)的代码,消息的持久化和消息的同步发送和异步发送

文章目录1. kafka的介绍1.2 Kafka适合的应用场景1.2 Kafka的四个核心API2. 代码实现kafka的生产者和消费者2.1 引入加入jar包2.2 生产者代码2.3 消费者代码2.4 介绍kafka生产者和消费者模式3. 消息持久化4. 消息的同步和异步发送5. 参考文档1. kafka的介绍 最近在学习kafka相关…

Ubuntu20.04+cuda11.2+cudnn8.1+Anaconda3安装tensorflow-GPU环境,亲测可用

(1)安装nvidia显卡驱动注意Ubuntu20.04和Ubuntu16.04版本的安装方法不同,安装驱动前一定要更新软件列表和安装必要软件、依赖(必须)sudo apt-get update #更新软件列表sudo apt-get install gsudo apt-get install gccsudo apt-get install make查看GP…

4.5.1 泛型

文章目录1.概述2.泛型的具体表现形式3.泛型的作用4.泛型示例5.练习:泛型测试一6.练习:泛型测试二1.概述 泛型不是指一种具体的类型,而是说,这里有个类型需要设置,那么具体设置成什么类型,得看具体的使用; …

RabbitMQ-持久化

一、介绍如何保证RabbitMQ服务停掉以后生产者发送过来的消息不丢失。默认情况下RabbitMQ退出或由于某种原因崩溃时,他将忽视队列和消息,除非告知它不要这样做。确保消息不丢失需要做两件事情:将队列和消息都标记为持久化二、队列持久化再声明…

(1分钟速通面试) SLAM中的最小二乘问题

最小二乘拟合问题 求解超定方程首先写这篇博客之前说一个背景,这个最小二乘拟合问题是我在去年面试实习的时候被问到的,然后当时是非常的尴尬,没有回答上来里面的问题。Hhh 所以这篇博客来进行一个补充学习一下下。感觉这个最小二乘问题还是比…

根据报告20%的白领在一年内做过副业,你有做副业吗?

现在大部分人收入单一,收入都是来源于本职工作,当没有了工作就没有了收入的来源,而生活压力又很大,各种开支,各种消费。所以很多人想要增加收入来源,增加被动收入,同时通过副业提升自己的价值和…

LeetCode·每日一题·1223.掷骰子模拟·记忆化搜索

作者:小迅链接:https://leetcode.cn/problems/dice-roll-simulation/solutions/2103471/ji-yi-hua-sou-suo-zhu-shi-chao-ji-xiang-xlfcs/来源:力扣(LeetCode)著作权归作者所有。商业转载请联系作者获得授权&#xff0…

libxlsxwriter中文报错问题

libxlsxwriter库在windows系统下VS中存在中文输入报错问题。这在小白关于libxlsxwriter的第一篇博客libxlsxwriter初体验里有所阐述。当时小白给出的解决方案是将文件编码修改成不带签名的utf-8。后来在使用中,小白发现这样并没有完全解决问题。有的中文可以正常写入…

VHDL语言基础-时序逻辑电路-触发器

目录 触发器: D触发器: 触发器的VHDL描述: 触发器的仿真波形如下:​编辑 时钟边沿检测的三种方法: 方法一: 方法二: 方法三: 带有Q非的D触发器: 带有Q非的D触发器的描述&am…

微信小程序 Springboot高校课堂教学管理系统-java

小程序端 学生在小程序端进行注册并且进行登录。 填写自己的个人信息进行注册 登录成功后可以看到有首页、课程资源、测试、互动论坛、我的功能模块。 课程资源学生可以点击想要查看的资源进行观看。 课程分类学生可以按照自己想要的分类进行搜索并且进行观看。 互动论坛可以查…

四种方式的MySQL安装过程 数据库(2)

目录 1. 仓库安装: 1.1 卸载数据库软件: 2. 本地安装: 2.1 卸载数据库软件: 3. 容器安装: 4. 源码安装: 4.1 使用systemctl命令启动进程 1. 仓库安装: (1)查看版本…

超级详细GitBook和GitLab集成步骤【linux环境】

介绍 本文主要是在 gitlab 上集成 gitbook 实现提交时 gitbook 自动刷新部署 ,以及在 linux 环境上搭建 gitlab gitbook,集成 GitLab CI 实现一个企业级或个人的 Wiki 系统 环境准备 1.一台 linux 服务器 2.安装 node 以及 npm 环境 (这里注意 node 环境不要过高 不…

CS反制之批量伪装上线

分析原理: 我们利用Wireshark抓包工具分析一下Cobalt Strike的上线过程是怎么样的 点击木马,主机上线并抓包 查看数据包 可以看到cookie是一串非对称RSA加密类型,需要一个私钥Private Key才能对其进行解密 我们对Cookie解密看看&#xff…