Transformer学习: Transformer小模块学习--位置编码,多头自注意力,掩码矩阵

news2025/1/16 14:09:46

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

Transformer学习

  • 1 位置编码模块
    • 1.1 PE代码
    • 1.2 测试PE
    • 1.3 原文代码
  • 2 多头自注意力模块
    • 2.1 多头自注意力代码
    • 2.2 测试多头注意力
  • 3 未来序列掩码矩阵
    • 3.1 代码
    • 3.2 测试掩码

在这里插入图片描述

1 位置编码模块

P E ( p o s , 2 i ) = sin ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE(pos,2i)=\sin(pos/10000^{2i/d_{\mathrm{model}}}) PE(pos,2i)=sin(pos/100002i/dmodel)

P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE(pos,2i+1)=\cos(pos/10000^{2i/d_\mathrm{model}}) PE(pos,2i+1)=cos(pos/100002i/dmodel)

pos 是序列中每个对象的索引, p o s ∈ [ 0 , m a x s e q l e n ] pos\in [0,max_seq_len] pos[0,maxseqlen], i i i 向量维度序号, i ∈ [ 0 , e m b e d d i m / 2 ] i\in [0,embed_dim/2] i[0,embeddim/2], d m o d e l d_{model} dmodel是模型的embedding维度

1.1 PE代码

import numpy as np
import matplotlib.pyplot as plt
import math
import torch
import seaborn as sns

def get_pos_ecoding(max_seq_len,embed_dim):
	# 初始化位置矩阵 [max_seq_len,embed_dim]
	pe = torch.zeros(max_seq_len,embed_dim])
	position = torch.arange(0,max_seq_len).unsqueeze(1) # [max_seq_len,1]
	print("位置:", position,position.shape)
	div_term = torch.exp(torch.arange(0,embed_dim,2)*-(math.log(10000.0)/embed_dim))   # 除项维度为embed_dim的一半,因为对矩阵分奇数和偶数位置进行填充。
	pe[:,0::2] = torch.sin(position/div_term)
	pe[:,1::2] = torch.cos(position/div_term)
	return pe

1.2 测试PE

pe = get_pos_ecoding(8,4)
plt.figure(figsize=(8,8))
sns.heatmap(pe)
plt.title("Sinusoidal Function")
plt.xlabel("hidden dimension")
plt.ylabel("sequence length")

输出:
位置: tensor([[0],
[1],
[2],
[3],
[4],
[5],
[6],
[7]]) torch.Size([8, 1])
除项: tensor([1.0000, 0.0100]) torch.Size([2])
在这里插入图片描述

plt.figure(figsize=(8, 5))
plt.plot(positional_encoding[1:, 1], label="dimension 1")
plt.plot(positional_encoding[1:, 2], label="dimension 2")
plt.plot(positional_encoding[1:, 3], label="dimension 3")
plt.legend()
plt.xlabel("Sequence length")
plt.ylabel("Period of Positional Encoding")

在这里插入图片描述

1.3 原文代码

class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=5000):
    	# max_len 序列最大长度,自定义的,不是真正的最大长度
    	# d_model 模型嵌入维度
        super(PositionalEncoding, self).__init__()
        # 实例化dropout层
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        # 初始化一个位置编码矩阵, shape: (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        #二维张量扩充为三维张量 shape: (1,max_len, d_model)
        pe = pe.unsqueeze(0)
        
        # 将位置编码注册为模型的buffer,注册为buffer之后,不会进行更新
        # 注册为buffer后可以再模型保存后重新加载时候将这个位置编码器和模型参数加载进来
        self.register_buffer('pe', pe)  # 注册的名字pe,变量也是pe
        
    def forward(self, x):
    	# x 序列的嵌入表示
    	# pe是按max_len进行注册的,太长了,将第二个维度(max_len对应的维度)缩小为真正的序列x的最大长度
    	# Variable(self.pe[:, :x.size(1)], requires_grad=False) 即位置编码
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)

2 多头自注意力模块

2.1 多头自注意力代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

# 复制网络,即使用几层网络就改变N的数量
# 如 4层线性层  clones(nn.Linear(model_dim,model_dim),4)
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# 计算注意力
def attention(q, k, v, mask=None, dropout=None):
	# q,k,v [bs,-1,head,embed_dim//head], q.size(-1) = embed_dim//head
    d_k = q.size(-1)
    # (head,embed_dim//head)*(embed_dim//head,head)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
    	# 如果使用mask, 0的位置用-1e9填充
        scores = scores.masked_fill(mask == 0, -1e9)
    # 对scores的最后一个维度进行softmax
    p_attn = F.softmax(scores, dim = -1)
	# 判断是否需要进行dorpout处理
    if dropout is not None:
        p_attn = dropout(p_attn)
    # 返回添加注意力后的结果,注意力系数
    return torch.matmul(p_attn, v), p_attn

# 计算多头注意力
class Multi_Head_Self_Att(nn.Module):
    def __init__(self,head,model_dim,dropout=0.1):
        super(Multi_Head_Self_Att,self).__init__()+
        # 判断嵌入维度能否被head整除,不能整除抛出异常
        assert model_dim % head == 0
        self.d_k = model_dim//head
        self.head = head
        self.linears = clones(nn.Linear(model_dim,model_dim),4)
        self.att = None
        self.dropout = nn.Dropout(p=dropout)
    def forward(self,q,k,v,mask=None):
        if mask is not None:
        	# 掩码非空,扩充维度,代表多头中的第n个头
            mask = mask.unsqueeze(1)
        nbatches = q.size(0)
        
        # zip函数 将线性层与q,k,v分别对应(self.linears,q),(self.linears,k),(self.linears,v)
        # q,k,v [bs,-1,head,embed_dim/head]
        # 使用线性层l处理x,把处理后的x形状view成(nbatches,-1,int(self.head),int(self.d_k)).transpose(1,2)) 第二个维度自适应维度大小
        # transpose(1,2) 让代表序列长度的维度与词向量的维度相邻
        q,k,v = [l(x).view(nbatches,-1,int(self.head),int(self.d_k)).transpose(1,2) for l,x in zip(self.linears,(q,k,v))] 
        
        # 返回计算注意力之后的值作为x和注意力分数
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        # 通过多头注意力计算后,得到了每个头计算结果组成的4维张量,需要将其转换成与输入一样的维度
        # 将第2个维度和第三个维度换回来,维度组成(nbatches,-1,model_dim)
        # contiguous()使得转置之后的张量能够运用view方法
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, int(self.head * self.d_k))
        # 之前建立了四个线性层,前面q,k,v用了三个线性层,最后一个现成层对注意力结果进行一次线性变换
        return self.linears[-1](x),self.attn  

2.2 测试多头注意力

# 模型参数
head = 4
model_dim = 128
seq_len = 10
dropout = 0.1

# 生成示例输入
q = torch.randn(seq_len, model_dim)
k = torch.randn(seq_len, model_dim)
v = torch.randn(seq_len, model_dim)

# 创建多头自注意力模块
att = Multi_Head_Self_Att(head, model_dim, dropout=dropout)

# 运行模块
output,att = att(q, k, v)

# 输出形状
print("Output shape:", output.shape)
print(att.shape())
sns.heatmap(att.squeeze().detach().cpu())

输出
Output shape: torch.Size([10, 1, 128])
torch.Size([10, 4, 1, 1])
在这里插入图片描述

3 未来序列掩码矩阵

作用:

  • 解码器中的掩码是防止泄露未来要预测的部分,掩码矩阵是一个除对角线的上三角矩阵
  • 序列填充部分的掩码是判断哪些部位是填充的部位,填充的部位在计算注意力时保证期注意力分数为0【以添加一个负无穷的小数,使得其softmax值为0】

3.1 代码

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    print("掩码矩阵:",subsequent_mask)
    return torch.from_numpy(subsequent_mask) == 0

测试掩码

plt.figure(figsize=(5,5))
print(subsequent_mask(8),subsequent_mask(8).shape)
plt.imshow(subsequent_mask(8)[0])

掩码矩阵:
[[[0 1 1 1 1 1 1 1]
[0 0 1 1 1 1 1 1]
[0 0 0 1 1 1 1 1]
[0 0 0 0 1 1 1 1]
[0 0 0 0 0 1 1 1]
[0 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0]]]

tensor([[[ True, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True]]]) torch.Size([1, 8, 8])
在这里插入图片描述
紫色部分为添加掩码的部分

3.2 测试掩码

import torch
from torch.autograd import Variable


# 函数接受两个参数 tgt 和 pad,其中 tgt 是目标序列的张量,pad 是表示填充的值
def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    # 首先,创建一个掩码 tgt_mask,其形状与 tgt 的形状相同,用于指示哪些位置不是填充位置。
    # 这是通过将 tgt 张量中不等于 pad 的位置设置为 True(1),其余位置设置为 False(0)来实现的。
    # unsqueeze(-2) 的作用是在倒数第二个维度上添加一个维度,以便后续的逻辑运算。
    tgt_mask = (tgt != pad).unsqueeze(-2)
    # 调用 subsequent_mask 函数,生成一个用于遮挡未来词的掩码。这个掩码是一个上三角矩阵,
    # 其对角线及其以下的元素为 True(1),其余元素为 False(0)。
    
    # tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)):
    # 将 tgt_mask 和生成的未来词掩码进行逻辑与操作,将未来词位置的掩码设置为 False,即遮挡掉未来词。
    tgt_mask = tgt_mask & Variable(
        subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(*attn_shape), diagonal=1)
    return subsequent_mask == 0

# 示例数据
tgt = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0], [6, 7, 8, 9, 10]])  # 目标序列,假设填充值为 0
pad = 0  # 填充值

# 创建掩码
tgt_mask = make_std_mask(tgt, pad)

# 打印结果
print("目标序列:")
print(tgt)
print("\n生成的掩码:")
print(tgt_mask)

输出:

目标序列:
tensor([[ 1, 2, 3, 0, 0],
[ 4, 5, 0, 0, 0],
[ 6, 7, 8, 9, 10]])

生成的掩码:
tensor([[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, False, False],
[ True, True, True, False, False]],
[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, False, False, False],
[ True, True, False, False, False],
[ True, True, False, False, False]],
[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1569502.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

绿联 安装halo博客,使用MySQL数据库

绿联 安装halo博客,使用MySQL数据库 1、镜像 halohub/halo:2 halo2系列已支持halohub/halo:2直接拉取最新版本,因此更新容器可以使用相同镜像重新编辑的方式进行升级。 安装前准备: 绿联 安装Mysql 5.7版本数据库 绿联 安装phpmyadmin管理M…

基本线段树以及相关例题

1.线段树的概念 线段树是一种二叉树,也就是对于一个线段,我们会用一个二叉树来表示。 这个其实就是一个线段树,我们会将其每次从中间分开,其左孩子就是左边的集合的和,其右孩子就是右边集合的和; 我们可以…

VSCode常用修改默认设置(settings.json)

❓ 问题1 我现在在vscode中鼠标选中某个单词,相同的单词都会自动出现一个高亮背景色,我需要怎么关闭这个功能呢? ⚠️ 注意 selectionHighlight 这个是鼠标双击后的高亮匹配,可以保留默认开启的配置,不用去改它。 …

力扣每日一题:LCR112--矩阵中的最长递增路径

题目 给定一个 m x n 整数矩阵 matrix ,找出其中 最长递增路径 的长度。 对于每个单元格,你可以往上,下,左,右四个方向移动。 不能 在 对角线 方向上移动或移动到 边界外(即不允许环绕)。 示例…

前端入门系列-HTML-HTML结构

&#x1f308;个人主页&#xff1a;羽晨同学 &#x1f4ab;个人格言:“成为自己未来的主人~” HTML基础 HTML结构 认识HTML标签 HTML代码是由“标签”构成的。 形如 <body>hello </body> 标签名&#xff08;body&#xff09;放到<>中大部分标签成对…

见证历史:Quantinuum与微软取得突破性进展,开启了可靠量子计算的新时代!

Quantinuum与微软的合作取得了重大突破&#xff0c;将可靠量子计算带入了新的时代。他们结合了Quantinuum的System Model H2量子计算机和微软创新的量子比特虚拟化系统&#xff0c;在逻辑量子比特领域取得了800倍于物理电路错误率的突破。这一创新不仅影响深远&#xff0c;加速…

C#仿OutLook的特色窗体设计

目录 1. 资源图片准备 2. 设计流程&#xff1a; &#xff08;1&#xff09;用MenuStrip控件设计菜单栏 &#xff08;2&#xff09;用ToolStrip控件设计工具栏 &#xff08;3&#xff09;用StatusStrip控件设计状态栏 &#xff08;4&#xff09;ImageList组件装载树节点图…

【ORB-SLAM3】Ubuntu20.04 使用 RealSense D435i 运行 ORB-SLAM3 时遇到的一些 Bug

【ORB-SLAM3】使用 RealSense D435i 跑 ORB-SLAM3 时遇到的一些 Bug 1 hwmon command 0x80( 5 0 0 0 ) failed (response -7 HW not ready)2 No rule to make target /opt/ros/noetic/lib/x86_64-linux-gnu/librealsense2.so, needed by ../lib/libORB_SLAM3.so 1 hwmon comman…

微信小程序 python+django口腔牙科问诊系统 springboot设计与实现_1171u

口腔助手”小程序主要有管理员&#xff0c;医生和用户三个功能模块。以下将对这三个功能的作用进行详细的剖析。 本文通过采用B/S架构&#xff0c;uniapp框架、MySQL数据库&#xff0c;结合国内“口腔助手”管理现状&#xff0c;开发了一个基于微信小程序的“口腔助手”小程序。…

基于ArrayList实现简单洗牌

前言 在之前的那篇文章中&#xff0c;我们已经认识了顺序表—>http://t.csdnimg.cn/2I3fE 基于此&#xff0c;便好理解ArrayList和后面的洗牌游戏了。 什么是ArrayList? ArrayList底层是一段连续的空间&#xff0c;并且可以动态扩容&#xff0c;是一个动态类型的顺序表&…

二、企业级架构之Nginx

一、Nginx的重装与升级 1、为什么需要重装与升级&#xff1a; 在实际业务场景中&#xff0c;需要使用软件新版本的功能、特性&#xff0c;就需要对原有软件进行升级或者重装操作。 Nginx&#xff1a;1.12版本 → 1.16版本 2、Nginx重装&#xff1a; 第一步&#xff1a;停止…

linux操作系统的进程状态

这个博客只是为了自己复习用的&#xff01;&#xff01;&#xff01; 冯诺依曼体系结构 计算机是由一个一个硬件组成的 输入设备&#xff1a;键盘&#xff0c;鼠标&#xff0c;扫描仪&#xff0c;写板等等 中央处理器&#xff08;CPU&#xff09;:含有运算器和控制器等 输出单…

Vue-05

v-model 应用于其他表单元素 常见的表单元素都可以用v-model绑定关联 → 快速获取或设置表单元素的值 它会根据控件类型自动选取正确的方法来更新元素 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name…

树莓派部署yolov5实现目标检测(ubuntu22.04.3)

最近两天搞了一下树莓派部署yolov5&#xff0c;有点难搞&#xff08;这个东西有点老&#xff0c;版本冲突有些包废弃了等等&#xff09; 最后换到ubuntu系统弄了&#xff0c;下面是我的整体步骤&#xff08;建议先使能一下ssh&#xff08;最下面有&#xff09;&#xff0c;结合…

Macbook文件清理软件 Mac电脑清理垃圾文件怎么清理

为了维护Macbook电脑的系统健康&#xff0c;我们需要定期给电脑进行全面清理&#xff0c;清除系统垃圾文件、软件缓存和系统内存。那么好用的Macbook文件清理软件有哪些呢&#xff1f;今天就给大家介绍几款好用的电脑清理软件并介绍Mac电脑清理垃圾文件怎么清理。 一、Macbook…

RPA自动化小红书自动化写文以及发文!

1、视频演示 RPA自动化小红书自动写作发文 2、核心功能点 采集笔记&#xff1a;采集小红书上点赞量大于1000的爆款笔记 下载素材&#xff1a;下载爆款笔记的主图 爆款改写&#xff1a;根据爆款笔记的标题仿写新的标题以及新的文案 自动发布&#xff1a;将爆款笔记发布到小红…

Django之REST Client插件

一、接口测试工具介绍 在开发前后端分离项目时,无论是开发后端,还是前端,基本都是需要测试API接口的内容,而目前我们需要开发遵循RESTFul规范的项目,也是必然的(自己不开发前端页面)。 在网上有很多这样的工具,常用的postman,但还是需要下载安装。在这我们介绍一个VSCod…

【GEE实践应用】GEE下载遥感数据以及下载后在ArcGIS中的常见显示问题处理(以下载哨兵2号数据为例)

本期内容我们使用GEE进行遥感数据的下载&#xff0c;使用的相关代码如下所示&#xff0c;其中table是我们提前导入的下载遥感数据的研究区域的矢量边界数据。 var district table;var dsize district.size(); print(dsize);var district_geometry district.geometry();Map.…

51入门之LED

目录 1.配置文件 2.点亮一个LED 2.1单个端口操作点亮单个LED 2.2整体操作点亮LED 3.LED闪烁 4.LED实现流水灯 4.1使用for循环和移位实现 4.1.1移位操作符 4.1.2使用移位操作和for循环实现 4.2使用移位函数实现LED流水灯 众所周知&#xff0c;任何一个硬件工程师…

go juc 线程中的子类

1.go test() 主死随从 package mainimport ("fmt""strconv""time" )func test() {for i : 1; i < 10; i {fmt.Println("hello " strconv.Itoa(i))//阻塞time.Sleep(time.Second)} } func main() {//开启协程go test()for i : 1; …