Attention Is All You Need原理与代码详细解读

news2025/1/25 9:21:54

文章目录

  • 前言
  • 一、Transformer结构的原理
    • 1、Transform结构
    • 2、位置编码公式
    • 3、transformer公式
    • 4、FFN结构
  • 二、Encode模块代码解读
    • 1、编码数据
    • 2、文本Embedding编码
    • 3、位置position编码
    • 4、Attention编码
    • 5、FFN编码
  • 三、Decode模块代码解读
    • 1、编码数据
    • 2、文本Embedding与位置编码
    • 3、mask编码
    • 4、Attention编码
      • self attention
      • cross attention
    • 5、FFN编码
  • 四、源码附件(源码有注释)
  • 总结


前言

目前,我研究大模型相关知识,常用到transformer结构,我想到NLP领域开篇之作Attention is all you need论文,论文实际提出transform结构,可与CNN并驾齐驱的结构,该结构利用Q/K/V模式整合全局信息,与CNN提取局部信息有所差别。介于此,我将一年前博客园更新笔记迁入该博客中,本文将介绍transform原理,也根据源码解读,深入介绍transforme经典典结构,并附有代码。


论文链接:点击这里

一、Transformer结构的原理

该部分主要介绍Attention is all you need 结构、模块、公式。暂时不介绍什么Q K V 什么Attention 什么编解码等,后面我将会根据代码解读介绍,让读者更容易理解。

1、Transform结构

Transformer由且仅由Attention和Feed Forward Neural Network(也称FFN)组成,其中Attention包含self Attention与Mutil-Head Attention,如下图:
示例:pandas 是基于NumPy 的一种工具,该工具是为了解决数据分析任务而创建的。
注:模型一般可有encode与decode组成,encode负责特征编码,decode负责解码。目前,也有论文不使用解码器decode,如swin-transform。

2、位置编码公式

位置编码公式(还有很多其它公式,该论文使用此公式),如下:

在这里插入图片描述

3、transformer公式

在这里插入图片描述

4、FFN结构

FFN是由nn.Linear线性和激活函数构成,后面代码详细说明。

二、Encode模块代码解读

1、编码数据

编码输入数据介绍:
enc_input = [
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3]]
编码使用输入数据,为4x6行,表示4个句子,每个句子有6个单词,包含标点符号。
注:至于文本如何表示数字,可参考这里

2、文本Embedding编码

文本嵌入embedding:

self.src_emb = nn.Embedding(vocab_size, d_model) # d_model=128

vocab_size:词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999)

d_model:嵌入向量的维度,即用多少维来表示一个词或符号

nn.Embedding()函数可使用torch调用,建议读者百度了解其功能。

随后可将输入x=enc_input,可将enc_outputs则表示嵌入成功,维度为[4,6,128]分别表示batch为4,词为6,用128维度描述词6

x = self.src_emb(x)  # 词嵌入

3、位置position编码

位置编码,使用上面公式嵌入,我将不再介绍,其代码如下:

 pe = torch.zeros(max_len, d_model)
         position = torch.arange(0., max_len).unsqueeze(1)
         div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))  # 偶数列
         pe[:, 0::2] = torch.sin(position * div_term) # 奇数列
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0)

将编码进行位置编码后,位置为[1,6,128]+输入编码的[4,6,128],相当于句子已经结合了位置编码信息,作为新新的输入,代码如下:

x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)  #torch.autograd.Variable 表示有梯度的张量变量

4、Attention编码

在介绍此之前,先普及一个知识,若X与Y相等,则为self attention 否则为cross-attention,因为解码时候X!=Y.
在这里插入图片描述

获取Q K V 代码,实际是一个线性变化,将以上输入x变成[4,6,512],然后通过head个数8与对应dv,dk将512拆分[8,64],随后移维度位置,变成[4,8,6,64]

 self.WQ = nn.Linear(d_model, d_k * n_heads)  # 利用线性卷积
 self.WK = nn.Linear(d_model, d_k * n_heads)
 self.WV = nn.Linear(d_model, d_v * n_heads)

变化后的q k v

 q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # 线性卷积后再分组实现head功能
 k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
 v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
 attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)  # 编导对应的头

随后通过以上self公式,将其编码计算

scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
attn = nn.Softmax(dim=-1)(scores)
context = torch.matmul(attn, V)

以上编码将是encode编码得到结果,我们将得到结果进行还原:

context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # 将其还原
output = self.linear(context)  # 通过线性又将其变成原来模样维度
layer_norm(output + Q)  # 这里加Q 实际是对Q寻找

以上将重新得到新的输入x,维度为[4,6,128]

5、FFN编码

将以上的输出维度为[4,6,128]进行FFN层变化,实际类似线性残差网络变化,得到最终输出

  class PoswiseFeedForwardNet(nn.Module):
  
      def __init__(self, d_model, d_ff):
          super(PoswiseFeedForwardNet, self).__init__()
          self.l1 = nn.Linear(d_model, d_ff)
          self.l2 = nn.Linear(d_ff, d_model)
  
          self.relu = GELU()
          self.layer_norm = nn.LayerNorm(d_model)
 
     def forward(self, inputs):
         residual = inputs
         output = self.l1(inputs)  # 一层线性卷积
         output = self.relu(output)
         output = self.l2(output)  # 一层线性卷积
         return self.layer_norm(output + residual)

重复以上顺序编码,即将得到经过FFN变化的输出x,维度为[4,6,128],将其重复步骤③-④,因其编码为6个,可重复5个便是完成相应的编码模块。

三、Decode模块代码解读

1、编码数据

解码输入数据介绍,包含以下数据输入dec_input、enc_input的输入与解码后输出的数据,维度为[4,6,128],而dec_input输入如下:

dec_input = [
[1, 0, 0, 0, 0, 0],
[1, 3, 0, 0, 0, 0],
[1, 3, 4, 0, 0, 0],
[1, 3, 4, 1, 0, 0]]

2、文本Embedding与位置编码

dec_input的Embedding与位置编码,因其与encode的实现方法一致,只需将enc_input使用dec_input取代,得到dec_outputs,因此这里将不在介绍。

3、mask编码

整体编码,代码如下:

  def get_attn_pad_mask(seq_q, seq_k, pad_index):
     batch_size, len_q = seq_q.size()
     batch_size, len_k = seq_k.size()
     pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
     pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
     return pad_attn_mask.expand(batch_size, len_q, len_k)

以上代码实际是将dec_input进行处理,实际变成以下数据:

[[0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1]]

将其增添维度为[4,1,6],并将其扩张为[4,6,6]

局部代码编写,实际为上三角矩阵:

[[0. 1. 1. 1. 1. 1.]
[0. 0. 1. 1. 1. 1.]
[0. 0. 0. 1. 1. 1.]
[0. 0. 0. 0. 1. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 0.]]
将以上数据添加维度为[1,6,6],在将扩展变成[4,6,6]
关于整体mask与局部mask编码,我的理解是整体信息为语句4个词6个,根据解码输入编码整体信息,而局部编码是基于一个语句6*6编码信息,将其扩张重复到4个语句,
使其mask获得整体信息与局部信息。

         dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)  # 整体编码的mask
         dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
         dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)  # torch.gt(a,b) a>b 则为1否则为0
         dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)

最终将mask整合,获取dec_self_attn_mask信息,同理dec_enc_attn_mask(维度为解码编码词维度)采用dec_self_attn_mask的第一步便可获取。

4、Attention编码

编码输入self-Attention,包含2部分,self Attention与cross Attention。

self attention

解码输入dec_outputs进行self.Attention:
实际使用以上Q K V公式,具体实现和编码实现方法一致,唯一不同是在Q*K^T会使用解码maskdec_self_attn_mask,其重要代码为scores.masked_fill_(attn_mask, -1e9),代码如下:

  class ScaledDotProductAttention(nn.Module):
  
      def __init__(self, d_k, device):
          super(ScaledDotProductAttention, self).__init__()
          self.device = device
          self.d_k = d_k
  
      def forward(self, Q, K, V, attn_mask):
          scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
          attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
          attn_mask = attn_mask.to(self.device)
          scores.masked_fill_(attn_mask, -1e9)  # it is true give -1e9
          attn = nn.Softmax(dim=-1)(scores)
          context = torch.matmul(attn, V)
          return context, attn

以上代码将执行以下代码:

context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
                                                                            attn_mask=attn_mask)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # 将其还原
output = self.linear(context)  # 通过线性又将其变成原来模样维度
dec_outputs = self.layer_norm(output + Q)  # 这里加Q 实际是对Q寻找

到此为止已经完成了解码输入的self-attention模块,输出为dec_outputs实际除了增加mask编码调整Q*K^T以外,其它完全相同。

cross attention

编码输出dec_outputs进行Cross Attention:

dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask) 

重点说明enc_outputs来源编码结果,是一直不变的,以上为Cross Attention 过程,以上代码除了Q来源dec_outputs,K V 来源编码输出enc_outputs以外,即论文所说X与Y不等得到的Q K V称为Cross Attention。
实际以上代码与执行解码self-Attention方法完全一致,仅仅mask更改上文提供的方法,得到输出结果为dec_outputs,因此这里将不在解释了。

5、FFN编码

该部分编码与encode的FFN一样,我将不在解释。

重复步骤上面4与5为n次,便实现解码过程。

四、源码附件(源码有注释)

最后,我给出attention is all you need的所有代码,只需简单环境便可使用,整体实现代码如下:

import json
import math
import torch
import torchvision
import torch.nn as nn
import numpy as np
from pdb import set_trace

from torch.autograd import Variable


def get_attn_pad_mask(seq_q, seq_k, pad_index):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
    pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
    return pad_attn_mask.expand(batch_size, len_q, len_k)


def get_attn_subsequent_mask(seq):
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequent_mask = np.triu(np.ones(attn_shape), k=1)
    subsequent_mask = torch.from_numpy(subsequent_mask).int()
    return subsequent_mask


class GELU(nn.Module):

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):  #
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))  # 偶数列
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)  # 将变量pe保存到内存中,不计算梯度

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)  # torch.autograd.Variable 表示有梯度的张量变量
        return self.dropout(x)


class ScaledDotProductAttention(nn.Module):

    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.device = device
        self.d_k = d_k

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
        attn_mask = attn_mask.to(self.device)
        scores.masked_fill_(attn_mask, -1e9)  # it is true give -1e9
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn


class MultiHeadAttention(nn.Module):

    def __init__(self, d_model, d_k, d_v, n_heads, device):
        super(MultiHeadAttention, self).__init__()
        self.WQ = nn.Linear(d_model, d_k * n_heads)  # 利用线性卷积
        self.WK = nn.Linear(d_model, d_k * n_heads)
        self.WV = nn.Linear(d_model, d_v * n_heads)

        self.linear = nn.Linear(n_heads * d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.device = device

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads

    def forward(self, Q, K, V, attn_mask):
        batch_size = Q.shape[0]
        q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # 线性卷积后再分组实现head功能
        k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)  # 编导对应的头
        context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
                                                                                    attn_mask=attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # 将其还原
        output = self.linear(context)  # 通过线性又将其变成原来模样维度
        return self.layer_norm(output + Q), attn  # 这里加Q 实际是对Q寻找


class PoswiseFeedForwardNet(nn.Module):

    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.l1 = nn.Linear(d_model, d_ff)
        self.l2 = nn.Linear(d_ff, d_model)

        self.relu = GELU()
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        residual = inputs
        output = self.l1(inputs)  # 一层线性卷积
        output = self.relu(output)
        output = self.l2(output)  # 一层线性卷积
        return self.layer_norm(output + residual)


class EncoderLayer(nn.Module):

    def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
        self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(Q=enc_inputs, K=enc_inputs, V=enc_inputs, attn_mask=enc_self_attn_mask)
        # X=Y 因此Q K V相等
        enc_outputs = self.pos_ffn(enc_outputs)  #
        return enc_outputs, attn


class Encoder(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
        #                   4        128     256   64   64     8        4          0
        super(Encoder, self).__init__()
        self.device = device
        self.pad_index = pad_index
        self.src_emb = nn.Embedding(vocab_size, d_model)
        # vocab_size:词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999) d_model:嵌入向量的维度,即用多少维来表示一个符号
        self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)

        self.layers = []
        for _ in range(n_layers):
            encoder_layer = EncoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
            self.layers.append(encoder_layer)
        self.layers = nn.ModuleList(self.layers)

    def forward(self, x):
        enc_outputs = self.src_emb(x)  # 词嵌入
        enc_outputs = self.pos_emb(enc_outputs)  # pos+matx
        enc_self_attn_mask = get_attn_pad_mask(x, x, self.pad_index)

        enc_self_attns = []
        for layer in self.layers:
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)

        enc_self_attns = torch.stack(enc_self_attns)
        enc_self_attns = enc_self_attns.permute([1, 0, 2, 3, 4])
        return enc_outputs, enc_self_attns


class DecoderLayer(nn.Module):

    def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
        self.dec_enc_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
        self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn


class Decoder(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
        super(Decoder, self).__init__()
        self.pad_index = pad_index
        self.device = device
        self.tgt_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
        self.layers = []
        for _ in range(n_layers):
            decoder_layer = DecoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
            self.layers.append(decoder_layer)
        self.layers = nn.ModuleList(self.layers)

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        dec_outputs = self.tgt_emb(dec_inputs)
        dec_outputs = self.pos_emb(dec_outputs)

        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
        dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(
                dec_inputs=dec_outputs,
                enc_outputs=enc_outputs,
                dec_self_attn_mask=dec_self_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        dec_self_attns = torch.stack(dec_self_attns)
        dec_enc_attns = torch.stack(dec_enc_attns)

        dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
        dec_enc_attns = dec_enc_attns.permute([1, 0, 2, 3, 4])

        return dec_outputs, dec_self_attns, dec_enc_attns


class MaskedDecoderLayer(nn.Module):

    def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
        super(MaskedDecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
        self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)

    def forward(self, dec_inputs, dec_self_attn_mask):
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn


class MaskedDecoder(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff, d_k,
                 d_v, n_heads, n_layers, pad_index, device):
        super(MaskedDecoder, self).__init__()
        self.pad_index = pad_index
        self.tgt_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)

        self.layers = []
        for _ in range(n_layers):
            decoder_layer = MaskedDecoderLayer(
                d_model=d_model, d_ff=d_ff,
                d_k=d_k, d_v=d_v, n_heads=n_heads,
                device=device)
            self.layers.append(decoder_layer)
        self.layers = nn.ModuleList(self.layers)

    def forward(self, dec_inputs):
        dec_outputs = self.tgt_emb(dec_inputs)
        dec_outputs = self.pos_emb(dec_outputs)

        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
        dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
        dec_self_attns = []
        for layer in self.layers:
            dec_outputs, dec_self_attn = layer(
                dec_inputs=dec_outputs,
                dec_self_attn_mask=dec_self_attn_mask)
            dec_self_attns.append(dec_self_attn)
        dec_self_attns = torch.stack(dec_self_attns)
        dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
        return dec_outputs, dec_self_attns


class BertModel(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
        super(BertModel, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = PositionalEncoding(d_model=d_model, dropout=0)
        self.seg_embed = nn.Embedding(2, d_model)

        self.layers = []
        for _ in range(n_layers):
            encoder_layer = EncoderLayer(
                d_model=d_model, d_ff=d_ff,
                d_k=d_k, d_v=d_v, n_heads=n_heads,
                device=device)
            self.layers.append(encoder_layer)
        self.layers = nn.ModuleList(self.layers)

        self.pad_index = pad_index

        self.fc = nn.Linear(d_model, d_model)
        self.active1 = nn.Tanh()
        self.classifier = nn.Linear(d_model, 2)

        self.linear = nn.Linear(d_model, d_model)
        self.active2 = GELU()
        self.norm = nn.LayerNorm(d_model)

        self.decoder = nn.Linear(d_model, vocab_size, bias=False)
        self.decoder.weight = self.tok_embed.weight
        self.decoder_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.tok_embed(input_ids) + self.seg_embed(segment_ids)
        output = self.pos_embed(output)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.pad_index)

        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        h_pooled = self.active1(self.fc(output[:, 0]))
        logits_clsf = self.classifier(h_pooled)

        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(self.active2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias

        return logits_lm, logits_clsf, output


class GPTModel(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff,
                 d_k, d_v, n_heads, n_layers, pad_index,
                 device):
        super(GPTModel, self).__init__()
        self.decoder = MaskedDecoder(
            vocab_size=vocab_size,
            d_model=d_model, d_ff=d_ff,
            d_k=d_k, d_v=d_v, n_heads=n_heads,
            n_layers=n_layers, pad_index=pad_index,
            device=device)
        self.projection = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, dec_inputs):
        dec_outputs, dec_self_attns = self.decoder(dec_inputs)
        dec_logits = self.projection(dec_outputs)
        return dec_logits, dec_self_attns


class Classifier(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff,
                 d_k, d_v, n_heads, n_layers,
                 pad_index, device, num_classes):
        super(Classifier, self).__init__()
        self.encoder = Encoder(
            vocab_size=vocab_size,
            d_model=d_model, d_ff=d_ff,
            d_k=d_k, d_v=d_v, n_heads=n_heads,
            n_layers=n_layers, pad_index=pad_index,
            device=device)
        self.projection = nn.Linear(d_model, num_classes)

    def forward(self, enc_inputs):
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        mean_enc_outputs = torch.mean(enc_outputs, dim=1)
        logits = self.projection(mean_enc_outputs)
        return logits, enc_self_attns


class Translation(nn.Module):

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model,
                 d_ff, d_k, d_v, n_heads, n_layers, src_pad_index,
                 tgt_pad_index, device):
        super(Translation, self).__init__()
        self.encoder = Encoder(
            vocab_size=src_vocab_size,  # 5
            d_model=d_model, d_ff=d_ff,  # 128  256
            d_k=d_k, d_v=d_v, n_heads=n_heads,  # 64 64  8
            n_layers=n_layers, pad_index=src_pad_index,  # 4  0
            device=device)
        self.decoder = Decoder(
            vocab_size=tgt_vocab_size,  # 5
            d_model=d_model, d_ff=d_ff,  # 128  256
            d_k=d_k, d_v=d_v, n_heads=n_heads,  # 64 64  8
            n_layers=n_layers, pad_index=tgt_pad_index,  # 4  0
            device=device)
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)

    # def forward(self, enc_inputs, dec_inputs, decode_lengths):
    #     enc_outputs, enc_self_attns = self.encoder(enc_inputs)
    #     dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
    #     dec_logits = self.projection(dec_outputs)
    #     return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns, decode_lengths

    def forward(self, enc_inputs, dec_inputs):
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)
        return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns


if __name__ == '__main__':
    enc_input = [
        [1, 3, 4, 1, 2, 3],
        [1, 3, 4, 1, 2, 3],
        [1, 3, 4, 1, 2, 3],
        [1, 3, 4, 1, 2, 3]]
    dec_input = [
        [1, 0, 0, 0, 0, 0],
        [1, 3, 0, 0, 0, 0],
        [1, 3, 4, 0, 0, 0],
        [1, 3, 4, 1, 0, 0]]
    enc_input = torch.as_tensor(enc_input, dtype=torch.long).to(torch.device('cpu'))
    dec_input = torch.as_tensor(dec_input, dtype=torch.long).to(torch.device('cpu'))
    model = Translation(
        src_vocab_size=5, tgt_vocab_size=5, d_model=128,
        d_ff=256, d_k=64, d_v=64, n_heads=8, n_layers=4, src_pad_index=0,
        tgt_pad_index=0, device=torch.device('cpu'))

    logits, _, _, _ = model(enc_input, dec_input)
    print(logits)


总结

本文已全部介绍完transformer结构原理及代码,但我个人有以下几点说明:
编码传递K V 解码传递Q;
self-attention 和 cross attention本质是X与Y值不同,即得到Q 和 K V 数据来源不同,但实现方法一致;
transformer重点模块为attention(一般是mutil-head attention)、FFN、位置编码、mask编码;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1094384.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

centos6/7 SOCKS5 堆溢出漏洞修复(RPM方式)curl 8.4 CVE-2023-38545 CVE-2023-38546

引用 https://darkdark.top/update-curl.html centos6 rpm 升级包下载:https://download.csdn.net/download/sinat_24092079/88425840 yum update libcurl-8.4.0-1.el6.1.x86_64.rpm curl-8.4.0-1.el6.1.x86_64.rpmcentos7 rpm 升级包下载:https://down…

ELF和静态链接:为什么程序无法同时在Linux和Windows下运行?

目录 疑问 编译、链接和装载:拆解程序执行 ELF 格式和链接:理解链接过程 小结 疑问 既然我们的程序最终都被变成了一条条机器码去执行,那为什么同一个程序,在同一台计算机上,在 Linux 下可以运行,而在…

Linux | 关于入门Linux你有必要了解的指令

目录 前言 1、ls指令 2、pwd指令 3、cd指令 4、touch指令 5、stat指令 6、mkdir指令 7、rmdir 与 rm指令 8、man指令 9、cp指令 10、mv指令 11、cat指令 (1)输入重定向 (2)输出重定向与追加重定向 12、less指令 1…

多模态模型文本预处理方式

句子级别 句子级别的表征编码一整个句子到一个特征中。如果一个句子有多个短语,提取这些短语丢弃其他的单词。 缺点:这种方式会丢失句子中细粒度的信息。 单词级别 将句子中的类别提取出来,结合成一个句子。 缺点:会在类别之…

【数据结构】线性表的抽象数据类型

🦄个人主页:修修修也 🎏所属专栏:数据结构 ⚙️操作环境:Visual Studio 2022 线性表抽象数据类型(LinearListAbstractDataType,简称 ADT)是一种非常重要的抽象数据类型,它是一种使用抽象的方式表示和实现一组数据元素的集合以及与…

宝塔面板服务器内存使用率高的三招解决方法

卸载多余PHP版本。假若安装了多个PHP版本,甚至把 php 5.3、5.4、7.0、7.3 全都安装上了,就会严重增加系统负载和内存使用率。 安装memcached 缓存组件,建议在宝塔面板后台直接安装。 卸载不常用软件。如:宝塔运维、宝塔一键安装…

php如何查找地图距离

要在PHP中使用高德地图、百度地图或腾讯地图获取位置信息,您可以使用它们的相应API服务。以下是获取位置信息的一般步骤: 思路: 获取API密钥:首先,您需要注册并获取相应地图服务提供商的API密钥。这将允许您访问他们的API以获取位…

CSS的美化(文字、背景) Day02

一、文字控制属性 分为:字体样式属性 、文本样式属性 1.1 CSS字体样式属性 1.color定义元素内文字颜色2.font-size 字号大小3 font-family 字体4 font-weight 字体粗细5.font-style 字体风格6.font 字体综合属性 1.1.1 > 文字颜色 color 属性名: color color …

Yakit工具篇:简介和安装使用

简介(来自官方文档) 基于安全融合的理念,Yaklang.io 团队研发出了安全领域垂直语言Yaklang,对于一些无法原生集成在Yak平台中的产品/工具,利用Yaklang可以重新编写 他们的“高质量替代”。对于一些生态完整且认可度较高的产品,Y…

C# CodeFormer 图像修复

效果 项目 代码 using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using OpenCvSharp; using System; using System.Collections.Generic; using System.Drawing; using System.Drawing.Imaging; using System.Windows.Forms;namespace 图像修复 {p…

高校教务系统登录页面JS分析——南京邮电大学

高校教务系统密码加密逻辑及JS逆向 本文将介绍南京邮电大学教务系统的密码加密逻辑以及使用JavaScript进行逆向分析的过程。通过本文,你将了解到密码加密的基本概念、常用加密算法以及如何通过逆向分析来破解密码。 本文仅供交流学习,勿用于非法用途。 一…

C++标准模板(STL)- 类型支持 (数值极限,min_exponent10,max_exponent,max_exponent10)

数值极限 std::numeric_limits 定义于头文件 <limits> 定义于头文件 <limits> template< class T > class numeric_limits; numeric_limits 类模板提供查询各种算术类型属性的标准化方式&#xff08;例如 int 类型的最大可能值是 std::numeric_limits&l…

多个Python包懒得import,那就一包搞定!

使用Python时&#xff0c;有的代码需要依赖多个框架或库者来完成&#xff0c;代码开头需要import多次&#xff0c;比如&#xff0c; import pandas as pd from pyspark import SparkContext from openpyxl import load_workbook import matplotlib.pyplot as plt import seabo…

Java Day2(Java基础语法)

Java基础 Java基础语法1. 注释、关键字、标识符1.1 Java中的注释1.2 关键字1.3 标识符 2. 数据类型&#xff08;1&#xff09;基本类型&#xff08;primitive type&#xff09;a.字节b.进制c. 浮点数拓展d. 字符拓展 &#xff08;2&#xff09; 引用类型(Reference type ) 3. 类…

【软件测试】总结

文章目录 一. 测试用例1. 常见设计测试用例(1)非软件题型(2)软件题型(3)代码型题(4)关于个人项目设计测试用例 2. 万能公式和具体的方法如何理解(1)万能公式(2)Fiddler实现弱网模式(3)针对公交卡设计测试用例 3. 进阶设计测试用例 二. 自动化1. 什么是自动化以及为什么要做自动…

杀死僵尸进程ZooKeeperMain

关闭Hadoop后jps发现还有个进程ZooKeeperMain没有关闭&#xff0c;使用kill -9 <>也没有用&#xff0c;这种就是僵尸进程&#xff0c;需要用父进程ID来杀死 解决方法 话不多说&#xff0c;直接上解决方案&#xff0c; 1. 第一步 清楚需要关闭的进程ID&#xff0c;我…

CentOS-7下安装及配置vsftpd详细步骤(可匿名访问)

第一步安装vsftpd&#xff1a; yum -y install vsftpd 第二步修改ftp主目录所属用户为用户ftp&#xff1a; chown ftp /var/ftp/pub 第三步备份及配置ftp&#xff1a; cp /etc/vsftpd/vsftpd.conf ~/vsftpd.conf.bakvim /etc/vsftpd/vsftpd.conf 配置如下图&#xff1a;…

《AWD特训营》CTF/AWD竞赛的速胜指南!全面提升安全技术

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏:《粉丝福利》 《C语言进阶篇》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 文章目录 前言一、AWD竞赛的由来《AWD特训营&#xff1a;技术解析、赛题实战与竞赛技巧》1.1介绍&#xff1a; 《AWD特训营》…

Qt拖拽文件到窗口、快捷方式打开

大部分客户端都支持拖拽文件的功能&#xff0c;本篇博客介绍Qt如何实现文件拖拽到窗口、快捷方式打开&#xff0c;以我的开源视频播放器项目为例&#xff0c;介绍拖拽视频到播放器窗口打开。   需要注意的是&#xff0c;Qt拖拽文件的功能&#xff0c;不支持以管理员权限启动的…

《PyTorch深度学习实践》第三讲 反向传播

《PyTorch深度学习实践》第三讲 反向传播 问题描述问题分析编程实现代码实现效果 参考文献 问题描述 问题分析 编程实现 代码 import torch # 数据集 x_data [1.0, 2.0, 3.0] y_data [2.0, 4.0, 6.0] # w权重 w torch.tensor([1.0]) w.requires_grad True # 需要计算梯度…