Transformer学习理解

news2024/12/22 0:19:22

1.前言        

        本文介绍当下人工智能领域的基石与核心结构模型——Transformer,为什么说它是基石,因为以ChatGPT为代表的聊 天机器人以及各种有望通向AGI(通用人工智能)的道路上均在采用的Transformer。 Transformer也是当下NLP任务的底座,包括后续的BERT和GPT,都是Transformer架构,BERT主要由Transformer的 encoder构成,GPT则主要是decoder构成。

        本文将会通读Transformer原论文《Attention is all you need》,总结Transformer结构以及文章要点,然后采用pytorch 代码进行机器翻译实验,通过代码进一步了解Transformer在应用过程中的步骤。

2.论文阅读笔记

        Transformer的论文是《Attention is all you need》(https://arxiv.org/abs/1706.03762),由Google团队在2017提出的 一种针对机器翻译的模型结构,后续在各类NLP任务上均获取SOTA。

 Motivation:针对RNN存在的计算复杂、无法串行的问题,提出仅通过attention机制的简洁结构——Transformer,在多个序列任务、机器翻译中获得SOTA,并有运用到其他数据模态的潜力。

 模型结构: 模型由encoder和decoder两部分构成,分别采用了6个block堆叠, encoder的block有两层,multi-head attention和FC层 decoder的block有三层,处理自回归的输入masked multi-head attention,处理encoder的attention,FC层。

注意力机制:scale dot-production attention,采用QK矩阵乘法缩放后softmax充当权重,再与value进行乘法。 多头注意力机制:实验发现多个头效果好,block的输出,把多个头向量进行concat,然后加一个FC层。因此每个头的向量长度是总长度/头数,例:512/8=64,每个头是64维向量。

Transformer的三种注意力层:

(1)encoder:输入来自上一层的全部输出

(2)decoder-输入:为避免模型未卜先知,只允许看见第i步之前的信息,需要做mask操作,确保在生成序列的每个 元素时,只考虑该元素之前的元素。这里通过softmax位置设置负无穷来控制无效的连接。

(3)decoder-第二个attention:q来自上一步输出,k和v来自encoer的输出。这样解码器可以看到输入的所有序列。

FFN层:attention之后接入两个FC层,第一个FC层采用2048,第二个是512,第一个FC层采用max(0, x)作为激活函数。

embedding层的缩放:在embedding层进行尺度缩放,乘以根号d_model。

位置编码:采用正余弦函数构建位置向量,然后采用加法,融入embedding中。

实验:两个任务,450万句子(英-德),3600万句子(英-法),8*16GB显卡,分别训练12小时,84小时。10万step 时,采用4kstep预热;采用标签平滑0.1。

重点句子摘录

1. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section 3.2.

        由于注意力机制最后的加权平均可能会序列中各位置的细粒度捕捉不足,因此引入多头注意力机制。 这里是官方对多头注意力机制引入的解释。

2. We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4. To counteract this effect, we scale the dot products by 根号 dk

        在Q和K做点积时,若向量维度过长,会导致点积结果过大,再经softmax映射后,梯度会变小,不利于模型学习,因此 需要进行缩放。缩放因子为除以根号dk。

多头注意力机制的三种情况:

        i. encoder:输入来自上一层的全部输出。

        ii. decoder-输入:为避免模型未卜先知,只允许看见第i步之前的信息,需要做mask操作,确保在生成序列的每个 元素时,只考虑该元素之前的元素。这里通过softmax位置设置负无穷来控制无效的连接。

        iii. decoder-第二个attention:q来自上一步输出,k和v来自encoer的输出。这样解码器可以看到输入的所有序列。

首次阅读遗留问题

        1. 位置编码是否重新学习? 不重新学习,详见 PositionalEncoding的         self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

        2. qkv具体实现过程

                i. 通过3个w获得3个特征向量:

                ii. attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) # q.shape == (bs, n_head, len_seq, d_k/n_head) ,每个token用 一个向量表示,总向量长度是头数*每个头的向量长度。

                iii. attn = self.dropout(F.softmax(attn, dim=-1)) iv. output = torch.matmul(attn, v)

        3. decoder输入的处理细节

                 i. 训练阶段,无特殊处理,一个样本可以直接输入,根据下三角mask避免未卜先知。

                ii. 推理阶段,首先手动执行token的输入,然后for循环直至最大长度,期间输出的token拼接到输出列表中,并作为下一步decoder的输入。选择输出token时还采用了beam search来有效地平衡广度和深度。

3.数据集构建

        数据下载

        本案例数据集Tatoeba下载自这里。该项目是帮助不同语言的人学习英语,因此是英语与其它几十种语言的翻译文本。 其中就包括本案例使用的英中文本,共计29668条(Mandarin Chinese - English cmn-eng.zip (29668)) 数据以txt形式存储,一行是一对翻译文本,例如长这样:

        1.That mountain is easy to climb. 那座山很容易爬。

        2.It's too soon. 太早了。

        3.Japan is smaller than Canada. 日本比加拿大小。

        数据集划分

         对于29668条数据进行8:2划分为训练、验证,这里采用配套代码a_data_split.py进行划分,即可在统计目录下获得 train.txt和text.txt。

4.词表构建

        文本任务首要任务是为文本构建词表,这里采用与上节一样的方法,首先对文本进行分词,然后统计语料库中所有的 词,最后根据最大上限、最小词频等约束,构建词表。本部分配套代码是b_gen_vocabulary.py

        词表的构建过程中,涉及两个知识点:中文分词和特殊token。

        1. 中文分词

        对于英文,分词可以直接采用空格。而对于中文,就需要用特定的分词方法,这里采用的是jieba分词工具,以下是英文 和中文的分词代码。

        source.append(parts[0].split(' '))

        target.append(list(jieba.cut(parts[1]))) # 分词

        2. 特殊token

         由于seq2seq任务的特殊性,在解码器部分,通常需要一个token告诉模型,现在是开始,同时还需要有个token让模型 输出,以此告诉人类,模型输出完毕,不要再继续生成了。

         因此相较于文本分类,还多了bos, eos,两个特殊token,有的时候,开始token也会用start表示。

        PAD_TAG = "" # 用PAD补全句子长度

         BOS_TAG = "" # 用BOS表示开始

        EOS_TAG = "" # 用EOS表示结束

        UNK_TAG = "" # 用EOS表示结束

        PAD = 0 # PAD字符对应的数字

        BOS = 1 # BOS字符对应的数字

         EOS = 2 # EOS字符对应的数字

         UNK = 3 # UNK字符对应的数字

        运行代码后,词表字典保存到了result目录下,并得到如下输出,表明英文中有2518个词,中文有3365,但经过最大长 度3000的截断后,只剩下2996,另外4个是特殊token。

        100%|██████████| 23635/23635 [00:00<00:00, 732978.24it/s]

        原始词表长度:2518,截断后长度:2518

        2522

        保存词频统计图:vocab_en.npy_word_freq.jpg

        100%|██████████| 23635/23635 [00:00<00:00, 587040.62it/s]

        保存统计图:vocab_en.npy_length_freq.jpg

        原始词表长度:3365,截断后长度:2996

        3000

5.Dataset编写

        NMTDataset的编写逻辑与上一小节的Dataset类似,首先在类初始化的时候加载原始数据,并进行分词;在getitem迭代 时,再进行token转index操作,这里会涉及增加结束符、填充符、未知符。 核心代码如下:

def __init__(self, path_txt, vocab_path_en, vocab_path_fra, max_len=32):
    self.path_txt = path_txt
    self.vocab_path_en = vocab_path_en
    self.vocab_path_fra = vocab_path_fra
    self.max_len = max_len
    self.word2index = WordToIndex()
    self._init_vocab()
    self._get_file_info()

def __getitem__(self, item):
# 获取切分好的句子list,一个元素是一个词
    sentence_src, sentence_trg = self.source_list[item], self.target_list[item]
# 进行填充, 增加结束符,索引转换
    token_idx_src = self.word2index.encode(sentence_src, self.vocab_en, self.max_len)
    token_idx_trg = self.word2index.encode(sentence_trg, self.vocab_fra, self.max_len)
    str_len, trg_len = len(sentence_src) + 1, len(sentence_trg) + 1 # 有效长度, +1是填充的结束符 <eos>.
    return np.array(token_idx_src, dtype=np.int64), str_len, np.array(token_idx_trg,         dtype=np.int64), trg_len

def _get_file_info(self):
    text_raw = read_data_nmt(self.path_txt)
    text_clean = text_preprocess(text_raw)
    self.source_list, self.target_list = text_split(text_clean)

6.模型构建

        Transformer代码梳理如下图所示,大体可分为三个层级

        1. Transformer的构建,包含encoder、decoder两个模块,以及两个mask构建函数。

        2. 两个coder内部实现,包括位置编码、堆叠的block。

        3. block实现,包含多头注意力、FFN,其中多头注意力将softmax(QK)*V拆分为ScaledDotProductAttention类。

        代码整体与论文中保持一致,总结几个不同之处:

        1. layernorm使用时机前置到attention层和FFN层之前

        2. position embedding 的序列长度默认采用了200,如果需要更长的序列,则要注意配置。 具体代码实现不再赘述,把论文中图2的结果熟悉,并掌握上面的代码结构,可以快速理解各模块、组件的运算和操作步 骤,如有疑问的点,再打开代码观察具体运算过程即可。

7.模型训练

        原论文进行两个数据集的机器翻译任务,采用的数据和超参数列举如下,供参考。

        英语-德语,450万句子对,英语-法语,3600万句子对。均进行base/big两种尺寸训练,分别进行10万step和30万step训 练,耗时12小时/84小时。10万step时,采用4千step进行warmup。正则化方面采用了dropout=0.1的残差连接,0.1的标 签平滑。

        本实验中有2.3万句子对训练,只能作为Transformer的学习,性能指标仅为参考,具体任务后续由BERT、GPT、T5来 实现更为具体的项目。 采用配套代码train_transformer.py,执行训练即可:

        采用配套代码train_transformer.py,执行训练即可:

python train_transformer.py -embs_share_weight -proj_share_weight -label_smoothing -b 256 -warmup 128000 -epoch 400

        训练完成后,在result文件夹下会得到日志与模型,接下来采用配套代码c_train_curve_plot.py

        对日志数据进行绘图可视化,Loss和Accuracy分别如下,可以看出模型拟合能力非常强,性能还在提高,但受限于数据量过少,模型在200个 epoch之后就已经出现了过拟合。

        这里面的评估指标用的acc,具体是直接复用github项目,也能作为模型性能的评估指标,就没有去修改为BLUE。

8.模型推理

        Transformer的推理过程与训练时不同,值得仔细学习。

        Transformer的推理输出是典型的自回归(Auto regressive),并且需要根据多个条件综合判断何时停止,因此推理部分的逻辑值得认真学习,具体步骤如下:

        第一步:输入序列经encoder,获得特征,每个token输出一个向量,这个特征会在decoder的每个step都用到,即 decoder中各block的第二个multi-head attention。需要enc_output去计算k和v,用decoder上一层输出特征向量去计算 q,以此进行decoder的第二个attention。

        enc_output, *_ = self.model.encoder(src_seq, src_mask)

        第二步:手动执行decoder第一步,输入是这个token,输出的是一个概率向量,由分类概率向量再决定第一个输出 token。

        self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]))

        dec_output =  self._model_decode(self.init_seq, enc_output, src_mask)

        第三步:循环max_length次执行decoder剩余步。第i步时,将前序所有步输出的token,组装为一个序列,输入到 decoder。

         在代码中用一个gen_seq维护模型输出的token,输入给模型时,只需要gen_seq[:, :step]即可,很巧妙。 在每一步输出时,都会判断是否需要停止输出字符。

for step in range(2, max_seq_len):
    dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask)
略
    if (eos_locs.sum(1) > 0).sum(0).item() == beam_size:
        _, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
        ans_idx = ans_idx.item()
        break

        借助李宏毅老师2021年课件中一幅图,配合代码,可以很好的理解Transformer在推理时的流程。

         1. 输入序列经encoder,获得特征(绿色、蓝色、蓝色、橙色)

          2. decoder输入第一个序列(序列长度为1,token是),输出第一个概率向量,并通过max得到“机”。

          3. decoder输入第二个序列(序列长度为2,token是[BOS, 机]),输出得到“器”

          4. decoder输入第三个序列(序列长度为3,token是[BOS, 机,器]),输出得到“学”

          5. decoder输入第四个序列(序列长度为4,token是[BOS, 机,器,学]),输出得到“习”

          6. ...以此类推 

        在推理过程中,通常还会使用Beam Search 来最终确定当前步应该输出哪个token,此处不做展开。

        运行配套代码inference_transformer.py可以看到10条训练集的推理结果。

         从结果上看,基本像回事儿了。

        src: tom's door is open .

        trg: 湯姆的門開著 。

        pred: 汤姆的 <unk>了 。

        src:<unk> is a <unk> country .

        trg: <unk>是一個<unk>的國家 。

         pred: <unk>是一個<unk>的城市 。

         src: i can come at three .

         trg: <unk>可以 來 。

        pred: 我 可以 在 那裡

9.总结

        本文通过Transformer论文的学习,了解Transformer基础架构,并通过机器翻译案例,从代码实现的角度深入剖析 Transformer训练和推理的过程。由于Transformer是目前人工智能的核心与基石,因此需要认真、仔细的掌握其中细节。

        本文内容值得注意的知识点:

         1. 多头注意力机制的引入:官方解释为,由于注意力机制最后的加权平均可能会序列中各位置的细粒度捕捉不足,因此引入多头注意力机制。

         2. 注意力计算时的缩放因子:QK乘法后需要缩放,是因为若向量维度过长,会导致点积结果过大,再经softmax映射后,梯度会变小,不利于模型学习,因此需要进行缩放。缩放因子为除以根号dk。

        3. 多头注意力机制的三种情况:

          i. encoder:输入来自上一层的全部输出

         ii. decoder-输入:为避免模型未卜先知,只允许看见第i步之前的信息,需要做mask操作,确保在生成序列的每个元素时,只考虑该元素之前的元素。这里通过softmax位置设置负无穷来控制无效的连接。

         iii. decoder-第二个attention:q来自上一步输出,k和v来自encoer的输出。这样解码器可以看到输入的所有序列。

        4. 输入序列的msk:代码实现时,由于输入数据是通过添加来组batch的,并且为了在运算时做并行运算,因此需要对 src中是pad的token做mask,这一点在论文是不会提及的。

        5. decoder的mask:根据下三角mask避免未卜先知。

        6. 推理时会采用beam search进行搜索,确定输出的token。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1843499.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

API接口测试要注意什么?API接口如何开发?

API接口怎么保证数据安全&#xff1f;如何安全调用邮件API接口&#xff1f; API接口不仅能够确保系统的稳定性和可靠性&#xff0c;还能提高开发效率和用户满意度。然而&#xff0c;要进行有效的API接口测试&#xff0c;需要注意多个方面。AokSend将介绍一些关键点&#xff0c…

【云原生】Kubernetes----Metrics-Server组件与HPA资源

目录 引言 一、概述 &#xff08;一&#xff09;Metrics-Server简介 &#xff08;二&#xff09;Metrics-Server的工作原理 &#xff08;三&#xff09;HPA与Metrics-Server的作用 &#xff08;四&#xff09;HPA与Metrics-Server的关系 &#xff08;五&#xff09;HPA与…

百万级 QPS 接入层网关架构方案演进

文章目录 前言1、单机架构2、DNS 轮询3、Nginx 单机4、Nginx 主备 Keepalived5、LVS 主备 Keepalived Nginx 集群6、LVS 主备 Keepalived Nginx 集群 DNS 轮询 前言 随着PC、移动互联网的快速发展&#xff0c;越来越多的人通过手机、电脑、平板等设备访问各种各样APP、网…

pdf只要前几页,pdf怎么只要前几页

在现代办公和学习环境中&#xff0c;PDF文件已成为我们日常处理信息的重要工具。然而&#xff0c;有时我们并不需要整个PDF文件的内容&#xff0c;而只是其中的几页。那么&#xff0c;如何高效地提取PDF文件中的特定页面呢&#xff1f;本文将为您介绍几种实用的方法。 打开 “ …

阿里CEO个人投资的智驾公司,走了不一样的路

佑驾创新在去年8月和11月完成两轮融资&#xff0c;在今年5月底递表港交所&#xff0c;目前拿到了29家车企88款车型的量产订单。自动驾驶赛道不缺明星&#xff0c;这些因素本不足以凸显它的差异化。但是在招股书中&#xff0c;一条特殊的发展路线&#xff0c;却让佑驾创新显得不…

搭建开发模式下的以太坊私有链【Geth:1.14.5】

一、为什么用到私有链&#xff1f; 在以太坊的公有链上部署智能合约、发起交易需要花费以太币。而通过修改配置&#xff0c;可以在本机搭建一套以太坊私有链&#xff0c;因为与公有链没关系&#xff0c;既不用同步公有链庞大的数据&#xff0c;也不用花钱购买以太币&#xff0…

2024GLEE生活暨教育(上海)博览会,8月20-22日,国家会展中心(上海)

2024GLEE生活暨教育(上海)博览会将于8月20-22日在中国国家会展中心&#xff08;上海&#xff09;举行&#xff0c;博览会总面积近万平方米&#xff0c;设有美好生活和教育产品两大主力展区&#xff0c;全面覆盖婴幼儿、学龄前、小学、初中、高中、大学、中年、老年各个年龄段的…

2024年了,上大学可以不需要用到电脑吗?

前言 在2024年的今天&#xff0c;电脑已经成为了人们工作生活的一大部分。Oh, no&#xff01;好像手机才是。 好像每个人都是这样的&#xff1a;可以没有电脑&#xff0c;但不能没有手机…… 所以2024年的今天&#xff0c;上大学的小伙伴们可以不需要用到电脑吗&#xff1f;…

DDMA信号处理以及数据处理的流程---cfar检测

Hello,大家好,我是Xiaojie,好久不见,欢迎大家能够和Xiaojie一起学习毫米波雷达知识,Xiaojie准备连载一个系列的文章—DDMA信号处理以及数据处理的流程,本系列文章将从目标生成、信号仿真、测距、测速、cfar检测、测角、目标聚类、目标跟踪这几个模块逐步介绍,这个系列的…

SpringBoot 实现全局异常处理

为什么要使用全局异常处理&#xff1f; 减少冗余代码&#xff1a; 在不使用全局异常处理器的情况下&#xff0c;项目中各层可能会出现大量的try {…} catch {…} finally {…}代码块&#xff0c;这些代码块不仅冗余&#xff0c;还影响代码的可读性。全局异常处理器允许我们在一…

第二十五篇——信息加密:韦小宝说谎的秘诀

目录 一、背景介绍二、思路&方案三、过程1.思维导图2.文章中经典的句子理解3.学习之后对于投资市场的理解4.通过这篇文章结合我知道的东西我能想到什么&#xff1f; 四、总结五、升华 一、背景介绍 加密这件事&#xff0c;对于这个时代的我们来说非常重要&#xff0c;那么…

数据分析中的数学:从基础到应用20240617

数据分析中的数学&#xff1a;从基础到应用 数据分析离不开数学的支持&#xff0c;统计学和概率论是其重要组成部分。本文将通过几个具体的实例&#xff0c;详细讲解数据分析中常用的数学知识&#xff0c;并通过Python代码演示如何应用这些知识。 1. 描述性统计 基本概念和用…

java学习--集合(大写二.2)

看尚硅谷视频做的笔记 2.collection接口及方法 jdk8里有一些默认的方法&#xff0c;更多的是体现的是一种规范&#xff0c;规范更多关注的是一些抽象方法。 看接口里面的抽象方法&#xff0c;选一个具体的实现类。 测试collection的方法&#xff0c;存储一个一个数据都有哪些…

ENVI实战—一文搞定监督分类

实验1&#xff1a;利用ROI建立样本训练集和验证集 目的&#xff1a;学会利用ROI建立计算机分类时的样本集 过程&#xff1a; ①导入影像&#xff1a;打开ENVI&#xff0c;选择“打开→打开为→光学传感器→ESA→Sentinel-2”&#xff0c;将Sentinel-2影像导入到ENVI平台中。…

LabVIEW与3D相机开发高精度表面检测系统

使用LabVIEW与3D相机开发一个高精度表面检测系统。该系统能够实时获取三维图像&#xff0c;进行精细的表面分析&#xff0c;广泛应用于工业质量控制、自动化检测和科学研究等领域。通过真实案例&#xff0c;展示开发过程中的关键步骤、挑战及解决方案&#xff0c;确保系统的高性…

MySQL客户端与服务端建立连接抓包分析

文章目录 MySQL客户端与服务端建立连接流程抓包分析1.连接建立流程2.各类数据包介绍2.1挑战数据包2.2认证数据包2.3切换认证插件请求数据包2.4切换认证插件响应数据包2.5成功数据包2.6失败数据包3.注意点4.测试代码MySQL客户端与服务端建立连接流程抓包分析 抓包工具采用的是W…

大厂的 404 页面都长啥样?看到最后一个,我笑了~

每天浏览各大网站&#xff0c;难免会碰到404页面啊。你注意过404页面么&#xff1f;猿妹搜罗来了下面这些知名网站的404页面&#xff0c;以供大家欣赏&#xff0c;看看哪个网站更有创意&#xff1a; 腾讯 网易 淘宝 百度 新浪微博 新浪 京东 优酷 腾讯视频 搜狐 携程 去哪儿 今…

C#——装箱与拆箱详情

装箱与拆箱 装箱: 将值类型转换成引用类型的过程&#xff1b; 拆箱: 把引用类型转为值类型的过程&#xff0c;就是拆箱 装箱 拆箱

usb摄像头应用编程

作者简介&#xff1a; 一个平凡而乐于分享的小比特&#xff0c;中南民族大学通信工程专业研究生在读&#xff0c;研究方向无线联邦学习 擅长领域&#xff1a;驱动开发&#xff0c;嵌入式软件开发&#xff0c;BSP开发 作者主页&#xff1a;一个平凡而乐于分享的小比特的个人主页…

JUC并发编程-第二天:线程高级部分

线程高级部分 线程不安全原子性可见性有序性&#xff08;指令重排&#xff09;更多的解决线程安全 线程不安全 多线程下并发同时对共享数据进行读写&#xff0c;会造成数据混乱线程不安全 当多线程下并发访问临界资源时&#xff0c;如果破坏其原子性、可见性、有序性&#xff…