【transformer】自注意力源码解读和复杂度计算

news2024/11/26 17:46:58

Self-attention

1

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

其中, Q Q Q为查询向量, K K K V V V为键向量和值向量, d k d_k dk为向量的维度。 Q Q Q K K K V V V在一般情况下是相同的。公式中的softmax函数将分数归一化为概率,得到加权的值向量。这里的注意力机制是通过计算查询向量 Q Q Q和键向量 K K K之间的相似性,来为值向量 V V V分配不同的权重。如果两个向量越相似,则它们之间的权重应该越大,反之则越小。

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)  # 获取文本嵌入维度大小
    # 按照注意力机制的公式计算注意力分数
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    # 是否使用掩码
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # 使用softmax对最后一个维度获得注意力张量
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # 注意力张量与value相乘得到query的注意力表示
    return torch.matmul(p_attn, value), p_attn

一个形状为 N × M N\times M N×M 的矩阵,与另一个形状为 M × P M\times P M×P的矩阵相乘,其运算复杂度来源于乘法操作的次数,时间复杂度为 O ( N M P ) O(NMP) O(NMP)

Self-attention的公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V其中, Q Q Q为查询向量, K K K V V V为键向量和值向量, d k d_k dk为向量的维度。 Q Q Q K K K V V V在一般情况下是相同的。公式中的softmax函数将分数归一化为概率,得到加权的值向量。
Self-Attention的计算复杂度主要来自三个方面:查询矩阵、键矩阵和值矩阵的乘积、softmax 的计算、以及输出向量和值的加权平均。
对于一个由n个单词组成的输入序列,假设有d个维度的特征,那么查询矩阵、键矩阵和值矩阵的维度都将是 n × d。

  • 对于查询矩阵 Q 和键矩阵 K 的点积, n × d n\times d n×d d × n d\times n d×n计算复杂度是 O ( n 2 d ) O(n^2d) O(n2d)
  • 每行 softmax 的计算,计算复杂度为 O ( n ) O(n) O(n),对n行做softmax,复杂度为 O ( n 2 ) O(n^2) O(n2)
  • 对于值矩阵 V (维度 n × d n\times d n×d)和 softmax 后的结果(维度 n × n n\times n n×n)进行点积,得到每个查询向量的加权平均值,复杂度是 O ( n 2 d ) O(n^2d) O(n2d)

因此,总的计算复杂度是 O ( n 2 d ) + O ( n 2 ) + O ( n 2 d ) ≃ O ( n 2 d ) O(n^2d) + O(n^2) + O(n^2d) \simeq O(n^2d) O(n2d)+O(n2)+O(n2d)O(n2d)
由于这个复杂度是关于输入序列长度n的平方级别,因此Self-Attention在处理长序列时可能会面临计算上的挑战。

多头注意力

2
多头注意力的计算公式如下:
MultiHead ⁡ ( Q , K , V ) = Concat ⁡ ( head ⁡ 1 , … ,  head  h ) W O  where   head  i = A ( Q W i Q , K W i K , V W i V ) \begin{aligned} \operatorname{MultiHead}(Q, K, V) & =\operatorname{Concat}\left(\operatorname{head}_1, \ldots, \text { head }_{\mathrm{h}}\right) W^O \\ \text { where } \text { head }_{\mathrm{i}} & =A\left(Q W_i^Q, K W_i^K, V W_i^V\right) \end{aligned} MultiHead(Q,K,V) where  head i=Concat(head1,, head h)WO=A(QWiQ,KWiK,VWiV)其中, Q , K , V Q,K,V Q,K,V 分别表示查询、键和值, h h h 表示头数, h e a d i head_i headi 表示第 i i i 个注意力头, W O W^O WO 表示输出层的权重矩阵。

# 用于深度拷贝的copy工具包
import copy

# 首先需要定义克隆函数, 因为在多头注意力机制的实现中, 用到多个结构相同的线性层.
# 我们将使用clone函数将他们一同初始化在一个网络层列表对象中. 之后的结构中也会用到该函数.
def clones(module, N):
    """用于生成相同网络层的克隆函数, 它的参数module表示要克隆的目标网络层, N代表需要克隆的数量"""
    # 在函数中, 我们通过for循环对module进行N次深度拷贝, 使其每个module成为独立的层,
    # 然后将其放在nn.ModuleList类型的列表中存放.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# 我们使用一个类来实现多头注意力机制的处理
class MultiHeadedAttention(nn.Module):
    def __init__(self, head, embedding_dim, dropout=0.1):
        """在类的初始化时, 会传入三个参数,head代表头数,embedding_dim代表词嵌入的维度, 
           dropout代表进行dropout操作时置0比率,默认是0.1."""
        super(MultiHeadedAttention, self).__init__()

        # 在函数中,首先使用了一个测试中常用的assert语句,判断h是否能被d_model整除,
        # 这是因为我们之后要给每个头分配等量的词特征.也就是embedding_dim/head个.
        assert embedding_dim % head == 0

        # 得到每个头获得的分割词向量维度d_k
        self.d_k = embedding_dim // head

        # 传入头数h
        self.head = head

        # 然后获得线性层对象,通过nn的Linear实例化,它的内部变换矩阵是embedding_dim x embedding_dim,然后使用clones函数克隆四个,
        # 为什么是四个呢,这是因为在多头注意力中,Q,K,V各需要一个,最后拼接的矩阵还需要一个,因此一共是四个.
        self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)

        # self.attn为None,它代表最后得到的注意力张量,现在还没有结果所以为None.
        self.attn = None

        # 最后就是一个self.dropout对象,它通过nn中的Dropout实例化而来,置0比率为传进来的参数dropout.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """前向逻辑函数, 它的输入参数有四个,前三个就是注意力机制需要的Q, K, V,
           最后一个是注意力机制中可能需要的mask掩码张量,默认是None. """

        # 如果存在掩码张量mask
        if mask is not None:
            # 使用unsqueeze拓展维度
            mask = mask.unsqueeze(0)

        # 接着,我们获得一个batch_size的变量,他是query尺寸的第1个数字,代表有多少条样本.
        batch_size = query.size(0)

        # 之后就进入多头处理环节
        # 首先利用zip将输入QKV与三个线性层组到一起,然后使用for循环,将输入QKV分别传到线性层中,
        # 做完线性变换后,开始为每个头分割输入,这里使用view方法对线性变换的结果进行维度重塑,多加了一个维度h,代表头数,
        # 这样就意味着每个头可以获得一部分词特征组成的句子,其中的-1代表自适应维度,
        # 计算机会根据这种变换自动计算这里的值.然后对第二维和第三维进行转置操作,
        # 为了让代表句子长度维度和词向量维度能够相邻,这样注意力机制才能找到词义与句子位置的关系,
        # 从attention函数中可以看到,利用的是原始输入的倒数第一和第二维.这样我们就得到了每个头的输入.
        query, key, value = \
           [model(x).view(batch_size, -1, self.head, self.d_k).transpose(1, 2)
            for model, x in zip(self.linears, (query, key, value))]

        # 得到每个头的输入后,接下来就是将他们传入到attention中,
        # 这里直接调用我们之前实现的attention函数.同时也将mask和dropout传入其中.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 通过多头注意力计算后,我们就得到了每个头计算结果组成的4维张量,我们需要将其转换为输入的形状以方便后续的计算,
        # 因此这里开始进行第一步处理环节的逆操作,先对第二和第三维进行转置,然后使用contiguous方法,
        # 这个方法的作用就是能够让转置后的张量应用view方法,否则将无法直接使用,
        # 所以,下一步就是使用view重塑形状,变成和输入形状相同.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.head * self.d_k)

        # 最后使用线性层列表中的最后一个线性层对输入进行线性变换得到最终的多头注意力结构的输出.
        return self.linears[-1](x)

在多头注意力中,假设有 h h h 个头,每个头的查询、键和值的维度是 d k d_k dk d k d_k dk d v d_v dv,一般情况 d q = d k = d v = d h d_q=d_k=d_v=\frac{d}{h} dq=dk=dv=hd, 输入序列的长度为 N N N

  • 输入线性映射的复杂度: n × d n\times d n×d d × d h d \times \frac{d}{h} d×hd,计算复杂度是 O ( n d 2 h ) O(\frac{nd^2 }{h}) O(hnd2)
  • 注意力计算:输入线性映射后的维度 n × d h n \times \frac{d}{h} n×hd n × d h n \times \frac{d}{h} n×hd d h × n \frac{d}{h}\times n hd×n计算复杂度是 O ( n 2 d h ) O(n^2\frac{d}{h}) O(n2hd)
  • 输出线性映射: 多个头的结果concat成一个 n × d n\times d n×d矩阵, n × d n\times d n×d d × d d \times d d×d,计算复杂度是 O ( n d 2 ) O(nd^2) O(nd2)

总时间复杂度 O ( n d 2 h + n 2 d h + n d 2 ) O(\frac{nd^2}{h}+n^2\frac{d}{h}+nd^2) O(hnd2+n2hd+nd2)


参考:
传智博客-Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/976888.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Langchain使用介绍之 - 基于向量存储进行检索

Text Embedding Models 如何将一段Document转换成向量存储到向量数据库中,首先需要了解Langchain提供了哪些将文本转换成向量的model,langchian提供了很多将自然语言转换成向量的的模型,如下图所示,除了下图列举的model&#xff0…

java和js实现MD5加密

java import java.security.MessageDigest;public class Demo2 {public static void main(String[] args) {Demo2 demo2 new Demo2();String encry demo2.md5("admin");System.out.println("加密后:" encry);}/*** md5加密*/private static…

正则性能提升之-Matcher.appendReplacementappendTail使用(别再无脑用字符串替换啦)

首先是用法: appendReplacement是java中替换相应字符串的一个方法 appendReplacement(StringBuffer sb,String replacement) 将当前匹配子串替换为指定字符串,并且将替换后的子串以及其之前到上次匹配子串之后的字符串段添加到一个 StringBuffer 对象里…

oracle将一个用户的表复制到另一个用户

注:scott用户和scott用户下的源表(EMP)本身就有,无需另行创建。 GRANT SELECT ON SCOTT.emp TO BI_ODSCREATE TABLE ODS_EMP AS SELECT * FROM SCOTT.emphttp://www.bxcqd.com/news/77615.html

Java多线程基础(创建、使用,状态)——Java第九讲

前言 这一讲开始我们将进入java高级部分,包括多线程编程、数据结构、并发编程、设计模式等。本讲先介绍多线程,多线程编程是Java编程中的一个重要部分。它允许程序同时执行多个任务,这有助于提高程序的效率和性能。在Java中,可以通过实现Runnable接口或继承Thread类来创建线…

android studio安卓模拟器高德SDK定位网络连接异常

背景 使用了高德SDK创建了一个 project, 下面是运行界面: 点击 "开始定位"按钮, 结果并没有返回定位信息, 而是报错了: 根据错误提示打开这个网址: https://lbs.amap.com/api/android-location-sdk/guide/utilities/errorcode, 并且找到错误码 4 的信息, 显示的是网…

zabbix -- 新建主机

目录 一、新建主机 二、新建监控项 IP主机192.168.136.55zabbix控制端/服务端192.168.136.56zabbix被控端/客户端 一、新建主机 主机参数 名称、群组(每台主机必须属于某个主机组内)、ip、端口 创建完成,如果你的ZBX为灰色,代…

SOLIDWORKS倒角是什么?

在现代工程设计中,倒角是一项常见而重要的工艺。它不仅可以提升产品的外观美观度,还能改善产品的强度和耐用性。SOLIDWORKS作为一款广泛应用于3D建模和设计的软件,提供了强大的倒角功能,使工程师能够轻松地在设计过程中添加和编辑…

基于SSM的线上旅行信息管理系统

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用Vue技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

Python 中下划线详解(_、_xx、xx_、__xx、__xx__)

文章目录 1 概述2 演示2.1 _:不重要的变量2.2 _xx:声明私有,仅内部调用2.3 xx_:区分关键字2.4 __xx:声明私有,仅当前类中可用2.5 __xx__:内置函数 1 概述 2 演示 2.1 _:不重要的变…

兼容性测试基本原则是什么

兼容性测试是计算机软件测试过程中的一项重要活动,旨在验证软件在不同平台、操作系统、网络环境、硬件设备或软件版本之间的正确运行和兼容性。那么,兼容性测试的基本原则是什么?下面,就来看看具体介绍吧! 兼容性测试的基本原则:…

zookeeper/HA集群配置

1.zookeep配置 1.1 安装4台虚拟机 (1)按照如下设置准备四台虚拟机,其中三台作为zookeeper,配置每台机器相应的IP,hostname,下载vim,ntpdate配置定时器定时更新时间,psmisc&#xff…

11.3.1-使用Pythton抓取股票基金数据

文章目录 1. 哪些方式获取股票数据1.1. yifinance1.2. JoinQuant聚宽1.3. tushare1.4. 自己动手,丰衣足食 2. 使用python抓取数据2.1. 查看请求报文2.2. 解析返回报文2.3. 数据存储2.4. 开始python代码编写2.4.1. 构造时间区间2.4.2. requests调用2.4.3. 数据存储 2…

让API开发更高效——Apipost

作为一款专为API开发设计的工具,Apipost凭借其强大的功能和高效的特点,正逐渐受到越来越多开发者的欢迎。本文将向您详细介绍Apipost的独特优势以及如何让您的API开发更加高效。 Apipost适用于所有与API开发相关的从业者,包括但不限于前端工…

Nor flash 页写地址与数据大小的限制

厂商提供的flash手册如下 如果页写指令的地址不是256的整数倍,并且写入的数据量超过了当前地址所在页的边界,则超过的那些数据会重新写入当前页的首地址(即256的整数倍地址),所以,在进行页写的时候&#x…

Unity Shader着色器知识

学习3D开发技术的时候无可避免的要接触到Shader,那么Shader是个什么概念呢?其实对于开发同事来说还是比较难理解的,一般来说Shader是服务于图形渲染的一类技术,开发人员可以通过其shader语言来自定义显卡渲染页面的算法&#xff0…

Django学习

1、启动项目 python manage.py runserversettings.py

微信小程序新建页面文件

1、在app.json->pages中新增页面的存放路径 list文件夹之前是直接右键加上去,后面删掉了,利用上述操作新增,只出现了两个文件。暂时还不清楚需要怎样才能正式生成4个文件

【STM32】锁存器

问题背景 在学习FSMC控制外部NOR存储器时,看到在NOR复用接口模式下,AD信号[15:0]是复用的。也就是说,若不使用锁存器:当NADV为低时,ADx(x0…15)上出现地址信号Ax,当NADV变高时,ADx上出现数据信号Dx。若使用…

基于深度学习网络的火灾检测算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 ................................................................................ load F…