机器学习笔记:门控循环单元的建立

news2024/9/24 15:28:10

目录

介绍

结构

模型原理

重置门与更新门

候选隐状态

输出隐状态

模型实现

引入数据

初始化参数

定义模型

训练与预测

简洁实现GRU

思考


介绍

门控循环单元(Gated Recurrent Unit,简称GRU)是循环神经网络一种较为复杂的构成形式,其用途也是处理时序数据,相比具有单隐藏状态的RNN,GRU具有忘记的能力,可以忘记无用的数据。

结构

与传统RNN相比,GRU的结构引入了的概念,比RNN复杂许多,不过可以看出,其输入仍然是X_t和上一时间步隐状态H_{t-1},输出仍然是本时间步隐状态H_t。区别在于“细胞”内部结构,RNN只需要将H和X分别处理,之后结合在一起,激活函数激活后将其输出即可。而GRU内部处理十分复杂。

模型原理

我们以处理的顺序来依次讲解各个组成部分的模型原理。

重置门与更新门

首先介绍重置门(reset gate)R_t更新门(update gate)Z_t。 我们把它们设计成(0,1)区间中的向量。 重置门允许我们控制“可能还想记住”的过去状态的数量; 更新门将允许我们控制新状态中有多少个是旧状态的副本。后面还会再提到两个门的具体作用。

重置门和更新门的计算公式如下所示,由于使用sigmoid函数,R_tZ_t的值在(0,1)区间内。

R_t= \sigma (X_t \cdot W_{xr} + H_{t-1} \cdot W_{hr}+b_r)

Z_t= \sigma (X_t \cdot W_{xz} + H_{t-1} \cdot W_{hz}+b_z)

候选隐状态

候选隐状态的计算公式如下,是RNN中计算公式的升级版。(\bigodot是哈达玛积)

\tilde{H}_t = tanh(X_t\cdot W_{xh}+(H_{t-1}\bigodot R_t)\cdot W_{hh}+b_h)

当重置门R的值接近1时,则候选隐状态的计算与RNN一致,当重置门R的值接近0时,则候选隐状态计算时会完全“忘记”之前的值。

输出隐状态

输出隐状态需要更新门,候选隐状态和上一阶段隐状态共同计算得到。

H_t=Z_t \bigodot H_{t-1}+ (1-Z_t)\bigodot \tilde{H_t}

由公式可以看出,当Z_t接近0时,隐状态即为候选隐状态,当Z_t接近1时,隐状态即为上一阶段隐状态,更新门决定隐状态中有多少部分进行更新。

模型实现

引入数据

我们从零开始实现一个GRU,首先引入相关的库,并定义相关的一系列超参数。

from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l

npx.set_np()

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化参数

将需要学习的参数进行初始化。

def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return np.random.normal(scale=0.01, size=shape, ctx=device)

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                np.zeros(num_hiddens, ctx=device))

    W_xz, W_hz, b_z = three()  # 更新门参数
    W_xr, W_hr, b_r = three()  # 重置门参数
    W_xh, W_hh, b_h = three()  # 候选隐状态参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = np.zeros(num_outputs, ctx=device)
    # 附加梯度
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.attach_grad()
    return params

定义模型

定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

def init_gru_state(batch_size, num_hiddens, device):
    return (np.zeros(shape=(batch_size, num_hiddens), ctx=device), )
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)
        R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)
        H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = np.dot(H, W_hq) + b_q
        outputs.append(Y)
    return np.concatenate(outputs, axis=0), (H,)

训练与预测

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果如下:

perplexity 1.1, 10510.3 tokens/sec on gpu(0)
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

简洁实现GRU

mxnet框架中自带GRU的API,可以直接调用。GRU唯一需要的参数就是隐藏单元的数量。

接下来根据上一篇文章中定义好的train_ch8进行反向计算更新参数并进行预测即可。

gru_layer = rnn.GRU(num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果如下:

perplexity 1.1, 183591.3 tokens/sec on gpu(0)
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

思考

  1. 如果仅仅实现门控循环单元的一部分,例如,只有一个重置门或一个更新门会怎样?

  2. 比较rnn.RNNrnn.GRU的不同实现对运行时间、困惑度和输出字符串的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2033669.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

轻量级的灰度配置平台|得物技术

一、前言 随着近几年得物的业务和技术的快速发展,我们不管是在面向C端场景还是B端供应链;业务版本的迭代更新,技术架构的不断升级;不管是业务稳定性还是架构稳定性,业务灰度的能力对我们来说都是一项重要的技术保障&a…

x264 编码器 PSNR算法源码分析

PSNR PSNR(Peak Signal-to-Noise Ratio,峰值信噪比)是一种常用的图像质量评价指标,用于衡量图像或视频的清晰度和质量。PSNR是基于信号的最大可能功率与影响信号的噪声功率之间的比率。在图像处理领域,PSNR通常用来评估图像压缩或图像增强算法的效果。 PSNR的计算公式是…

思科CCNP最新考证流程

CCNP CCNP全称思科网络高级工程师认证(Cisco Certified Network Professional),是Cisco思科认证中的中级认证。获得ccnp证书表示着资深网络工程师具有对100个节点到超过500个节点的融合局域网和广域网进行安装、配置和故障排除的能力。能够管…

LeetCode257 二叉树的所有路径

前言 题目: 257. 二叉树的所有路径 文档: 代码随想录——二叉树的所有路径 编程语言: C 解题状态: 没思路,简单题强度好高… 思路 本题利用了递归加回溯的思路。 这道题目要求从根节点到叶子的路径,所以需…

一个Indie Hacker的微SaaS技术栈

如今,可用的技术非常多,我们每个月都会看到各种新的 JS 框架发布,有时,如果你一开始没有选择正确的技术堆栈,以后扩展起来就会很困难。因此,在今天的文章中,我将与你分享我用于开发微型 SaaS 的…

vue使用富文本编辑器+自由伸缩图片

首先要下载依赖,下方是本人使用的package.json,下载完依赖如果有启动项目失败的情况,建议将依赖版本降低或使用和下方一样的版本 package.json代码 {"name": "l","version": "0.1.0","privat…

Linux中线程常用接口(创建,等待,退出,取消)

pthread_create #include <pthread.h> int pthread_create(pthread_t *thread, const pthread_attr_t *attr, void *(*start_routine) (void *), void *arg); Compile and link with -pthread. 编译时应注意。 #include<iostream> #in…

使用Playwright解决reCAPTCHA的分步指南

您是否在您的网络爬虫中遇到过CAPTCHA&#xff1f;许多网站使用CAPTCHA系统&#xff08;最常见的是reCAPTCHA&#xff09;来防止自动化访问。但是&#xff0c;本文将指导您使用Playwright&#xff08;一种强大的浏览器自动化工具&#xff09;和CapSolver&#xff08;一个设计用…

# 利刃出鞘_Tomcat 核心原理解析(二)

利刃出鞘_Tomcat 核心原理解析&#xff08;二&#xff09; 一、 Tomcat专题 - Tomcat架构 - HTTP工作流程 1、Http 工作原理 HTTP 协议&#xff1a;是浏览器与服务器之间的数据传送协议。作为应用层协议&#xff0c;HTTP 是基于 TCP/IP 协议来传递数据的&#xff08;HTML文件…

AI 的偏见来自数据集,而数据集的偏见来自人类 | Open AGI Forum

作者 | Annie Xu 采访、责编 | Eric Wang 出品丨GOSIM 开源创新汇 Richard Vencu&#xff0c;现任 Stability AI 机器学习运维负责人、LAION 工程负责人兼创始人&#xff0c;他的人生可谓十分精彩。 已过知天命之年的他是个中国通&#xff0c;极其热爱中国的武术、茶叶、诱人…

BugKu CTF Misc:被勒索了 disordered_zip simple MQTT 请攻击这个压缩包

前言 BugKu是一个由乌云知识库&#xff08;wooyun.org&#xff09;推出的在线漏洞靶场。乌云知识库是一个致力于收集、整理和分享互联网安全漏洞信息的社区平台。 BugKu旨在提供一个实践和学习网络安全的平台&#xff0c;供安全爱好者和渗透测试人员进行挑战和练习。它包含了…

03. 剑指offer刷题-二叉树篇(第二部分)

class Solution { public:TreeNode* Convert(TreeNode* pRootOfTree) {if(pRootOfTree nullptr) return nullptr;vector<TreeNode*> cur traversal(pRootOfTree);return cur[0];}// 这道题需要用到「分解问题」的思维&#xff0c;想把整棵链表&#xff0c;可以先把左右…

[upload]-做题笔记

项目下载地址&#xff1a;https://github.com/c0ny1/upload-labs 第一关 查看源代码&#xff0c;可以看到是前端js限制上传jpg,png,gif后缀文件 function checkFile() {var file document.getElementsByName(upload_file)[0].value;if (file null || file "") …

Unity读取Android外部文件

最近近到个小需求,需要读Android件夹中的图片.在这里做一个记录. 首先读写部分,这里以图片为例子: 一读写部分 写入部分: 需要注意的是因为只有这个地址支持外部读写,所以这里用到的地址都以 :Application.persistentDataPath为地址起始. private Texture2D __CaptureCamera…

促进服务消费高质量发展虽好,但不能缺钱

近日&#xff0c;国务院印发《关于促进服务消费高质量发展的意见》&#xff0c;提出6方面20项重点任务。 百度图片&#xff1a;2024讲党课ppt国务院关于促进服务消费高质量发展​ 一是挖掘餐饮住宿、家政服务、养老托育等基础型消费潜力&#xff1b; 二是激发文化娱乐、旅游、…

Upload 上传图标不显示

el-upload如果在使用 Element UI 的 <el-upload> 组件时上传图标不显示&#xff0c;可能是由几个不同的原因造成的。以下是一些排查和解决这个问题的步骤&#xff1a; 如果在使用 Element UI 的 <el-upload> 组件时上传图标不显示&#xff0c;可能是由几个不同的原…

antd react echarts地图组件及使用

地图组件&#xff1a; import { useRef, useEffect } from "react"; import * as echarts from "echarts"; import chinaJson from ./chinaJson;const MapIndex ({option,width "100%",height "100%", }) > {const ref useRef…

08:【stm32】中断二:EXTI(外部中断)

EXTI&#xff08;外部中断&#xff09; 1、EXTI简介2、EXTI的内部结构2.1、EXTI通道2.2、内部寄存器 3、EXTI的编写程序3.1、EXTI的编程接口3.1.1、EXTI_Init 4、编写实验 1、EXTI简介 外部中断控制器&#xff0c;能够检测外部输入信号的变化边沿并由此产生中断。通过检测上升沿…

BugKu CTF Misc:密室逃脱 铁子,来一道 想要种子吗 哥哥的秘密

前言 BugKu是一个由乌云知识库&#xff08;wooyun.org&#xff09;推出的在线漏洞靶场。乌云知识库是一个致力于收集、整理和分享互联网安全漏洞信息的社区平台。 BugKu旨在提供一个实践和学习网络安全的平台&#xff0c;供安全爱好者和渗透测试人员进行挑战和练习。它包含了…

Sql语句出现ORA-00933: SQL command not properly ended的解决方法

目录 1. 问题所示2. 原理分析3. 解决方法1. 问题所示 执行sql语句的时候出现如下问题: ORA-00933: SQL command not properly ended截图如下所示: 2. 原理分析 ORA-00933: SQL command not properly ended 是 Oracle 数据库中的错误,指示 SQL 语句存在语法问题 MySQL 和…