d2l 里面GRU与Lstm实现

news2025/1/12 23:12:19

此二者的本质都是对rnn进行改良:关注当前多还是关注之前多。

在此详细讲一下。

目录

1.GRU门循环控制单元

1.1理论:

 1.2初始化参数

1.3定义网络

1.4训练命令行

1.5简洁实现

2.Lstm长短期记忆网络

2.1理论

 2.2加载参数

2.3定义lstm计算

2.4定义模型:

2.5训练命令行

3.强调


1.GRU门循环控制单元

1.1理论:

其参数多了两个,本质都是对H的计算进行了改进。

 

 1.2初始化参数

  与从零开始RNN的初始化参数类似,首先指定输入输出维度=len(vocab)
  构建一个均值=0,std=0.01的初始化tensor,传入的是尺寸
  将更新门、重置门、候选隐状态的参数都是3个,构造初始化辅助函数three,直接赋值即可得到相应的初始化参数(初始化需要的参数形式相同,故一样)
  传入的X尺寸为(bs,V),系数W的尺寸为(V,h)/(h,V),H的尺寸为(bs,h);与X或H相乘得到(bs,V),这其实就是Y的尺寸,再相应dim=0上叠加,得到最终一个T的outputs为(bs*T,V)

def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size
    
    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01
    
    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))
    
    W_xz, W_hz, b_z = three() # 更新⻔参数
    W_xr, W_hr, b_r = three() # 重置⻔参数
    W_xh, W_hh, b_h = three() # 候选隐状态参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

1.3定义网络

  初始化参数:

def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )

  定义gru计算:

  注意: @这个符号是矩阵乘法,*是哈达玛积
              H的尺寸仍是(bs,h),注意W_hz;W_hr等的尺寸为(h,h)

  计算公式与开始的理论图里面的公式一致。

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

  与rnn不同的是:只要传入初始化参数,初始化state,以及forward如何计(定义的gru)丢尽RNNModel即可运算。

1.4训练命令行

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

1.5简洁实现

  经过nn的直接调用的RNN或GRU,得到的将X(bs,V)送进去这个net得到的都是Y(T,bs,h),都需要额外添加Linear(h,V)得到最终的outputs(T*bs,V)
  再与y(bs,T)转置reshape得到的(T*bs)计算交叉熵loss

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

2.Lstm长短期记忆网络

2.1理论

 

 2.2加载参数

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

2.3定义lstm计算

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size
    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01
    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))
    
    W_xi, W_hi, b_i = three() # 输⼊⻔参数
    W_xf, W_hf, b_f = three() # 遗忘⻔参数
    W_xo, W_ho, b_o = three() # 输出⻔参数
    W_xc, W_hc, b_c = three() # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
                b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

2.4定义模型:

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

有了c记忆元的存在,所以要初始化两个。

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
    W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

2.5训练命令行

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

3.强调

  送入for X,T in train_iter里面的X,Y均为(bs,T)
  X直接送入net会先经过one-hot变成(T,bs,V)
  经过net后得到的y_hat为(T*bs,V)
  在net中,如果是调用nn.RNN(len(vocab),num_hiddens)或者gru或lstm,则通过调用层得到的Y尺寸都为(T,bs,hiddens),本质上是所有时间步的隐层,需要再接一个LInear(h,V)得到输出y_hat为(T*bs,V)。
  如果是从零实现,则得到的每一个Y为(bs,V),再通过dim=0的累加得到y_hat为(T*bs,V)
  然后在于y(经过转置为(T*bs))进行交叉熵计算loss。
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/429005.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iTOP4412开发板Qt程序打包和部署

因为我们要把写好的程序发给用户来用,写好的源码也不方便给别人看,所以要把程序进行打包部署。 步骤一:点击左下角的电脑图标将 Debug 模式切换到 Release 模式。 release 模式:发布版本,不对源代码进行调试&#xff…

微信小程序:表格中更改输入框的值,实时获取表格全部数据,点击按钮更改数据库指定项数据

样例: 样式展示 数据库中原始第一条数据 修改表格第一行的数量: 数据库结果 核心代码 wxml ①wx:for:执行循环将数组数据展示出来 ②在某一单元格加上input样式 ③在input中绑定:文本框改变事件,并且绑定data-index便于知道…

网络编程,IO流

网络编程 计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备,通过通信线路连接起来,在网络操作系统,网络管理软件及网络通信协议的管理和协调下,实现资源共享和信息传递的计算机系统。 1.网络通信的要素 通信…

程序环境和预处理(下)——“C”

各位CSDN的uu们你们好呀,今天小雅兰的内容是程序环境和预处理的下篇知识点,那么,这篇博客写完后,C语言的知识点就到这里就结束啦,后续会专注于刷题和读书,也是关于C语言的,会写一些数据结构和C的…

本地Linux服务器安装宝塔面板,并内网穿透实现公网远程登录

文章目录前言1. 安装宝塔2. 安装cpolar内网穿透3. 远程访问宝塔4. 固定http地址5. 配置二级子域名6. 测试访问二级子域名转发自CSDN远程穿透的文章:Linux安装宝塔,并实现公网远程登录宝塔面板【内网穿透】 前言 宝塔面板作为建站运维工具,它…

尚融宝17-用户身份认证的三种模式

目录 1、单一服务器模式 2、SSO(Single Sign On)模式 3、Token模式 1、单一服务器模式 即只有一个服务器,用户通过输入账户和密码,提交表单后服务器拿到前端发送过来的数据查询数据库是否存在该用户,其一般流程如下…

【分享】体验微软Bing在线绘图功能

哈喽,大家好,我是木易巷~ 木易巷体验了一下子微软Bing在线绘图功能,快来看看吧~ 简单介绍 New Bing 不了解或者没有注册New Bing的小伙伴可以看看这一篇: 【教程】你现在还不知道微软的New Bing?你out了&#xff0…

【NestJs】使用MySQL关联查询

上一篇文章介绍了NestJs使用MySQL创建多个实体,接下来讲到的则是实体创建完毕之后,开始进行查询。里面可能涉及到nestjs使用语法,要是不知道的小伙伴可以先行了解,也可以模仿写,后面我会继续出nestjs的教程。也欢迎大家…

SpringMVC的基本使用-------基本注解RequestMapping、基本数据类型绑定、参数绑定、POJO类型绑定

SpringMVC的三层架构和MVC SpringMVC简介 三层架构概述: 一种是 C/S 架构,也就是客户端/服务器,另一种是 B/S 架构,也就是浏览器服务器。在 JavaEE 开发中,几乎全都是基于 B/S 架构的开发。那么在 B/S 架构中&#…

时间序列信号阈值降噪方法,有什么可以创新的地方吗?

可以换个空间,从图域的角度进行分析,比如图傅里叶变换,图小波变换等图时频分析方法。图小波阈值降噪的基本思想是通过将时间序列信号转换成路图信号,再利用图小波变换分解成尺度函数系数和一系列对应不同尺度的谱图小波系数&#…

VAE 理论推导及代码实现

VAE 理论推导及代码实现 熵、交叉熵、KL 散度的概念 熵(Entropy) 假设 p (x)是一个分布函数,满足在 x 上的积分为 1,那么 p(x)p(x)p(x)的熵定义为 H(p(x))H (p (x))H(p(x)),这里我们简写为 H(p)H(p)H(p) H(p)∫p(x)…

移动硬盘文件或目录损坏且无法读取,这样做就对了!

案例:移动硬盘文件或目录损坏且无法读取怎么办 【我的移动硬盘插入电脑后突然就显示文件损坏,遇到这种情况我应该怎么处理呀?感谢回答!】 移动硬盘是一种方便携带和存储数据的设备,然而,有时候可能会遇到…

UE中的channel

当我们需要处理碰撞矩阵,或者调用接口投射射线进行检测等,为了区分哪些对象可以被射线检测到,哪些对象忽略,就需要用到channel。 1.Channel 简介 在UE5中,一个对象的channel可以在Physics下查看: 设置成…

如何确保采购过程中的产品质量

在企业采购过程中,确保采购的产品质量是至关重要的。采购的质量直接关系到企业的生产和销售质量,影响企业的形象和利润。为了确保采购过程中的质量,企业需要采取一些措施来保证采购物料和商品的质量,以下是一些有效的方法&#xf…

Linux学习-----Chapter nine

使用 ssh 服务管理远程主机9.1 配置网络服务9.1.1 配置网卡参数9.1.2 创建网络会话9.1.3 绑定两块网卡1、创建一个bond网卡2、向bond0设备添加从属网卡3、配置bond0设备的网络信息4、启动它9.2 远程控制服务9.2.1 配置sshd服务9.2.2 安全密钥验证9.2.3 远程传输命令9.3 不间断会…

日撸 Java 三百行day25-26

文章目录说明day25 二叉树深度遍历的栈实现 (中序)1.具有通用性的对象栈2.栈实现中序遍历2.1 思路2.2 代码day26 二叉树深度遍历的栈实现 (前序和后序)1.前序遍历2.后序遍历说明 闵老师的文章链接: 日撸 Java 三百行(总述)_minfanphd的博客-…

Redis第二十八讲 Redis集群脑裂数据丢失问题与集群是否完整才能对外提供服务

集群脑裂数据丢失问题 所谓的脑裂,就是指在主从集群中,同时有两个主节点,它们都能接收写请求。而脑裂最直接的影响,就是客户端不知道应该往哪个主节点写入数据,结果就是不同的客户端会往不同的主节点上写入数据。而且,严重的话,脑裂会进一步导致数据丢失。 redis的集群…

银行数字化转型导师坚鹏:银行业同业竞争策略分析

《银行业同业竞争策略分析》 —数字化背景下银行转型发展创新思维 课程背景: 数字化背景下,很多银行存在以下问题: 不清楚国内领先银行的业务发展现状? 不清楚如何制定竞争策略? 不知道其他银行转型的成功做法&…

Matplotlib学习挑战第六关--散点图、柱形图、饼图

1、Matplotlib 散点图 我们可以使用 pyplot 中的 scatter() 方法来绘制散点图。 scatter() 方法语法格式如下: matplotlib.pyplot.scatter(x, y, sNone, cNone, markerNone, cmapNone,normNone, vminNone, vmaxNone, alphaNone, linewidthsNone, *, edgecolorsNo…

【RabbitMQ】RabbbitMQ的六种工作模式以及代码实现

目录 一、交换机类型 二、简单模式 1、介绍 2、代码实现 三、Work Queues工作队列模式 1、介绍 2、代码实现 四、Pub/Sub订阅模式 1、介绍 2、代码实现 五、Routing路由模式 1、介绍 2、代码实现 六、Topics通配符模式 1、介绍 2、代码实现 一、交换机类型 在…