【论文复现】Transformer

news2024/9/22 1:02:05

Transformer

  • 前言
  • 网络架构
  • 数据处理
    • 词嵌入向量
    • 位置编码
  • 模型定义
    • 多头注意力机制
    • 编码器Encoder
    • 解码器Decoder

前言

Transformer应用范围非常广泛,涵盖了自然语言处理、时间序列数据预测、图像处理等领域。由于笔者之前都是应用,但是对于原理并没有深刻理解导致想要进行进一步的调试难度比较大,所以学习Transformer的原理以便更加好地运用Transformer。Transformer-vit-tutorial-baseline这一篇文章写的非常好,非常清晰地展示了用代码实现Transformer的具体编程步骤以及原理解释,另外Transformer From Scratch和Build Transformer With Pytorch展示了用Pytorch复现Transformer代码的全过程。

网络架构

在这里插入图片描述

数据处理

词嵌入向量

将单词转化成嵌入向量能够更加好地捕捉单词之间的语法和语义关系,比如在自然语言处理中经常设置512这个超参数的数值,表示可以最多提取出512个表征特征用于做模型训练。

class Embedding(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        """
        Args:
            vocab_size: size of vocabulary
            embed_dim: dimension of embeddings
        """
        super(Embedding, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
    def forward(self, x):
        """
        Args:
            x: input vector
        Returns:
            out: embedding vector
        """
        out = self.embed(x)
        return out

位置编码

由于一次性是将所有的单词输入,模型并不知道单词之间的先后关系,然而对于自然语言来说语义顺序很重要,比如“我是你爸爸”顺序不对就变成了“你是我爸爸”。

class PositionalEmbedding(nn.Module):
    def __init__(self,max_seq_len,embed_model_dim):
        """
        Args:
            seq_len: length of input sequence
            embed_model_dim: demension of embedding
        """
        super(PositionalEmbedding, self).__init__()
        self.embed_dim = embed_model_dim

        pe = torch.zeros(max_seq_len,self.embed_dim)
        for pos in range(max_seq_len):
            for i in range(0,self.embed_dim,2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/self.embed_dim)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/self.embed_dim)))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)


    def forward(self, x):
        """
        Args:
            x: input vector
        Returns:
            x: output
        """
      
        # make embeddings relatively larger
        x = x * math.sqrt(self.embed_dim)
        #add constant to embedding
        seq_len = x.size(1)
        x = x + torch.autograd.Variable(self.pe[:,:seq_len], requires_grad=False)
        return x

模型定义

多头注意力机制

多头注意力机制作用是让模型知道一句话单词与单词之间的关系,比如有一句话“Dog is crossing the street because it saw its master”。多头注意力就可以分析出it代指的就是Dog。多头注意力机制的基础是自注意力机制,也就是当前单词和当前词所在的句子中其他单词和当前单词之间的关系。对多个单词进行自注意力计算并拼接起来就实现了多头注意力。
在这里插入图片描述

1、从输入的文本中得到Q、K、V三个权值向量,Q代表用户的请求,就是输入的对话,K是与Q对应的键值,V是最后的输出。
2、计算[Q*K.t],先进行乘积,这里K用转置矩阵
3、运行Softmax函数并隐藏当前处理单词后面的单词,得到输出得分
4、将输出与Value进行乘积操作
5、通过线性层输出

class MultiHeadAttention(nn.Module):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    assert d_model % num_heads == 0 

    # 初始化维度
    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads

    # 初始化K, Q, V以及O向量空间
    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)

  def scaled_dot_product_attention(self, Q, K, V, mask=None):
    # 计算注意力系数
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

    # 如果有mask层就应用
    if mask is not None: 
      attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

    # 将注意力系数转化成概率
    attn_probs = torch.softmax(attn_scores, dim=-1)

    # 与V矩阵相乘得到结果
    output = torch.matmul(attn_probs, V)
    return output
  
  def split_heads(self, x):
    # 改变输入内容的形状适应多头注意力
    batch_size, seq_length, d_model = x.size()
    return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
  
  def combine_heads(self, x):
    # 将多头变回到原始大小
    batch_size, _, seq_length, d_k = x.size()
    return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
  
  def forward(self, Q, K, V, mask=None):
    # 计算Q,K,V矩阵
    Q = self.split_heads(self.W_q(Q))
    K = self.split_heads(self.W_k(K))
    V = self.split_heads(self.W_v(V))

    # 计算注意力系数
    attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

    # 将Q和K矩阵计算结果进行合并
    output = self.W_o(self.combine_heads(attn_output))
    return output

编码器Encoder

在这里插入图片描述

编码器将输入的词加载到模型中,值得注意的是编码器中没有Mask层,输入中的句子是整体”可见“的。以输入长度为d的句子为例,编码器的输入就是n个长度为d的向量。输出就是当前单词本身跟其他单词的相似度,是一个d*d的矩阵。输出实际上就是Value的加权和

1、将单词编码成词嵌入向量
2、给单词加上位置编码
3、将输入的n个长为d的向量复制两份并输入到多头注意力层中
4、将结果矩阵输出归一化方便进行统一评价
5、经过Feedforward层(TODO:这里还需要解释)

class EncoderLayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout):
    super(EncoderLayer, self).__init__()
    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = PositionalEncoding(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def froward(self, x, mask):
    attn_output = self.self_attn(x, x, x, mask)
    x = self.norm1(x + self.dropout(attn_output))
    ff_output = self.feed_forward(x)
    x = self.norm2(x + self.dropout(ff_output))
    return x

解码器Decoder

在这里插入图片描述
解码器的输入是每一次生成的文字序列,作用是在编码器的输出中挑选感兴趣的内容。跟编码器中的多头注意力原理相似,输出就是从Value中选取感兴趣的内容。

1.跟编码器一样,讲输入转化成词嵌入向量并进行位置编码
2.经过多头注意力层并产生输出向量,不同的一点是这里使用了Mask层屏蔽了当前处理单词后面的所有单词
3.将结果相加并归一化输出作为Value向量
4.将Encoder输出的Key和Value向量和Decoder输出的Query向量相结合组成多头注意力层

class DecoderLayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout):
    super(DecoderLayer, self).__init__()
    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.cross_attn = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = PositionalEncoding(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc_output, src_mask, tgt_mask):
    attn_output = self.self_attn(x, x, x, tgt_mask)
    x = self.norm1(x + self.dropout(attn_output))
    attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
    x = self.norm2(x + self.dropout(attn_output))
    ff_output = self.feed_forward(x)
    x = self.norm3(x + self.dropout(ff_output))
    return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1995166.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

树莓派Pico 2,RP2350 现已发售!

https://www.bilibili.com/video/BV1n5YeeMETu/?vd_sourcea637ced2b66f15709d16fcbaceeb47a9 我们很高兴地宣布推出Raspberry Pi Pico 2,我们的第二代微控制器板:采用了由我们自主设计的新款高性能安全型微控制器 RP2350。 Raspberry Pi Pico 2&#…

5 种经过验证的查询翻译技术可提高您的 RAG 性能

如何在用户输入模糊的情况下获得近乎完美的 LLM 性能 欢迎来到雲闪世界。你认为用户会向 LLM 提出完美的问题,这是大错特错。如果我们不直接执行,而是细化用户的问题,结果会怎样?这就是查询转换。 我们开发了一款应用程序&#x…

查看DrawCall流程 Unity工具 Frame Debug

切换帧率 基础面板 可以看到 我们可以通过切换面板 看DrwaCall产生 MainTex 就是材质了 如何优化? 合批 就会一次性直接渲染

双端队列Deque

Deque(双端队列)是一种允许在两端都进行插入和删除操作的线性数据结构。它在 Java Collections Framework 中作为一个重要的接口,具有以下结构特点: 1. 双端操作 两端插入和删除:与传统队列(只能在一端入…

迭代次数顺序的双重性

(A,B)---6*30*2---(0,1)(1,0) 收敛误差为7e-4,收敛199次取迭代次数平均值, 让A是4a1,4a2,…,4a16,B全是0得到迭代次数的顺序就是1,2,…,16. 但是如果让训练集A-B矩阵的高…

kafka-go使用:以及kafka一些基本概念说明

关于kafka 作为开发人员kafka中最常关注的几个概念,是topic,partition和group这几个概念。topic是主题的意思,简单的说topic是数据主题,这样解释好像显得很苍白,只是做了个翻译。一图胜前言,我们还是通过图解来说明。…

PDF密码移除技巧: 五大 PDF 密码移除器

如果您想解密或删除 PDF 密码,该怎么办?PDF 是一种经常用于商业的格式,您可以在培训、教育和商业场合使用它。添加这些 PDF 文件的密码可以保护您的安全和隐私。因此,有很多 PDF 都用密码加密,当您想要查看这些 PDF 时…

PTrade常见问题系列22

反馈定义的上午7点执行run_daily函数,但是每周一上午都没法正常执行? 1、run_daily函数加载在initialize函数中,执行后才会创建定时任务; 2、由于周末会有例行重启操作,在重启以后拉起交易时相当于非交易日启动的交易…

【人工智能训练师】2 集群搭建

题目一、基础配置 core-site.xml参数配置详情 官方文档:http://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/core-default.xml core-default.xml与core-site.xml的功能是一样的,如果在core-site.xml里没有配置的属性&#xff0c…

【C++二分查找 决策包容性】1300. 转变数组后最接近目标值的数组和

本文涉及的基础知识点 C二分查找 决策包容性 LeetCode1300. 转变数组后最接近目标值的数组和 给你一个整数数组 arr 和一个目标值 target ,请你返回一个整数 value ,使得将数组中所有大于 value 的值变成 value 后,数组的和最接近 target …

【开端】JAVA中的切面使用

一、绪论 在不使用过滤器和 拦截器的前提下,如果统一对JAVA的 方法进行 管理。比如对一类方法或者类进行日志监控,前后逻辑处理。这时就可以使用到切面。它的本质还是一个拦截器。只是通过注解的方式来标识所切的方法。 二、JAVA中切面的使用实例 Aspec…

如何看待“低代码”开发平台的兴起

目录 1.概述 1.1.机遇 1.2.挑战 1.3.对开发者工作方式的影响 2.技术概览 2.1.主要特点 2.2.市场现状 2.3.主流低代码平台 2.4.分析 3.效率与质量的权衡 3.1.提高开发效率 3.2.质量与安全隐患 3.3.企业应用开发的利弊分析 4.挑战与机遇 4.1.机遇 4.2.挑战 4.3.…

为什么需要在线实时预览3D模型?如何实现?

在线实时预览3D模型在现代设计、产品开发、市场营销、以及娱乐等领域中变得越来越重要,原因可以归结为以下几个方面: 1、多平台兼容性: 在线实时预览通常不依赖于特定的操作系统或软件平台,只要设备能够访问互联网和浏览器&…

21-原理图的可读性的优化处理

1.自定义原理图尺寸 先将原理图移动到左下角 2.划分模块 3.放置模块字符串

第三期书生大模型实战营——基础岛

1.书生大模型全链路开源体系 【书生浦语大模型全链路开源开放体系】 https://www.bilibili.com/video/BV18142187g5/?share_sourcecopy_web&vd_source711f676eb7f61df7d2ea626f48ae1769 视频里介绍了书生浦语大模型的开源开放体系,包括了其的技术发展、模型架…

ubuntu系统下安装LNMP集成环境的详细步骤(保姆级教程)

php开发中集成环境的安装是必不可少的技能,而LNMP代表的是:Linux系统下Nginx+MySQL+PHP这种网站服务器架构。今天就给大家分享下LNMP的安装步骤。 1 Nginx安装 在安装Nginx前先执行下更新命令: sudo apt-get update 接下来开始安装Nginx, 提示:Could not get lock /v…

【mysql 第二篇章】请求到真正执行 SQL 到底是一个怎么样的过程?

从用户调用到SQL执行的流程中间发生了什么事情 1、网络请求使用 线程 来处理,当数据库连接池中监听到有连接请求,这个时候会分配一个线程来处理。 2、SQL接口 负责接收 SQL 语句,当线程监听到有请求和读取数据的之后,将 SQL 语句…

Android Fragment:详解,结合真实开发场景Navigation

目录 1)Fragment是什么 2)Fragment的应用场景 3)为什么使用Fragment? 4)Fragment如何使用 5)Fragment的生命周期 6)Android开发,建议是多个activity,还是activity结合fragment&…

SparkSQL——AnalyzedLogicalPlan生成

Rule和RuleExecutor SparkSQL中对LogicalPlan的解析、优化、还有物理执行计划生成都是分成一个个Rule进行的。 RuleExecutor是一个规则引擎,它收集Rule,并对plan按照rule进行执行。 每一个Rule的实现类都要实现apply方法,具体逻辑都放在这个…

mysql中的时间相关函数

MySQL服务器中有3种时区设置: 系统时区(保存在system_time_zone系统变量中)服务器时区(保存在全局系统变量time_zone中)每个客户端连接的时区(保存在会话变量time_zone中) 其中,客…