LSTM MultiheadAttention 输入维度

news2025/1/11 6:11:49

最近遇到点问题,对于模块的输入矩阵的维度搞不清楚,这里在学习一下,记录下来,方便以后查阅。

LSTM & Attention 输入维度

  • LSTM
    • 记忆单元
    • 门控机制
    • LSTM结构
    • LSTM的计算过程
      • 遗忘门
      • 输入门
      • 更新记忆单元
      • 输出门
    • LSTM单元的pytorch实现
    • Pytorch中的LSTM
      • 参数
      • 输入Inputs: input, (h_0, c_0)
      • 输出Outputs: output, (h_n, c_n)
      • 参数解释
  • MultiheadAttention
    • Self Attention 计算过程
    • Multihead Attention 计算过程
    • MultiheadAttention单元的pytorch实现
    • Pytorch中的MultiheadAttention
    • 输入的矩阵维度
  • 参考资料

LSTM

LSTM是RNN的一种变种,可以有效地解决RNN的梯度爆炸或者消失问题。

在这里插入图片描述

记忆单元

LSTM引入了一个新的记忆单元 c t c_t ct,用于进行线性的循环信息传递,同时输出信息给隐藏层的外部状态 h t h_t ht。在每个时刻 t t t c t c_t ct记录了到当前时刻为止的历史信息。

门控机制

LSTM引入门控机制来控制信息传递的路径,类似于数字电路中的门,0即关闭,1即开启。

LSTM中的三个门为遗忘门 f t f_t ft,输入门 i t i_t it,和输出门 o t o_t ot

  • f t f_t ft控制上一个时刻的记忆单元 c t − 1 c_{t-1} ct1需要遗忘多少信息
  • i t i_t it控制当前时刻的候选状态 c ~ t \tilde{c}_t c~t有多少信息需要存储
  • o t o_t ot控制当前时刻的记忆单元 c t c_t ct有多少信息需要输出给外部状态 h t h_t ht

LSTM结构

如图一所示为LSTM的结构,LSTM网络由一个个的LSTM单元连接而成。

在这里插入图片描述

LSTM 的关键就是记忆单元,水平线在图上方贯穿运行。

记忆单元类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。

LSTM的计算过程

遗忘门

在这里插入图片描述

在这一步中,遗忘门读取 h t − 1 h_{t-1} ht1 x t x_t xt,经由sigmoid,输入一个在0到1之间数值给每个在记忆单元 c t − 1 c_{t-1} ct1中的数字,1表示完全保留,0表示完全舍弃。

输入门

在这里插入图片描述
输入门将确定什么样的信息内存放在记忆单元中,这里包含两个部分。

  1. sigmoid层同样输出[0,1]的数值,决定候选状态 c ~ t \tilde{c}_t c~t有多少信息需要存储
  2. tanh层会创建候选状态 c ~ t \tilde{c}_t c~t

更新记忆单元

随后更新旧的细胞状态,将 c t − 1 c_{t-1} ct1更新为 c t c_t ct

在这里插入图片描述

首先将旧状态 c t − 1 c_{t-1} ct1 f t f_t ft相乘,遗忘掉由 f t f_t ft所确定的需要遗忘的信息,然后加上 i t ∗ c ~ t i_t*\tilde{c}_t itc~t,由此得到了新的记忆单元 c t c_t ct

输出门

结合输出门 o t o_t ot将内部状态的信息传递给外部状态 h t h_t ht。同样传递给外部状态的信息也是个过滤后的信息,首先sigmoid层确定记忆单元的那些信息被传递出去,然后,把细胞状态通过tanh层进行处理(得到[-1,1]的值)并将它和输出门的输出相乘,最终外部状态仅仅会得到输出门确定输出的那部分。

在这里插入图片描述

LSTM单元的pytorch实现

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, cell_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size # 隐含状态h的大小,也即LSTM单元隐含层神经元数量
        self.cell_size = cell_size # 记忆单元c的大小
        # 门
        self.gate = nn.Linear(input_size+hidden_size, cell_size)
        self.output = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden, cell):
        # 连接输入x与h 
        combined = torch.cat((input, hidden), 1)
        # 遗忘门
        f_gate = self.sigmoid(self.gate(combined))
        # 输入门
        i_gate = self.sigmoid(self.gate(combined))
        z_state = self.tanh(self.gate(combined))
        # 输出门
        o_gate = self.sigmoid(self.gate(combined))
        # 更新记忆单元
        cell = torch.add(torch.mul(cell, f_gate), torch.mul(z_state, i_gate))
        # 更新隐藏状态h
        hidden = torch.mul(self.tanh(cell), o_gate)
        output = self.output(hidden)
        output = self.softmax(output)
        return output, hidden, cell
    
    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

    def initCell(self):
        return torch.zeros(1, self.cell_size)

Pytorch中的LSTM

在这里插入图片描述

参数

  • input_size – 输入特征维数
  • hidden_size – 隐含状态h hh的维数
  • num_layers – RNN层的个数:(在竖直方向堆叠的多个相同个数单元的层数),默认为1
  • bias – 隐层状态是否带bias,默认为true
  • batch_first – 是否输入输出的第一维为batchsize
  • dropout – 是否在除最后一个RNN层外的RNN层后面加dropout层
  • bidirectional –是否是双向RNN,默认为false
  • proj_size – 如果>0, 则会使用相应投影大小的LSTM,默认值:0

其中比较重要的参数就是hidden_size与num_layers,hidden_size所代表的就是LSTM单元中神经元的个数。num_layers所代表的含义,就是depth的堆叠,也就是有几层的隐含层。

在这里插入图片描述

这张图是以MLP的形式展示LSTM的传播方式(不用管左边的符号,输出和隐状态其实是一样的),方便理解hidden_size这个参数。其实hidden_size在各个函数里含义都差不多,就是参数W的第一维(或最后一维)。那么对应前面的公式,hidden_size实际就是以这个size设置所有W的对应维。

在这里插入图片描述

这张图非常便于理解参数num_layers。实际上就是个depth堆叠,每个蓝色块都是LSTM单元。只不过第一层输入是 x t , h t − 1 ( 0 ) , c t − 1 ( 0 ) x_t, h_{t-1}^{(0)}, c_{t-1}^{(0)} xt,ht1(0),ct1(0),中间层输入是 h t ( k − 1 ) , h t − 1 ( k ) , c t − 1 ( k ) h_{t}^{(k-1)}, h_{t-1}^{(k)}, c_{t-1}^{(k)} ht(k1),ht1(k),ct1(k)

输入Inputs: input, (h_0, c_0)

  • input:当batch_first = False 时形状为(L,N,H_in),当 batch_first = True 则为(N, L, H_in​) ,包含批量样本的时间序列输入。该输入也可是一个可变换长度的时间序序列。
  • h_0:形状为(D∗num_layers, N, H_out),指的是包含每一个批量样本的初始隐含状态。如果模型未提供(h_0, c_0) ,默认为是全0矩阵。
    c_0:形状为(D∗num_layers, N, H_cell), 指的是包含每一个批量样本的初始记忆细胞状态。 如果模型未提供(h_0, c_0) ,默认为是全0矩阵。

输出Outputs: output, (h_n, c_n)

  • output: 当batch_first = False 形状为(L, N, D∗H_out​) ,当batch_first = True 则为 (N, L, D∗H_out​) ,包含LSTM最后一层每一个时间步长 的输出特征()。
  • h_n: 形状为(D∗num_layers, N, H_out​),包括每一个批量样本最后一个时间步的隐含状态。
  • c_n: 形状为(D∗num_layers, N, H_cell​),包括每一个批量样本最后一个时间步的记忆细胞状态。

参数解释

  • N = 批量大小
  • L = 序列长度
  • D = 2 如果模型参数bidirectional = 2,否则为1
  • H_in = 输入的特征大小(input_size)
  • H_cell = 隐含单元数量(hidden_size)
  • H_out = proj_size, 如果proj_size > 0, 否则的话 = 隐含单元数量(hidden_size)

MultiheadAttention

Self Attention 计算过程

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

Multihead Attention 计算过程

在这里插入图片描述

MultiheadAttention单元的pytorch实现

class Attention(nn.Module):
    '''
    Attention Module used to perform self-attention operation allowing the model to attend
    information from different representation subspaces on an input sequence of embeddings.
    The sequence of operations is as follows :-

    Input -> Query, Key, Value -> ReshapeHeads -> Query.TransposedKey -> Softmax -> Dropout
    -> AttentionScores.Value -> ReshapeHeadsBack -> Output

    Args:
        embed_dim: Dimension size of the hidden embedding
        heads: Number of parallel attention heads (Default=8)
        activation: Optional activation function to be applied to the input while transforming to query, key and value matrixes (Default=None)
        dropout: Dropout value for the layer on attention_scores (Default=0.1)

    Methods:
        _reshape_heads(inp) :- 
        Changes the input sequence embeddings to reduced dimension according to the number
        of attention heads to parallelize attention operation
        (batch_size, seq_len, embed_dim) -> (batch_size * heads, seq_len, reduced_dim)

        _reshape_heads_back(inp) :-
        Changes the reduced dimension due to parallel attention heads back to the original
        embedding size
        (batch_size * heads, seq_len, reduced_dim) -> (batch_size, seq_len, embed_dim)

        forward(inp) :-
        Performs the self-attention operation on the input sequence embedding.
        Returns the output of self-attention as well as atttention scores
        (batch_size, seq_len, embed_dim) -> (batch_size, seq_len, embed_dim), (batch_size * heads, seq_len, seq_len)

    Examples:
        >>> attention = Attention(embed_dim, heads, activation, dropout)
        >>> out, weights = attention(inp)
    '''
    def __init__(self, embed_dim, heads=8, activation=None, dropout=0.1):
        super(Attention, self).__init__()
        self.heads = heads
        self.embed_dim = embed_dim
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'elu':
            self.activation = nn.ELU()
        else:
            self.activation = nn.Identity()
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, inp):
        # inp: (batch_size, data_aug, cha_tim_dim, embed_dim)
        batch_size, data_aug, cha_tim_dim, embed_dim = inp.size()
        assert embed_dim == self.embed_dim

        query = self.activation(self.query(inp))
        key   = self.activation(self.key(inp))
        value = self.activation(self.value(inp))

        # output of _reshape_heads(): (batch_size * heads, data_aug, cha_tim_dim, reduced_dim) | reduced_dim = embed_dim // heads
        query = self._reshape_heads(query)
        key   = self._reshape_heads(key)
        value = self._reshape_heads(value)

        # attention_scores: (batch_size * heads, data_aug, cha_tim_dim, cha_tim_dim) | Softmaxed along the last dimension
        attention_scores = self.softmax(torch.matmul(query, key.transpose(2, 3)))

        # out: (batch_size * heads, data_aug, cha_tim_dim, reduced_dim)
        out = torch.matmul(self.dropout(attention_scores), value)

        # output of _reshape_heads_back(): (batch_size, data_aug, cha_tim_dim, embed_dim)
        out = self._reshape_heads_back(out)

        return out, attention_scores

    def _reshape_heads(self, inp):
        # inp: (batch_size, data_aug, cha_tim_dim, embed_dim)
        batch_size, data_aug, cha_tim_dim, embed_dim = inp.size()

        reduced_dim = self.embed_dim // self.heads
        assert reduced_dim * self.heads == self.embed_dim
        out = inp.reshape(batch_size, data_aug, cha_tim_dim, self.heads, reduced_dim)
        out = out.permute(0, 3, 1, 2, 4)
        out = out.reshape(-1, data_aug, cha_tim_dim, reduced_dim)

        # out: (batch_size * heads, data_aug, cha_tim_dim, reduced_dim)
        return out

Pytorch中的MultiheadAttention

在这里插入图片描述

在这里插入图片描述

输入的矩阵维度

在这里插入图片描述

参考资料

LSTM详解

Pytorch LSTM模型 参数详解

[译] 理解 LSTM 网络

https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html?highlight=attention#torch.nn.MultiheadAttention

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/172336.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Security in Action 第七章 配置授权:限制访问

本专栏将从基础开始,循序渐进,以实战为线索,逐步深入SpringSecurity相关知识相关知识,打造完整的SpringSecurity学习步骤,提升工程化编码能力和思维能力,写出高质量代码。希望大家都能够从中有所收获&#…

[leetcode 72] 编辑距离

题目 题目:https://leetcode.cn/problems/edit-distance/description/ 类似题目:[leetcode 583] 两个字符串的删除操作 解法 动态规划 这题应该是字符串dp的终极形态了吧🤣,不看答案完全不会…看了答案发现原来还是dp… 以例题…

未来的竞争是认知和执行力的竞争,只有认知高,强执行才能赚钱

之前很火的一句话是:你永远赚不到认知范围之外的钱所以只有持续不断地提升认知才能持续成长,持续提升,持续赚钱。未来的竞争从另一方面来说也是认知的竞争。不同的认知对待同一事物、信息有不同的理解;不同的认知对待同一事物、信…

固高科技在创业板提交注册:业绩开始下滑,实控人均为“学院派”

近日,固高科技股份有限公司(下称“固高科技”)在深圳证券交易所创业板递交注册。据贝多财经了解,固高科技于2021年12月在创业板递交上市申请,2022年8月17日获得上市委会议通过。 本次冲刺创业板上市,固高科…

【一道面试题】说一下Synchronized?

说一下Synchronized? Synchronized锁是Java中为了解决线程安全问题的一种方式,是一种悲观锁Synchronized可以用来修饰方法或者以代码块,用来保证线程执行方法或代码块时的原子性Java中任何一个类的对象都可以用来作为锁对象,但是…

docker-15-镜像Ubuntu20.04中安装python3.9

1 拉取并运行镜像 从docker hub 拉取镜像,以ubuntu20.04为例: docker pull ubuntu:20.04 docker run -it ubuntu:20.04 /bin/bash发现命令行变为root1234abcd5678:,这样就是进入docker容器里了。以下是docker常用的命令: # 以…

8086到80386汇编数据传送指令的扩展

80386及以上汇编的数据传送指令如下; MOV 传送字或字节. MOVSX 先符号扩展,再传送. MOVZX 先零扩展,再传送. PUSH 把字压入堆栈. POP 把字弹出堆栈. PUSHA 把AX,CX,DX,BX,SP,BP,SI,DI依次压入堆栈. POPA 把DI,SI,BP,SP,BX,DX,CX,A…

人大金仓数据库KSQL常用命令

第三章KSQL常用命令 登陆前显示ksql的帮助命令 Ksql --help 列出所有的SQL命令清单 test# \h 列出某个SQL命令语法大纲 \h <sql命令> 如&#xff1a;\h delect 查看ksql元命令的帮助 ..... 查看数据库列表 显示当前连接的数据库和登录用户 \c 显示当前test数据库的…

数学和统计方法

平均数&#xff0c;加权平均数&#xff0c;中位数&#xff0c;众数 1、平均数&#xff1a;所有数加在一起求平均 2、中位数&#xff1a;对于有限的数集&#xff0c;可以通过把所有观察值高低排序后找出正中间的一个作为中位数。如果观察值有偶数个&#xff0c;通常取最中间的 …

Spring Boot学习篇(十一)

Spring Boot学习篇(十一) shiro安全框架使用篇(三) 1.shiro过滤地址配置(部分地址必须要登录才能访问) 1.1 在controll包下创建CRUDController类(用于提供地址进行测试),其内容如下所示 package com.zlz.controller;import org.springframework.stereotype.Controller; imp…

回顾一次后台从war包启动到jar包启动的改造

一、背景描述 1.项目情况 有个项目后台一开始是war包部署到tomcat中部署的 配置文件放在项目中 考虑到这种部署方式相对spring boot项目内置tomcat部署不太便捷&#xff0c;配置也没有独立出来&#xff0c;考虑将原来的spring mvc项目稍微改造为spring boot项目。 2.要求 1&am…

Linux设备树的概念

一.设备树概念以及作用1.设备树概念设备树(Device Tree)&#xff0c;将这个词分开就是“设备”和“树”&#xff0c;描述设备树的文件叫做 DTS(DeviceTree Source)&#xff0c;这个 DTS 文件采用树形结构描述板级设备&#xff0c;也就是开发板上的设备信息&#xff0c;比如CPU …

flowable的Task使用

ReceiveTask UserTask ServiceTask ScriptTask ReceiveTask 执行到这个ReceiveTask会停下来&#xff0c;需要人工触发一下&#xff0c;才会继续执行 ClassPathResource classPathResource new ClassPathResource("processes/ReceiveTaskDemo.bpmn20.xml");String f…

C++——模板与STL标准模板库

目录 一、模板 1.1类型模板 1.2非类型模板 二、STL 2.1链表实现 2.2迭代器 2.3STL容器 2.4STL算法 三、模板特化的匹配规则 (1) 类模板的匹配规则 (2) 函数模板的匹配规则 一、模板 1.1类型模板 #include <stdio.h> #include <iostream>using namespac…

深度学习 GAN生成对抗网络-手写数字生成及改良

如果你有一定神经网络的知识基础&#xff0c;想学习GAN生成对抗网络&#xff0c;可以按顺序参考系列文章&#xff1a; 深度学习 自动编码器与生成模型 深度学习 GAN生成对抗网络-1010格式数据生成简单案例 深度学习 GAN生成对抗网络-手写数字生成 一、前言 在前面一篇文章&am…

877. 石子游戏

877. 石子游戏题目算法设计&#xff1a;奇偶算法设计&#xff1a;动态规划题目 算法设计&#xff1a;奇偶 最简单的情况&#xff0c;只有2堆石子&#xff08;石子奇数&#xff09;&#xff0c;先稳赢。 但是四堆情况不同了&#xff0c;如 [3 7 2 3]。 不能直接选最大的&…

2023年五大趋势预测 | 大数据分析、人工智能和云产业展望

随着我们迈入2023年&#xff0c;大数据分析、人工智能和云产业将迎来蓬勃的创新和发展阶段 以下是我们预测的&#xff0c;将对行业格局产生重大影响的五大趋势&#xff1a; 世界在剧变&#xff0c;我们需要尽快寻找行业中的方向&#xff0c;迅速重回轨道 2023年&#xff0c;全…

TryHackMe-NahamStore(常见web漏洞 大杂烩)

NahamStore 漏洞赏金web安全 NahamStore的创建是为了测试您在NahamSec的“漏洞赏金狩猎和Web应用程序黑客入门”Udemy课程中学到的知识。 部署计算机&#xff0c;获得 IP 地址后&#xff0c;进入下一步&#xff01; 写在前面 可能我的顺序&#xff0c;跟别人以及题目都不太一…

spring boot集成activemq(windows)

目录 1.环境配置 2.说明 3.服务启动 4.示例 导入依赖 配置文件 service层 配置类 监听器 5.总结 1.环境配置 下载地址&#xff1a;https://activemq.apache.org/components/classic/download/安装&#xff1a;解压缩即可注意每个版本对应的java版本不一样&#xff0c…

分享96个PHP源码,总有一款适合您

PHP源码 分享96个PHP源码&#xff0c;总有一款适合您 下面是文件的名字&#xff0c;我放了一些图片&#xff0c;文章里不是所有的图主要是放不下...&#xff0c; 96个PHP源码下载链接&#xff1a;https://pan.baidu.com/s/1B-tNZlbfjT_D3n_Y6ZwfDw?pwduq19 提取码&#xff…