动手学深度学习(Pytorch版)代码实践 -循环神经网络- 56门控循环单元(`GRU`)

news2025/1/15 23:35:06

56门控循环单元(GRU

我们讨论了如何在循环神经网络中计算梯度, 以及矩阵连续乘积可以导致梯度消失或梯度爆炸的问题。 下面我们简单思考一下这种梯度异常在实践中的意义:

  • 我们可能会遇到这样的情况:早期观测值对预测所有未来观测值具有非常重要的意义。 考虑一个极端情况,其中第一个观测值包含一个校验和, 目标是在序列的末尾辨别校验和是否正确。 在这种情况下,第一个词元的影响至关重要。 我们希望有某些机制能够在一个记忆元里存储重要的早期信息。 如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度, 因为它会影响所有后续的观测值。
  • 我们可能会遇到这样的情况:一些词元没有相关的观测值。 例如,在对网页内容进行情感分析时, 可能有一些辅助HTML代码与网页传达的情绪无关。 我们希望有一些机制来跳过隐状态表示中的此类词元。
  • 我们可能会遇到这样的情况:序列的各个部分之间存在逻辑中断。 例如,书的章节之间可能会有过渡存在, 或者证券的熊市和牛市之间可能会有过渡存在。 在这种情况下,最好有一种方法来重置我们的内部状态表示。

门控循环单元与普通的循环神经网络之间的关键区别在于: 前者支持隐状态的门控。 这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。 这些机制是可学习的,并且能够解决了上面列出的问题。 例如,如果第一个词元非常重要, 模型将学会在第一次观测之后不更新隐状态。 同样,模型也可以学会跳过不相关的临时观测。 最后,模型还将学会在需要的时候重置隐状态。

1.重置门和更新门
  • 重置门有助于捕获序列中的短期依赖关系。
  • 更新门有助于捕获序列中的长期依赖关系。

在这里插入图片描述

2.候选隐状态

在这里插入图片描述

3.隐状态

在这里插入图片描述

4.从零开始实现
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 定义批量大小和时间步数
batch_size, num_steps = 32, 35

# 使用d2l库的load_data_time_machine函数加载数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

def get_params(vocab_size, num_hiddens, device):
    """
    初始化GRU模型的参数。
    参数:
        vocab_size (int): 词汇表的大小。
        num_hiddens (int): 隐藏单元的数量。
        device (torch.device): 张量所在的设备。

    返回:list of torch.Tensor: 包含所有参数的列表。
    """
    num_inputs = num_outputs = vocab_size  # 输入和输出的数量都等于词汇表大小

    def normal(shape):
        """
        使用均值为0,标准差为0.01的正态分布初始化张量。
        参数: shape (tuple): 张量的形状。
        返回:torch.Tensor: 初始化后的张量。
        """
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        """
        初始化GRU门的参数。
        返回:tuple of torch.Tensor: 包含门的权重和偏置的元组。
        """
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()   # 更新门参数
    W_xr, W_hr, b_r = three()   # 重置门参数
    W_xh, W_hh, b_h = three()   # 候选隐藏状态参数
    
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    
    # 将所有参数收集到一个列表中
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
   
    for param in params: # 启用所有参数的梯度计算
        param.requires_grad_(True)
    
    return params

def init_gru_state(batch_size, num_hiddens, device):
    """
    初始化GRU的隐藏状态。

    参数:
        batch_size (int): 批量大小。
        num_hiddens (int): 隐藏单元的数量。
        device (torch.device): 张量所在的设备。

    返回:tuple of torch.Tensor: 初始隐藏状态。
    """
    return (torch.zeros((batch_size, num_hiddens), device=device), )

def gru(inputs, state, params):
    """
    定义GRU的前向传播。

    参数:
        inputs (torch.Tensor): 输入数据。
        state (tuple of torch.Tensor): 隐藏状态。
        params (list of torch.Tensor): GRU的参数。

    返回:
        torch.Tensor: GRU的输出。
        tuple of torch.Tensor: 更新后的隐藏状态。
    """
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state  # 获取隐藏状态
    outputs = []  # 存储输出的列表
    for X in inputs:  # 遍历每一个输入时间步
        # 计算更新门Z
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        # 计算重置门R
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        # 计算候选隐藏状态H_tilda
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        # 更新隐藏状态H
        H = Z * H + (1 - Z) * H_tilda
        # 计算输出Y
        Y = H @ W_hq + b_q
        outputs.append(Y)  # 将输出添加到列表中
    return torch.cat(outputs, dim=0), (H,)  # 返回连接后的输出和更新后的隐藏状态

# 获取词汇表大小、隐藏单元数量和设备
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
# 定义训练的轮数和学习率
num_epochs, lr = 500, 1
# 初始化GRU模型
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)
# 使用d2l库的train_ch8函数训练模型
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.1, 38557.3 tokens/sec on cuda:0
# time traveller for so it will be convenient to speak of himwas e

在这里插入图片描述

5.简洁实现
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 定义批量大小和时间步数
batch_size, num_steps = 32, 35
# 使用d2l库的load_data_time_machine函数加载数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

num_epochs, lr = 500, 1
# # 获取词汇表大小、隐藏单元数量和设备
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens) # 定义一个GRU层,输入大小为num_inputs,隐藏单元数量为num_hiddens
model = d2l.RNNModel(gru_layer, len(vocab)) # 使用GRU层和词汇表大小创建一个RNN模型
model = model.to(device)
# 该函数需要模型、训练数据迭代器、词汇表、学习率、训练轮数和设备作为参数
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.0, 248342.8 tokens/sec on cuda:0
# time travelleryou can show black is white by argument said filby

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1904245.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows电脑下载、安装VS Code的方法

本文介绍Visual Studio Code(VS Code)软件在Windows操作系统电脑中的下载、安装、运行方法。 Visual Studio Code(简称VS Code)是一款由微软开发的免费、开源的源代码编辑器,支持跨平台使用,可在Windows、m…

C++ 模版进阶

目录 前言 1. 非类型模版参数 1.1 概念与讲解 1.2 array容器 2. 模版的特化 2.1 概念 2.2 函数模版特化 2.3 类模版特化 2.3.1 全特化 2.3.2 偏特化 3.模版的编译分离 3.1 什么是分离编译 3.2 模版的分离编译 3.3 解决方法 4. 模版总结 总结 前言 本篇文章主要…

在mac下 Vue2和Vue3并存 全局Vue2环境创建Vue3新项目(Vue cli2和Vue cli4)

全局安装vue2 npm install vue-cli -g自行在任意位置创建一个文件夹vue3,局部安装vue3,注意不要带-g npm install vue/cli安装完成后,进入目录,修改vue为vue3 找到vue3/node-moudles/.bin/vue,把vue改成vue3。 对环境变量进行配置…

【6】图像分类部署

【6】图像分类部署 文章目录 前言一、将pytorch模型转为ONNX二、本地终端部署2.1. ONNX Runtime部署2.2. pytorch模型部署(补充) 三、使用flask的web网页部署四、微信小程序部署五、使用pyqt界面化部署总结 前言 包括将训练好的模型部署在本地终端、web…

在Linux操作系统中去修复/etc/fstab文件引起的系统故障。

如果/etcfstab文件中发生错误,有可能导致系统无法正常启动。 比如:系统里的一块磁盘被删除,但是/etc/fstab中关于这块磁盘的信息依然被保存在文件/etc/fstab中。 主要看倒数后两行,系统提示,敲ctrlD或者是直接输入密码…

LeetCode 744, 49, 207

目录 744. 寻找比目标字母大的最小字母题目链接标签思路代码 49. 字母异位词分组题目链接标签思路代码 207. 课程表题目链接标签思路代码 744. 寻找比目标字母大的最小字母 题目链接 744. 寻找比目标字母大的最小字母 标签 数组 二分查找 思路 本题比 基础二分查找 难的一…

redhat7.x 升级openssh至openssh-9.8p1

1.环境准备: OS系统:redhat 7.4 2.备份配置文件: cp -rf /etc/ssh /etc/ssh.bak cp -rf /usr/bin/openssl /usr/bin/openssl.bak cp -rf /etc/pam.d /etc/pam.d.bak cp -rf /usr/lib/systemd/system /usr/lib/systemd/system.bak 3.安装…

阿里云存储应用

如何做好权限控制 小浩在梳理门户网站静态资源时,发现有些资源是仅内部员工可访问,有些资源是特定的注册客户可访问,还有些资源是匿名客户也可以访问。针对不同场景、不同用户,小浩该如何规划企业门户网站静态资源的权限控制呢&a…

解析商场智能导视系统背后的科技:AR导航与大数据如何助力商业运营

在布局复杂的大型商场中,顾客常常面临寻找特定店铺的挑战。商场的规模庞大,店铺众多,使得顾客在享受购物乐趣的同时,也不得不面对寻路的难题。维小帮商场智能导航导视系统的电子地图、AR导航营销能为顾客提供更加便捷的购物体验。…

Linux—网络设置

目录 一、ifconfig——查看网络配置 1、查看网络接口信息 1.1、查看所有网络接口 1.2、查看具体的网络接口 2、修改网络配置 3、添加网络接口 4、禁用/激活网卡 二、hostname——查看主机名称 1、查看主机名称 2、临时修改主机名称 3、永久修改主机名称 4、查看本…

2023年了,还在手动px转rem吗?

px-to-rem 使用amfe-flexible和postcss-pxtorem在webpack中配置px转rem npm i amfe-flexible -Snpm i postcss-pxtorem -D在main.js中 import flexible from amfe-flexible Vue.use(flexible);index.html中 <meta name"viewport" content"widthdevice-w…

用 Echarts 画折线图

https://andi.cn/page/621503.html

Floyd判圈算法——环形链表(C++)

Floyd判圈算法(Floyd Cycle Detection Algorithm)&#xff0c;又称龟兔赛跑算法(Tortoise and Hare Algorithm)&#xff0c;是一个可以在有限状态机、迭代函数或者链表上判断是否存在环&#xff0c;求出该环的起点与长度的算法。 …

Java语言程序设计篇一

Java语言概述 Java语言起源编程语言最新排名名字起源Java语言发展历程Java语言的特点Java虚拟机垃圾回收Java语言规范Java技术简介Java程序的结构Java程序注意事项&#xff1a;注释编程风格练习 Java语言起源 1990年Sun公司提出一项绿色计划。1992年语言开发成功最初取名为Oak…

vue3使用方式汇总

1、引入iconfont阿里图库图标&#xff1a; 1.1 进入阿里图标网站&#xff1a; iconfont阿里&#xff1a;https://www.iconfont.cn/ 1.2 添加图标&#xff1a; 1.3 下载代码&#xff1a; 1.4 在vue3中配置代码&#xff1a; 将其代码复制到src/assets/fonts/目录下&#xff1…

大众点评2024年全球必吃榜清单

大众点评2024年全球必吃榜清单共2797家&#xff0c;奇怪的是官方并没有发布详细清单&#xff0c;只发布了新闻通稿介绍大概情况。这里做一些整理。 按城市分布情况&#xff0c;数量如下 上海 144 北京 137 成都 96 重庆 93 广州 81 深圳 79 武汉 69 苏州 67 杭州 61 …

应急响应--网站(web)入侵篡改指南

免责声明:本文... 目录 被入侵常见现象: 首要任务&#xff1a; 分析思路&#xff1a; 演示案例: IIS&.NET-注入-基于时间配合日志分析 Apache&PHP-漏洞-基于漏洞配合日志分析 Tomcat&JSP-弱口令-基于后门配合日志分析 (推荐) Webshell 查杀-常规后门&…

17_VGG深度学习图像分类算法

1.1 简介 VGG网络&#xff0c;全称为Visual Geometry Group网络&#xff0c;是由牛津大学的Visual Geometry Group和谷歌DeepMind的研究人员共同提出的深度卷积神经网络模型。这一模型因在2014年ILSVRC&#xff08;ImageNet大规模视觉识别挑战赛&#xff09;中取得图像分类任务…

昇思25天学习打卡营第4天|MindSpore数据集和数据变换

# 打卡 目录 # 打卡 Dateset&#xff1a;Pipeline 的起始 具体步骤 数据处理 Pipeline 代码例子 内置数据集的情况 自定义数据集的情况 可迭代的数据集 生成器 Transforms&#xff1a;数据预处理 代码例子 通用变换Compose 文本变换 Text Lambda变换 Dateset&…

STM32芯片系列与产品后缀解读

一. 产品系列 STM32单片机是一系列基于ARM Cortex-M内核的32位微控制器&#xff0c;广泛应用于嵌入式系统中。 STM32系列由STMicroelectronics&#xff08;意法半导体&#xff09;开发和生产&#xff0c;并凭借其灵活的设计、丰富的外设和强大的生态系统&#xff0c;成为嵌入式…