门控循环神经网络学习笔记

news2024/10/5 14:01:05

在介绍门控循环神经网络之前,先简单介绍循环神经网络的基本计算方式

在这里插入图片描述

循环神经网络之称之为“循环”,因为其隐藏状态是循环利用的:

上一次输入计算出的隐藏状态与当前的输入结合,得到当前隐藏状态。

cur_output, cur_state = rnn(cur_X, last_state)

隐状态中保留了之前输入的特征和结构(对应句子的词元和结构)。

接下来介绍门控循环神经网络的几个方面:功能、计算方式、完整实现

(一)门控循环神经网络的功能:

门控循环神经网络和常规的循环神经网络有什么不同呢?

门控循环神经网络相比于常规的循环神经网络,可以有选择性地保留词元间的长期依赖关系。

这种描述或许有点抽象,所以我们通过两种不同的情况来理解一下其含义:

<1> 当早期的观测值对于接下来的观测具有重要意义时:

举一个具体的例子,当你看一篇文章或者一个句子,开头给出了时间或者地点,这个预测信息可能会影响到之后所有的观测值。

如果小说开头交代了一个年代信息,那么之后的事件都会发生在这个年代,不会出现这个年代不该出现的东西。

这时长期依赖关系对于我们的预测有着重要意义,所以应该选择性加以保留。

<2> 当一些观测值与我们接下来的观测没有联系时:

同样是一篇小说,我们不能根据它描述的一个人的发色来判断这个人的心情。

这时长期依赖关系对于我们的观测没有意义,应该选择性加以丢弃。

(二)门控循环神经网络的计算方式:

在这里插入图片描述

我们把门控循环神经网络的计算方式分为三步:

第一步:由cur_Xlast_state计算得到重置门和更新门:

门控循环神经网络有两种门:重置门和更新门。

重置门负责的是如何将过去的信息与新的输入相结合,保留可能还想留下的旧记忆。(之所以是可能,是因为是否保留还取决于更新门)

它有助于捕获短期依赖关系。

更新门负责帮助模型决定到底传递多少过去的信息到未来,也就是更新记忆。

它有助于捕获长期依赖关系。

这两项作用分别在第二步和第三步中有所体现。

我们先来看看第一步的更新方式:

然后是计算代码:

Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)  # 更新门
R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)  # 重置门

其中W_xz, W_hz, b_zW_xr, W_hr, b_r分别是更新门和重置门的可学习参数。

(2)第二步:用重置门去结合过去的隐状态与新的输入

还是先来看一下更新方式:

请添加图片描述

然后是计算代码:

H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)  # 候选隐状态

其中W_xh, W_hh, b_h是候选隐状态的可学习参数

(3)第三步:用更新门获取当前隐状态

先来看一下第三步的更新方式:

请添加图片描述

然后是计算代码:

cur_state = Z * H + (1 - Z) * H_tilda  # 更新

最后我们就可以根据得到的隐状态计算输出了:

cur_output = H @ W_ho + b_o  # 输出

其中W_ho, b_o是输出的可学习参数

(三)门控循环神经网络的完整实现:

import torch


class GRU:
    def __init__(self, vocab_size, hidden_size, device):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.device = device
        self.parameters = self.get_params()
    
    def get_params(self):
        """获取参数"""
        num_inputs = num_outputs = self.vocab_size
    
        def normal(shape):
            return torch.randn(size=shape, device=self.device)
    
        def three():
            return (
                normal((num_inputs, self.hidden_size)),  # 输入参数
                normal((self.hidden_size, self.hidden_size)),  # 隐状态参数
                torch.zeros(self.hidden_size, device=self.device)  # 偏移量
            )
    
        W_xz, W_hz, b_z = three()  # 更新门(z)参数(x, h, b)
        W_xr, W_hr, b_r = three()  # 重置门(r)参数(x, h, b)
        W_xh, W_hh, b_h = three()  # 候选隐状态(h)参数(x, h, b)
        W_ho, b_o = normal((self.hidden_size, num_outputs)),\
                    torch.zeros(num_outputs, device=self.device)  # 输出(o)参数(h, b)
    
        #  附加梯度
        params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_ho, b_o]
        for param in params:
            param.requires_grad_(True)
    
        return params
    
    
    def init_state(batch_size, num_hiddens, device):
        """初始化隐状态"""
        return torch.zeros((batch_size, num_hiddens), device=device)
    
    
    def __call__(inputs, state, params):
        W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_ho, b_o = params
        H = state
        outputs = []
    
        for X in inputs:  # 输入为独热编码
            Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)  # 更新门
            R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)  # 重置门
            H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)  # 候选隐状态
            H = Z * H + (1 - Z) * H_tilda  # 更新
            Y = H @ W_ho + b_o  # 输出
            outputs.append(Y)
    
        return torch.cat(outputs, dim=0), H  # 返回输出和更新后的隐状态

到这里门控循环神经网络的介绍就结束了,我们这里给出门控神经网络的简洁实现:

循环神经网络层的实现:

rnn = torch.nn.GRU(input_size, hidden_size, layers, dropout=dropout)

其中input_size是输入特征维度,hidden_size是隐藏层维度,layers是循环网络层数,dropout是暂退层超参数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/445565.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【嵌入式笔/面试】嵌入式软件基础题和真题总结——操作系统

在学习的时候找到几个十分好的工程和个人博客&#xff0c;先码一下&#xff0c;内容都摘自其中&#xff0c;有些重难点做了补充&#xff01; 才鲸 / 嵌入式软件笔试题汇总 嵌入式与Linux那些事 阿秀的学习笔记 小林coding 百问网linux 嵌入式软件面试合集 2022年春招实习十四面…

电脑丢失的dll文件怎么一键修复?修复dll方法分享

电脑丢失的dll文件怎么一键修复&#xff1f;电脑状况常常让人遇到各种问题&#xff0c;其中“DLL文件丢失”是最常见的问题之一。在这篇文章中&#xff0c;我们会介绍为何会出现DLL文件丢失的问题&#xff0c;以及提供一种简单、快捷的DLL文件修复方法。 一.为何会出现DLL文件丢…

vue使用vue-mapvgl实现烟台市各区县行政区绘制、三维柱状图

一、效果展示 二、地图组件&#xff1a; vue-mapvgl https://docs.guyixi.cn/vue-mapvgl/#/ 三、代码 main.js //vue-mapvGL import VueBMap from vue-bmap-gl; import vue-bmap-gl/dist/style.css import VueMapvgl from vue-mapvgl; Vue.use(VueBMap); Vue.use(VueMapvg…

c++算法——算法章节-时间空间复杂度

算法开章咯 这次是csp-j组算法 枚举法常用排序合集hash一维前缀和vector结构体queuestack贪心-简单贪心区间递归二分setmap二叉树图的遍历-邻接矩阵迷宫问题-dfs-深度优先搜素bfs-广度优先搜索动态规划-简单动态规划-01背包动态规划-背包-多重背包二分答案 算法是什么嘛&…

腾讯云轻量4核8G12M应用服务器带宽、月流量详细性能评测

腾讯云轻量4核8G12M应用服务器带宽&#xff0c;12M公网带宽下载速度峰值可达1536KB/秒&#xff0c;折合1.5M/s&#xff0c;每月2000GB月流量&#xff0c;折合每天66GB&#xff0c;系统盘为180GB SSD盘&#xff0c;地域节点可选上海、广州或北京&#xff0c;4核8G服务器网来详细…

0基础同学如何快速入门学Python

转自&#xff1a;https://www.zhihu.com/question/596253606/answer/2994169972 想学Python的小伙伴&#xff0c;这里给你们汇总了&#xff1a;学习资源、平台、小白环境配置、相关课程、书籍资料&#xff01;并且&#xff0c;附送学习方法以及计划制定。 一、可以了解到Pyth…

Appuploader证书申请教程

转载&#xff1a;IOS证书制作教程 点击苹果证书 按钮 点击新增 输入证书密码&#xff0c;名称 这个密码不是账号密码&#xff0c;而是一个保护证书的密码&#xff0c;是p12文件的密码&#xff0c;此密码设置后没有其他地方可以找到&#xff0c;忘记了只能删除证书重新制作&…

还在精神内耗?还在焦虑?可以看看这个

作为一个即将毕业的本科生&#xff0c;总是会不由自主的焦虑。因为不考研&#xff0c;所以显得和同学们格格不入&#xff0c;每天都在进行精神内耗&#xff0c;但是我不经意间看到了一个东西-《邓宁克鲁格效应》 上述的四个阶段刻画出了一条典型的“大师养成之路”。但大师毕竟…

华为三层交换机命令集合,已经分好类了,网工建议收藏!

你好&#xff0c;这里是网络技术联盟站。 本文给大家带来的是华为三层交换机的命令集合&#xff0c;我已经分好类&#xff0c;大家可以收藏备用&#xff01; 一、系统管理命令 1.1 查看版本信息 display version此命令用于查看交换机的版本信息&#xff0c;包括交换机的软件…

【AI理论学习】深入理解Prompt Learning和Prompt Tuning

深入理解Prompt Learning和Prompt Tuning 背景Prompt Learning简介1. Prompt是什么&#xff1f;2. 为什么要使用Prompt&#xff1f;3. Prompt Learning的形式&#xff08;举例&#xff09;4. 有哪些Pre-training language model&#xff1f;5. 常见的Prompt Learning的方法 Pro…

WebRTC 源码分析——Android 视频硬件编码

作者&#xff1a;DevYK 1. 简介 本文将重点介绍在 Android 平台上&#xff0c;WebRTC 是如何使用 MediaCodec 对视频数据进行编码&#xff0c;以及在整个编码过程中 webrtc native 与 java 的流程交互。 本篇开始会先回顾一下 Andorid MediaCodec 的概念和基础使用&#xff0…

Node【Global全局对象】之【Process】

文章目录 &#x1f31f;前言&#x1f31f;Process&#x1f31f;process属性&#x1f31f;process.env &#x1f31f;process方法&#x1f31f;process事件&#x1f31f;uncaughtException &#x1f31f;写在最后 &#x1f31f;前言 哈喽小伙伴们&#xff0c;新的专栏 Node 已开…

VSCode + GCC编译器(MinGW)开发环境中文字符乱码问题踩坑与解决办法

文章目录 问题背景问题描述测试代码测试结果现象描述问题分析 解决方案修改默认配置1. 已经存在的文件全部使用gbk编码重新保存。2. 在工程目录下新建.vscode目录&#xff0c;如果已存在则跳过此步骤。3. 在.vscode目录中新建settings.json&#xff0c;launch.json两个文件&…

SAP CAP篇二:为Service加上数据库支持

在篇一快速创建一个Service&#xff0c;基于Java的实现中&#xff0c;可见使用SAP CAP &#xff08;Cloud Programming Model&#xff09;确实可以提高开发效率。尤其是Java技术栈上&#xff0c;对比于之前使用Olingo框架来实现oData&#xff0c;使用SAP CAP真的可以做到指数级…

Hightopo应邀参加 2023 第十届中国工业数字化论坛

3 月 30 日&#xff0c;以“加快数字化转型&#xff0c;助推高质量发展”为主题的第十届中国工业数字化论坛在北京隆重举行。厦门图扑软件科技有限公司&#xff08;以下简称“图扑软件”&#xff09;应邀参展&#xff0c;与诸位专家、领导、业界同仁共同研讨工业领域的数字化创…

红包算法关于---随机分发和平均分发

目录 群发普通红包 流程图 MainRedPacket类 Manager类 Member类 User类 群发普通红包 题目介绍 某软件有多名用户&#xff08;User类&#xff09;&#xff0c;某群聊中有群主&#xff08;Manager类&#xff09;和多名普通成员&#xff08;Member类&#xff09;&#x…

c++ 11 auto的概念和用法

目录 auto的概念&#xff1a; 使用auto声明变量的语法: auto关键字使用场景: 1.简化代码的书写和阅读 2.避免类型繁琐的重复定义 auto使用时的注意事项&#xff1a; auto的概念&#xff1a; 在C11标准中&#xff0c;auto是一种类型推导机制。它可以让编译器根据右值表达式…

代码随想录训练营day52|300、最长递增子序列;674、最长连续递增序列;718、最长重复子数组

300、最长递增子序列 给你一个整数数组 nums &#xff0c;找到其中最长严格递增子序列的长度。 子序列是由数组派生而来的序列&#xff0c;删除&#xff08;或不删除&#xff09;数组中的元素而不改变其余元素的顺序。例如&#xff0c;[3,6,2,7] 是数组 [0,3,1,6,2,2,7] 的子…

【Spring Boot】SpringBoot 优雅整合Swagger Api 自动生成文档

文章目录 前言一、添加 Swagger 依赖二、创建接口类三、添加 Swagger 配置类四、访问 Swagger 页面五、整合一个更友好的UI接口文档 Knife4j1、添加 Knife4j 依赖2、添加 Knife4j 配置类3、访问 Knife4j 页面 总结 前言 Swagger 是一套 RESTful API 文档生成工具&#xff0c;可…

《选择》比努力更重要——C语言

目录 前言: 1.语句 2.选择语句 2.1小栗子 2.2选择结构 3.误导性else 3.1写法上的可读性和代码的稳健性&#xff1a; 3.2一些练习 4.switch选择语句 4.1嵌套的switch ❤博主CSDN:啊苏要学习 ▶专栏分类&#xff1a;C语言◀ C语言的学习&#xff0c;是为我们今后学习其…