NLP实战8:图解 Transformer笔记

news2025/1/5 15:19:03

目录

1.Transformer宏观结构

2.Transformer结构细节

2.1输入

2.2编码部分

2.3解码部分

2.4多头注意力机制

2.5线性层和softmax

2.6 损失函数

3.参考代码


🍨 本文为[🔗365天深度学习训练营]内部限免文章(版权归 *K同学啊* 所有)
🍖 作者:[K同学啊]

Transformer整体结构图,与seq2seq模型类似,Transformer模型结构中的左半部分为编码器(encoder),右半部分为解码器(decoder),接下来拆解Transformer。

1.Transformer宏观结构

Transformer模型类似于seq2seq结构,包含编码部分和解码部分。不同之处在于它能够并行计算整个序列输入,无需按时间步进行逐步处理。

其宏观结构如下:

6层编码和6层解码器

其中,每层encoder由两部分组成:

  • Self-Attention Layer
  • Feed Forward Neural Network(前馈神经网络,FFNN)

decoder在encoder的Self-Attention和FFNN中间多加了一个Encoder-Decoder Attention层。该层的作用是帮助解码器集中注意力于输入序列中最相关的部分。

单层encoder和decoder

2.Transformer结构细节

2.1输入

Transformer的数据输入与seq2seq不同。除了词向量,Transformer还需要输入位置向量,用于确定每个单词的位置特征和句子中不同单词之间的距离特征。

2.2编码部分

编码部分的输入文本序列经过处理后得到向量序列,送入第一层编码器。每层编码器输出一个向量序列,作为下一层编码器的输入。第一层编码器的输入是融合位置向量的词向量,后续每层编码器的输入则是前一层编码器的输出。

2.3解码部分

最后一个编码器输出一组序列向量,作为解码器的K、V输入。

解码阶段的每个时间步输出一个翻译后的单词。当前时间步的解码器输出作为下一个时间步解码器的输入Q,与编码器的输出K、V共同组成下一步的输入。重复此过程直到输出一个结束符。

解码器中的 Self-Attention 层,和编码器中的 Self-Attention 层的区别:

  • 在解码器里,Self-Attention 层只允许关注到输出序列中早于当前位置之前的单词。具体做法是:在 Self-Attention 分数经过 Softmax 层之前,屏蔽当前位置之后的那些位置(将Attention Score设置成-inf)。
  • 解码器 Attention层是使用前一层的输出来构造Query 矩阵,而Key矩阵和Value矩阵来自于编码器最终的输出。

2.4多头注意力机制

Transformer论文引入了多头注意力机制(多个注意力头组成),以进一步完善Self-Attention。

  • 它扩展了模型关注不同位置的能力
  • 多头注意力机制赋予Attention层多个“子表示空间”。

残差链接&Normalize: 编码器和解码器的每个子层(Self-Attention 层和 FFNN)都有一个残差连接和层标准化(layer-normalization),细节如下图

2.5线性层和softmax

Decoder最终输出一个浮点数向量。通过线性层和Softmax,将该向量转换为一个包含模型输出词汇表中每个单词分数的logits向量(假设有10000个英语单词)。Softmax将这些分数转换为概率,使其总和为1。然后选择具有最高概率的数字对应的词作为该时间步的输出单词。

2.6 损失函数

在Transformer训练过程中,解码器的输出和标签一起输入损失函数,以计算损失(loss)。最终,模型通过方向传播(backpropagation)来优化损失。

3.参考代码

class SelfAttention(nn.Module):
	def __init__(self, embed_size, heads):
		super(SelfAttention, self).__init__()
		self.embed_size = embed_size
		self.heads = heads
		self.head_dim = embed_size // heads

		assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

		self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

	def forward(self, values, keys, query, mask):
		N =query.shape[0]
		value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]

		# split embedding into self.heads pieces
		values = values.reshape(N, value_len, self.heads, self.head_dim)
		keys = keys.reshape(N, key_len, self.heads, self.head_dim)
		queries = query.reshape(N, query_len, self.heads, self.head_dim)
		
		values = self.values(values)
		keys = self.keys(keys)
		queries = self.queries(queries)

		energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
		# queries shape: (N, query_len, heads, heads_dim)
		# keys shape : (N, key_len, heads, heads_dim)
		# energy shape: (N, heads, query_len, key_len)

		if mask is not None:
			energy = energy.masked_fill(mask == 0, float("-1e20"))

		attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)

		out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
		# attention shape: (N, heads, query_len, key_len)
		# values shape: (N, value_len, heads, heads_dim)
		# (N, query_len, heads, head_dim)

		out = self.fc_out(out)
		return out


class TransformerBlock(nn.Module):
	def __init__(self, embed_size, heads, dropout, forward_expansion):
		super(TransformerBlock, self).__init__()
		self.attention = SelfAttention(embed_size, heads)
		self.norm1 = nn.LayerNorm(embed_size)
		self.norm2 = nn.LayerNorm(embed_size)

		self.feed_forward = nn.Sequential(
			nn.Linear(embed_size, forward_expansion*embed_size),
			nn.ReLU(),
			nn.Linear(forward_expansion*embed_size, embed_size)
		)
		self.dropout = nn.Dropout(dropout)

	def forward(self, value, key, query, mask):
		attention = self.attention(value, key, query, mask)

		x = self.dropout(self.norm1(attention + query))
		forward = self.feed_forward(x)
		out = self.dropout(self.norm2(forward + x))
		return out


class Encoder(nn.Module):
	def __init__(
			self,
			src_vocab_size,
			embed_size,
			num_layers,
			heads,
			device,
			forward_expansion,
			dropout,
			max_length,
		):
		super(Encoder, self).__init__()
		self.embed_size = embed_size
		self.device = device
		self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
		self.position_embedding = nn.Embedding(max_length, embed_size)

		self.layers = nn.ModuleList(
			[
				TransformerBlock(
					embed_size,
					heads,
					dropout=dropout,
					forward_expansion=forward_expansion,
					)
				for _ in range(num_layers)]
		)
		self.dropout = nn.Dropout(dropout)


	def forward(self, x, mask):
		N, seq_length = x.shape
		positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
		out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
		for layer in self.layers:
			out = layer(out, out, out, mask)

		return out


class DecoderBlock(nn.Module):
	def __init__(self, embed_size, heads, forward_expansion, dropout, device):
		super(DecoderBlock, self).__init__()
		self.attention = SelfAttention(embed_size, heads)
		self.norm = nn.LayerNorm(embed_size)
		self.transformer_block = TransformerBlock(
			embed_size, heads, dropout, forward_expansion
		)

		self.dropout = nn.Dropout(dropout)

	def forward(self, x, value, key, src_mask, trg_mask):
		attention = self.attention(x, x, x, trg_mask)
		query = self.dropout(self.norm(attention + x))
		out = self.transformer_block(value, key, query, src_mask)
		return out


class Decoder(nn.Module):
	def __init__(
			self,
			trg_vocab_size,
			embed_size,
			num_layers,
			heads,
			forward_expansion,
			dropout,
			device,
			max_length,
	):
		super(Decoder, self).__init__()
		self.device = device
		self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
		self.position_embedding = nn.Embedding(max_length, embed_size)
		self.layers = nn.ModuleList(
			[DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
			for _ in range(num_layers)]
			)
		self.fc_out = nn.Linear(embed_size, trg_vocab_size)
		self.dropout = nn.Dropout(dropout)

	def forward(self, x ,enc_out , src_mask, trg_mask):
		N, seq_length = x.shape
		positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
		x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

		for layer in self.layers:
			x = layer(x, enc_out, enc_out, src_mask, trg_mask)

		out =self.fc_out(x)
		return out


class Transformer(nn.Module):
	def __init__(
			self,
			src_vocab_size,
			trg_vocab_size,
			src_pad_idx,
			trg_pad_idx,
			embed_size = 256,
			num_layers = 6,
			forward_expansion = 4,
			heads = 8,
			dropout = 0,
			device="cuda",
			max_length=100
		):
		super(Transformer, self).__init__()
		self.encoder = Encoder(
			src_vocab_size,
			embed_size,
			num_layers,
			heads,
			device,
			forward_expansion,
			dropout,
			max_length
			)
		self.decoder = Decoder(
			trg_vocab_size,
			embed_size,
			num_layers,
			heads,
			forward_expansion,
			dropout,
			device,
			max_length
			)


		self.src_pad_idx = src_pad_idx
		self.trg_pad_idx = trg_pad_idx
		self.device = device


	def make_src_mask(self, src):
		src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
		# (N, 1, 1, src_len)
		return src_mask.to(self.device)

	def make_trg_mask(self, trg):
		N, trg_len = trg.shape
		trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
			N, 1, trg_len, trg_len
		)
		return trg_mask.to(self.device)

	def forward(self, src, trg):
		src_mask = self.make_src_mask(src)
		trg_mask = self.make_trg_mask(trg)
		enc_src = self.encoder(src, src_mask)
		out = self.decoder(trg, enc_src, src_mask, trg_mask)
		return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/781750.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Okhttp-LoggingInterceptor的简单使用

概述 Okhttp除了提供强大的get,post网络请求外,还包含请求日志的拦截器,可以监视,重写,重试调用请求。 简单使用 我们在构造OkHttpClient时,通过addInterceptor()方法添加我们需要的过滤器。 object OkhttpUtils{……

SpringBoot知识范围-学习步骤【JSB系列之000】

语言视频选择收录专辑链接C张雪峰推荐选择了计算机专业之后-在大学期间卷起来-【大学生活篇】JAVA黑马B站视频JAVA部分的知识范围、学习步骤详解JAVAWEB黑马B站视频JAVAWEB部分的知识范围、学习步骤详解SpringBootSpringBoot知识范围-学习步骤【JSB系列之000】微信小程序详细解…

【stable diffusion】保姆级入门课程04-Stable diffusion(SD)图生图-局部重绘的用法

目录 0.本章素材 1.什么是局部重绘 2.局部重绘和涂鸦有什么不同 3.操作界面讲解 3.1.蒙版模糊 3.2.蒙版模式 3.3.蒙版蒙住的内容 3.4.重绘区域 4.局部重绘的应用(面部修复) 5.课后训练 0.本章素材 chilloutmix模型(真人模型)百度地址&#xf…

数据结构—树状数组

树状数组 单点修改、区间查询区间修改、单点查询区间修改、区间查询 单点修改、区间查询 这里讲解树状数组的最基本操作单点修改、区间查询,当然能做到单点修改、区间查询,肯定就能做到单点修改、单点查询了。树状数组是用来快速求前缀和的,…

MGRE之OSPF实验

目录 题目: 步骤二:拓扑设计与地址规划​编辑 步骤三:IP地址配置 步骤四:缺省路由配置 步骤五:NAT的配置 步骤六:MGRE配置 中心站点R1配置 分支站点配置 中心站点R5 R1配置 分支站点配置 检测&…

UE 材质学习补充

Add Name Reroute Node ...(本地变量) 该节点可以整理节点,优化界面 Texture Texture(纹理图像),一般由RGB三个通道混合构成,RGB三个通道的值代表亮度,RGB三个通道分别都是0-1(0-255&#xff09…

征服FarmerJohn(二) Naptime【USACO05JAN】

题解目录 前言题目内容题目描述输入输出样例题目思路示例代码AC图片 后记往期精彩 前言 在上一期征服FarmerJohn(一)三角形【USACO2020FEB-B】结束之后,我们来看一道难度有所提升的DP问题,也就是常说的动态规划,今天我…

Please set the ROCKETMQ_HOME variable in your environment!

原因 启动ROCKETMQ执行命令start mqnamesrv.cmd时报错 翻译意思是请在您的环境中设置ROCKETMQ_HOME变量! 查看mqnamesrv.cmd可以看到如果"%ROCKETMQ_HOME%\bin\runserver.cmd"不存在会报此错误 配置上环境变量ROCKETMQ_HOME即可

《深入理解计算机系统》(美)布赖恩特(Bryant,R.E.) 等

适合对象:对计算机感兴趣的朋友。 需要相关资料的可私信我。 持续更新中: 第一章:计算机系统漫游 主要知识点:解读全书结构框架,解释OS的原理和相关硬件软件。计算机系统是由硬件和系统软件组成,共同协作…

kafka消费者api和分区分配和offset消费

kafka消费者 消费者的消费方式为主动从broker拉取消息,由于消费者的消费速度不同,由broker决定消息发送速度难以适应所有消费者的能力 拉取数据的问题在于,消费者可能会获得空数据 消费者组工作流程 Consumer Group(CG&#x…

如何在 SwiftUI 中使用 Touch ID 和 Face ID?

1. 需要通过指纹,面容认证后才能打开 App 2. 添加配置 需要向 Info.plist 文件中添加一个配置,向用户说明为什么要访问 添加 Privacy - Face ID Usage Description 并为其赋予值 $(PRODUCT_NAME) need Touch Id or Face ID permission for app lock 3. …

RTC在不同业务场景下的最佳音质实践

背景介绍 WebRTC是目前实时音视频领域最流行的开源框架。2010年Google收购GIPS引擎后,将其纳入Chrome体系且开源后, 命名为“WebRTC”。WebRTC获得各大浏览器厂商的支持并纳入W3C标准,促进了实时音视频在移动互联网应用中的 普及。2021年1月&…

算法练习——力扣随笔【LeetCode】【C++】

文章目录 LeetCode 练习随笔力扣上的题目和 OJ题目相比不同之处?定义问题排序问题统计问题其他 LeetCode 练习随笔 做题环境 C 中等题很值,收获挺多的 不会的题看题解,一道题卡1 h ,多来几道,时间上耗不起。 力扣上的题…

Pytorch个人学习记录总结 06

目录 神经网络-卷积层 torch.nn.Conv2d 神经网络-最大池化的使用 torch.nn.MaxPool2d 神经网络-卷积层 torch.nn.Conv2d torch.nn.Conv2d的官方文档地址 CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride1, padding0, dilation1, groups1, biasTrue,…

TMS FNC Dashboard Pack Crack

TMS FNC Dashboard Pack Crack TTMSFNCWidgetProgress:循环进度指示器 TTMSFNCWidget设定值:带范围和设定值的值指示器 TTMSFNCWidgetMultiProgress:多个值的基于同心圆的进度指示器 TTMSFNCWidgetDistributionIndicator:各种模式…

【Kubernetes部署篇】ingress-nginx高可用架构实施部署

文章目录 一、环境说明二、实施过程1、部署Ingress Controller2、安装并配置Nginx3、安装并配置Keepalived3、测试keepalived主备切换 三、创建Ingress规则,测试七层转发 一、环境说明 1、环境说明: IP地址主机名称备注16.32.15.201node-1K8S节点16.32…

AMS358i和施耐德TM241 EtherNet 通信

产品、配件及工具型号 设备名称 型号 数量 激光测距 AMS358i 1 直流电源24VDC 1 连接电缆 KD U-M12-5A-V1-050 1 交换机 1 施耐德PLC TM241 1 AMS358i通信网线 KSS ET-M12-4A-RJ45-A-P7-020 1 网线 双向水晶头 2 电气连接图及说明 点击桌面的Somachi…

【NLP】使用 Keras 保存和加载深度学习模型

一、说明 训练深度学习模型是一个耗时的过程。您可以在训练期间和训练后保存模型进度。因此,您可以从上次中断的地方继续训练模型,并克服漫长的训练挑战。 在这篇博文中,我们将介绍如何保存模型并使用 Keras 逐步加载它。我们还将探索模型检查…

虹科活动 | 虹科ADAS自动驾驶研讨会

​​虹科ADAS/自动驾驶研讨会将于8月7日在上海闵行展开——加快ADAS/AD开发步伐! 期待您的参与!

Day45: 300.最长递增子序列,674. 最长连续递增序列,718. 最长重复子数组

目录 300.最长递增子序列 思路 674. 最长连续递增序列 思路 718. 最长重复子数组 思路 300.最长递增子序列 300. 最长递增子序列 - 力扣(LeetCode) 思路 1. 确定dp数组及其下标含义 dp[i]表示i之前包括i的以nums[i]结尾的最长递增子序列…