Transformer详解encoder

news2025/1/21 15:39:07

目录

1. Input Embedding

2. Positional Encoding

3. Multi-Head Attention

4. Add & Norm

5. Feedforward + Add & Norm

6.代码展示

(1)layer_norm

(2)encoder_layer=1


最近刚好梳理了下transformer,今天就来讲讲它~

        Transformer是谷歌大脑2017年在论文attention is all you need中提出来的seq2seq模型,它的本质就是由编码器和解码器组成,今天的主角则是其中的编码器(在BERT预训练模型中也只用到了编码器部分)如下图所示,这个模块的输入为 𝑋 (每一行代表一个句子,batchsize有多大就有多少行),我们将从输入到隐藏层按照从1到4的顺序逐层来看一下各个维度的变化。

1. Input Embedding

        所谓的Embedding其实就是查字典或者叫查表,也就是将一个句子里的每一个字转化为一个维度为embedding dimension的向量来表示,因此 𝑋 经过嵌入后变成 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 ,三个维度分别表示一个批次的句子数,每个句子的字数,每个字的嵌入维度。

2. Positional Encoding

        位置编码,按照字面意思理解就是给输入的位置做个标记,简单理解比如你就给一个字在句子中的位置编码1,2,3,4这样下去,高级点的比如作者用的正余弦函数

𝑃𝐸(𝑝𝑜𝑠,2𝑖)=𝑠𝑖𝑛(𝑝𝑜𝑠/100002𝑖/𝑑𝑚𝑜𝑑𝑒𝑙)

𝑃𝐸(𝑝𝑜𝑠,2𝑖+1)=𝑐𝑜𝑠(𝑝𝑜𝑠/100002𝑖/𝑑𝑚𝑜𝑑𝑒𝑙)

 

        其中pos表示字在句子中的位置,i指的词向量的维度。经过位置编码,相当于能够得到一个和输入维度完全一致的编码数组 𝑋𝑝𝑜𝑠 ,当它叠加到原来的词嵌入上得到新的词嵌入

𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔+𝑋𝑝𝑜𝑠

        此时的维度为:一个批次的句子数 × 一个句子的词数 × 一个词的嵌入维度

3. Multi-Head Attention

        注意力机制,其实可以理解为就是在计算相关性,很自然的想法就是去更多地关注那些相关更大的东西。这里首先要引入Query,Key和Value的概念,Query就是查询的意思,Key就是键用来和你要查询的Query做比较,比较得到一个分数(相关性或者相似度)再乘以Value这个值得到最终的结果。

        那么这个Q,K,V从哪里来呢,这里采用的是self-attention的方式,也就是从输入自己 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 来产生,即做线性映射产生Q,K,V:

𝑄=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝑄𝐾=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝐾𝑉=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝑉

        这里三个权重矩阵均为维度为Embedding的方阵,也就是说Q,K,V的维度和 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 是一致的。

        接下来考虑什么叫做multi-head(多头)呢,本质上就是从embedding的维度上将矩阵切分为多份,每一份就是一个头,比如之前的Q,K,V切完后的维度就是一个批次的句子数 × 一个句子的词数 × 头数 × (词嵌入维度/头数)这个多头的切分体现在最后两个维度:词嵌入维度=数 × (词嵌入维度/头数)为了便于计算,通常会将第二第三维度进行转置,即最终的维度为一个批次的句子数 × 头数 × 一个句子的词数 × (词嵌入维度/头数)

        接下来说说注意力机制的计算,假设Q,K,V为切分完后的矩阵(其中一个头),根据两个向量的点积越大越相似,我们通过 𝑄𝐾𝑇 求出注意力矩阵,再根据注意力矩阵来给Value进行加权,即

𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛(𝑄,𝐾,𝑉)=𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑄𝐾𝑇𝑑𝑘)𝑉

        其中 𝑑𝑘 是为了把注意力矩阵变成标准正态分布,softmax进行归一化,使每个字与其他字的注意力权重之和为1。这一操作使得每一个字的嵌入都包含当前句子内所有字的信息,注意Attention(Q,K,V)的维度和 𝑉 的维度保持一致。

4. Add & Norm

这里主要做了两个操作

  • 一个是残差连接(或者叫做短路连接),说得直白点就是把上一层的输入 𝑋 和上一层的输出加起来 𝑆𝑢𝑏𝐿𝑎𝑦𝑒𝑟(𝑋) ,即 𝑋+𝑆𝑢𝑏𝐿𝑎𝑦𝑒𝑟(𝑋) ,举例说明,比如在注意力机制前后的残差连接:

𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔+𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛(𝑄,𝐾,𝑉)

  • 一个是LayerNormalization(作用是把神经网络中隐藏层归一为标准正态分布,加速收敛),具体操作是将每一行每一个元素减去这行的均值, 再除以这行的标准差, 从而得到归一化后的数值。

5. Feedforward + Add & Norm

前馈网络也就是简单的两层线性映射再经过激活函数一下,比如

𝑋ℎ𝑖𝑑𝑑𝑒𝑛=𝑅𝑒𝑙𝑢(𝑋ℎ𝑖𝑑𝑑𝑒𝑛∗𝑊1∗𝑊2)

残差操作和层归一化同步骤3.


上述的1,2,3,4就构成Transformer中的一个encoder模块,经过1,2,3,4后得到的就是encode后的隐藏层表示,可以发现它的维度其实和输入是一致的!即:一个批次中句子数 × 一个句子的字数 × 字嵌入的维度

6.代码展示

(1)layer_norm

bs=2,seq=3,dim=5

import torch

batch_size = 2
seq = 3
fea_dim = 5
X = torch.rand(batch_size,seq,fea_dim)
layer_norm = torch.nn.LayerNorm(fea_dim)
out = layer_norm(X)
print(out)
print('-'*30)

mean = torch.mean(X,dim=-1,keepdim=True)
std = torch.sqrt(torch.var(X,unbiased=False,dim=-1,keepdim=True) + 1e-5)
weight = layer_norm.state_dict()['weight']
bias = layer_norm.state_dict()['bias']
my_norm = ((X - mean)/std) * weight + bias
print(my_norm)

(2)encoder_layer=1

bs=1,seq=1,dim=6,head=1

import torch

seq = 1
dim = 6
heads = 1
batch_size = 1
value = torch.rand(batch_size,seq,dim)

encoder_layer = torch.nn.TransformerEncoderLayer(dim,heads,dropout=0.0,batch_first=True)
out = encoder_layer(value)
print(out)

# 多头自注意力
def my_scaled_dot_product(query,key,value):
    qk_T = torch.mm(query,key.T)
    qk_T_scale = qk_T / torch.sqrt(torch.tensor(value.shape[1]))
    qk_exp = torch.exp(qk_T_scale)
    qk_exp_sum = torch.sum(qk_exp,dim=1,keepdim=True)
    qk_softmax = qk_exp / qk_exp_sum
    v_attn = torch.mm(qk_softmax,value)
    return v_attn,qk_softmax

in_proj_weight = encoder_layer.state_dict()['self_attn.in_proj_weight']
in_proj_bias = encoder_layer.state_dict()['self_attn.in_proj_bias']

out_proj_weight = encoder_layer.state_dict()['self_attn.out_proj.weight']
out_proj_bias = encoder_layer.state_dict()['self_attn.out_proj.bias']

batch_V_output = torch.empty(batch_size,seq,dim)
for i in range(batch_size):
    in_proj = torch.mm(value[i],in_proj_weight.T) + in_proj_bias
    Qs,Ks,Vs = torch.split(in_proj,dim,dim=-1)
    head_Vs = []
    attn_weight = torch.zeros(seq,seq)
    for Q,K,V in zip(torch.split(Qs,dim//heads,dim=-1),torch.split(Ks,dim//heads,dim=-1),torch.split(Vs,dim//heads,dim=-1)):
        head_v,_ = my_scaled_dot_product(Q,K,V)
        head_Vs.append(head_v)
    V_cat = torch.cat(head_Vs,dim=-1)
    V_ouput = torch.mm(V_cat,out_proj_weight.T) + out_proj_bias
    batch_V_output[i] = V_ouput

# 第一次加
first_Add = value + batch_V_output

# 第一次layer_norm
norm1_mean = torch.mean(first_Add,dim=-1,keepdim=True)
norm1_std = torch.sqrt(torch.var(first_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm1_weight = encoder_layer.state_dict()['norm1.weight']
norm1_bias = encoder_layer.state_dict()['norm1.bias']
norm1 = ((first_Add - norm1_mean)/norm1_std) * norm1_weight + norm1_bias

# feed forward
linear1_weight = encoder_layer.state_dict()['linear1.weight']
linear1_bias = encoder_layer.state_dict()['linear1.bias']
linear2_weight = encoder_layer.state_dict()['linear2.weight']
linear2_bias = encoder_layer.state_dict()['linear2.bias']
linear1 = torch.matmul(norm1,linear1_weight.T) + linear1_bias
linear1_relu = torch.nn.functional.relu(linear1)
linear2 = torch.matmul(linear1_relu,linear2_weight.T) + linear2_bias

# 第二次加
second_Add = norm1 + linear2

# 第二次layer_norm
norm2_mean = torch.mean(second_Add,dim=-1,keepdim=True)
norm2_std = torch.sqrt(torch.var(second_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm2_weight = encoder_layer.state_dict()['norm2.weight']
norm2_bias = encoder_layer.state_dict()['norm2.bias']
norm2 = ((second_Add - norm2_mean)/norm2_std) * norm2_weight + norm2_bias
print(norm2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1880454.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入理解PHP命名空间

在PHP项目中,命名空间(namespace)是一个非常重要的特性。它不仅帮助开发者组织代码,还能避免类、函数、常量等命名冲突问题。本文将详细介绍PHP命名空间的概念、使用方法和最佳实践。 一、什么是命名空间? 命名空间…

LeetCode:经典题之2、445 题解及延伸

系列目录 88.合并两个有序数组 52.螺旋数组 567.字符串的排列 643.子数组最大平均数 150.逆波兰表达式 61.旋转链表 160.相交链表 83.删除排序链表中的重复元素 389.找不同 1491.去掉最低工资和最高工资后的工资平均值 896.单调序列 206.反转链表 92.反转链表II 141.环形链表 …

github主页这样优化,让人眼前一亮

我的主页(一之十六) 1. 创建与账户ID同名的仓库 注意:记得勾选Add a README file 2. markdown语法自定义README.md 3. 辅助工具 优秀profile:https://zzetao.github.io/awesome-github-profile/动态文字:https://r…

pytest测试框架pytest-cov插件生成代码覆盖率

Pytest提供了丰富的插件来扩展其功能,本章介绍下pytest-cov插件,用于生成测试覆盖率报告,帮助开发者了解哪些部分的代码被测试覆盖,哪些部分还需要进一步的测试。 pytest-cov 支持多种报告格式,包括纯文本、HTML、XML …

修复vcruntime140.dll方法分享

修复vcruntime140.dll方法分享 最近在破解typora的时候出现了缺失vcruntime140.dll文件的报错导致软件启动失败。所以找了一番资料发现都不是很方便的处理,甚至有的dll处理工具还需要花钱????,我本来就是为…

前端学习 Vue 插槽如何实现组件内容分发?

目录 一、Vue.js框架介绍二、什么是Vue 插槽三、Vue 插槽的应用场景四、Vue 插槽如何实现组件内容分发 一、Vue.js框架介绍 Vue.js是一个用于构建用户界面的渐进式JavaScript框架。它设计得非常灵活,可以轻松地被集成到现有的项目中,也可以作为一个完整…

新型发电系统——光伏行业推动能源转型

一、发展背景 “十四五”期间,随着“双碳”目标提出及逐步落实,本就呈现出较好发展势头的分布式光伏发展有望大幅提速。就“十四五”光伏发展规划,国家发改委能源研究所可再生能源发展中心副主任陶冶表示,“双碳”目标意味着国家…

轻松解锁电脑强悍性能,4000MHz的玖合星舞 DDR4 内存很能打

轻松解锁电脑强悍性能,4000MHz的玖合星舞 DDR4 内存很能打 哈喽小伙伴们好,我是Stark-C~ 很多有经验的电脑玩家在自己DIY电脑选购内存条的时候,除了内存总容量,最看重的参数那就是频率了。内存频率和我们常说的CPU主频一样&…

计网之IP

IP IP基本认识 不使用NAT时,源IP地址和目的IP地址不变,只要源MAC和目的MAC地址在变化 IP地址 D类是组播地址,E类是保留地址 无分类地址CIDR 解决直接分类的B类65536太多,C类256太少a.b.c.d/x的前x位属于网路号,剩…

pytest测试框架pytest-random-order插件随机执行用例顺序

Pytest提供了丰富的插件来扩展其功能,本章介绍下pytest-random-order插件,随机设置pytest测试用例的运行顺序,并对随机性进行一些控制。 官方文档: https://pytest-cov.readthedocs.io/en/latest/index.html 适配版本说明&#x…

ComfyUI局部重绘的四种方式 (附件工作流在最后)

前言 局部重绘需要在图片中选择重绘区域,点击图片右击选择Open in MaskEditor(在蒙版编辑器中打开),用鼠标描绘出需要重绘的区域 方式一:重绘编码器 这种方式重绘比较生硬,需要额外搭配使用才行 方式二&…

ThreadPoolExecutor基于ctl变量的声明周期管理

个人博客 ThreadPoolExecutor基于ctl变量的声明周期管理 | iwts’s blog 总集 想要完整了解下ThreadPoolExecutor?可以参考: 基于源码详解ThreadPoolExecutor实现原理 | iwts’s blog ctl字段的应用 线程池内部使用一个变量ctl维护两个值&#xff…

逆变器--学习笔记(一)

并网: 逆变器中的“并网”指的是逆变器将其产生的交流电与电网同步,并输送到公共电网中。并网逆变器通常用于太阳能发电系统和其他分布式发电系统,将其产生的电能输送到电网供其他用户使用。 THD谐波失真总量: 逆变器的THD(Tot…

【TB作品】温度DS18B20读取,温控风扇,ATMEGA128单片机,Proteus仿真

读取温度; PWM风扇控制; 蜂鸣器控制。 写博客介绍这个基于ATmega128的作品时,可以从以下几个方面展开描述: 概述 介绍项目的背景和目的,说明使用ATmega128的原因以及项目的整体架构。 硬件设计 主要元件 详细列出…

180Kg大载重多旋翼无人机技术详解

一、机体结构与材料 180Kg大载重多旋翼无人机在机体结构上采用了高强度轻量化设计。其主体框架采用航空铝合金材料,既保证了机体的结构强度,又减轻了整体重量。同时,关键部位如连接件、旋翼支撑臂等则采用碳纤维复合材料,以进一步…

主流电商平台API接口(天猫获得淘宝商品详情,获得淘宝app商品详情原数据 ,获得淘口令真实url API,按图搜索淘宝商品(拍立淘) API )

主流电商平台商品接口在电商企业中具有重要应用价值。通过商品接口,电商企业可以实现商品同步功能: 商品信息同步:通过接口可以实时同步主流电商平台上的商品信息,包括商品标题、价格、库存、销量等数据,确保企业在自…

微服务中的Docker详细学习

Docker的个人理解 首先我对于Docker的理解分为两部分,第一是对名字上的理解,我们都知道docker的英文翻译是“码头工人”的意思,所以我们也可以理解为docker是码头上的一个个集装箱的使用。这也与他的图标很相似。其次我是对于其功能上的理解&…

Excel显示/隐藏批注按钮为什么是灰色?

在excel中,经常使用批注来加强数据信息的提示,有时候会把很多的批注显示出来,但是再想将它们隐藏起来,全选工作表后,“显示/隐藏批注”按钮是灰色的,不可用。 二、可操作方法 批注在excel、WPS表格中都是按…

002-基于Sklearn的机器学习入门:回归分析(上)

本节及后续章节将介绍机器学习中的几种经典回归算法,所选方法都在Sklearn库中聚类模块有具体实现。本节为上篇,将介绍基础的线性回归方法,包括线性回归、逻辑回归、多项式回归和岭回归等。 2.1 回归分析概述 回归(Regression&…

【知识学习】Unity3D中Scriptable Render Pipeline的概念及使用方法示例

Unity3D中的Scriptable Render Pipeline(SRP)是一种高度可定制的渲染管线框架,允许开发者完全控制渲染流程,以适应不同的渲染需求和硬件平台。SRP使得开发者可以编写自己的渲染逻辑,包括摄像机管理、渲染设置、光照处理…