pytorch实现transformer(1): 模型介绍

news2024/11/25 10:42:59

文章目录

    • 1. transformer 介绍
    • 2 Position Encoding
      • 2.1 位置编码原理
      • 2.2 代码实现
    • 3 Self-attention
    • 4 前馈层FFN
    • 5 残差连接与层归一化
    • 6 编码器和解码器结构

1. transformer 介绍

Transformer 模型是由谷歌在 2017 年提出并首先应用于机器翻译的神经网络模型结构。机器翻译的目标是从源语言(Source Language)转换到目标语言(Target Language)。Transformer 结构完全通过注意力机制完成对源语言序列和目标语言序列全局依赖的建模。当前几乎全部大语言模型都是基于 Transformer 结构,本节以应用于机器翻译的基于 Transformer 的编码器和解码器介绍该模型。

Transformer它的提出最开始是针对NLP领域的,在次之前大家主要用的是RNN,LSTM这类时序网络。像RNN这类网络其实它是有些问题的,首先它的记忆的长度是有限的,特别像RNN它的记忆长度就比较短,所以后面就有提出LSTM。但是他们还有另外一个问题就是无法并行化,也就是说我们必须先计算 t 0 t_0 t0时刻的输出,计算完之后我们才能进一步计算 t 1 t_1 t1时刻的数据。由于无法并行化,训练效率就比较低.

针对这一问题,Google提出了Transformer来替代之前的时序网络,

  • Transformer不受硬件限制的情况下,理论上记忆是可以无限长的。
  • 其次,它是可以做并行化的,这是一个非常大的优点

基于 Transformer 结构主要分两部分:编码器Encoder和解码器Decoder,它们均由若干个基本的 Transformer 块(Block)组成(对应着图中的灰色框)。每个 Transformer 块都接收一个向量序列 { x i } i = 1 t \{x_i\}_{i=1}^t { xi}i=1t作为输入,并输出一个等长的向量序列作为输出 { y i } i = 1 t \{y_i\}_{i=1}^t { yi}i=1t, 而 y i y_i yi是当前 Transformer 块对输入 x i x_i xi进一步整合其上下文语义后对应的输出。

在这里插入图片描述

图1:基于 Transformer 的编码器和解码器结构

主要涉及到如下几个模块:

  • Position Encoding: 使用位置编码来理解文本的顺序

  • 注意力层:使用多头注意力(Multi-Head Attention)机制 整合上下文语义 ,它使得序列中任意两个token之间的依赖关系可以直接被建模而不基于传统的循环结构,从而更好地解决文本的长程依赖

  • 位置感知前馈层(Position-wise FFN):通过全连接层对输入文本序列中的每个单词表示进行更复杂的变换

  • 残差连接:对应图中的 Add 部分。它是一条分别作用在上述两个子层当中的直连通路,被用于连接它们的输入与输出。从而使得信息流动更加高效,有利于模型的优化

  • 层归一化:对应图中的 Norm 部分。作用于上述两个子层的输出表示序列中,对表示序列进行层归一化操作,同样起到稳定优化的作用。

2 Position Encoding

2.1 位置编码原理

由于Transformer模型没有循环神经网络的迭代操作,所以我们必须提供每个token位置信息给Transformer, 这样它才能识别出语言中的顺序关系。

Positional Encoding(位置嵌入), 它的维度为[num_token,embedding_dimension], 位置嵌入的维度与词向量token的维度是相同的,都是embedding_dimension。其中max_sequence_length属于超参数,指的是限定每个句子最长由多少个词组成。

注意,我们一般以为单位训练Transformer模型,首先初始化字编码大小为[vocab_size,embedding_dimension], vocab_size 为字库中所有字的数量,embedding_dimension为字向量的维度,对应pytorch中,其实就是 nn.Embedding(vocab_size,embedding_dimension)

论文中使用了 s i n sin sin c o s cos cos函数的线性变换来提供给模型位置信息:

在这里插入图片描述
其中pos指的是一句话中某个字的位置,取值范围为[0,max_sequence_length], i i i指的是字向量的序号,取值范围是[0,embedding_dimension/2], d m o d e l d_{model} dmodel指的是embedding_dimension的值。

通过sincos处理,从而使得位置编码产生不同对的周期性变化,使得每个位置在embedding_dimension维度上都会得到不同周期的sincos函数的取值组合,从而产生独一的位置信息,最终使得模型学到位置之间的依赖关系和自然语言的时序特效

2.2 代码实现

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 
import math


def get_positional_encoding(max_seq_len,embed_dim):
    # 初始化一个positional encoding
    # embed_dim: 字嵌入的维度
    # max_seq_len: 最大的序列长度
    
    positional_encoding = np.array([
        [pos/np.power(1000,2*i/ embed_dim) for i in range(embed_dim)] if pos !=0 else \
            np.zeros(embed_dim) for pos in range(max_seq_len)
            ])
    
    positional_encoding[1:,0::2] = np.sin(positional_encoding[1:,0::2]) #dim 2i 偶数
    positional_encoding[1:,1::2] = np.cos(positional_encoding[1:,1::2]) #dim 2i+1 奇数
    
    return positional_encoding

positional_encoding = get_positional_encoding(max_seq_len=100,embed_dim=16)
plt.figure(figsize=(10,10))
sns.heatmap(positional_encoding)
plt.title('Sinusoidal Function')
plt.xlabel('hidden dimension')

在这里插入图片描述

plt.figure(figsize=(8,5))
plt.plot(positional_encoding[1:,1],label="dimension 1")
plt.plot(positional_encoding[1:,2],label="dimension 2")
plt.plot(positional_encoding[1:,3],label="dimension 3")
plt.legend()
plt.xlabel("Sequence length")
plt.ylabel("Period of Positional Encoding")

在这里插入图片描述

3 Self-attention

自注意力(Self-Attention)操作是基于 Transformer 的机器翻译模型的基本操作,在源语言的编码和目标语言的生成中频繁地被使用以建模源语言、目标语言任意两个单词之间的依赖关系。给定由单词语义嵌入及其位置编码叠加得到的输入表示 { x i ∈ R d } i t \{x_i \in R^d\}_i^t { xiRd}it, 为了实现对上下文语义依赖的建模,进一步引入在自注意力机制中涉及到的三个元素:查询 q i q_i qi

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1652544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

生信新包|LINGER·从单细胞多组学数据推断基因调控网络

题目:Inferring gene regulatory networks from single-cell multiome data using atlas-scale external data 原理 LINGER 是一个计算框架,旨在从单细胞多组学数据推断基因调控网络。 使用基因表达和染色质可及性的计数矩阵以及细胞类型注释作为输入&…

添砖Java之路其一——Java跨平台原理,JRE与JDK(为什么要安装)。

目录 前言: Java跨平台工作原理简单的理解: JRE与JDK: 前言: 最近又开始学Java了,所以又开一个板块来记录我Java的笔记。 Java跨平台工作原理简单的理解: 简单概括:简单来说Java跨平台原理…

【数据结构与算法】力扣 226. 翻转二叉树

题目描述 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 示例 1: 输入: root [4,2,7,1,3,6,9] 输出: [4,7,2,9,6,3,1]示例 2: 输入: root [2,1,3] 输出: [2,3,1…

Ubuntu16.04 离线安装CDH6.2.1

1. 离线包工作 下载Cloudera Manager安装包,地址:https://archive.cloudera.com/cm6/6.2.1/repo-as-tarball/ cm6.2.1-ubuntu1604.tar.gz下载CDH6.2.1安装包,地址:https://archive.cloudera.com/cdh6/6.2.1/parcels/ CDH-6.2.1-1.…

分布式架构的演技进过程

最近看了一篇文章,觉得讲的挺不错,就借机给大家分享一下。 早期应用:早期的应用比较简单,访问人数有限,大部分的开发单机就能完成。 分离模型:在业务发展后,用户数量逐步上升,服务器的性能出现瓶颈;就需要将应用和数据分开存储,避免相互抢占资源。 缓存模式:随着系…

LeetCode746:使用最小花费爬楼梯

题目描述 给你一个整数数组 cost ,其中 cost[i] 是从楼梯第 i 个台阶向上爬需要支付的费用。一旦你支付此费用,即可选择向上爬一个或者两个台阶。 你可以选择从下标为 0 或下标为 1 的台阶开始爬楼梯。 请你计算并返回达到楼梯顶部的最低花费。 代码 …

Qt 6.7 正式发布!

本文翻译自:Qt 6.7 Released! 原文作者:Qt Group研发总监Volker Hilsheimer 在最新发布的Qt 6.7版本中,我们大大小小作出了许多改善,以便您在构建现代应用程序和用户体验时能够享受更多乐趣。 部分新增功能已推出了技术预览版&a…

sql 注入 1

当前在email表 security库 查到user表 1、第一步,知道对方goods表有几列(email 2 列 good 三列,查的时候列必须得一样才可以查,所以创建个临时表,select 123 ) 但是你无法知道对方goods表有多少列 用order …

操作系统之管程

目录 一. 为什么要引入管程二. 管程的定义与基本特征三. 扩展1:用管程来解决生产者和消费者问题四. 扩展2: Java中类似于管程的机制 \quad 一. 为什么要引入管程 \quad \quad 二. 管程的定义与基本特征 \quad \quad 三. 扩展1:用管程来解决生产者和消费者问题 \quad 很智能 \qu…

如何绘制厂区地图?厂区地图路线规划图怎么做的?

随着工业化的快速发展,工厂规模越来越大,厂内货车往往因路线不明兜转,造成物流效率低,甚至路线拥堵;其他也存在基于安全管理的人员定位,访客指引,厂区设备可视化管理等需求。这些需求都与空间位…

基于STM32的智能垃圾桶设计(论文+源码)_kaic

基于STM32的智能垃圾桶设计 摘 要 随着社会科学技术的迅猛进展,人们的生活质量和速度也在不断提高。然而,大多数传统的家庭垃圾桶已经过时且缺乏创新,缺乏人性化设计。它们使用起来不方便、不卫生,所有的生活和废物垃圾都被混合…

根据不同权限,显示不同的菜单界面

本节:根据不同权限,显示不同的菜单界面 1.写几个角色不同的路由路径配置,有的角色有页面的配置,有的角色就没那几个页面的配置。 根据提供的token来判断 2.然后进行路由比对

力扣刷题Day5——内涵动态规划讲解

题目1: 先来一道很简单的题目: 2697. 字典序最小回文串 - 力扣(LeetCode) 思路: 为了得到字典序最小的回文字符串,对于回文串,就是需要左右的字符相等,而要最小的回文串&#xff…

【intro】GraphSAGE

论文 https://arxiv.org/pdf/1706.02216 abstract 大图中节点的低维embedding已经被证明在各种预测任务中非常有用,然而,大多数现有的方法要求在embedding训练期间图中的所有节点都存在;这些先前的方法属于直推式(transductive&#xff09…

中仕公考:非应届生能考军队文职吗?

军队文职考试的招生对象主要针对普通高等学校的毕业生以及社会人才,报考条件中并没有限制考生必须是应届毕业生。所以,往届毕业生也是具备报考资格的,只需其满足相关的申请条件即可。 报考人员可大致分为三类:普通高校毕业生、社…

LabelImg下载及目标检测数据标注

为什么这一部分内容这么少会单独拎出来呢,因为后期会接着介绍YOLOv8中的其他任务,会使用其他软件进行标注,所以就单独区分开来每一个任务的标注方式了。 这一部分就介绍目标检测任务的标注,数据集是我从COCO2017Val中抽出来两类&a…

Edge的使用心得和深度探索-Sider: ChatGPT 侧边栏

作为一款备受欢迎的网络浏览器,Microsoft Edge在用户体验和功能方面都有着诸多优势。在长期的使用中,我总结出了三条使用心得,同时也发现了三个能够极大提高效率的功能。让我们一起深度探索Edge的潜力吧! 使用心得: 界…

如果你已经掌握了C语言和C++,想要学习QT

在开始前我有一些资料,是我根据网友给的问题精心整理了一份「Qt的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!我认为这并不是难事。对于我来说,我…

4000定制网站,因为没有案例,客户走了

接到一个要做企业站点的客户,属于定制开发,预算4000看起来是不是还行的一个订单? 接单第一步:筛客户 从客户询盘的那一刻开始就要围绕核心要素:预算和工期,凡是不符合预期的一律放掉就好了,没必…