LLM - Make Causal Mask 构造因果关系掩码

news2024/11/19 17:26:23

目录

一.引言

二.make_causal_mask

1.完整代码

2.Torch.full

3.torch.view

4.torch.masked_fill_

5.past_key_values_length

6.Test Main

三.总结


一.引言

Causal Mask 主要用于限定模型的可视范围,防止模型看到未来的数据。在具体应用中,Causal Mask 可将所有未来的 token 设置为零,从注意力机制中屏蔽掉这些令牌,使得模型在进行预测时只能关注过去和当前的 token,并确保模型仅基于每个时间步骤可用的信息进行预测。

在 Transformer 模型中,Multihead Attention 中的 Causal Mask 就是用于解决这个问题,以实现模型对于输入序列的正确处理。下面是 Causal 的可视化示例,在实践中其呈现倒三角形状:

全文为 'I love eating lunch.' ,对于 'love' 而言其只能看到 'I',不能看到未来的 'eating'、'lunch'。

二.make_causal_mask

为了方便后续示例的展示,这里选择较小的参数,batch_size = 2,target_length = 4。

1.完整代码

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

代码的 Input 主要就是一个二维的 input_ids_shape,分别为 batch_size 和 target_length,dtype 和 device 在这里比较好理解,还有就是最后的 past_key_values_length,用于补齐,这个也比较简单。Output 则是 (batch_size, 1, target_length, target_length) 的 Causal Mask,其中 Msak 的矩阵 target_length x target_length 就是上面所示的倒三角形状。

2.Torch.full

函数介绍

该函数用于创建一个具有指定填充值的新张量。该函数的语法如下:

torch.full(size, fill_value, *, dtype=None, device=None, requires_grad=False)

参数说明:

  • size:张量的形状,可以是一个整数或者一个元组,例如:(3, 3) 或 3。
  • fill_value:张量的填充值。
  • dtype:张量的数据类型,默认为None,即根据输入的数据类型推断。
  • device:张量所在的设备,默认为None,即根据输入的设备推断。
  • requires_grad:是否需要计算梯度,默认为False。

该函数返回一个与指定形状相同且所有元素都被设置为指定填充值的新张量。

函数使用

mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))

这里 torch.finfo(dtype).min 为对应 torch.dtype 类型的最小值,以 bfloat16 为例:

print(torch.tensor(torch.finfo(dtype).min))
=> tensor(-3.3895e+38)

而这一步 mask 的操作就是生成一个 tgt_len x tgt_len 的充满 min 元素的方阵:

3.torch.view

函数介绍

在 PyTorch 库中,view 函数用于改变一个张量(Tensor)的形状(shape)。它返回一个新的张量,其元素与原始张量相同,但形状(shape)已被改变。view 函数的行为非常类似于 NumPy的 reshape 函数。它会返回一个与原始张量共享数据但具有不同形状的新的张量。如果给定的形状与原始张量的元素总数不匹配,则会引发错误。

import torch  
  
x = torch.randn(4, 5)  # 创建一个4x5的随机张量  
y = x.view(20)  # 改变形状为20的一维张量  
z = x.view(-1, 10)  # 改变形状为10的一维张量,第一维度由其他维度决定

函数使用

mask_cond = torch.arange(mask.size(-1))

mask_cond 是一个1维向量:

(mask_cond + 1).view(mask.size(-1), 1)

这一步相当于在 mask_cond 基础上先加常量再 reshape:

4.torch.masked_fill_

函数介绍

在 PyTorch 库中,masked_fill_() 函数是一个张量(Tensor)方法,用于将张量中的指定区域填充为特定值。此函数需要一个掩码(mask)作为输入,该掩码应与原张量具有相同的形状。掩码中的 True 值表示需要填充的区域,False 值表示需要保留的原始值。

torch.Tensor.masked_fill_(mask, value)

参数说明: 

  • mask (Bool tensor) - 掩码张量,用于指定需要填充的区域。
  • value (float) - 填充的值。

示例

假设我们有一个 3x3 的张量,我们想要将所有大于 5 的元素替换为 -1。我们可以使用该函数来实现这个目标。

import torch  
  
# 创建一个3x3的张量  
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  
  
# 创建一个掩码,其中大于5的元素为True,其余为False  
mask = x > 5  
  
# 使用masked_fill_函数将大于5的元素替换为-1  
x.masked_fill_(mask, -1)  
  
print(x)

输出:

tensor([[ 1,  2,  3],  
        [ 4,  5,  6],  
        [-1, -1, -1]])

函数使用

mask_cond < (mask_cond + 1).view(mask.size(-1), 1)

传入函数的 mask 如下,呈倒三角形态,其中 True 的部分填充新值,False 部分保持不变: 

mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

 根据 mask 对 target x target 的方阵进行填充 0 得到我们上面提到的倒三角:

5.past_key_values_length

在 PyTorch 中,past_key_values_length 是一个参数,用于指定在使用 Transformer 模型时,过去键值缓存(past key-value cache)的长度。该参数通常与 Transformer 模型中的自注意力机制(self-attention mechanism)一起使用。在过去键值缓存中,模型保存了过去的键和值向量,以便在生成序列时重复使用它们。这些过去的键和值向量可以用于计算自注意力分数,从而提高生成序列的效率。较大的past_key_values_length可以增加模型的表现力,但也会增加计算量和内存消耗。因此,需要根据具体任务和资源限制来选择合适的值。

if past_key_values_length > 0:
    mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)

这里定义 past_key_values_length = 1,代码逻辑就是在原有的 tgt x tgt 方阵前补 past_key_values_length 个 0:

6.Test Main

mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
if __name__ == '__main__':
    batch_size = 2
    target_length = 4
    input_shape = (batch_size, target_length)
    data_type = torch.bfloat16

    causal_mask = _make_causal_mask(input_shape, data_type, 1)
    print(causal_mask)
    print(causal_mask.shape)

=>        pask_key_length = 1填充后,tgt x tgt 变为 tgt x (1 + tgt)

=>        通过 None + expand 的组合,tgt x (1 + tgt) 变为 bsz x 1 x tgt x (1 + tgt)

三.总结

新系列博文一方面是阅读 HF 上 LLM 模型实现的源码,了解对应知识的实现过程。另一方面是之前很多同学主要接触 TF 1.x、TF 2.x 以及 Estimator 和 Keras 这一类深度学习工具,趁此机会也能熟悉 Torch 的使用方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1039363.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

中国核动力研究设计院使用 DolphinDB 替换 MySQL 实时监控仪表

随着仪表测点的大幅增多和采样频率的增加&#xff0c;中国核动力研究设计院仪控团队原本基于 MySQL 搭建的旧系统已经无法满足大量数据并发写入、实时查询和聚合计算的需求。他们在研究 DB-Engines 时序数据库榜单时了解到国内排名第一的 DolphinDB。经过测试&#xff0c;发现其…

瞄准核心因素:Boruta特征选择算法助力精准决策

一、引言 特征选择在机器学习和数据挖掘中扮演着重要的角色。通过选择最相关和最有信息量的特征&#xff0c;可以降低维度&#xff0c;减少数据复杂性&#xff0c;并提高模型的预测性能和解释能力。在实际应用中&#xff0c;特征选择有助于识别最具影响力的因素&#xff0c;提供…

【内网穿透】在Ubuntu搭建Web小游戏网站,并将其发布到公网访问

目录 前言 1. 本地环境服务搭建 2. 局域网测试访问 3. 内网穿透 3.1 ubuntu本地安装cpolar 3.2 创建隧道 3.3 测试公网访问 4. 配置固定二级子域名 4.1 保留一个二级子域名 4.2 配置二级子域名 4.3 测试访问公网固定二级子域名 前言 网&#xff1a;我们通常说的是互…

复亚智能落地江苏化工,安防巡逻无人机守牢“安全线”

化工业是国民经济的重要组成部分&#xff0c;但其生产环境和条件充满了挑战。大部分化学反应发生在高温、高压且有毒的环境中&#xff0c;而近70%的原料、中间体和终产品都带有易燃、易爆、有毒、有害以及腐蚀性的特性。在这样的情境下&#xff0c;安全生产不仅仅是一项日常任务…

leetCode 968.监控二叉树(利用状态转移+贪心)

968. 监控二叉树 - 力扣&#xff08;LeetCode&#xff09; 给定一个二叉树&#xff0c;我们在树的节点上安装摄像头。节点上的每个摄影头都可以监视其父对象、自身及其直接子对象。计算监控树的所有节点所需的最小摄像头数量。 >>解题思路: 重要线索->题目示例中的摄…

AP5126 降压恒流 PIN对PIN替LN2566 LED汽车大灯车灯驱动芯片

产品描述 AP5126 是一款 PWM 工作模式,高效率、外 围简单、内置功率管&#xff0c;适用于 12-80V 输入的高 精度降压 LED 恒流驱动芯片。输出最大功率可达 15W&#xff0c;最大电流 1.2A。 AP5126 可实现全亮/半亮功能切换&#xff0c;通过 MODE 切换&#xff1a;全亮/半…

YZ09: VBA_Excel之读心术

【分享成果&#xff0c;随喜正能量】多要求自己&#xff0c;你会更加独立&#xff0c;少要求别人&#xff0c;你会减少失望&#xff0c;宁愿花时间去修炼 不完美的自己&#xff0c;也不要浪费时间去期待完美的别人&#xff01;。 我给VBA下的定义&#xff1a;VBA是个人小型自动…

SolidJs节点级响应性

前言 随着组件化、响应式、虚拟DOM等技术思想引领着前端开发的潮流&#xff0c;相关的技术框架大行其道&#xff0c;就以目前主流的Vue、React框架来说&#xff0c;它们都基于组件化、响应式、虚拟DOM等技术思想的实现&#xff0c;但是具有不同开发使用方式以及实现原理&#…

Elasticsearch:与多个 PDF 聊天 | LangChain Python 应用教程(免费 LLMs 和嵌入)

在本博客中&#xff0c;你将学习创建一个 LangChain 应用程序&#xff0c;以使用 ChatGPT API 和 Huggingface 语言模型与多个 PDF 文件聊天。 如上所示&#xff0c;我们在最最左边摄入 PDF 文件&#xff0c;并它们连成一起&#xff0c;并分为不同的 chunks。我们可以通过使用 …

在线商城项目EShop【ListView、adapter】

要求如下&#xff1a; 1、创建在线商城项目EShop&#xff1b; 2、修改布局文件activity_main.xml&#xff0c;使用LineaLayout和ListView创建商品列表UI&#xff1b; 3、创建列表项布局list_item.xml&#xff0c;设计UI用于显示商品图标、名称和价格信息&#xff1b; 4、创…

IT监控制度,IT监控体系如何分层

IT监控系统是指监控和管理it服务管理的软件。它涵盖了监控管理、服务台管理、问题管理和变更管理&#xff0c;旨在帮助组织更有效的运营信息系统&#xff0c;提高运营事故的响应速度。  通过建立集中监控平台&#xff0c;IT监控系统与信息系统的融合可以完成统一的展示和管理…

ICMP差错包

ICMP报文分类 Type Code 描述 查询/差错 0-Echo响应 0 Echo响应报文 查询 3-目的不可达 0 目标网络不可达报文 差错 1 目标主机不可达报文 差错 2 目标协议不可达报文 差错 3 目标端口不可达报文 差错 4 要求分段并设置DF flag标志报文 差错 5 源路由…

【从0学习Solidity】 50. 多签钱包

【从0学习Solidity】50. 多签钱包 博主简介&#xff1a;不写代码没饭吃&#xff0c;一名全栈领域的创作者&#xff0c;专注于研究互联网产品的解决方案和技术。熟悉云原生、微服务架构&#xff0c;分享一些项目实战经验以及前沿技术的见解。关注我们的主页&#xff0c;探索全栈…

Mac磁盘空间满了怎么办?Mac如何清理磁盘空间

你是不是发现你的Mac电脑存储越来越满&#xff0c;甚至操作系统本身就占了100多G的空间&#xff1f;这不仅影响了电脑的性能&#xff0c;而且也让你无法存储更多的重要文件和软件。别担心&#xff0c;今天这篇文章将告诉你如何清除多余的文件&#xff0c;让你的Mac重获新生。 一…

【kafka实战】01 3分钟在Linux上安装kafka

本节采用docker安装Kafka。采用的是bitnami的镜像。Bitnami是一个提供各种流行应用的Docker镜像和软件包的公司。采用docker的方式3分钟就可以把我们想安装的程序运行起来&#xff0c;不得不说真的很方便啊&#xff0c;好了&#xff0c;开搞。使用前提&#xff1a;Linux虚拟机&…

找不到msvcp110dll,无法继续执行代码,msvcp110dll丢失是什么意思

MSVCP110.dll是一个动态链接库文件&#xff0c;它是Microsoft Visual C 2012 Redistributable package的一部分。这个文件通常用于支持许多Microsoft Visual Studio 2012开发的应用程序。当您在运行某些程序时遇到“找不到msvcp110.dll”的错误时&#xff0c;这意味着您的计算机…

PHY6230低成本遥控灯控芯片国产蓝牙BLE5.2 2.4G SoC

高性价比的低功耗高性能蓝牙5.2系统级芯片&#xff0c;适用多种PC/手机外设连接场景。 高性能多模射频收发机&#xff1a; 通过硬件模块的充分复用实现高性能多模数字收发机。发射机&#xff0c;最大发射功率10dBm&#xff1b;BLE 1Mbps速率接收机灵敏度达到-96dBm&#xff1…

查准率(precision,也叫精确率)和查全率(recall,也叫召回率)

精确率和召回率是广泛用于信息检索和统计学分类领域的两个度量值&#xff0c;用来评价结果的质量。其中精确率是检索出相关文档数与检索出的文档总数的比率&#xff0c;衡量的是检索系统的查准率&#xff1b;召回率是指检索出的相关文档数和文档库中所有的相关文档数的比率&…

【算法小课堂】滑动窗口

滑动窗口 基本概念&#xff1a; 滑动窗口本质是双指针算法的一种演变 本质上就是同向双指针&#xff0c;窗口的范围就是[left,right&#xff09; 滑动窗口大致可以分为两类 窗口大小不变的窗口大小变化的 滑动窗口遇到一些验证重复性的问题的时候可以用哈希表来优化 核心思想…

​旅行季《乡村振兴战略下传统村落文化旅游设计》许少辉八一书作想象和世界一样宽广

​旅行季《乡村振兴战略下传统村落文化旅游设计》许少辉八一书作想象和世界一样宽广