循环神经网络(RNN)实现股票预测

news2024/11/24 16:56:25

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
    • 2. 导入数据
  • 四、数据预处理
    • 1.归一化
    • 2.设置测试集训练集
  • 五、构建模型
  • 六、激活模型
  • 七、训练模型
  • 八、结果可视化
    • 1.绘制loss图
    • 2.预测
    • 3.评估

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

2. 导入数据

import os,math
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
from sklearn.preprocessing   import MinMaxScaler
from sklearn                 import metrics
import numpy             as np
import pandas            as pd
import tensorflow        as tf
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
data = pd.read_csv('./datasets/SH600519.csv')  # 读取股票文件

data
Unnamed: 0dateopenclosehighlowvolumecode
0742010-04-2688.70287.38189.07287.362107036.13600519
1752010-04-2787.35584.84187.35584.68158234.48600519
2762010-04-2884.23584.31885.12883.59726287.43600519
3772010-04-2984.59285.67186.31584.59234501.20600519
4782010-04-3083.87182.34083.87181.52385566.70600519
242124952020-04-201221.0001227.3001231.5001216.80024239.00600519
242224962020-04-211221.0201200.0001223.9901193.00029224.00600519
242324972020-04-221206.0001244.5001249.5001202.22044035.00600519
242424982020-04-231250.0001252.2601265.6801247.77026899.00600519
242524992020-04-241248.0001250.5601259.8901235.18019122.00600519

2426 rows × 8 columns

training_set = data.iloc[0:2426 - 300, 2:3].values  
test_set = data.iloc[2426 - 300:, 2:3].values  

四、数据预处理

1.归一化

sc           = MinMaxScaler(feature_range=(0, 1))
training_set = sc.fit_transform(training_set)
test_set     = sc.transform(test_set) 

2.设置测试集训练集

x_train = []
y_train = []

x_test = []
y_test = []

"""
使用前60天的开盘价作为输入特征x_train
    第61天的开盘价作为输入标签y_train
    
for循环共构建2426-300-60=2066组训练数据。
       共构建300-60=260组测试数据
"""
for i in range(60, len(training_set)):
    x_train.append(training_set[i - 60:i, 0])
    y_train.append(training_set[i, 0])
    
for i in range(60, len(test_set)):
    x_test.append(test_set[i - 60:i, 0])
    y_test.append(test_set[i, 0])
    
# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
"""
将训练数据调整为数组(array)

调整后的形状:
x_train:(2066, 60, 1)
y_train:(2066,)
x_test :(240, 60, 1)
y_test :(240,)
"""
x_train, y_train = np.array(x_train), np.array(y_train) # x_train形状为:(2066, 60, 1)
x_test,  y_test  = np.array(x_test),  np.array(y_test)

"""
输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
"""
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
x_test  = np.reshape(x_test,  (x_test.shape[0], 60, 1))

五、构建模型

model = tf.keras.Sequential([
    SimpleRNN(80, return_sequences=True), #布尔值。是返回输出序列中的最后一个输出,还是全部序列。
    Dropout(0.2),                         #防止过拟合
    SimpleRNN(80),
    Dropout(0.2),
    Dense(1)
])

六、激活模型

# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='mean_squared_error')  # 损失函数用均方误差

七、训练模型

history = model.fit(x_train, y_train, 
                    batch_size=64, 
                    epochs=20, 
                    validation_data=(x_test, y_test), 
                    validation_freq=1)                  #测试的epoch间隔数

model.summary()
Epoch 1/20
33/33 [==============================] - 6s 123ms/step - loss: 0.1809 - val_loss: 0.0310
Epoch 2/20
33/33 [==============================] - 3s 105ms/step - loss: 0.0257 - val_loss: 0.0721
Epoch 3/20
33/33 [==============================] - 3s 85ms/step - loss: 0.0165 - val_loss: 0.0059
Epoch 4/20
33/33 [==============================] - 3s 85ms/step - loss: 0.0097 - val_loss: 0.0111
Epoch 5/20
33/33 [==============================] - 3s 90ms/step - loss: 0.0099 - val_loss: 0.0139
Epoch 6/20
33/33 [==============================] - 3s 105ms/step - loss: 0.0067 - val_loss: 0.0167
Epoch 7/20
33/33 [==============================] - 3s 86ms/step - loss: 0.0067 - val_loss: 0.0095
Epoch 8/20
33/33 [==============================] - 3s 91ms/step - loss: 0.0063 - val_loss: 0.0218
Epoch 9/20
33/33 [==============================] - 3s 99ms/step - loss: 0.0052 - val_loss: 0.0109
Epoch 10/20
33/33 [==============================] - 3s 99ms/step - loss: 0.0043 - val_loss: 0.0120
Epoch 11/20
33/33 [==============================] - 3s 92ms/step - loss: 0.0044 - val_loss: 0.0167
Epoch 12/20
33/33 [==============================] - 3s 89ms/step - loss: 0.0039 - val_loss: 0.0032
Epoch 13/20
33/33 [==============================] - 3s 88ms/step - loss: 0.0041 - val_loss: 0.0052
Epoch 14/20
33/33 [==============================] - 3s 93ms/step - loss: 0.0035 - val_loss: 0.0179
Epoch 15/20
33/33 [==============================] - 4s 110ms/step - loss: 0.0033 - val_loss: 0.0124
Epoch 16/20
33/33 [==============================] - 3s 95ms/step - loss: 0.0035 - val_loss: 0.0149
Epoch 17/20
33/33 [==============================] - 4s 111ms/step - loss: 0.0028 - val_loss: 0.0111
Epoch 18/20
33/33 [==============================] - 4s 110ms/step - loss: 0.0029 - val_loss: 0.0061
Epoch 19/20
33/33 [==============================] - 3s 104ms/step - loss: 0.0027 - val_loss: 0.0110
Epoch 20/20
33/33 [==============================] - 3s 90ms/step - loss: 0.0028 - val_loss: 0.0037
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn (SimpleRNN)       (None, 60, 80)            6560      
_________________________________________________________________
dropout (Dropout)            (None, 60, 80)            0         
_________________________________________________________________
simple_rnn_1 (SimpleRNN)     (None, 80)                12880     
_________________________________________________________________
dropout_1 (Dropout)          (None, 80)                0         
_________________________________________________________________
dense (Dense)                (None, 1)                 81        
=================================================================
Total params: 19,521
Trainable params: 19,521
Non-trainable params: 0
_________________________________________________________________

八、结果可视化

1.绘制loss图

plt.plot(history.history['loss']    , label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()

2.预测

predicted_stock_price = model.predict(x_test)                       # 测试集输入模型进行预测
predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 对预测数据还原---从(0,1)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])              # 对真实数据还原---从(0,1)反归一化到原始范围

# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted Stock Price')
plt.title('Stock Price Prediction by K同学啊')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()

在这里插入图片描述

3.评估

MSE   = metrics.mean_squared_error(predicted_stock_price, real_stock_price)
RMSE  = metrics.mean_squared_error(predicted_stock_price, real_stock_price)**0.5
MAE   = metrics.mean_absolute_error(predicted_stock_price, real_stock_price)
R2    = metrics.r2_score(predicted_stock_price, real_stock_price)

print('均方误差: %.5f' % MSE)
print('均方根误差: %.5f' % RMSE)
print('平均绝对误差: %.5f' % MAE)
print('R2: %.5f' % R2)
均方误差: 1833.92534
均方根误差: 42.82435
平均绝对误差: 36.23424
R2: 0.72347

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1236980.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

防雷接地+防雷工程施工综合方案

一、地凯科技防雷工程接地概述 防雷接地工程是指在建筑物或其他设施上安装防雷装置,以防止雷电对人员、设备和建筑物造成危害的工程。防雷装置主要包括避雷针(网)、引下线、接地体(网)等部分,其中接地体&a…

Zookeeper初识及安装配置

文章目录 写在前面一、Zookeeper概念二、下载安装2.1 环境准备2.2 下载上传2.3 解压 三、配置启动3.1 配置zoo.cfg3.2 启动Zookeeper 写在前面 最近接受了一个比较老的分布式项目,用的Zookeeper协调服务,所以虽然相关的服务注册等功能有很多可能更好的代…

技术分享|电商数据接口|淘宝天猫京东商品API接口之数据同步

常见的数据同步/集成场景多发生于不同的存储系统、不同的存储格式,如从 mysql 同步数据至数仓、excel 或 csv 导入数据库中,但是众多数据同步解决方案很少涉及从 http 接口同步数据。 如淘宝、拼多多等电商平台,平台内部不同团队之间的数据打…

4.3、Linux进程(2)

个人主页:Lei宝啊 愿所有美好如期而遇 通过系统调用创建进程--fork函数 结果是什么呢? 为什么会出来三个打印呢? 就是因为父进程调用了fork函数创建出了子进程的task_struct,但是一个进程不止task_struct,还有代码和数据,他们…

JVM基础- 垃圾回收器

基本介绍 Java虚拟机(JVM)中的垃圾回收器是用来自动管理内存的关键组件。它负责识别并回收不再使用的内存,从而防止内存泄漏。不同的JVM实现提供了多种垃圾回收器,每种回收器都有其特定的使用场景和性能特点。以下是一些常见的JV…

Rust生态系统:探索常用的库和框架

大家好!我是lincyang。 今天我们来探索Rust的生态系统,特别是其中的一些常用库和框架。 Rust生态系统虽然相比于一些更成熟的语言还在成长阶段,但已经有很多强大的工具和库支持各种应用的开发。 常用的Rust库和框架 Serde:一个…

Tesco EDI需求分析

Tesco,成立于1919年,是一家全球领先的综合性零售企业,总部位于英国。公司致力于提供高质量、多样化的商品和服务,以满足客户的需求。Tesco的使命是通过创新和卓越的客户服务,为客户创造更美好的生活。多年来&#xff0…

51单片机LED灯渐明渐暗实验

51单片机LED灯渐明渐暗实验 1.概述 这篇文章介绍使用单片机控制两个LED彩灯亮度渐明渐暗效果,详细介绍了操作步骤以及完整的程序代码,动手就能制作的小实验。 2.操作步骤 2.1.硬件搭建 1.硬件准备 名称型号数量单片机STC12C2052AD1LED彩灯无2晶振1…

抖音预约服务小程序开发:前端与后端技术的完美融合

开发抖音预约服务小程序成为了一种有趣而又实用的尝试。本篇文章,小编会与大家共同探讨抖音预约服务小程序开发的前端与后端技术融合的关键要点。 一、前端技术选择与设计 1.小程序框架 开发抖音预约服务小程序的前端,首先需要选择一个适用的小程序框…

系统试运行方案

系统试运行的目的: 试运行目的通过既定时间段的试运行,全面考察项目建设成果。并通过试运行发现项目存在的问题,从而进一步完善项目建设内容,确保项目顺利通过竣工验收并平稳地移交给运行管理单位。通过实际运行中系统功能与性能的…

系列三、ThreadLocal vs synchronized

一、ThreadLocal vs synchronized 虽然ThreadLocal与synchronized关键字都能用于处理多线程并发访问变量的问题,但是两者处理问题的角度和思路是不一样的。区别如下: 小总结:虽然上一篇中的案例都实现了线程隔离,但是使用ThreadLo…

计算机网络之概述

一、概述 1.1因特网概述 定义 网络(Network)由若干结点(Node)和连接这些结点的链路(Link)组成。多个网络还可以通过路由器互连起来,这样就构成了一个覆盖范围更大的网络,即互联网(或互连网)因此,互联网是“网络的网络…

【C++】类与对象(中)

一、类的默认成员函数 如果一个类中什么成员都没有,简称为空类。 空类中真的什么都没有吗?并不是,任何类在什么都不写时,编译器会自动生成以下6个默认成员函数。 默认成员函数:用户没有显式实现,编译器会自…

最新红盟云卡个人自动发卡开源系统源码+全开源无加密+虚拟商品在线售卖平台

源码简介: 最新红盟云卡个人自动发卡开源系统源码全开源无加密虚拟商品在线售卖平台,支持多个接口的个人免签功能。 红盟云卡系统是一款基于PHP和MySQL开发的虚拟商品在线售卖平台。它具备美观且功能丰富的发卡网站特性,并可与社区进行无缝…

微信小程序点击图片放大预览,新页面中全屏预览图片

第一步&#xff1a;在wxml中定义image组件&#xff0c;并设置绑定事件。 <image src"{{priceUrl}}" bindtap"imgClick"></image>第二步&#xff1a;在js中设置需要预览图片的URL数组&#xff0c;切记一定要是数组&#xff0c;即使一张图也要是…

排序算法--插入排序

实现逻辑 ① 从第一个元素开始&#xff0c;该元素可以认为已经被排序 ② 取出下一个元素&#xff0c;在已经排序的元素序列中从后向前扫描 ③如果该元素&#xff08;已排序&#xff09;大于新元素&#xff0c;将该元素移到下一位置 ④ 重复步骤③&#xff0c;直到找到已排序的元…

DataFunSummit:2023年智能风控技术峰会-核心PPT资料下载

一、峰会简介 智能风控的技术体系涉及多个方面&#xff0c;包括数据架构、数据类型、风控算法、团队建设等。随着风控事件发生愈发频繁和规模化&#xff0c;风控数据架构将持续优化具有高并发、高吞吐特点的实时计算场景&#xff0c;带来更加及时的响应。 由于风控事件以人与…

为什么录屏没声音?实用技巧大放送!

录屏已成为我们在数字时代记录和分享内容的重要方式之一。但有时&#xff0c;您可能会遇到录制视频却没有声音的问题。这个问题可能出现在不同的录屏软件中&#xff0c;导致许多人感到疑惑。在本文中&#xff0c;我们将探讨为什么录屏没声音&#xff0c;并提供两种解决方案&…

【狂神说】CSS3详解

目录 CSS概述什么是CSSCSS发展史快速入门CSS的三种导入方式 2 选择器2.1 基本选择器标签选择器类选择器id选择器 2.2 层次选择器2.3 结构伪类选择器2.4 属性选择器&#xff08;常用&#xff09; 3 美化网页元素3.1 为什么要美化网页3.2 字体样式3.3 文本样式 视频课程见链接&am…