d2l Transformer

news2025/1/10 14:14:17

终于到变形金刚了,他的主要特征在于多头自注意力的使用,以及摒弃了rnn的操作。

目录

1.原理

2.多头注意力

3.逐位前馈网络FFN

4.层归一化

5.残差连接

6.Encoder

7.Decoder

8.训练

9.预测


1.原理

主要贡献:1.纯使用attention的Encoder-Decoder;2.Encoder与Decoder都有n个transformer块;3.每个块使用多头自注意力、层归一化、逐位前馈网络

原理图如下图所示:

 Decoder掩蔽多头注意力:Decoder输出预测时,不应该考虑该元素之后的元素(模拟真实预测),计算Xi输出是,假设当前序列长度为i。对于(Xi+1,Xi+1)...的k-v忽略。

2.多头注意力

通过FC将qvk映射到不同的dimension,使用n个独立的注意力池化层,再合并各个头的输出,在经过FC拿到想要的最终维数

 

#@save
class MultiHeadAttention(nn.Module):
    """多头注意⼒"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
        num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        
    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:
        # (batch_size,查询或者“键-值”对的个数,num_hiddens)
        # valid_lens 的形状:
        # (batch_size,)或(batch_size,查询的个数)
        # 经过变换后,输出的queries,keys,values 的形状:
        # (batch_size*num_heads,查询或者“键-值”对的个数,
        # num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        
        if valid_lens is not None:
            # 在轴0,将第⼀项(标量或者⽮量)复制num_heads次,
            # 然后如此复制第⼆项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
            
        # output的形状:(batch_size*num_heads,查询的个数,
        # num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        
        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

  避免有几个多头就写n个for-loop:把原本(bs,q,h)拆成(bs*n,q,h/n),本质是原本单个h的self.attention复制成了n个h/n的self.attention
  W_q(queries)--(bs,q,h)--qkv--(bs*n,q,h/n)
  映射输入q,k,v的都一样,都是通过grad的反向梯度下降玄学分工。类似于GoogLenet里面的分很多块最后再concat。
   3维可直接送到attention中P290,注意self.attention的输出形状与query一致。
  下面的transpose_output是你想qkv操作,最终返回(bs,q,h),本质是将n个h/n进行concat,实现多个多注意头拼接操作。

#@save
def transpose_qkv(X, num_heads):
    """为了多注意⼒头的并⾏计算⽽变换形状"""
    # 输⼊X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    
    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    
    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

3.逐位前馈网络FFN

  torch中的Linear只会对最后一个维度当作是特征维进行计算,所以输入的是三维,输出的改变也只有在最后一维改变,所以叫ffn,其本质其实就是mlp。

#@save
class PositionWiseFFN(nn.Module):
    """基于位置的前馈⽹络"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
        **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)
        
    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

4.层归一化

  因为T的有效长度不一,所以用层归一化更稳定。见下图:

 b是bs,d是序列维数,可以理解为有效步长T的valid_len

ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# 在训练模式下计算X的均值和⽅差
print('layer norm:', ln(X), '\nbatch norm:', bn(X))

'''
layer norm: tensor([[-1.0000,  1.0000],
        [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>) 
batch norm: tensor([[-1.0000, -1.0000],
        [ 1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)
'''

  如上,可见层归一化是对每个样本(行)变成u=0;std=1。对同一个样本example,同一个feature,不同的diemsion做归一化。

5.残差连接

  Y为transformer的输出,X为原始输入。

#@save
class AddNorm(nn.Module):
    """残差连接后进⾏层规范化"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)
    
    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)

  注意,残差相加连接,必须维度与里面的维数全都一样才行,且相加也不会导致形状变化!!

6.Encoder

  编码块

#@save
class EncoderBlock(nn.Module):
    """transformer编码器块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens) # ffn_num_output设置的是num_hiddens
        self.addnorm2 = AddNorm(norm_shape, dropout)
        
    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

  总结一下,transformer的encoder的输出与输入的形状一致,不会改变输入的形状,容易使用n块叠加
  原因:1.self.attention里面的qkv都是X,输出与query一致,所以不会改变形状,还是X;2.addnorm是加法操作,tensor只有维度与维数一摸一样才能相加,且不会改变形状;3.ffn里面的ffn_num_output最后设置的是num_hiddens,与输入的X相一致,所以ffn也不会改变形状

#@save
class TransformerEncoder(d2l.Encoder):
    """transformer编码器"""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                            norm_shape, ffn_num_input, ffn_num_hiddens,
                            num_heads, dropout, use_bias))
        
    def forward(self, X, valid_lens, *args):
        # 因为位置编码值在-1和1之间,
        # 因此嵌⼊值乘以嵌⼊维度的平⽅根进⾏缩放,
        # 然后再与位置编码相加。
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[
                i] = blk.attention.attention.attention_weights
        return X

7.Decoder

  解码块:有两个attention,一个是掩蔽多头自注意力,,一个是编码器-解码器注意力,以及逐位前馈网络ffn。

class DecoderBlock(nn.Module):
    """解码器中第i个块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                    num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)
        
    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # 训练阶段,输出序列的所有词元都在同⼀时间处理,
        # 因此state[2][self.i]初始化为None。
        # 预测阶段,输出序列是通过词元⼀个接着⼀个解码的,
        # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表⽰
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), axis=1)
            state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # dec_valid_lens的开头:(batch_size,num_steps),
            # 其中每⼀⾏是[1,2,...,num_steps]
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None
            
        # ⾃注意⼒
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # 编码器-解码器注意⼒。
        # enc_outputs的开头:(batch_size,num_steps,num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

  注意,Y2中的k-v来自encoder的输出。

​
class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)
            
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
    
    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range (2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # 解码器⾃注意⼒权重
            self._attention_weights[0][
                i] = blk.attention1.attention.attention_weights
            # “编码器-解码器”⾃注意⼒权重
            self._attention_weights[1][
                i] = blk.attention2.attention.attention_weights
        return self.dense(X), state
    
        @property
        def attention_weights(self):
            return self._attention_weights

​

补充:经过dec_x经过解码器形状也不会改变

8.训练

重要的两个参数:h、p(num_heads)

num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)

net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

   训练时,k-v就是dec_x的输入x---自注意力

9.预测

原代码有错,要对predict_seq2seq改一下,把net.decoder直接改成Decoder的class即可。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
        f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

'''
go . => va !,  bleu 1.000
i lost . => j'ai perdu .,  bleu 1.000
he's calm . => il est <unk> .,  bleu 0.658
i'm home . => je suis .,  bleu 0.432
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/453399.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaFX与Liberica JDK,搭建,运行,打包,放弃Eclipse

1、官网 JavaFX中文官方网站、Oracle官方文档 2、教程 JavaFX中文基础教程视频合集 JavaFX实战教程 3、VSCode/Eclipse VSCode(写HelloWorld用)、VSCode的Java扩展 Eclipse&#xff0c;跳至第9段 4、Liberica JDK安装 Liberica JDK官网下载 依次选择&#xff0c;All ve…

压力测试防踩坑指南,压测中要注意的那些事儿

对于一些高频访问接口&#xff0c;压力测试必不可少&#xff0c;本文主要叙述了自己在压测过程中遇到的问题&#xff0c;在此分享&#xff0c;希望能帮助大家避免踩坑&#xff0c;提高效率。 1.pod数量 现象&#xff1a;服务器资源充足&#xff0c;tps上不去&#xff0c;检查发…

OneData 共享同一套数据技术和资产

一、什么是 OneData 体系? 官方&#xff1a;阿里云OneData数据中台解决方案基于大数据存储和计算平台为载体&#xff0c;以OneModel统一数据构建及管理方法论为主干&#xff0c;OneID核心商业要素资产化为核心&#xff0c;实现全域链接、标签萃取、立体画像&#xff0c;以数据…

ASEMI代理ADI亚德诺ADAU1701JSTZ-RL车规级芯片

编辑-Z ADAU1701JSTZ-RL芯片参数&#xff1a; 型号&#xff1a;ADAU1701JSTZ-RL 模拟电源电压&#xff1a;3.3 V 数字电源电压&#xff1a;1.8 V 输入/输出电压&#xff1a;3.3 V 环境温度&#xff1a;25 C 主时钟输入&#xff1a;12.288 MHz 满刻度模拟输入&#xff1…

彻底掌握FreeRTOS中的务通知(Task Notifications)

​在之前的文章中已经讲解了很多种用于任务件通信的机制&#xff0c;包括队列、事件组和各种不同类型的信号量。使用这些机制都需要创建一个通信对象。 事件和数据不会直接发送到接收任务或接收ISR&#xff0c;而是发送到通信对象&#xff08;也就是发送到队列、事件组、信号量…

2023软考中级《软件设计师》(备考冲刺版) | 操作系统

目录 1.操作体统相关概念 1.1 操作系统的功能 1.2 特殊的操作系统 2.进程管理 2.1进程的概念 2.1.1 线程的概念 2.1.2 进程的状态 2.2 进程调度 2.2.1 PV操作的概念 2.2.2 信号量和PV操作 2.2.3 前趋图与PV操作 3.存储管理 3.1 页式存储 3.2 段式存储 3.3 段页式…

智慧安防小区管控系统解决方案(ppt可编辑)

本资料来源公开网络&#xff0c;仅供个人学习&#xff0c;请勿商用&#xff0c;如有侵权请联系删除 智慧安防小区-建设思路及目标 智慧安防小区管控子系统是&#xff0c;按照“数据向上集中、服务向下延伸”的思路&#xff0c;对相关要素进行重点采集&#xff0c;实现社区态势…

【JAVAEE】网络原理之网络通信基础

目录 1. &#x1f48b;IP地址 1.1 &#x1f35f;IP地址的格式 1.2 &#x1f381;特殊IP地址 2. ✨端口号 2.1 &#x1f383;端口号的格式 3. &#x1f618;网络协议 3.1 &#x1f3a8;为什么需要网络协议&#xff1f; 3.2 &#x1f49b;网络协议的概念与组成 3.3 &am…

答题积分小程序云开发实战-界面交互篇:首页页面布局样式与逻辑交互开发

微信小程序云开发实战-答题积分赛小程序 界面交互篇:首页页面布局样式与逻辑交互开发 首页效果图 布局思路 5行布局,即5个块级元素,轮播图、通告栏、个人信息、功能区、版权。

将服务器select模型设置为非阻塞,处理更多业务

timeval结构体在头文件为sys/time.h中&#xff0c;定义如下&#xff1a; struct timeval {long tv_sec; /* seconds */long tv_usec; /* and microseconds */ }; 其中tv_sec是秒&#xff0c;tv_usec是微秒&#xff08;microsecond &#xff09;&#xff0…

[单片机框架][bsp层][cx32l003][bsp_tim] Baes TIM 基础定时器配置和使用

文章目录 一、基础定时器介绍二、功能描述(1) Buzzer 功能 三、示例代码(PWM) 一、基础定时器介绍 基础定时器 Base Timer 包含两个定时器 TIM10/11。TIM10/11 功能完全相同。TIM10/11 是同步定时/计数器&#xff0c;可以作为 16/32 位自动重装载功能的定时/计数器&#xff0c…

VS2022配置GDAL

GDAL&#xff08;Geospatial Data Abstraction Library&#xff09;是一个用于处理地理空间数据的开源库。它提供了一组功能丰富的API&#xff0c;用于读取、写入、转换和处理各种地理空间数据格式&#xff0c;包括栅格数据&#xff08;如卫星图像、数字高程模型&#xff09;和…

Jupyter创建Anaconda多个虚拟环境教程

这里写目录标题 1.1界面化创建虚拟环境1.2命令行创建虚拟环境2.查看是否创建成功3.激活虚拟环境pylessonppt4.更改工作目录5.删除6.查看是否删除成功 1.1界面化创建虚拟环境 1.2命令行创建虚拟环境 conda create -n myenv——name pythonx.xmyenv-name:自己定义的环境名称 pyt…

fastjson反序列化漏洞复现

fastjson反序列化漏洞复现 一.影响版本: Fastjson<1.2.24二.实验过程图三.实验步骤四&#xff0c;实验结果以及参考链接 一.影响版本: Fastjson<1.2.24 二.实验过程图 (踩坑) rmijndi环境&#xff1a;java.sql.SQLException: JdbcRowSet (连接) JNDI 无法连接 2、ldapjn…

上海无纺布制造商【盈兹】申请纳斯达克IPO上市,募资1100万美元

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 猛兽财经获悉&#xff0c;来自上海的无纺布制造商【盈兹】&#xff0c;近期已向美国证券交易委员会&#xff08;SEC&#xff09;提交招股书&#xff0c;申请在纳斯达克IPO上市&#xff0c;股票代码为&#xff08;ETZ&#…

Invalid bound statement (not found)的原因以及解决方法

相信我们在学习Mybatis的时候都出现过 Invalid bound statement (not found) 这个错误&#xff0c;一般由以下几种可能导致这个错误 一&#xff1a;mapper方法名 和 mapper.xml id名不对应 例如&#xff1a; mapper&#xff1a; 对应的mapper.xml 这里建议小伙伴们下载一个插…

Linux中的YUM源仓库和NFS文件共享服务

这里写目录标题 一 、YUM仓库源的介绍和相关信息1.1yum相关介绍1.2 Linux系统各家厂商用的安装源1.3 yum下载方式 二 、 yum 仓库源的三种搭建2.1yum 配置本地源2.2创建ftp源2.3 配置http源2.4 配置yum在线源 三 、NFS的简介3.1 什么是NFS3.2 linux中要使用NFS需要下载的软件包…

User Diverse Preference Modeling by Multimodal Attentive Metric Learning

BACKGROUND 现有模型通常采用一个固定向量去表示用户偏好&#xff0c;在假设——特征向量每一个维度都代表了用户的一种特性或者一个方面&#xff0c;这种方式似乎不妥&#xff0c;因为用户对于不同物品的偏好是不一样的&#xff0c;例如因演员喜欢一部电影&#xff0c;而因特…

C++中的vector容器

文章目录 vector的介绍vector的使用vector的定义vector初始化vector iterator的使用vector空间增长问题vector增删改查vector迭代器失效问题 vector的介绍 vector是封装动态数组的顺序容器。   就像数组一样&#xff0c;vector也采用的连续存储空间来存储元素。这也就意味着我…

Java核心技术 卷1-总结-15

自己实现的hashCode方法应该与equals方法兼容 Java核心技术 卷1-总结-15 视图与包装器子范围不可修改的视图同步视图受查视图 并发线程状态新创建线程可运行线程被阻塞线程和等待线程被终止的线程 视图与包装器 子范围 可以为很多集合建立子范围&#xff08;subrange&#x…