GRU模块:nn.GRU层

news2024/11/16 9:53:45

摘要: 

       如果需要深入理解GRU的话,内部实现的详细代码和计算公式就比较重要,中间的一些过程及中间变量的意义需要详细关注。只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等。所以,仔细研究这个模块的实现还是非常有必要的。以此类推,对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

       本文中介绍GRU内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即语句:self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。

       先从输入输出的角度看,即代码中的这一行:output, state = self.rnn(X) 。在 GRU(Gated Recurrent Unit)中,outputstate 都是由 GRU 层的循环计算产生的,它们之间有直接的关系。state 实际上是 output 中最后一个时间步的隐藏状态。 

代码示例

class Seq2SeqEncoder(d2l.Encoder):
"""⽤于序列到序列学习的循环神经⽹络编码器"""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):
       super(Seq2SeqEncoder, self).__init__(**kwargs)
        # 嵌⼊层
       self.embedding = nn.Embedding(vocab_size, embed_size)
       self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,
       dropout=dropout)

    def forward(self, X, *args):
    # 输出'X'的形状:(batch_size,num_steps,embed_size)
        X = self.embedding(X)
    # 在循环神经⽹络模型中,第⼀个轴对应于时间步
        X = X.permute(1, 0, 2)
    # 如果未提及状态,则默认为0
        output, state = self.rnn(X)
    # output的形状:(num_steps,batch_size,num_hiddens)
    # state的形状:(num_layers,batch_size,num_hiddens)
        return output, state

output:在完成所有时间步后,最后⼀层的隐状态的输出output是⼀个张量(output由编码器的循环层返回),其形状为(时间步数,批量⼤⼩,隐藏单元数)。

state:最后⼀个时间步的多层隐状态是state的形状是(隐藏层的数量,批量⼤⼩, 隐藏单元的数量)。

GRU模块的框图 

GRU 的基本公式

GRU 的核心计算包括更新门(update gate)和重置门(reset gate),以及候选隐藏状态(candidate hidden state)。数学表达式如下:

  1. 更新门 \( z_t \): \[ z_t = \sigma(W_z \cdot h_{t-1} + U_z \cdot x_t) \]
       其中,\( \sigma \) 是sigmoid 函数,\( W_z \) 和 \( U_z \) 分别是对应于隐藏状态和输入的权重矩阵,\( h_{t-1} \) 是上一个时间步的隐藏状态,\( x_t \) 是当前时间步的输入。

  2. 重置门 \( r_t \):
       \[ r_t = \sigma(W_r \cdot h_{t-1} + U_r \cdot x_t) \]
       \( W_r \) 和 \( U_r \) 是更新门中定义的相似权重矩阵。

  3. 候选隐藏状态 \( \tilde{h}_t \):
       \[ \tilde{h}_t = \tanh(W \cdot r_t \odot h_{t-1} + U \cdot x_t) \]
       这里,\( \tanh \) 是激活函数,\( \odot \) 表示元素乘法(Hadamard product),\( W \) 和 \( U \) 是隐藏状态的权重矩阵。

  4. 最终隐藏状态 \( h_t \):
       \[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

output 和 state 的关系

  • output:在 GRU 中,output 包含了序列中每个时间步的隐藏状态。具体来说,对于每个时间步 \( t \),output 的第 \( t \) 个元素就是该时间步的隐藏状态 \( h_t \)。

  • state:state 是 GRU 层最后一层的隐藏状态,也就是 output 中最后一个时间步的隐藏状态 \( h_{T-1} \),其中 \( T \) 是序列的长度。

数学表达式

如果我们用 \( O \) 表示 output,\( S \) 表示 state,\( T \) 表示时间步的总数,那么:

\[ O = [h_0, h_1, ..., h_{T-1}] \]
\[ S = h_{T-1} \]

因此,state 实际上是 output 中最后一个元素,即 \( S = O[T-1] \)。

在 PyTorch 中,output 和 state 都是由 GRU 层的 `forward` 方法计算得到的。`output` 是一个三维张量,包含了序列中每个时间步的隐藏状态,而 `state` 是一个二维张量,仅包含最后一个时间步的隐藏状态。

GRU的内部实现

上面一节的代码示例,是通过调用API实现的,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)。那么,GRU内部是如何实现的呢?

分为模型、模型参数初始化和隐状态初始化三个部分:

模型定义(模型定义与数学表示式一致,也可以参考上图):

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

  模型参数初始化(权重是从标准差0.01的高斯分布中提取的,超参数num_hiddens定义隐藏单元的数量,偏置项设置为0):

def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01
    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three() # 更新⻔参数
    W_xr, W_hr, b_r = three() # 重置⻔参数
    W_xh, W_hh, b_h = three() # 候选隐状态参数
# 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
# 附加梯度
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

隐状态初始化函数(此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值都为0

def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )

最后由一个函数统一起来,实现模型:

model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)

小结 

       总体上说,内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。但是,如果需要深入理解GRU的话,那么内部实现的详细代码和计算公式就比较重要,中间的一些过程和变量的意义需要详细关注,只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等,所以,仔细研究这个模块的实现还是非常有必要的。对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1656453.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

药品管理系统的设计与实现

互联网发展至今,无论是其理论还是技术都已经成熟,而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播,搭配信息管理工具可以很好地为人们提供服务。针对药品信息管理混乱,出错率高,信息安全性差&#xff0…

学习笔记:【QC】Android Q : telephony-phone 模块

一、phone init 流程图 高清的流程图参考:【高清图,保存后可以放大看】 二、phone MO 流程图 高清的流程图参考:【高清图,保存后可以放大看】 三、phone MT 流程图 高清的流程图参考:【高清图,保存后可以…

C 深入指针(3)

目录 一、关于数组名 1 数组名的理解 2 数组名 与 &数组名 的区别 二、使用数组访问指针 三、一维数组传参的本质 四、二级指针 五、指针数组 六、指针数组模拟二维数组 一、关于数组名 1 数组名的理解 //VS2022 x64 #include <stdio.h> int main() {int a…

Cesium 问题:billboard 加载未出来

文章目录 问题分析问题 接上篇 Cesium 展示——图标的依比例和不依比例缩放,使用加载 billboard 时,怀疑是路径的原因导致未加载成功 分析 原先

知名专业定制线缆知名智造品牌品牌推荐-精工电联:如何实现清扫机器人线缆产品的精益求精

在科技日新月异的今天&#xff0c;智能清扫机器人已经融入我们的日常生活。然而&#xff0c;其背后不可或缺的一部分&#xff0c;就是那些被称为机器人血管的精密线缆。精工电联作为高科技智能化产品及自动化设备专用连接线束和连接器配套服务商&#xff0c;致力于通过精益求精…

用户行为分析与内容创新:Kompas.ai的数据驱动策略

在数字化营销的今天&#xff0c;用户行为数据分析已成为内容创新和策略调整的核心。通过深入理解用户的行为模式和偏好&#xff0c;品牌能够创造出更具吸引力和相关性的内容&#xff0c;从而实现精准营销。本文将探讨用户行为数据分析在内容创新和策略调整中的价值&#xff0c;…

Langchain实战

感谢阅读 LangChain介绍百度文心API申请申请百度智能云创建应用 LLMChain demo以及伪幻觉问题多轮对话的实现Sequential ChainsSimpleSequentialChainSequentialChainRouter Chain Documents ChainStuffDocumentsChainRefineDocumentsChainMapReduceDocumentsChainMapRerankDoc…

JavaEE 初阶篇-深入了解 HTTP 协议

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 HTTP 协议概述 2.0 HTTP 请求协议 2.1 请求方式的具体体现 3.0 HTTP 响应协议 3.1 常见的状态码及描述 3.2 常见的响应头 4.0 HTTP 协议解析 4.1 简单实现服务器响…

win11环境下,Idea 快捷键shift+f6重命名无法使用

微软自带输入法导致 单击win弹出搜索框

Java数据结构---栈和队列

目录 栈&#xff08;Stack&#xff09; 队列&#xff08;Queue&#xff09; 循环队列 栈&#xff08;Stack&#xff09; 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除操作元素操作。进行数据插入和删除操作的一端称为栈顶&#xff0c;另一…

【启明智显分享】7寸HMI智能屏应用方案:Model4芯片实现展会多媒体互动展示

所谓&#xff1a;“技术是底层逻辑&#xff0c;用户是中心动力”。技术水平不断提升&#xff0c;人们对展会展览的形式和体验不断提出更多的要求&#xff0c;展会展现形式也正在不断地寻求突破。 传统的展示形式往往只能提供静态显示和有限的互动体验&#xff0c;在这个公众注意…

因表别名引用错误导致查询SQL执行时间长未出结果

问题描述&#xff1a; 项目组人员反馈在执行一条提取数据SQL时执行很慢&#xff0c;每次执行一段时间就报超时&#xff0c;要求帮忙提取下。 解决过程&#xff1a; 项目组人员发来SQL后&#xff0c;看了下SQL&#xff0c;没什么问题&#xff0c;就在客户端上执行了下&#xff0…

Antd Table组件,state改变,但是render并不会重新渲染

背景 在table上面&#xff0c;当鼠标放在cell上面的时候&#xff0c;需要去请求接口拉取数据&#xff0c;然后setList(res.result)后&#xff0c;希望render中的traceIds也能够实时更新渲染。 const [traceIds, setTraceIds] useState() // 需要展示在popover上面的数据&…

社媒营销中的截流获客是怎么一回事?

如果你要问&#xff0c;现在做社媒营销是通过哪些方式进行引流的&#xff0c;那么必然有一种是截流&#xff0c;顾名思义也就是分取别人的流量&#xff0c;方法其实很简单&#xff0c;主要分为两种&#xff1a;&#xff08;1&#xff09;抓取别人的粉丝出来进行群发私信&#x…

2024上海国际地下空间工程与技术展览会

2024上海国际地下空间工程与技术展览会 2024 Underground Space Project and Technology Expo 时间&#xff1a;2024年10月30日-11月1日 地点&#xff1a;上海世博展览馆 主办单位&#xff1a;联合国人居署 上海市住房与城乡建设管理委员会 城博会背景 上海国际城市与建筑…

mysql数据库---操作数据库跟表的命令总结

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 本文着重整理mysql管理库跟表的指令。 不涉及增删查改等指令 其实本篇主要是我做好笔记格式 用的时候直接复制粘贴的 所以排版大多是为了快速找功能来排的 方便大家快速找目标语法 数据库的简介 一个数据库系…

(代码以上传,超级详细)面试必备算法题----Leeecode290单词规律

文章目录 概要题目要求测试and提交结果技术细节 概要 来自Leecode ​ 代码已上传&#xff09;仓库&#xff0c;需要测试实例和其他题型解决&#xff0c;可以去自行浏览 点击这里进入仓库领取代码喔&#xff01;顺便点个star给原子加油吧&#xff01; ​ 题目要求 使用哈希表 …

国内外主流大模型都具备有哪些特点?

文章目录 ⭐ 火爆全网的大模型起点⭐ 国外主流LLM及其特点⭐ 国内主流LLM及其特点⭐ 全球大模型生态的发展 该章节呢&#xff0c;我们主要是看一下关于国内外主流的大语言模型&#xff0c;通过它们都具备哪些特点&#xff0c;来达成对多模型有一个清晰的认知。对于 “多模型” …

tensorflow报错

参考 TensorFlow binary is optimized to use available CPU instructions in performance-critical operations._this tensorflow binary is optimized to use availab-CSDN博客 解决Python中cuBLAS插件无法注册问题_unable to register cudnn factory: attempting to re-CS…

获取两个时间之间的月份

工具类 public static List<String> getMonthBetweenDate(Date startDate, Date endDate) {ArrayList<String> result new ArrayList<String>();SimpleDateFormat sdf new SimpleDateFormat("yyyy.MM");//格式化&#xff0c;调整为自己需要的格…