程序员学长 | 快速学会一个算法,RNN

news2024/11/24 6:16:56

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:快速学会一个算法,RNN

今天给大家分享一个超强的算法模型,RNN

循环神经网络(Recurrent Neural Network, RNN)是一种专门用于「处理序列数据的神经网络」

由于其能够处理不同长度的输入序列,并保持过去信息的能力,它广泛应用于自然语言处理、语音识别和时间序列预测等领域

RNN的算法原理

RNN 的核心思想是使用循环的连接结构来保持对之前处理过的信息的记忆。

这种记忆通过隐藏层的状态来表达,每个时间步的隐藏状态都依赖于前一时间步的隐藏状态和当前时间步的输入。这种结构使得 RNN 能够捕获时间序列数据中的动态变化特性。

RNN 的问题

循环神经网络(RNN)虽然在处理序列数据方面具有明显优势,但在实际应用中遇到了几个关键问题,特别是梯度消失和梯度爆炸问题。这些问题直接影响了网络的训练效率和性能,进而催生了长短时记忆网络(LSTM)和门控循环单元(GRU)这两种更为高效的RNN变体。

  1. 梯度消失

    在 RNN 中,当网络层较多或者处理的序列数据较长时,由于梯度在反向传播过程中反复乘以小于1的数(如激活函数的导数),导致梯度逐渐变小,最终接近于零。这会使得网络中的权重无法有效更新,特别是序列前端的权重,从而难以捕捉到序列中早期的重要信息。

  2. 梯度爆炸

    与梯度消失相反,梯度爆炸是指在反向传播过程中梯度逐渐变得非常大,这通常发生在权重值较大时。梯度爆炸会导致网络权重的大幅波动,使得训练过程变得不稳定,甚至导致数值计算溢出。

  3. 难以捕捉长期依赖

    由于梯度消失和梯度爆炸的问题,「标准的 RNN 在处理长序列时难以学习到输入序列中的长距离依赖关系」。这意味着网络难以记忆并利用序列中早期的信息来影响后续的输出,这对于许多需要理解整个输入序列上下文的任务来说是一个大问题,如语言翻译、文本生成等。

为了解决这些问题,研究者们开发了 LSTM 和 GRU 这两种特殊类型的RNN。

RNN 变体

LSTM

LSTM(Long Short-Term Memory)是一种特殊类型的循环神经网络(RNN),「专门设计用来解决传统 RNN 在处理序列数据时面临的长期依赖问题」

LSTM 的关键特征是其维持细胞状态的能力,「细胞状态充当可以存储长序列信息的记忆单元」。这使得 LSTM 能够随着时间的推移选择性地记住或忘记信息,使它们非常适合上下文和远程依赖性至关重要的任务。

图片

LSTM 的核心组件

LSTM 的关键在于其内部状态(cell state)和三个重要的门控机制:输入门、遗忘门和输出门。这些门控制着信息的流入、更新和流出,使 LSTM 能够在必要时保存信息跨越多个时间步,或者丢弃不再需要的信息。

GRU

门控循环单元(Gated Recurrent Unit, GRU)旨在简化长短时记忆网络(LSTM)的结构,同时保持对长期依赖信息的捕捉能力。

GRU 对比 LSTM 的主要区别在于「其结构更简单,参数更少,这使得 GRU 在某些情况下训练更快,计算效率更高。」

GRU的核心组件

GRU 将 LSTM 中的三个门控合并为两个门控,即更新门(update gate)和重置门(reset gate)。

这两个门控决定了信息是如何在单元间传递的,帮助网络捕捉时间序列中的长距离依赖。

案例分享

下面我们来使用 RNN、GRU 和 LSTM 进行苹果股价的预测。

import yfinance as yf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout,SimpleRNN,GRU

# 获取苹果公司的股票数据
data = yf.download('AAPL', start='2018-01-01', end='2023-01-01')

# 使用收盘价
close_prices = data['Close'].values.reshape(-1, 1)

# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
close_prices = scaler.fit_transform(close_prices)

# 划分数据集为训练集和测试集
split = int(0.8 * len(close_prices))
train = close_prices[:split]
test = close_prices[split:]

# 创建序列数据集
def create_dataset(data, steps):
    X, y = [], []
    for i in range(len(data) - steps):
        X.append(data[i:(i + steps), 0])
        y.append(data[i + steps, 0])
    return np.array(X), np.array(y)
steps = 60
X_train, y_train = create_dataset(train, steps)
X_test, y_test = create_dataset(test, steps)

# 重塑输入以符合 RNN 模型的期望格式 [样本数, 时间步, 特征数]
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))


# 构建 LSTM 模型
model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(X_train.shape[1], 1)))
model.add(LSTM(units=50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(X_train, y_train, epochs=50, batch_size=32, verbose=1)


# 构建 RNN 模型
model_rnn = Sequential()
model_rnn.add(SimpleRNN(units=50, return_sequences=True, input_shape=(X_train.shape[1], 1)))
model_rnn.add(SimpleRNN(units=50))
model_rnn.add(Dense(1))
model_rnn.compile(optimizer='adam', loss='mean_squared_error')
model_rnn.fit(X_train, y_train, epochs=50, batch_size=32, verbose=1)

# 构建 GRU 模型
model_gru = Sequential()
model_gru.add(GRU(units=50, return_sequences=True, input_shape=(X_train.shape[1], 1)))
model_gru.add(GRU(units=50))
model_gru.add(Dense(1))
model_gru.compile(optimizer='adam', loss='mean_squared_error')
model_gru.fit(X_train, y_train, epochs=50, batch_size=32, verbose=1)

接下来,我们来看一下预测的结果。

# LSTM 预测
predicted_stock_price_lstm = model.predict(X_test)
predicted_stock_price_lstm = scaler.inverse_transform(predicted_stock_price_lstm)

# RNN 预测
predicted_stock_price_rnn = model_rnn.predict(X_test)
predicted_stock_price_rnn = scaler.inverse_transform(predicted_stock_price_rnn)

# GRU 预测
predicted_stock_price_gru = model_gru.predict(X_test)
predicted_stock_price_gru = scaler.inverse_transform(predicted_stock_price_gru)

# 绘图
plt.figure(figsize=(14, 5))
plt.plot(real_stock_price, color='red', label='Real Apple Stock Price')
plt.plot(predicted_stock_price_lstm, color='blue', label='Predicted Apple Stock Price (LSTM)')
plt.plot(predicted_stock_price_rnn, color='green', label='Predicted Apple Stock Price (RNN)')
plt.plot(predicted_stock_price_gru, color='purple', label='Predicted Apple Stock Price (GRU)')
plt.title('Apple Stock Price Prediction Comparison')
plt.xlabel('Time')
plt.ylabel('Apple Stock Price')
plt.legend()
plt.show()

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1865581.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

怎么打印加密的pfd文件,有那些方法?

现在人们的保密意识越来越强了,越来越多的人在完成pdf文档后就会对文档进行保护,但有的PDF文档被添加了密码,限制了打印的权限,导致我们想打印PDF文档的时候就提示我们要输入密码。面对这种情况,我们要怎样才能把PDF文档打印出来呢…

atcoder abc 359

A count takahashi 问题: 思路&#xff1a;字符串比较 代码&#xff1a; #include <bits/stdc.h>using namespace std;int main() {int n;cin >> n;int ans 0;for(int i 1; i < n; i ) {string s;cin >> s;if(s[0] T) ans ;}cout << ans;re…

BenchmarkSQL 对 MySQL 测试时请注意隔离级别!

BenchmarkSQL 是一款经典的开源数据库测试工具&#xff0c;内含了TPC-C测试脚本&#xff0c;可支持 Oracle、MySQL、PostgreSQL、SQL Server以及一些国产数据库的基准测试。 作者&#xff1a;李彬&#xff0c;爱可生 DBA 团队成员&#xff0c;负责项目日常问题处理及公司平台问…

机器学习--概念理解

知识点 一、机器学习概述 人工智能 机器学习 深度学习 学习的范围&#xff1a;模式识别、数据挖掘、统计学习、计算机视觉、语音识别、自然语言处理 可以解决的问题&#xff1a;给定数据的预测问题 二、机器学习的类型 监督学习 分类 回归 无监督学习 聚类 降维 强化…

【Java Web】会话管理

目录 一、为什么需要会话管理&#xff1f; 二、会话管理机制 三、Cookie概述 四、HttpSession概述 4.1 HttpSession时效性 一、为什么需要会话管理&#xff1f; HTTP协议在设计之初就是无状态的&#xff0c;所谓无状态就是在浏览器和服务器之间的通信过程中&#xff0c;服务器并…

免费录屏软件哪个好?录屏软件,分享3款免费工具

在日常生活或者工作中&#xff0c;录屏软件已经成为我们的得力助手。无论是教学、演示、娱乐&#xff0c;录屏软件都能为我们带来极大的便利。然而&#xff0c;市面上有些录屏软件的价格却十分的昂贵&#xff0c;让人望而却步。那么市面上到底有没有免费的录屏软件&#xff1f;…

MySQL的jdbc、odbc驱动版本必须和MySQL版本一样吗?

MySQL的版本和JDBC&#xff0c;ODBC驱动版本大体一致就可以 比如说MySQL的版本是8.0.35&#xff0c;您可以用8.0.19版本的JDBC,ODBC。或者8.0.31的版本。 除此之外我也查看了其他资料&#xff0c;这个哥们总结的也不错&#xff0c;我把链接放到这里 MySQL JDBC驱动版本与数据…

电脑可以录屏吗?5个方法,珍藏分享

在数字化时代&#xff0c;电脑的录屏功能已经成为许多人工作和学习的必备工具。录屏可以帮助我们捕捉屏幕上的动态内容&#xff0c;记录下重要的瞬间。所以&#xff0c;录屏是一个非常实用的功能。那么&#xff0c;电脑可以录屏吗&#xff1f;答案是肯定的。本文将为您介绍5种电…

一文详细了解Bootloader

Bootloader是什么 bootloader是一个引导加载程序&#xff0c;它的主要作用是初始化硬件设备、设置硬件参数&#xff0c;并加载操作系统内核。在嵌入式系统中&#xff0c;bootloader是硬件启动后第一个被执行的程序&#xff0c;它位于操作系统和硬件之间&#xff0c;起到桥梁的…

加速科技Flash存储测试解决方案 全面保障数据存储可靠性

Flash存储芯片 现代电子设备的核心数据存储守护者 Flash存储芯片是一种关键的非易失性存储器&#xff0c;作为现代电子设备中不可或缺的核心组件&#xff0c;承载着数据的存取重任。这种小巧而强大的芯片&#xff0c;以其低功耗、可靠性、高速的读写能力和巨大的存储容量&…

聊聊AI在企业数字化转型中的作用

随着科技的飞速发展&#xff0c;人工智能&#xff08;AI&#xff09;已经深入到我们生活的方方面面&#xff0c;尤其在数字化转型的浪潮中&#xff0c;AI技术更是扮演着举足轻重的角色。数字化转型&#xff0c;简而言之&#xff0c;就是企业利用数字技术来改造其业务运营方式&a…

[AI MoneyPrinterTurbo] 一键成片,超级印钞机

今天&#xff0c;我们将踏上一段关于MoneyPrinterTurbo的探索之旅&#xff0c;这是一个文生视频工具&#xff0c;旨在让视频创作变得轻松而有趣。 故事的开始 想象一下&#xff0c;你只需要提供一个视频主题或关键词&#xff0c;剩下的——视频文案、素材、字幕、背景音乐&am…

转:关于征集第三批工业软件新场景新技术难题解决思路的公告

工业软件是先进工业知识与经验的凝炼&#xff0c;工业软件自身的先进性既来自对先进工业先进需求的汲取提炼&#xff0c;也来自对根技术新突破、新成果的高效采用。为增强根技术新成果提供方与工业软件厂家或最终用户方的连接&#xff0c;促进国产工业软件差异化竞争力的打造&a…

浅学JVM

一、基本概念 目录 一、基本概念 二、JVM 运行时内存 1、新生代 1.1 Eden 区 1.2. ServivorFrom 1.3. ServivorTo 1.4 MinorGC 的过程 &#xff08;复制- >清空- >互换&#xff09; 1.4.1&#xff1a;eden 、servicorFrom 复制到ServicorTo&#xff0c;年龄1 …

docker入门配置

1、创建配置镜像 由于国内docker连接外网速度慢&#xff0c;采用代理 vi /etc/docker/daemon.json添加以下内容 {"registry-mirrors": ["https://9cpn8tt6.mirror.aliyuncs.com","https://dockerproxy.com","https://hub-mirror.c.163.co…

用一个暑假|用AlGC-stable diffusion 辅助服装设计及展示,让你在同龄人中脱颖而出!

大家好&#xff0c;我是设计师阿威 Stable Diffusion是一款开源AI绘画工具&#xff0c; 用户输入语言指令&#xff0c;即可自动生成各种风格的绘画图片 Stable Diffusion功能强大&#xff0c;生态完整、使用方便。支持大部分视觉模型上传&#xff0c;且可自己定制模型&#x…

[深度学习] Transformer

Transformer是一种深度学习模型&#xff0c;最早由Vaswani等人在2017年的论文《Attention is All You Need》中提出。它最初用于自然语言处理&#xff08;NLP&#xff09;任务&#xff0c;但其架构的灵活性使其在许多其他领域也表现出色&#xff0c;如计算机视觉、时间序列分析…

程序设计语言前言

1.机器语言及特点 2.编译语言及特点 3.高级语言及特点 4.编译和解释 5.IPO编程方式 一、机器语言 机器语言&#xff0c;也被称为二进制代码语言&#xff0c;是计算机硬件能够直接识别的程序语言或指令代码。它是由一系列由0和1组成的二进制指令码构成&#xff0c;每一条指令码…

分页组件 vue/uniapp

失效如上图 1.父组件调用 <onion-pagination :page.sync="todusGameQuery.pageSize" @update:page="changeTodusLoadMore":pageSize="todusGameQuery.pageNum" :total="todusGameTotal"></onion-pagination> 2.组件封装…

ios18开发者预览,Beta 2升级新增镜像等功能

近日&#xff0c;苹果发布了 iOS 18 开发者预览版 Beta 2 升级&#xff0c;为 iPhone 用户带来了多项新功能。据了解&#xff0c;这些新功能包括 iPhone 镜像和 SharePlay 屏幕共享&#xff0c;以及其他新增功能。 据了解&#xff0c;iPhone镜像可以让Mac用户将iPhone屏幕镜像…