[PyTorch][chapter 46][LSTM -1]

news2024/9/23 15:30:40

前言:

           长短期记忆网络(LSTM,Long Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的。

目录:

  1.      背景简介
  2.      LSTM Cell
  3.      LSTM 反向传播算法
  4.      为什么能解决梯度消失
  5.       LSTM 模型的搭建


一  背景简介:

       1.1  RNN

         RNN 忽略o_t,L_t,y_t 模型可以简化成如下

      

       

          图中Rnn Cell 可以很清晰看出在隐藏状态h_t=f(x_t,h_{t-1})

            得到 h_t后:

              一方面用于当前层的模型损失计算,另一方面用于计算下一层的h_{t+1}

    由于RNN梯度消失的问题,后来通过LSTM 解决 

       1.2 LSTM 结构

        


二  LSTM  Cell

   LSTMCell(RNNCell) 结构

          

          前向传播算法 Forward

         2.1   更新: forget gate 忘记门

             f_t=\sigma(W_fh_{t-1}+U_{t}x_t+b_f)

             将值朝0 减少, 激活函数一般用sigmoid

             输出值[0,1]

         2.2 更新: Input gate 输入门

                i_t=\sigma(W_ih_{t-1}+U_ix_t+b_i)

                决定是不是忽略输入值

    

           2.3 更新: 候选记忆单元

                    a_t=\widetilde{c_t}=tanh(W_a h_{t-1}+U_ax_t+b_a)

           2.4 更新: 记忆单元

               c_t=f_t \odot c_{t-1}+i_t \odot a_t

             2.5  更新: 输出门

                决定是否使用隐藏值

                 o_t=\sigma(W_oh_{t-1}+U_ox_t+b_0)  

           2.6. 隐藏状态

                h_t=o_t \odot tanh(c_t)

           2.7  模型输出

                  \hat{y_t}=\sigma(Vh_t+b)

LSTM 门设计的解释一:

 输入门 ,遗忘门,输出门 不同取值组合的时候,记忆单元的输出情况


三  LSTM 反向传播推导

      3.1 定义两个\delta_t

             \delta_h^t=\frac{\partial L}{\partial h_t}

            \delta_c^t=\frac{\partial L}{\partial C_t}

    3.2  定义损失函数

            损失函数L(t)分为两部分: 

             时刻t的损失函数 l(t)

             时刻t后的损失函数L(t+1)

              L(t)=\left\{\begin{matrix} l(t)+L(t+1), if: t<T\\ l(t), if: t=T \end{matrix}\right.

      3.3 最后一个时刻\tau

              

 这里面要注意这里的o^{\tau}= Vh_{\tau}+c

    证明一下第二项,主要应用到微分的两个性质,以及微分和迹的关系:

   

   dl= tr((\frac{\partial L^{\tau}}{\partial h^{\tau}})^Tdh^{\tau})  ... 公式1: 微分和迹的关系

       =tr((\delta_h^{\tau})^Tdh^{\tau})

     因为

    h^{\tau}=o^{\tau} \odot tanh(c^{\tau})

   dh_T=o^{\tau}\odot(d(tanh (c^{\tau})))

           =o^{\tau} \odot (1-tanh^2(c^{\tau})) \odot dc^{\tau}

     带入上面公式1:

      dl= tr((\delta_h^{\tau})^T (o^{\tau}\odot(1-tanh^2(c^{\tau}))\odot dc^{\tau})

           =tr((\delta_h^{\tau} \odot o^{\tau} \odot(1-tanh^2(c^{\tau}))^Tdc^{\tau})

    所以

3.4   链式求导过程

       求导结果:

 

  这里详解一下推导过程:

  这是一个符合函数求导:先把h 写成向量形成

h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

 ------------------------------------------------------------   

 第一项: 

             

         h_{t+1}=o_{t+1}\odot tanh(c_{t+1})

         o_{t+1}=\sigma(W_oh_t+U_ox_{t+1}+b_0)

        设 a_{t+1}=W_oh_t+U_ox_{t+1}+b_0

           则    \frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial o_{t+1}}\frac{\partial o_{t+1}}{\partial a_{t+1}}\frac{\partial a_{t+1}}{\partial h_{t}}

 

            其中:(利用矩阵求导的定义法 分子布局原理)

                    \frac{\partial h_{t+1}}{\partial o_{t+1}}=diag(tanh(c^{t+1})) 是一个对角矩阵

                  o=\begin{bmatrix} \sigma(a_1)\\ \sigma(a_2) \\ .... \\ \sigma(a_n) \end{bmatrix}

                 \frac{\partial o_{t+1}}{\partial a_{t+1}}=diag(o_{t+1}\odot(1-o_{t+1}))

                 \frac{\partial a_{t+1}}{\partial h_{t}}=W_o

                 几个连乘起来就是第一项

               

第二项

    c_{t+1}=f_{t+1}\odot c_t+i_{t+1}\odot a_{t+1}

   f_{t+1}=\sigma(W_fh_t+U_tx_{t+1}+b_f)

   i_{t+1}=\sigma(W_ih_t+U_i x_{t+1}+b_i)

  a_{t+1}=tanh(W_a h_t +U_ax_t +b_a)

参考:

   h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

其中:

\frac{\partial h_{t+1}}{\partial c^{t+1}}=diag(o^{t+1}\odot (1-tanh^2(c^{t+1}))

\frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial c_{t+1}}\frac{\partial c_{t+1}}{\partial f_{t+1}}\frac{\partial f_{t+1}}{\partial h_{t}}

 \frac{\partial c_{t+1}}{\partial f_{t+1}}=diag(c^{t})

 \frac{\partial a_{t+1}}{\partial h_{t}}=diag(f_t \odot(1-f_t))W_f

其它也是相似,就有了上面的求导结果


四  为什么能解决梯度消失

    

     4.1 RNN 梯度消失的原理

                ,复旦大学邱锡鹏书里面 有更加详细的解释,通过极大假设:

在梯度计算中存在梯度的k 次方连乘 ,导致 梯度消失原理。

    4.2  LSTM 解决梯度消失 解释1:

            通过上面公式发现梯度计算中是加法运算,不存在连乘计算,

            极大概率降低了梯度消失的现象。

    4.3  LSTM 解决梯度 消失解释2:

              记忆单元c  作用相当于ResNet的残差部分.  

   比如f_{t}=1,\hat{c_t}=0 时候,\frac{\partial c_t}{\partial c_{t-1}}=1,不会存在梯度消失。

       


五 模型的搭建

   

    我们最后发现:

    O_t,C_t,H_t 的维度必须一致,都是hidden_size

    通过C_t,则 I_t,F_t,\tilde{c} 最后一个维度也必须是hidden_size

    

# -*- coding: utf-8 -*-
"""
Created on Thu Aug  3 15:11:19 2023

@author: chengxf2
"""

# -*- coding: utf-8 -*-
"""
Created on Wed Aug  2 15:34:25 2023

@author: chengxf2
"""

import torch
from torch import nn
from d21 import torch as d21


def normal(shape,devices):
    
    data = torch.randn(size= shape, device=devices)*0.01
    
    return data


def get_lstm_params(input_size, hidden_size,categorize_size,devices):
    


    
    #隐藏门参数
    W_xf= normal((input_size, hidden_size), devices)
    W_hf = normal((hidden_size, hidden_size),devices)
    b_f = torch.zeros(hidden_size,devices)
    
    #输入门参数
    W_xi= normal((input_size, hidden_size), devices)
    W_hi = normal((hidden_size, hidden_size),devices)
    b_i = torch.zeros(hidden_size,devices)
    

    
    #输出门参数
    W_xo= normal((input_size, hidden_size), devices)
    W_ho = normal((hidden_size, hidden_size),devices)
    b_o = torch.zeros(hidden_size,devices)
    
    #临时记忆单元
    W_xc= normal((input_size, hidden_size), devices)
    W_hc = normal((hidden_size, hidden_size),devices)
    b_c = torch.zeros(hidden_size,devices)
    
    #最终分类结果参数
    W_hq = normal((hidden_size, categorize_size), devices)
    b_q = torch.zeros(categorize_size,devices)
    
    
    params =[
        W_xf,W_hf,b_f,
        W_xi,W_hi,b_i,
        W_xo,W_ho,b_o,
        W_xc,W_hc,b_c,
        W_hq,b_q]
    
    for param in params:
        
        param.requires_grad_(True)
        
    return params

def init_lstm_state(batch_size, hidden_size, devices):
    
    cell_init = torch.zeros((batch_size, hidden_size),device=devices)
    hidden_init = torch.zeros((batch_size, hidden_size),device=devices)
    
    return (cell_init, hidden_init)


def lstm(inputs, state, params):
    [
        W_xf,W_hf,b_f,
        W_xi,W_hi,b_i,
        W_xo,W_ho,b_o,
        W_xc,W_hc,b_c,
        W_hq,b_q] = params    
    
    (H,C) = state
    outputs= []
    
    for x in inputs:
        
        #input gate
        I = torch.sigmoid((x@W_xi)+(H@W_hi)+b_i)
        F = torch.sigmoid((x@W_xf)+(H@W_hf)+b_f)
        O = torch.sigmoid((x@W_xo)+(H@W_ho)+b_o)
        
        C_tmp = torch.tanh((x@W_xc)+(H@W_hc)+b_c)
        C = F*C+I*C_tmp
        
        H = O*torch.tanh(C)
        Y = (H@W_hq)+b_q
        
        outputs.append(Y)
        
    return torch.cat(outputs, dim=0),(H,C)
        
    

def main():
    batch_size,num_steps =32, 35
    train_iter, cocab= d21.load_data_time_machine(batch_size, num_steps)

    

if __name__ == "__main__":
    
     main()


 参考

 

CSDN

https://www.cnblogs.com/pinard/p/6519110.html

57 长短期记忆网络(LSTM)【动手学深度学习v2】_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/834143.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

dubbo的高可用

1、zookeeper宕机与dubbo直连 现象&#xff1a;zookeeper注册中心宕机&#xff0c;还可以消费dubbo暴露的服务。 原因&#xff1a; 健壮性 &#xff08;1&#xff09;监控中心宕掉不影响使用&#xff0c;只是丢失部分采样数据. &#xff08;2&#xff09;数据库宕掉后&#x…

三、线性工作流

再生产的各个环节&#xff0c;正确使用gamma编码及gamma解码&#xff0c;使得最终得到的颜色数据与最初输入的物理数据一致。如果使用gamma空间的贴图&#xff0c;在传给着色器前需要从gamma空间转到线性空间。 如果不在线性空间下进行渲染&#xff0c;会产生的问题&#xff1a…

SQL Server安装配置

又得装数据库...头秃 报错 其他信息: 在与 SQL Server 建立连接时出现与网络相关的或特定于实例的错误。未找到或无法访问服务器。请验证... 1是可能主机名不一致导致的&#xff0c;这种换主机名&#xff0c;二是没开sql 服务。 第二种先打开SQL Server 资源配置管理器 打…

【腾讯云Cloud Studio实战训练营】React 快速构建点餐页面

前言&#xff1a; Cloud Studio是一个在线的云集成开发环境&#xff08;IDE&#xff09;&#xff0c;可以让开发人员在浏览器中轻松地开发、测试、调试和部署应用程序。它提供了基于云的计算资源和工具&#xff0c;例如代码编辑器、编译器、调试器、版本控制系统和项目管理工具…

8.4算法

#include <stdio.h> #include <stdlib.h> #include <string.h>// 1.求1/阶乘累加 int main() {int n;int i;double temp 1;double result 0;printf("Input:");scanf("%d", &n);for (i 1; i < n; i) {temp * i;//1,2,6,24resul…

Kafka的配置和使用

目录 1.服务器用docker安装kafka 2.springboot集成kafka实现生产者和消费者 1.服务器用docker安装kafka ①、安装docker&#xff08;docker类似于linux的软件商店&#xff0c;下载所有应用都能从docker去下载&#xff09; a、自动安装 curl -fsSL https://get.docker.com | b…

Jenkins工具系列 —— 插件 钉钉发送消息

文章目录 安装插件 Ding TalkJenkins 配置钉钉机器人钉钉APP配置项目中启动钉钉通知功能 安装插件 Ding Talk 点击 左侧的 Manage Jenkins —> Plugins ——> 左侧的 Available plugins Jenkins 配置钉钉机器人 点击 左侧的 Manage Jenkins &#xff0c;拉到最后 钉…

微前端中的 CSS

本文为翻译 本文译者为 360 奇舞团前端开发工程师原文标题&#xff1a;CSS in Micro Frontends 原文作者&#xff1a;Florian Rappl 原文地址&#xff1a;https://dev.to/florianrappl/css-in-micro-frontends-4jai 我被问得最多的问题之一是如何在微前端中处理 CSS。毕竟&…

QT图形视图系统 - 使用一个项目来学习QT的图形视图框架 -第一篇

文章目录 QT图形视图系统介绍开始搭建MainWindow框架设置scene的属性缩放功能的添加加上标尺 QT图形视图系统 介绍 详细的介绍可以看QT的官方助手&#xff0c;那里面介绍的详细且明白&#xff0c;需要一定的英语基础&#xff0c;我这里直接使用一个开源项目来介绍QGraphicsVi…

AI大模型之花,绽放在鸿蒙沃土

随着生成式AI日益火爆&#xff0c;大语言模型能力引发了越来越多对于智慧语音助手的期待。 我们相信&#xff0c;AI大模型能力加持下的智慧语音助手一定会很快落地&#xff0c;这个预判不仅来自对AI大模型的观察&#xff0c;更来自对鸿蒙的了解。鸿蒙一定会很快升级大模型能力&…

macOS 虚拟桌面黑屏(转)

转自&#xff1a;macOS重置虚拟桌面、macOS 虚拟桌面黑屏 有几次出现如图的情况&#xff0c;以为是iTerm的问题&#xff0c;但是在关闭软件&#xff0c;重启之后&#xff0c;依旧无效。 后面经过网友告知&#xff0c;才知道是虚拟桌面的问题。 为了清理这个问题&#xff0c;有以…

看redisson是如何解决锁超时问题

看redisson是如何解决锁超时问题 什么是锁超时问题&#xff1f; 比如利用redis实现的分布式锁会设置一定的过期时间&#xff0c;超过该时间&#xff0c;缓存自动删除&#xff0c;锁被释放。这是防止因程序宕机等原因导致锁一直被占用。 但存在一定的问题&#xff0c;如果是该…

ADC模拟看门狗

如果被ADC转换的模拟电压低于低阀值或高于高阀值&#xff0c;AWD模拟看门狗状态位被设置。阀值位 于ADC_HTR和ADC_LTR寄存器的最低12个有效位中。通过设置ADC_CR1寄存器的AWDIE位 以允许产生相应中断。通过以下函数可以进行配置 void ADC_AnalogWatchdogCmd(ADC_TypeDef* ADCx…

wolfSSL5.6.3 虚拟机ubuntu下编译运行记录(踩坑填坑)

网上相关教程很多(包括wolfSSL提供的手册上也是如此大而化之的描述)&#xff0c;大多类似如下步骤&#xff1a; ./configure //如果有特殊的要求的话可以在后面接上对应的语句&#xff0c;比如安装目录、打开或关闭哪些功能等等 make make install 然后结束&#xff0c;大体…

秒杀业务场景的处理方案

秒杀的处理方案 秒杀技术实现核心思想是运用缓存减少数据库瞬间的访问压力。在秒杀时&#xff0c;首先会将数据库的秒杀商品同步到缓存中&#xff0c;用户从缓存中查询秒杀商品&#xff0c;抢购商品时减少缓存中的库存数量。产生的秒杀订单先写到缓存&#xff0c;付款成功后再…

【TypeScript】安装的坑!

TypeScript安装 安装TypeScript安装时候可能报错这样开头的数据&#xff08;无法枚举容器中的对象&#xff09;——原因&#xff1a;没权限先解决没权限的问题如果发现无法修改-高级-修改继续安装想使用tsc-发现&#xff0c;tsc不能用解决方法&#xff1a;配置环境变量最后的最…

选读SQL经典实例笔记17_最多和最少

1. 问题4 1.1. 最多选修两门课程的学生&#xff0c;没有选修任何课程的学生应该被排除在外 1.2. sql select distinct s.*from student s, take twhere s.sno t.snoand s.sno not in ( select t1.snofrom take t1, take t2, take t3where t1.sno t2.snoand t2.sno t3.sno…

云原生之使用Docker部署homer静态主页

云原生之使用Docker部署homer静态主页 一、homer介绍1.1 homer简介1.2 homer特点 二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍 三、本地环境检查3.1 检查Docker服务状态3.2 检查Docker版本3.3 检查docker compose 版本 四、下载homer镜像五、部署homer静态主页5.1 创建挂…

Kubernetes详细概述

这里写目录标题 一&#xff1a;Kubernetes 概述1、K8S 是什么&#xff1f;2、为什么要用 K8S?2.1.nsenter 是k8s容器抓包工具 3、Kubernetes 集群架构与组件4.核心组件4.1 Master 组件4.1.1.Kube-apiserver4.1.2.Kube-controller-manager4.1.3.Kube-scheduler 4.2 配置存储中心…

三、Java NIO编程

目录 3.1 Java NIO基本介绍3.2 BIO 和 NIO的比较3.3 NIO三大核心 selector、channel、buffer之间的关系3.4 缓冲区&#xff08;Buffer&#xff09;3.4.1 基本介绍3.4.2 Buffer类及其子类 3.5 通道3.5.0 channel基本介绍3.5.1 FileChannel 类3.5.2 应用实例1 - 本地文件写数据3.…