【深度学习基础模型】门控循环单元 (Gated Recurrent Units, GRU)详细理解并附实现代码。

news2024/11/17 4:40:16

【深度学习基础模型】门控循环单元 (Gated Recurrent Units, GRU)

【深度学习基础模型】门控循环单元 (Gated Recurrent Units, GRU)


文章目录

  • 【深度学习基础模型】门控循环单元 (Gated Recurrent Units, GRU)
  • 1.门控循环单元 (Gated Recurrent Units, GRU) 原理详解
    • 1.1 GRU 概述
    • 1.2 GRU 的门控机制
    • 1.3 GRU 的优缺点
    • 1.4 GRU 的应用
  • 2.Python 实现 GRU 的实例
    • 2.1GRU 实现及应用实例
    • 2.2 代码解释
  • 3.总结


参考地址:https://www.asimovinstitute.org/neural-network-zoo/
论文地址:https://arxiv.org/pdf/1412.3555v1

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

1.门控循环单元 (Gated Recurrent Units, GRU) 原理详解

---

1.1 GRU 概述

GRU 是 LSTM(长短期记忆网络)的变体。与 LSTM 类似,GRU 也是为了解决 RNN 中的 梯度消失 和 梯度爆炸 问题而设计的,但 GRU 相比 LSTM 结构更为简单。GRU 去除了 LSTM 中的输出门,并结合了输入门和遗忘门为一个更新门。这使得 GRU 在某些情况下比 LSTM 更高效。

1.2 GRU 的门控机制

GRU 有两个门:更新门 (update gate) 和 重置门 (reset gate)

  • 更新门 (update gate): 控制当前隐藏状态中保留多少信息,决定保留多少先前的状态,以及从当前输入中引入多少新信息。
  • 重置门 (reset gate): 决定如何将新信息与之前的记忆结合起来,类似于 LSTM 的遗忘门,但工作方式稍有不同。

GRU 的公式为:

  • 更新门:
    z t = σ ( W z x t + U z h t − 1 ) z_t=σ(W_zx_t+U_zh_{t-1}) zt=σ(Wzxt+Uzht1)
  • 重置门:
    r t = σ ( W r x t + U r h t − 1 ) r_t=σ(W_rx_t+U_rh_{t-1}) rt=σ(Wrxt+Urht1)
  • 候选隐藏状态:
    h ~ t = t a n h ( W h x t + U h ( r t ⊙ h t − 1 ) ) \widetilde{h}_t=tanh(W_hx_t+U_h(r_t⊙h_{t-1})) h t=tanh(Whxt+Uh(rtht1))
  • 隐藏状态更新:
    h t = z t ⊙ h t − 1 + ( 1 − z t ) ⊙ h ~ t h_t=z_t⊙h_{t-1}+(1-z_t)⊙\widetilde{h}_t ht=ztht1+(1zt)h t

其中:

  • z t z_t zt是更新门,控制先前状态和当前候选状态的平衡。
  • r t r_t rt是重置门,控制前一时刻隐藏状态的影响程度。
  • h ~ t \widetilde{h}_t h t是候选的隐藏状态,使用当前输入和前一时刻的隐藏状态生成。
  • h t h_t ht是当前的隐藏状态。

1.3 GRU 的优缺点

  • 优点: 结构更简单,计算量较小,比 LSTM 更快,适合不需要复杂表达能力的场景。
  • 缺点: 由于少了一个门控机制(没有输出门),在某些任务中表现略逊于 LSTM。

1.4 GRU 的应用

GRU 和 LSTM 类似,广泛应用于序列数据处理任务,包括:

  • 自然语言处理 (NLP):如机器翻译、文本生成等。
  • 语音识别:处理连续的语音数据。
  • 时间序列预测:用于预测未来的趋势,例如股票预测等。

2.Python 实现 GRU 的实例

我们使用 PyTorch 实现一个基于 GRU 的文本分类模型。与前面 RNN 实例类似,我们将训练一个二分类模型。

2.1GRU 实现及应用实例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 构造简单的示例数据集
# 假设有两个类别的句子,分别标注为 0 和 1
X = [
    [1, 2, 3, 4],     # "I love machine learning"
    [5, 6, 7, 8],     # "deep learning is great"
    [1, 9, 10, 11],   # "I hate spam emails"
    [12, 13, 14, 15]  # "phishing attacks are bad"
]
y = [0, 0, 1, 1]  # 标签

# 转换为 Tensor 格式
X = torch.tensor(X, dtype=torch.long)
y = torch.tensor(y, dtype=torch.long)

# 定义数据集和数据加载器
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 定义 GRU 模型
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_size, hidden_size)  # 嵌入层
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)  # GRU 层
        self.fc = nn.Linear(hidden_size, output_size)  # 全连接层

    def forward(self, x):
        # 初始化隐藏状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        
        # 嵌入层
        out = self.embedding(x)
        
        # 通过 GRU
        out, _ = self.gru(out, h0)
        
        # 取最后一个时间步的隐藏状态
        out = out[:, -1, :]
        
        # 全连接层进行分类
        out = self.fc(out)
        return out

# 模型参数
input_size = 16  # 假设词汇表有 16 个词
hidden_size = 8  # 隐藏层维度
output_size = 2  # 输出为二分类
num_layers = 1   # GRU 层数

# 创建模型
model = GRUModel(input_size, hidden_size, output_size, num_layers)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 20
for epoch in range(num_epochs):
    for data, labels in dataloader:
        # 前向传播
        outputs = model(data)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    if (epoch+1) % 5 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 测试模型
with torch.no_grad():
    test_sentence = torch.tensor([[1, 2, 3, 4]])  # 测试句子 "I love machine learning"
    prediction = model(test_sentence)
    predicted_class = torch.argmax(prediction, dim=1)
    print(f'Predicted class: {predicted_class.item()}')

2.2 代码解释

1.定义 GRU 模型:

  • self.embedding = nn.Embedding(input_size, hidden_size):将输入的单词索引转换为高维向量表示。
  • self.gru = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True):定义 GRU 层,输入和输出维度为 hidden_sizebatch_first=True 表示输入序列按批次为第一维度。
  • self.fc = nn.Linear(hidden_size, output_size):全连接层将 GRU 输出映射为分类输出。

2.GRU 的前向传播:

  • h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device):初始化 GRU 的隐藏状态。
  • out, _ = self.gru(out, h0):通过 GRU 层,out 是每个时间步的输出。
  • out = out[:, -1, :]:取最后一个时间步的隐藏状态作为最终输出。
  • out = self.fc(out):通过全连接层进行分类。

3.数据集与加载器:

  • 使用简单的二分类文本数据,将其转换为 PyTorch 的 TensorDatasetDataLoader

4.训练与测试:

  • 使用 Adam 优化器和交叉熵损失函数训练模型,在每 5 个 epoch 打印一次损失。
  • 测试阶段输入测试句子,输出分类结果。

3.总结

GRU 是一种简化的循环神经网络,与 LSTM 类似,适用于处理时间序列数据或具有顺序依赖的任务。

相比于 LSTM,GRU 计算效率更高,但表达能力稍弱。在实际应用中,GRU 常用于自然语言处理、语音识别和时间序列预测等领域。通过 Python 和 PyTorch 实现的 GRU 模型,展示了其在文本分类中的应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2167355.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一个任务的一辈子

总览 孕育:这一步是生命的起源,对应"任务"就是:申办人因为办理业务而发起一个流程。这是任务产生的摇篮。 任务的使命就是为了完成业务;生产:这是新生命产生的过程,对应"任务"就是:任务…

IT运维挑战与对策:构建高效一体化运维管理体系

在当今数字化时代,IT运维作为企业运营的核心支撑,其重要性不言而喻。然而,随着业务规模的扩大和技术的不断革新,IT运维团队面临着前所未有的挑战。本文旨在深度剖析当前IT运维中存在的主要问题,并探索一体化解决方案&a…

1500PLC使用EPOS控制伺服电机

硬件配置与参数 硬件配置 名称 型号 数量 PLC 1512C-1 PN 1个 伺服放大器 V90 PN 1个 伺服电机 SIMOTICS 1个 V90 PN伺服驱动器: 伺服驱动器硬件参数 使用软件:V-ASSISTANT 软件连接时可选择USB连接或者Ethernet连接,根据实际…

【ComfyUI】生成图细节更清晰——Consistency_Decoder

原文:https://github.com/openai/consistencydecoder comfyui: https://github.com/gameltb/Comfyui_Consistency_Decoder_VAE 博文资料下载:https://pan.baidu.com/s/1SwfA4T6iMsA8IrRrGXm4sg?pwd0925 安装 【秋葉aaaki】comfyui一键运行包 夸克网盘…

Vue下载静态文件

1、需求:将静态文件放在本地,让用户进行下载。 2、文件位置: ① 原生js:直接将文件放在某个目录或者根目录下 ② Vue:将文件放在根目录的public文件夹下面 3、代码示例: const url "/模板.xlsx"…

音视频入门基础:AAC专题(9)——FFmpeg源码中计算AAC裸流每个packet的duration和duration_time的实现

音视频入门基础:AAC专题系列文章: 音视频入门基础:AAC专题(1)——AAC官方文档下载 音视频入门基础:AAC专题(2)——使用FFmpeg命令生成AAC裸流文件 音视频入门基础:AAC…

前言 动手学深度学习课程安排及介绍

前言 动手学深度学习课程安排及介绍 文章目录 前言 动手学深度学习课程安排及介绍课程预告课程安排深度学习介绍 课程预告 学习深度学习关键是动手。 深度学习是人工智能最热的领域核心是神经网络神经网络是一门语言应该像学习Python/C一样学习深度学习 课程安排 【动手学深…

Mysql 存储List类型的数据

python request 爬到的数据里面有一部分是List,一开始在建表时想当然地使用 create table if not exists demo (id TEXT, short_id TEXT, parent_ids LIST)结果报错syntax error,查半天才发现Mysql里没有LIST这个类型 所以存储一个List只能将List数据…

第十六章 模板与泛型编程

16.1 定义模板 模板是C泛型编程的基础。为模板提供足够的信息&#xff0c;就能生成特定的类或函数。 16.1.1 函数模板 在模板定义中&#xff0c;模板参数列表不能为空。 //T的实际类型在编译时根据compare的使用情况来确定 template <typename T> int compare(const …

乱篇弹(54)让子弹飞

创作者在知乎能挣到钱吗&#xff1f; 芝士平台的答案&#xff1a;“当然能&#xff0c;在知乎&#xff0c;无论是各领域的优秀回答者&#xff0c;还是拥有几百或几千关注者的潜力创作者&#xff0c;甚至是只在知乎创作过几篇回答的新人创作者&#xff0c;都有可能在知乎赚钱 。…

[Linux]从零开始的Linux的远程方法介绍与配置教程

一、为什么需要远程Linux 相信大家在学习Linux时&#xff0c;要么是使用Linux的虚拟机或者在物理机上直接安装Linux。这样确实非常方便&#xff0c;我们也能直接看到Linux的桌面或者终端。既然我们都能直接看到终端或者Linux的桌面了&#xff0c;那我们为什么还要远程Linux呢&a…

WebSocket消息防丢ACK和心跳机制对信息安全性的作用及实现方法

WebSocket消息防丢ACK和心跳机制对信息安全性的作用及实现方法 在现代即时通讯&#xff08;IM&#xff09;系统和实时通信应用中&#xff0c;WebSocket作为一种高效的双向通信协议&#xff0c;得到了广泛应用。然而&#xff0c;在实际使用中&#xff0c;如何确保消息的可靠传输…

ai智能抠图有哪些?我只告诉你这些

在广告、设计、摄影以及视频剪辑等创意领域&#xff0c;抠图技术就像是一把神奇的钥匙&#xff0c;能够将图片中的精彩瞬间或独特元素巧妙地分离出来&#xff0c;并融入到全新的背景之中&#xff0c;创造出无限的可能性。 当面对复杂图形的挑战时&#xff0c;使用高效的在线智…

RabbitMQ基础使用

1.MQ基础介绍 同步调用 OpenFeign的调用。这种调用中&#xff0c;调用者发起请求后需要等待服务提供者执行业务返回结果后&#xff0c;才 能继续执行后面的业务。也就是说调用者在调用过程中处于阻塞状态&#xff0c;因此我们称这种调用方式为同步调用 异步调用 异步调用通…

Lucene 倒排索引原理详解:深入探讨相关算法设计

引言 随着互联网的快速发展&#xff0c;数据量呈现爆炸性的增长&#xff0c;如何从海量数据中快速准确地获取所需信息成为了一项挑战。全文搜索引擎的出现极大地解决了这个问题&#xff0c;而 Lucene 正是一款优秀的开源全文搜索引擎库。本文将深入探讨 Lucene 的核心技术之一…

NtripShare测量机器人自动化监测系统测站更换仪器后重新设站

NtripShare测量机器人自动化监测系统投入商业运营已经很久了&#xff0c;在MosBox与自动优化网平差技术的加持下&#xff0c;精度并不让人担心&#xff0c;最近基于客户需求处理了两个比较大的问题。 1、增加对反射片和免棱镜的支持。 2、进一步优化测站更换仪器或重新整平后重…

顶点缓存对象(VBO)与顶点数组对象(VAO)

我们的顶点数组在CPU端的内存里是以数组的形式存在,想要GPU去绘制三角形,那么需要将这些数据传输给GPU。那这些数据在显存端是怎么存储的呢?VBO上场了,它代表GPU上的一段存储空间对象,表现为一个unsigned int类型的变量,GPU端内存对象的一个ID编号、地址、大小。一个VBO对…

Cpp内存管理(7)

文章目录 前言一、C/C内存区域划分二、C/C动态内存管理C语言动态内存管理C动态内存管理对于内置类型对于自定义类型 三、new和delete的底层实现四、new和delete的实现原理五、定位new六、malloc/free和new/delete的区别总结 前言 软件开发过程中&#xff0c;内存管理的重要性不…

vue3中echarts柱状图横轴文字太多放不下怎么解决

问题&#xff1a;在做数据展示的时候&#xff0c;使用的是echarts&#xff0c;遇到了个问题&#xff0c;就是数据过多&#xff0c;但是设置的x轴的文字名称又太长&#xff0c;往往左边第一个或右边最后一个的名称展示不全&#xff0c;只有半个。 从网上找到了几种办法&#xff…

进击J8:Inception v1算法实战与解析

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 一、实验目的&#xff1a; 了解并学习图2中的卷积层运算量的计算过程了解并学习卷积层的并行结构与1x1卷积核部分内容&#xff08;重点&#xff09;尝试根据模…