transformer代码注解

news2024/9/25 15:24:19

其中代码均来自李沐老师的动手学pytorch中。

class PositionWiseFFN(nn.Module):
    '''
    ffn_num_inputs 4
    ffn_num_hiddens 4
    ffn_num_outputs 8
    '''
    def __init__(self,ffn_num_inputs,ffn_num_hiddens,ffn_num_outputs):
        super(PositionWiseFFN,self).__init__()
        self.dense1 = nn.Linear(ffn_num_inputs,ffn_num_hiddens)#4*4
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens,ffn_num_outputs)#4*8
    def forward(self,X):
        return self.dense2(self.relu(self.dense1(X)))
positionWiseFFN = PositionWiseFFN(4,4,8)
positionWiseFFN.eval()
positionWiseFFN(torch.ones(size=(2,3,4)))[0]

上面的代码为前馈神经网络结构,其实也就是一个全连接层。

class AddNorm(nn.Module):
    def __init__(self,normalized_shape,dropout):
        super(AddNorm, self).__init__()
        self.dropout=nn.Dropout(dropout)
        self.layer_norm=nn.LayerNorm(normalized_shape=normalized_shape)
    def forward(self,x,y):
        return self.layer_norm(self.dropout(y)+x)
#比如[3, 4]或torch.Size([3, 4]),则会对网络最后的两维进行归一化,且要求输入数据的最后两维尺寸也是[3, 4]
add_norm = AddNorm(normalized_shape=[3,4],dropout=0.5)
add_norm.eval()
add_norm(torch.ones(size=(2,3,4)),torch.ones(size=(2,3,4)))

这里实现的是残差化和规范化。nn.LayerNorm(normalized_shape=normalized_shape)为layer规范化,其中normalized_shape为[3, 4],对网络最后的两维进行归一化。

class MultiHeadAttention(nn.Module):
    def __init__(self,query_size,key_size,value_size,num_hiddens,num_heads,dropout,bias=False):
        super(MultiHeadAttention, self).__init__()
        self.num_heads=num_heads
        #用独立学习得到的 ℎ 组不同的线性投影(linear projections)来变换查询、键和值
        self.attention=d2l.torch.DotProductAttention(dropout)
        self.W_q=nn.Linear(query_size,num_hiddens,bias=bias)
        self.W_k=nn.Linear(key_size,num_hiddens,bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
#总之就是:我们的Q,K,V的embedding,怎么拆分成k个头的数据,然后放到一个大头中,一遍算出multi_head的值
#这里是一组QKV乘一组W 直接生成特征大小的结果,在切分成8份,放到batch里等价于并行计算
    def forward(self,queries,keys,values,valid_lens):
        # print('----')
        # print(queries)
        queries = transpose_qkv(self.W_q(queries),self.num_heads)
        # print(queries)
        # print('----')
        keys = transpose_qkv(self.W_k(keys),self.num_heads)
        values = transpose_qkv(self.W_v(values),self.num_heads)
        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(valid_lens,repeats=self.num_heads,dim=0)
        #valid_lens tensor([3, 3, 3, 3, 3, 2, 2, 2, 2, 2])
        # output的形状:(batch_size*num_heads,查询的个数,
        # num_hiddens/num_heads)
        # print(queries.shape)
        # print(keys.shape)
        '''
        queries-->torch.Size([10, 4, 20])
        keys----->torch.Size([10, 6, 20])
        两个批次,每次五个多头注意力,就一共会有十个注意力需要做。得出的矩阵为10*4*6,表示为10次注意力
        每个注意力query和key的矩阵为4*6
              keys  keys  keys  keys  keys  keys
        Query
        Query
        Query
        Query
        在经过mask时 需要将10*4*6的矩阵,转为二维矩阵,就是40*6。
        valid_lens首先会在上面的代码中,扩展至num_heads,然后会在masked_softmax中扩至40大小。
        
        '''
        output = self.attention(queries,keys,values,valid_lens)
        # print('-----')
        # print(output)
        # print('-----')
        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output,self.num_heads)
       # print(output_concat.shape)torch.Size([2, 4, 100])
        return self.W_o(output_concat)
def transpose_qkv(X,num_heads):
    # 2,6,100 2,4,100
    X = X.reshape(X.shape[0],X.shape[1],num_heads,-1)
    # 2,5,6,20  2,5,4,20
    # 输出X的形状: (batch_size,num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    #最终输出的形状: (batch_size * num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
    #10,6,20  10,6,20
    return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X,num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1,num_heads,X.shape[1],X.shape[2])
    X = X.permute(0,2,1,3)
    return X.reshape(X.shape[0],X.shape[1],-1)
#在这里,我们设置head为5,也就是一共有5次self-attention。
num_hiddens,num_heads = 100,5
multiHeadAttention = MultiHeadAttention(num_hiddens,num_hiddens,num_hiddens,num_hiddens,5,0.5)
multiHeadAttention.eval()

batch_size,num_queries = 2,4
num_kvpairs,valid_lens = 6,torch.tensor([3,2])
#2,6,100 批次 句子长度 embedsize
Y = torch.ones(size=(batch_size,num_kvpairs,num_hiddens))
#2,4,100
X = torch.ones(size=(batch_size,num_queries,num_hiddens))
print(multiHeadAttention(X,Y,Y,valid_lens).shape)

首先我们设置head为5。num_hiddens可以理解为query或者key的大小,num_kvpairs表示每次注意力中key的数量,num_queries表示每次注意力中query的数量。value的数量与key的数量一样。另外一种理解就是,将X理解为批次*句子长度(单词的数量)*embedding size。每个单词对应一次查询。随后就是__init__,创建几个全连接层,对query、key、value进行变换,不同注意力的query、key和value,均不一样。
111111
主要是实现图中红色部分。然后会调用forward函数,transpose_qkv函数进行切分,假定原本的输入为2 * 6 * 100,因为大小为两个批次,每个批次需要做五个注意力机制,每个注意力机制的key的数量为6,所以将输入为2 * 6 * 100,转换为10 * 6 * 20。意思就是10次注意力,每个注意力中的key为6个,每个key由20维度的向量表示。query同理。因为我们要并行计算,这样使用torch.bmm可以直接进行计算,计算得出Query和key矩阵。在上面的例子中,计算得出的为10 * 4 * 6大小的矩阵。
在训练时刻的mask中,首先会将结果转变为二维矩阵40 * 6,其中的每一行代表了query与不同key计算的结果,有时候query只能和部分key进行计算,比如:第二个词的query只能计算第一个词与第二个词的key,而之后key需要进行mask。我们会给定一个valid_lens 代表需要保留的计算结果。其中mask部分会调用以下代码:

 mask = torch.arange((maxlen), dtype=torch.float32,device=X.device)[None, :] < valid_len[:, None]
 X[~mask] = value#value为极小值。

torch.arange((maxlen)会生成从0到5的矩阵,valid_len在之前会经过扩展为140大小的矩阵,然后转换为40 * 1的矩阵。最终的mask会变成40 * 6大小矩阵就像以下形式:
[True,True,True,False,False]
而最后两个False是需要进行mask的,X[~mask] = value将最后两个Fasle,变为负极小值,再经过softmax之后,结果将趋近于0,从而将其mask。然后与value相乘,得出结果为10 * 4
20矩阵大小的结果,在经过变换,变为2 * 4* 100矩阵,最后再经过最后一次全连接层,然后输出结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/805789.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3ds MAX绘制简单动画

建立一个长方体和茶壶&#xff1a; 在界面右下角点击时间配置&#xff1a; 这是动画制作的必要步骤 选择【自动】&#xff0c;接下来&#xff0c;我们只要在对应的帧改变窗口中图形的位置&#xff0c;就能自动记录该时刻的模样 这就意味着&#xff0c;我们通过电脑记录某几个…

工业平板电脑优化汽车工厂的生产流程

汽车行业一直是自动化机器人系统的早期应用领域之一。通过使用具有高负载能力和远程作用的大型机械臂&#xff0c;汽车装配工厂可以实现点焊、安装挡风玻璃、安装车轮等工作&#xff0c;而较小的机械手则用于焊接和安装子组件。使用机器人系统不仅提高了生产效率&#xff0c;还…

工业智能化的关键之二:集成监控和分析能力

将监控和分析能力集成到工厂运营的日常中是工业智能化发展的关键步骤。随着科技的进步和数字化技术的广泛应用&#xff0c;工厂正在逐步实现从传统的人工操作到智能化的转变。这种转变不仅提高了工厂的生产效率和产品质量&#xff0c;还极大地提升了工厂的安全性和可靠性。 1.…

Flutter 调试工具篇 | 壹 - 使用 Flutter Inspector 分析界面

theme: cyanosis 1. 前言 很多朋友可能在布局过程中、或者组件使用过程中&#xff0c;会遇到诸如颜色、尺寸、约束、定位等问题&#xff0c;可能会让你抓耳挠腮。俗话说&#xff0c;磨刀不误砍柴工&#xff0c;会使用工具是非常重要的&#xff0c;其实 Flutter 提供了强大的调试…

axios使用异步方式无感刷新token,简单,太简单了

文章目录 &#x1f349; 废话在前&#x1f357; 接着踩坑&#x1f969; 解决思路&#x1f353; 完整代码 &#x1f349; 废话在前 写vue的或帮们无感刷新token相信大家都不陌生了吧&#xff0c;刚好&#xff0c;最近自己的一个项目中就需要用到这个需求&#xff0c;因为之前没…

Fluentbit

Fluent Bit&#xff08;常简称为Fluent-Bit 或 Fluentbit&#xff09;是一个开源的、轻量级的日志数据收集器&#xff08;log collector&#xff09;和 转发器&#xff08;log forwarder&#xff09;&#xff0c;旨在高效地收集、处理和转发日志数据。它是Fluentd项目的一个子项…

山东农业大学图书馆藏书《乡村振兴战略下传统村落文化旅游设计》

山东农业大学图书馆藏书《乡村振兴战略下传统村落文化旅游设计》

数字化时代,企业研发效能跃升之道丨IDCF

本文节选自新书《数字化时代研发效能跃升方法与实践》 作者&#xff1a;冬哥 研发效能是近年的热词&#xff0c;企业言必谈效能&#xff0c;但究竟什么是研发效能&#xff0c;落地具体应该如何进行&#xff0c;相信每个人都会有无数的问题浮现。 什么是效能&#xff1f; 效能…

Element-plus侧边栏踩坑

问题描述 el-menu直接嵌套el-menu-item菜单&#xff0c;折叠时不会出现文字显示和小箭头无法隐藏的问题&#xff0c;但是实际开发需求中难免需要把el-menu-item封装为组件 解决 vue3项目中嵌套两层template <template><template v-for"item in list" :k…

内网隧道代理技术(十三)之内网代理介绍

前言 什么?你问我内网隧道代理技术怎么突然就第十三篇了,第十二篇呢?这个,因为某些不可抗拒力量,第十二篇博客无法发表,如果想要查阅,请加内网渗透qq群:838076210 内网代理介绍 内网代理介绍 内网资产扫描这种场景一般是进行内网渗透才需要的代理技术,如果你不打内…

公共字段的填充

方式1&#xff0c;通过mybatis-plus提供的MetaObjectHandler进行填充 import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;import com.sky.context.BaseContext; import lombok.extern.slf4j.Slf4j; import org.apache.ibatis.reflection.MetaObject; import o…

【ROS第一讲】一、创建工作空间

【ROS第一讲】一、创建工作空间 一、工作空间1.src&#xff1a;2.build&#xff1a;3.devel&#xff1a;4.install: 二、创建工作空间1.工作空间的编译2.配置环境变量&#xff1a; 三、创建功能包 一、工作空间 1.src&#xff1a; 放置所有功能包源码的空间 2.build&#xf…

Unity XML3——XML序列化

一、XML 序列化 ​ 序列化&#xff1a;把对象转化为可传输的字节序列过程称为序列化&#xff0c;就是把想要存储的内容转换为字节序列用于存储或传递 ​ 反序列化&#xff1a;把字节序列还原为对象的过程称为反序列化&#xff0c;就是把存储或收到的字节序列信息解析读取出来…

再见 MyBatis-Plus !

一、Mybatis-Flex是什么&#xff1f; Mybatis-Flex 是一个优雅的 Mybatis 增强框架&#xff0c;它非常轻量、同时拥有极高的性能与灵活性。我们可以轻松的使用 Mybaits-Flex 链接任何数据库&#xff0c;其内置的 QueryWrapper^亮点 帮助我们极大的减少了 SQL 编写的工作的同时…

odoo16-domain

odoo16-domain 参考:https://blog.csdn.net/u013250491/article/details/86699928 domain的使用注意以下几点: 是在py文件中使用还是在xml中使用,py文件是在后端使用可以利用orm, 而xml是在前端渲染,使用的是js,没有办法使用orm如果在xml中使用,domain的格式建议为[[]], 二维…

LeetCode32.Longest-Valid-Parentheses<最长有效括号>

题目&#xff1a; 思路&#xff1a; 遍历括号.遇到右括号然后前一个是左括号 那就res2,然后重定位 i 的值 并且长度减少2; 但是问题在于无法判断最长的括号.只能得到string内的全部括号长度. 错误代码: 写过一题类似的,那题是找括号数.记得是使用的栈,但是死活写不出来. 看完…

【Visual Studio Code】加载saved_model.pb时报错缺失‘cudart64_110.dll‘等

如果报错Could not load dynamic library cudart64_110.dll; dlerror: cudart64_110.dll not found&#xff0c; 将对应的cudart64_110.dll复制到C:\Windows\System32下即可 如果VScode仍继续报错&#xff0c;重新启动软件即解决问题。 同理&#xff0c;若仍有相同报错 Cou…

ios私钥证书的创建方法

ios私钥证书是苹果公司为ios开发者打包app&#xff0c;推出的一种数字证书&#xff0c;只有同一个苹果开发者账号生成的ios私钥证书打的包&#xff0c;才能上架同一个开发者账号的app store。因此不要指望别人给你共享私钥证书和描述文件&#xff0c;因为别人的证书和描述文件打…

Ubuntu Server版 之 apache系列 安装、重启、开启,版本查看

安装之前首先要检测是否安装过 apt list --installed | grep tool tool&#xff1a;要检测的名称&#xff0c;如mysql、apache 、ngnix 等 安装 apache sudo apt install apache2 安装apache 默认是开启的 可以通过浏览器 检测一下 service apache stop # apache 停止服务…

道本科技||全面建立国有企业合规管理体系

为全面深化国有企业法治建设&#xff0c;不断加强合规管理&#xff0c;防控合规风险&#xff0c;保障企业稳健发展&#xff0c;近日&#xff0c;市国资委印发《常州市市属国有企业合规管理办法&#xff08;试行&#xff09;》&#xff08;以下简称《办法》&#xff09;&#xf…