跟李沐学AI:Transformer

news2024/10/17 19:46:47

Transformer架构

(图源:10.7. Transformer — 动手学深度学习 2.0.0 documentation) 

基于编码器-解码器架构来处理序列对

与使用注意力的seq2seq不同,Transformer纯基于注意力

多头注意力(Multi-Head Attention)

(图源:10.5. 多头注意力 — 动手学深度学习 2.0.0 documentation) 

对于同一QKV,我们希望抽取不同的信息。多头注意力机制相比于标准的注意力机制,允许模型捕捉一个序列的不同方面并对不同方面生成不同权重。

假设输入序列为: I love palying football.

假设embedding_size=3并且使用2个注意力头,序列可表示为一个(4, 3)的矩阵(sequence_length, embed_size)。

[ I ]         [ love ]       [ playing ]    [ football ]
[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]  # Shape: (4, 3)

首先我们需要把embedding投影为Q,K,V。

权重矩阵的形状由输入特征数和头数共同决定,不同头的权重矩阵的形状是相同的。

W_Q^{i} \epsilon R^{d_{feature}\times d_{head}}

W_K^{i} \epsilon R^{d_{feature}\times d_{head}}

W_V^{i} \epsilon R^{d_{feature}\times d_{head}}

d_{feature}是输入的embedding_size,在本例中为3。

d_{head}=d_{feature}/h。其中h是注意力头的数量。

故本例中权重矩阵的形状为(3, 1),其中1为向下取整得到。

对于头1:

  • Q1 = Wq1 * Embedding
  • K1 = Wk1 * Embedding
  • V1 = Wv1 * Embedding

假设头1的权重为:

W_Q^{head1} =[[0.1],[0.2],[0.3]]

W_K^{head1} =[[0.4],[0.5],[0.6]]

W_V^{head1} =[[0.7],[0.8],[0.9]]

那么我们可以列出输入embedding变换至Q的计算过程

Q_I^{\text{head1}} = [0.1, 0.2, 0.3] * [[0.1], [0.2], [0.3]] = [0.14]

Q_{love}^{\text{head1}} = [0.32]

Q_{playing}^{\text{head1}} = [0.5]

Q_{football}^{\text{head1}} = [0.68]

矩阵表示为:

Q^{\text{head1}} = [[0.14], [0.32], [0.5], [0.68]]

同理可得:

K^{\text{head1}} = [[0.32], [0.77], [1.22], [1.67]]

V^{\text{head1}} = [[0.38], [0.92], [1.46], [2.0]]

对于头2,我们需要不同的权重,计算过程相同

随后对于每个头,对Q,K进行点积并进行softmax计算注意力权重,将权重与V相乘计算加权和。本例中,QK点积的结果为一个(4, 4)的矩阵。将注意力权重与V相乘得到(4, 1)的注意力输出。

计算了每个头的注意力后,对输出进行拼接(concatenate)操作。再将拼接后的结果输入另一个线性层以得到最后的输出。本例中两个注意力头拼接后得到(4, 2)的输出。随后使用一个线性层将结果转换为需要的维度(4, 3)。

过程总结:

  • 输入维度: (4, 3) (4 词元, 嵌入维度为 3).
  • 每个头的Q、K、V: (4, 1) (4 tokens, 1 dimension per head).
  • 每个头的输出:(4, 1).
  • 每个头拼接后:(4, 2).
  • 线性映射后的输出:(4, 3).

基于位置的前馈网络FFN

将输入形状由(b, n, d)变为(bn, d)

输出形状由(bn, d)变回(b, n, d)

FFN等价于两层kernel_size为1的一维卷积层

将输入由(b, n, d)变为(bn, d)有如下原因:

  • 全连接层要求输入为2维,一维为batch_size,二维表示特征。
  • 将(b, n, d)变为(bn, d)而不是(b, nd)是为了保留每个元素在序列中的位置信息(位置编码提供的位置信息)。
  • 可以有效地将所有批次序列中的每个位置视为具有d特征的单个样本。这允许密集层对序列中的每个位置应用相同的权重集合。
  • 同时保障了并行计算,(bn, d)可以并行处理多个序列,可以同时对整个序列进行操作,而不是按顺序处理每个元素,这大大加快了计算速度并减少了整体训练时间。

层归一化

什么是层归一化?

对每个样本的所有feature做归一化。意思是对每个样本独立地计算均值和方差。

假设输入x=[[1, 2], [3, 4], [5, 6]]为(3, 2)的矩阵,样本归一化对每行计算均值和方差:

对第一个样本:均值为(1+2)/2=1.5,方差为(1-1.5)^2+(2-1.5)^2=0.25

对每行计算并归一化后得到输出\hat{x}=[[-1, 1], [-1, 1], [-1, 1]]

对比BatchNorm

Batch normalization对输入的所有batch的每个feature批量计算均值和方差并进行归一化。对每个feature独立地计算均值和方差。

同样假设x=[[1, 2], [3, 4], [5, 6]],对于第一个特征计算均值和方差:

mean=(1+3+5)/3=3variance=(1-3)^2+(3-3)^3+(5-3)^2=4

对每列计算并归一化后的到输出:\hat{x}=[[-1, -1], [0, 0], [1, 1]]

为什么是有layer norm而不是batch norm?

Transform的输入一般为一个序列,但是每个序列的有效长度(valid length)有所不同,会导致batch norm不稳定,故不适合batch norm。

对每个样本的元素做归一化,在长度变化时更加稳定。

信息传递

编码器中的输出y_1,\dots,y_n作为解码器中第i个Transformer块中多头注意力的K和V,解码器的Q来自目标序列。

说明编码器和解码器中的Block个数和输出维度是相同的。

预测

预测时,解码器的掩码注意力和多头注意力的Q、K和V的来源分别如下:

  • Masked Attention:

    • Q: 解码器在时间步 t 的状态(之前生成的tokens的表示)。
    • K: 解码器在时间步 t 的状态(之前生成的tokens的表示)。
    • V: 解码器在时间步 t 的状态(之前生成的tokens的表示)。
  • Multi-Head Attention(交叉注意力):
    • Q: 解码器在时间步 ttt 的状态。
    • K: 编码器的最终输出(输入序列的上下文表示)。
    • V: 编码器的最终输出(输入序列的上下文表示)。

在训练时,解码器的Masked Attention和Multi-Head Attention的Q、K和V分别由以下组成:

  • Masked Attention:

    • Q: 解码器当前时间步的输入向量(包括所有已生成的tokens的表示)。
    • K: 解码器当前时间步的输入向量(包括所有已生成的tokens的表示)。
    • V: 解码器当前时间步的输入向量(包括所有已生成的tokens的表示)。
  • Multi-Head Attention(交叉注意力):

    • Q: 解码器当前时间步的输入向量(表示已生成的tokens)。
    • K: 编码器的最终输出(输入序列的上下文表示)。
    • V: 编码器的最终输出(输入序列的上下文表示)。

基于Pytorch的Transformer

MultiHeadAttention

class MultiHeadAttention(nn.Module):
    # num_hiddens: 输入的特征数
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        # heads的数量
        self.num_heads = num_heads
        # 点积注意力
        self.attention = d2l.DotProductAttention(dropout)
        # 计算Q的权重矩阵
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        # 计算K的权重矩阵
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        # 计算V的权重矩阵
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        # 对注意力输出进行变换的矩阵
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # QKV的shape:(batch_size, nums, num_hiddens)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads) 
        values = transpose_qkv(self.W_v(values), self.num_heads)
        
        # valid_lens的shape:(batch_size, ) or (batch_size, num_query)
        if valid_lens is not None:
            # 每个头都需要valid_lens的信息,因此需要复制num_heads次
            # repeats=self.num_heads, dim=0 表示沿着第一个维度(即 batch 维度)将 valid_lens 的每一项都重复 num_heads
            # 由于每个头都需要知道它正在处理的序列的有效长度,因此我们需要将 valid_lens 对应地复制给每一个头。
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        # output的shape:(batch_size*num_heads, num_query, num_hiddens/num_heads)
        # 将不同头的QKV合并成为一个大矩阵进行计算,提升计算并行度。即K=K1 concat K2 concat .. Kn
        output = self.attention(queries, keys, values, valid_lens)
        # output_concat的形状:(batch_size, num_query, num_hiddens)
        # 将不同头的结果拼接
        output_concat = transpose_output(output, self.num_heads)
        # 对注意力输出进行线性变换后输出
        return self.W_o(output_concat)
        
# 两个辅助函数
#@save
def transpose_qkv(X, num_heads):
    # 为了多注意头并行计算而变换形状
    # X的输入为:(batch_size, num_query, num_hiddens)
    # 输出为:(batch_szie, num_query, num_heads, num_hiddnes / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # 交换heads维和num_query维
    # 输出为:(batch_size, num_heads, num_query, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)
    # 将输出变为(batch_size * num_heads, num_query, num_hiddens / num_heads)用于并行计算
    return X.reshape(-1, X.shape[2], X.shape[3])

#@save
def transpose_output(X, num_heads):
    # 将合并的X还原num_heads维度
    # 输入为:(batch_size, num_heads, num_query, num_hiddens/num_heads)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # 转换为:(batch_size, num_query, num_heads, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)
    # 转换为原格式:(batch_size, num_query, num_hiddens)
    return X.reshape(X.shape[0], X.shape[1], -1)
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries = 2, 4
# valid_lens表示第一个batch的有效长度为3,第二个为2
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
'''
在编码器-解码器架构中,比如在Transformer模型里,解码器部分的每个时间步都会产生一个查询向量去与编码器的输出进行交互。
查询的数量通常是解码器当前时间步的数目,或者是整个目标序列的长度(如果一次性处理整个序列的话)。
对于自注意力机制(self-attention)(编码器中),每个位置上的元素都作为一个查询去与其他所有元素交互,所以这里的查询数量就是输入序列的长度。
'''
X = torch.ones((batch_size, num_queries, num_hiddens))
'''
在编码器中的自注意力层,键-值对的数量等同于输入序列的长度。
在解码器中,键-值对可以来自于编码器的输出,这时它们的数量就是编码器输出序列的长度;或者来自解码器自身的先前输出,
这时它们的数量将是解码器已经产生的序列长度。
'''
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
'''
Q与KV数量不相同,如何进行注意力计算?
Q.shape = (batch_size, num_queries, d_model) 
K.shape = (batch_size, num_kvpairs, d_model) 
Q @ K =  (batch_size, num_queries, num_kvpairs) = A
将结果进行softmax后再与V相乘
V.shape = (batch_size, num_kvpairs, d_model) 
A @ K = (batch_size, num_queries, d_model) 
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2211676.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MyBatis环境配置详细过程

在此之前我们需要知道Maven和依赖是什么。 什么是Maven? Maven 是一个项目管理和构建自动化工具,最初由Apache软件基金会开发,主要用于Java项目的构建、依赖管理、文档生成和发布。Maven使用一种基于XML的配置文件(pom.xml&…

vue后台管理系统从0到1(6)引入pinia实现折叠功能

文章目录 vue后台管理系统从0到1(6)引入pinia实现折叠功能分析:安装并使用 pinia vue后台管理系统从0到1(6)引入pinia实现折叠功能 分析: 首先,接着上一期,我们项目启动起来应该是…

【算法思想·二叉树】用「遍历」思维解题 II

本文参考labuladongsuanfa笔记[【强化练习】用「遍历」思维解题 II | labuladong 的算法笔记] 如果让你在二叉树中的某些节点上做文章,一般来说也可以直接用遍历的思维模式。 270. 最接近的二叉搜索树值 | 力扣 | LeetCode | 给你二叉搜索树的根节点 root 和一个目…

通信工程学习:什么是SDRAM同步动态随机存取存储器

SDRAM:同步动态随机存取存储器 SDRAM,全称为Synchronous Dynamic Random Access Memory,即同步动态随机存取存储器,是一种广泛应用于计算机和嵌入式系统中的内存技术。以下是对SDRAM的详细介绍: 一、SDRAM的定义与特点…

TimeGen3.2

一、安装 1.安装包下载 软件安装压缩包,点击链接下载,自取。 链接: https://pan.baidu.com/s/1kebJ2z8YPMhqyvDiHLKktw?pwd0000 提取码: 0000 二、解压安装 1.解压 2.安装软件 (1)双击timegen-pro-3.2.exe文件 &#xff…

[CTF夺旗赛] CTFshow Web13-14 详细过程保姆级教程~

前言 ​ CTFShow通常是指网络安全领域中的“Capture The Flag”(夺旗赛)展示工具或平台。这是一种用于分享、学习和展示信息安全竞赛中获取的信息、漏洞利用技巧以及解题思路的在线社区或软件。参与者会在比赛中收集“flag”,通常是隐藏在网络环境中的数据或密码形…

SHCTF-2024-week1-wp

文章目录 SHCTF 2024 week1 wpMisc[Week1]真真假假?遮遮掩掩![Week1]拜师之旅①[Week1]Rasterizing Traffic[Week1]有WiFi干嘛不用呢? web[Week1] 单身十八年的手速[Week1] MD5 Master[Week1] ez_gittt[Week1] jvav[Week1] poppopop[Week1] 蛐蛐?蛐蛐! SHCTF 2024…

一些自定义函数

目录 一.strcmp()函数 二.strstr()函数 三.memcpy函数 四.memmove函数 五.strncpy函数 六.strcat函数 七.atoi函数 八.strlen函数 一.strcmp()函数 strcmp 函数是用于比较两个字符串是否相等的函数。它通过逐个字符比较两个字符串的 ASCII 值,来判断它们的相…

QD1-P3 HTML 基础内容介绍

本节学习&#xff1a;HTML基础语法介绍。 本节视频 www.bilibili.com/video/BV1n64y1U7oj?p3 ‍ 一、运行HTML代码 在 HBuilderX编辑器中创建空项目&#xff0c;添加一个 html 文件 <!-- QD1-P3 HTML基础语法 --><!DOCTYPE html> <html><head>&l…

Java面试宝典-并发编程学习01

Java 并发编程学习 1、创建线程的有哪些方式&#xff1f; 创建线程的方式有以下几种&#xff1a; 1. 继承Thread类&#xff1a;创建一个类继承Thread类&#xff0c;并重写run()方法&#xff0c;然后通过创建类的实例来创建线程。 2. 实现Runnable接口&#xff1a;创建一个类实…

PH47框架下BBDB飞控基础开发平台极简教程

1 硬件准备 1.1 一块WeAct 的Stm32F411核心板 1.2 2个USB-TTL模块。 1.3 Stm32开发所必须的如STlink烧写器。 1.4 必要的线材。 2 软件准备 2.1 Stm32开发所必备的Keil开发环境。 2.2 PH47框架代码&#xff0c;下载链接 2.3 CSS及BBDB 控制站工程&#xff0c;下载链接 2.4…

鸿蒙面试的一些可能问到的点

页面跳转 router 鸿蒙中跳转主要有两种&#xff0c;一种是router&#xff0c;一种是Navigation&#xff0c;官方推荐使用Navigation。 Router适用于模块间与模块内页面切换&#xff0c;通过每个页面的url实现模块间解耦。模块内页面跳转时&#xff0c;为了实现更好的转场动效…

7.2-I2C的DMA中断

I2C的DMA中断 请先阅读完I2C的普通中断模式以后再阅读本教程 i2c的DMA模式 1.添加通道 &#xff0c;添加后的参数保持默认 2.可以看到自动给我们DMA添加了中断向量。 保存后只需要将下面_ IT改为_ DMA即可 运行代码 i2c1) { aht20State 4; } } /* USER CODE END 0 */ 以上就…

ssm基于java的网上手机销售系统

系统包含&#xff1a;源码论文 所用技术&#xff1a;SpringBootVueSSMMybatisMysql 免费提供给大家参考或者学习&#xff0c;获取源码请私聊我 需要定制请私聊 目 录 目 录 III 1 绪论 1 1.1 研究背景 1 1.2 目的和意义 1 1.3 论文结构安排 2 2 相关技术 3 2.1 SSM框…

yolov5环境GPU搭建 ,用GPU跑polov5算法

win10NVIDIA GeForce RTX 3050torch1.13.1torchaudio0.13.1torchvision 0.14.1 cuda11.7python3.8cudnn8.7.0 在环境搭建中踩了许多坑&#xff0c;yolov5环境的搭建需要依赖很多环境&#xff0c;用cpu跑很容易跑单张识别&#xff0c;用GPU跑却踩了很多坑&#xff0c;不过GPU环…

Mac 备忘录妙用

之前使用 Windows 的过程中&#xff0c;最痛苦的事是没有一款可以满足我快速进行记录的应用 基本都得先打开该笔记软件&#xff0c;然后创建新笔记&#xff0c;最后才能输入&#xff0c;这么多步骤太麻烦了 在切换到 MacOS 之后&#xff0c;让我惊喜的就是自带的备忘录&#…

【java面经thinking】一

目录 类加载过程 加载&#xff1a; 连接 初始化 GC回收机制&#xff08;垃圾回收&#xff09; 区域 判断对象是否存活 回收机制 HashMap 类加载器 加载标识 加载机制 缓存 自定义加载器&#xff1a; JVM内存结构 常量池 string设置成final 按下网址发生 类加…

C语言有关结构体的知识(后有通讯录的实现)

一、结构体的声明 1.1 结构体的定义 结构体是一些值的集合&#xff0c;这些值被称为成员变量。结构的每个成员可以是不同的类型 1.2 结构体的声明 这里以描述一个学生为例&#xff1a; struct stu {char name[10];//名字int age;//年龄char id[20];//学号char sex[5];//性别 }…

TIM定时器(标准库)

目录 一. 前言 二. 定时器的框图 三. 定时中断的基本结构 四. TIM定时器相关代码 五. 最终现象展示 一. 前言 什么是定时器&#xff1f; 定时器可以对输入的时钟进行计数&#xff0c;并在计数值达到设定值时触发中断。 TIM定时器不仅具备基本的定时中断功能&#xff0c;而且…

【LeetCode】708. 循环有序列表的插入

目录 一、题目二、解法完整代码 一、题目 给定循环单调非递减列表中的一个点&#xff0c;写一个函数向这个列表中插入一个新元素 insertVal &#xff0c;使这个列表仍然是循环非降序的。 给定的可以是这个列表中任意一个顶点的指针&#xff0c;并不一定是这个列表中最小元素的…