第九章(1):循环神经网络与pytorch示例(RNN实现股价预测)

news2024/11/24 11:12:22

第九章(1):循环神经网络与pytorch示例(RNN实现股价预测)


作者:安静到无声 个人主页

作者简介:人工智能和硬件设计博士生、CSDN与阿里云开发者博客专家,多项比赛获奖者,发表SCI论文多篇。

Thanks♪(・ω・)ノ 如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒+关注哦! o( ̄▽ ̄)d

欢迎大家来到安静到无声的 《基于pytorch的自然语言处理入门与实践》,如果对所写内容感兴趣请看《基于pytorch的自然语言处理入门与实践》系列讲解 - 总目录,同时这也可以作为大家学习的参考。欢迎订阅,请多多支持!

目录标题

  • 第九章(1):循环神经网络与pytorch示例(RNN实现股价预测)
  • 1. 概述
  • 2. 简单的循环神经网络
  • 3. RNN实现股价预测
  • 参考

1. 概述

循环神经网络(Recurrent Neural Network,RNN)是一种基于神经网络的机器学习模型,主要用于处理序列数据。与传统的前馈神经网络不同,RNN引入了循环连接,使得模型能够捕捉到输入序列中的上下文信息和时间依赖关系。

假设给定一个序列, x 1 : T = ( x 1 , x 2 , … , x t , … , x T ) x_{1:T}=(x_{1},x_{2},\ldots,x_{t},\ldots,x_{T}) x1:T=(x1,x2,,xt,,xT),RNN神经网络通过下面公式更新带反馈边的隐藏层的活性值 h t h_t ht
h t = f ( h t − 1 , x t ) , \boldsymbol{h}_{t}=f(\boldsymbol{h}_{t-1},\boldsymbol{x}_{t}), ht=f(ht1,xt),

其中 h 0 = 0 {}_{\boldsymbol{h}_{0}=0} h0=0 f ( ⋅ ) f(\cdot) f()为一个非线性的函数,例如前馈神经网络。

下图给出了循环神经网络的示例,其中“延时器”为一个虚拟单元,记录神经元的最近一次或几次活性值。
在这里插入图片描述

从数学上讲,上文公式可以看成一个动力系统。隐藏层的活性值 h t h_t ht,在很多文献上也称为状态( State ) 或隐状态( Hidden State )。

2. 简单的循环神经网络

简单循环网络( Simple Recurrent Network,SRN)是一个非常简单的循环神经网络,只有一个隐藏层的神经网络. 在一个两层的前馈神经网络中,连接存在相邻的层与层之间,隐藏层的节点之间是无连接的。而简单循环网络增加了从隐藏层到隐藏层的反馈连接。

向量 x t ∈ R M {\boldsymbol{x}}_{t}\in\mathbb{R}^{M} xtRM表示在 t t t时刻的一个输入, h t ∈ R D h_t\in\mathbb{R}^D htRD表示一个隐藏层的状态,这时 h t h_t ht与当前时刻的 x t x_t xt有关系,而且也和上一时刻的 h t − 1 h_{t-1} ht1有关系,简单的循环神经网络在 t t t时刻的更新公式如下所示:
z t = U h t − 1 + W x t + b z_{t}=U\boldsymbol{h}_{t-1}+Wx_{t}+\boldsymbol{b} zt=Uht1+Wxt+b h t = f ( z t ) h_t=f(\boldsymbol{z}_t) ht=f(zt)其中 z t z_{t} zt为隐藏层的净输入, U ∈ R D × D U\in\mathbb{R}^{D\times D} URD×D为状态-状态的矩阵, W ∈ R D × M W\in\mathbb{R}^{D\times M} WRD×M为状态-输入的权重矩阵, b ∈ R D b\in\mathbb{R}^{D} bRD为偏置项, f ( ⋅ ) f(\cdot) f()为一个非线性的激活函数。上述公式可以和写为:
h t = f ( U h t − 1 + W x t + b ) . \boldsymbol{h}_{t}=f(\boldsymbol{Uh}_{t-1}+\boldsymbol{W}\boldsymbol{x}_{t}+\boldsymbol{b}). ht=f(Uht1+Wxt+b).

如果我们把每个时刻的状态都看作前馈神经网络的一层,循环神经网络可以看作在时间维度上权值共享的神经网络。下给出了按时间展开的循环神经网络。

在这里插入图片描述

3. RNN实现股价预测

建立RNN基于zgpa train.csv数据,建立RNN模型,预测股价。

  • 完成数据预处理,将序列数据转化为可用子RNN输入的数据
  • 对新数据zgpa_test.csv进行预测,可视化结果
    模型结构:RNN 输出有120个神经元,每次使用前8个数据预测第9个数据。
import pandas as pd
import numpy as np
data = pd.read_csv(r'./zgpa_train.csv')
data.head()

在这里插入图片描述

price = data.loc[:,'close']
price.head()
0    28.78
1    29.23
2    29.26
3    28.50
4    28.67
Name: close, dtype: float64
#归一化处理
price_norm = price/max(price)

%matplotlib inline
from matplotlib import pyplot as plt
fig1 = plt.figure(figsize=(8,5))
plt.plot(price)
plt.title('close price')
plt.xlabel('time')
plt.ylabel('price')
plt.show()

在这里插入图片描述

#define X and y
#define method to extract X and y
def extract_data(data,time_step):
    X = []
    y = []
    #0,1,2,3...9:10个样本;time_step=8;0,1...7;1,2...8;2,3...9三组(两组样本)
    for i in range(len(data)-time_step):
        X.append([a for a in data[i:i+time_step]])
        y.append(data[i+time_step])
    X = np.array(X)
    X = X.reshape(X.shape[0],X.shape[1],1)
    return X, y
time_step = 8
X,y = extract_data(price_norm,time_step)


# 转换为Tensor
import torch
input_data = torch.tensor(X, dtype=torch.float32)
target_data = torch.tensor(y, dtype=torch.float32).unsqueeze(-1)
# 定义RNN模型
import torch.nn as nn
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):  #分别是输入,隐藏和输出的维度
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.fc(out[:, -1, :])
        return out


input_size = 1
hidden_size = 120
output_size = 1

# 创建模型实例
rnn = RNN(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01)

# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):
    outputs = rnn(input_data)
    loss = criterion(outputs, target_data)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 100 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.10f}")


Epoch [100/1000], Loss: 0.0004473718
Epoch [200/1000], Loss: 0.0006229825
Epoch [300/1000], Loss: 0.0003415935
Epoch [400/1000], Loss: 0.0002980608
Epoch [500/1000], Loss: 0.0002818418
Epoch [600/1000], Loss: 0.0002672504
Epoch [700/1000], Loss: 0.0002543367
Epoch [800/1000], Loss: 0.0002437856
Epoch [900/1000], Loss: 0.0002345658
Epoch [1000/1000], Loss: 0.0002437943
predict = rnn(input_data)
predict_out = []
for i in predict:
    predict_out.append(i)
    
fig2 = plt.figure(figsize=(8,5))
plt.plot(y,label='real price')
plt.plot(predict_out,label='predict price')
plt.title('close price')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-s6UNaWPo-1689238498868)(/imgs/2023-07-13/b7Gtk9DkbhFTkyjr.png)]

--------推荐专栏--------
🔥 手把手实现Image captioning
💯CNN模型压缩
💖模式识别与人工智能(程序与算法)
🔥FPGA—Verilog与Hls学习与实践
💯基于Pytorch的自然语言处理入门与实践

参考

邱锡鹏,神经网络与深度学习,机械工业出版社,https://nndl.github.io/, 2020.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/749427.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1.4 MVP矩阵

MVP矩阵代表什么 MVP矩阵分别是模型(Model)、观察(View)、投影(Projection)三个矩阵。 我们的顶点坐标起始于局部空间(Local Space),在这里他成为局部坐标(L…

【PHP面试题35】什么是MVC,为什么要使用它

文章目录 一、前言二、MVC介绍2.1 模型(Model)2.2 视图(View)2.3 控制器(Controller) 三、MVC模式的优点四、总结 一、前言 本文已收录于PHP全栈系列专栏:PHP面试专区。 计划将全覆盖PHP开发领域…

『分割』 平面模型分割

PCL提供的几个常见模型: pcl::SACMODEL_PLANE:平面模型,用于拟合平面结构的点云数据。 pcl::SACMODEL_SPHERE:球体模型,适用于拟合球体结构的点云数据。 pcl::SACMODEL_CYLINDER:圆柱体模型,用…

一个四年Android程序猿的2023上半年总结

一晃就做了四年的Android开发了,时光飞逝啊~ 工作的时间飞快,感觉每一天都很充实,但是大多数都是重复的样子。 去年的目标达成: 去年的目标就是学习学习,涨薪涨薪。上家公司的同事氛围很不错&#xff0…

一篇文章了解Redis分布式锁

Redis分布式锁 什么是分布式锁? ​ redis分布式锁是一种基于redis实现的锁机制,它用于在多并发分布式环境下控制并发访问共享资源。在多个应用程序或是进程访问共享资源时,分布式锁可以确保只有一个进程可以访问该资源,不会发生…

采用555时基电路的简易/可调定时长延时电路设计

采用 555 时基电路的简易长延时电路 本电路和一般的定时电路相比是通过在 555 时基电路的 5 脚处加了一个二极管 VD1,使得定时时间延长的特点。 一、电路工作原理 电路原理如图 11 所示。 当按下按钮SB时,12V的电源通过电阻器Rt向电容器Ct充电&#…

弹性IP和公网IP有什么区别?哪个好

​  弹性IP和公网IP有什么区别?哪个好。IP是服务器重要的组成资源,一台云服务器实例一般分为公网IP和内网IP,公网IP指的是对外访问的IP地址,是针对公众用户的IP,这是网站绑定的服务器IP地址。而内网IP顾名思义就是内部的网络IP…

Android Monkey稳定性测试

l 命令样例: adb shell monkey -p packagename --ignore-timeouts --ignore-crashes -v -v --throttle 200 1000000 各个参数的意义如下: -p 用此参数指定一个或多个包(Package)。指定包之后,Monkey将只允许系统启…

cmake多文件、多文件夹编译(2)

一、同级文件夹下代码调用问题 目录如下: ./testCMake(根目录): /build: /MyClass: CMakeLists.txt MyClass.cpp MyClass.h /MyFunction: CMakeLists.txt MyFunction.cpp MyFunction.h CMakeLists.txt main.cpp 上述…

day35-Postman/ajax

0目录 1.postman 2.ajax 1.Postman 1.1 定义:postman用于测试http协议接口,无论是开发还是测试人员 1.2 Servlet中的doGet()/daPost…

基于JavasSwing+MySQL的医药销售管理系统

点击以下链接获取源码: https://download.csdn.net/download/qq_64505944/87987881?spm1001.2014.3001.5503 功能:管理员与普通用户两个角色登录,可以增删改查用户,增删改查药品等功能 JDK1.8 MySQL5.7

微信小程序——开发入门

注册小程序 微信公众平台 设置相关信息 设置好之后需要去获取appID和秘钥,后序开发需要用到。 下载开发工具并安装 微信开发者工具(稳定版 Stable Build)下载地址与更新日志 | 微信开放文档 创建项目 打开开发者工具创建一个新项目并如下…

使用 ONLYOFFICE 宏检索网站详细信息

在上一篇文章中,我们基于一位用户发送的 VBA 参考构建了一个功能完善的 ONLYOFFICE 宏。今天,我们想再进一步,为其添加一些 Whois API 功能。 什么是 ONLYOFFICE 宏 如果您是一名资深 Microsoft Excel 用户,那么相信您已对于 VBA…

Nacos(服务注册与发现)+SpringBoot+openFeign项目集成

📝 学技术、更要掌握学习的方法,一起学习,让进步发生 👩🏻 作者:一只IT攻城狮 ,关注我,不迷路 。 💐学习建议:1、养成习惯,学习java的任何一个技术…

分割1——图像分割的前世今生

首先讲讲:什么是计算机视觉? 计算机视觉是一门让计算机学会“看”的学科,研究如何自动理解图像和视频中的内容。 其次讲讲:计算机视觉有哪些任务?我们所要讲的图像分割位于什么地位? 计算机视觉的三大经典…

计算机体系结构基础知识介绍之使用动态调度、多重问题和推测来利用流水线

我们已经了解了动态调度、多发射和推测等单独的机制是如何工作的。(具体请参见本人前几篇博客) 现在我们把这三种机制结合起来,得到一种和现代微处理器非常相似的微架构。为了简单起见,我们只考虑每个时钟周期发射两条指令的情况…

《算法竞赛·快冲300题》每日一题:“窗户”

《算法竞赛快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。 所有题目放在自建的OJ New Online Judge。 用C/C、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。 文章目录 题目描述题解C代码Java代码Python代码 “ 窗…

es6 数组操作个人总结

es6 数组操作个人总结 动机数组数组生成可枚举对象转数组箭头函数筛选判断所有元素枚举循环 小结 动机 es6 ,说白了,就是增强版本的 js 。。。。。嗯,说到底,还是原生 js 罢了,不过比原有的 js 多了一些属性、类型、指…

【c++修行之路】智能指针

文章目录 前言为什么用智能指针智能指针简单实现unique_ptrshared_ptr 循环引用和weak_ptr的引入循环引用weak_ptr 定制删除器 前言 大家好久不见,今天来学习有关智能指针的内容~ 为什么用智能指针 假如我们有如下场景: double Div() {int x, y;cin …

Clion 配置Mingw64的 c++开发环境

1、Mingw64的安装与环境变量的配置 Mingw64文件下载 Mingw64下载地址:https://sourceforge.net/projects/mingw-w64/files/ posix相比win32拥有C 11多线程特性,sjlj和seh对应异常处理特性,sjlj较为古老,所以选择seh 配置环境变…