GPT - 因果掩码(Causal Mask)

news2025/4/14 12:59:55

本节代码定义了一个函数 causal_mask,用于生成因果掩码(Causal Mask)。因果掩码通常用于自注意力机制中,以确保模型在解码时只能看到当前及之前的位置,而不能看到未来的信息。这种掩码在自然语言处理任务(如语言生成)中非常重要,因为它模拟了人类阅读或写作时的顺序性。

一、因果掩码(Causal Mask)代码实现

def causal_mask(x):
    mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0
    return mask
1. 输入参数
  • x:输入张量,通常是一个序列,形状为 (seq_len, d_model)(batch_size, seq_len, d_model)。这里的 seq_len 是序列的长度。

2. 生成掩码
mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0
  • torch.ones(x.shape[0], x.shape[0]):生成一个形状为 (seq_len, seq_len) 的全1矩阵。

  • torch.triu(..., diagonal=1):取该矩阵的上三角部分(包括对角线),其余部分设置为0。diagonal=1 表示从对角线的下一个位置开始取上三角部分。

  • == 0:将上三角部分(包括对角线)的值设置为 False,其余部分设置为 True。这样生成的掩码矩阵中,True 表示需要保留的注意力位置,False 表示需要被忽略的注意力位置。

3. 返回值
  • mask:生成的因果掩码,形状为 (seq_len, seq_len),是一个布尔张量。

示例

假设输入张量 x 的形状为 (5, d_model),即序列长度为5。那么:

x = torch.randn(5, d_model)  # 示例输入
mask = causal_mask(x)
print(mask)

输出的掩码矩阵 mask 将是:

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])

作用

在自注意力机制中,因果掩码用于确保模型在计算注意力分数时,只能看到当前及之前的位置,而不能看到未来的信息。具体来说:

  • True:表示可以计算注意力分数。

  • False:表示需要被忽略,注意力分数会被设置为一个非常小的值(如 -1e9),从而在 softmax 归一化后,其权重趋近于0。

二、因果掩码如何使用?

1. 因果掩码的生成

因果掩码的生成函数如下:

def causal_mask(x):
    mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0
    return mask
  • 输入x 是一个张量,通常是一个序列的嵌入表示,形状为 (seq_len, d_model)(batch_size, seq_len, d_model)

  • 输出:生成一个布尔张量 mask,形状为 (seq_len, seq_len),其中上三角部分(包括对角线)为 True,其余部分为 False

2. 因果掩码的应用

因果掩码在 Poetry 数据集类中被应用,具体如下:

class Poetry(Dataset):
    def __init__(self, poetries, tokenizer: Tokenizer):
        self.poetries = poetries
        self.tokenizer = tokenizer

        self.pad_id = self.tokenizer.vocab["[PAD]"]
        self.bos_id = self.tokenizer.vocab["[BOS]"]
        self.eos_id = self.tokenizer.vocab["[EOS]"]

    def __len__(self):
        return len(self.poetries)
    
    def __getitem__(self, idx):
        poetry = self.poetries[idx]
        poetry_ids = self.tokenizer.encode(poetry)
        input_ids = torch.tensor([self.bos_id] + poetry_ids)
        input_msk = causal_mask(input_ids)
        label_ids = torch.tensor(poetry_ids + [self.eos_id])
        return {
            "input_ids": input_ids,
            "input_msk": input_msk,
            "label_ids": label_ids
        }
  • __getitem__ 方法

    • 对于每首诗 poetry,将其编码为 poetry_ids

    • 在输入序列的开头添加 [BOS](开始标记符),生成 input_ids

    • 使用 causal_mask 函数生成因果掩码 input_msk

    • 在标签序列的末尾添加 [EOS](结束标记符),生成 label_ids

3. 因果掩码的传递

在训练过程中,因果掩码 input_msk 会被传递给模型的自注意力层。具体如下:

for epoch in range(epochs):
    for batch in tqdm(trainloader, desc="Training"):
        batch_input_ids = batch["input_ids"]
        batch_input_msk = batch["input_msk"]
        batch_label_ids = batch["label_ids"]

        output = model(batch_input_ids, batch_input_msk)
        loss = loss_fn(output.view(-1, len(vocab)), batch_label_ids.view(-1))
        loss.backward()
        optim.step()
        optim.zero_grad()
  • model(batch_input_ids, batch_input_msk)

    • batch_input_ids 是输入序列的嵌入表示。

    • batch_input_msk 是对应的因果掩码。

    • 模型在计算自注意力时,会使用 batch_input_msk 来确保解码器只能看到当前及之前的位置。

4. 因果掩码的作用

MultiHeadAttention 类中,因果掩码被应用到注意力分数矩阵中:

if attn_mask is not None:
    attn_mask = attn_mask.unsqueeze(1)
    atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)
  • attn_mask.unsqueeze(1)

    • 将掩码的形状从 (batch_size, seq_len, seq_len) 扩展为 (batch_size, 1, seq_len, seq_len)

  • masked_fill

    • 将掩码中为 False 的位置的注意力分数设置为 -1e9,确保这些位置的注意力权重趋近于0。

5. 生成诗歌时的因果掩码

在生成诗歌时,因果掩码同样被应用:

def generate_poetry(method="greedy", top_k=5):
    model.eval()
    with torch.no_grad():
        input_ids = torch.tensor(vocab["[BOS]"]).view(1, -1)

        while input_ids.shape[1] < seq_len:
            output = model(input_ids, None)
            probabilities = torch.softmax(output[:, -1, :], dim=-1)
            
            if method == "greedy":
                next_token_id = torch.argmax(probabilities, dim=-1)
            elif method == "top_k":
                top_k_probs, top_k_indices = torch.topk(probabilities[0], top_k)
                next_token_id = top_k_indices[torch.multinomial(top_k_probs, 1)]

            if next_token_id == vocab["[EOS]"]:
                break

            input_ids = torch.cat([input_ids, next_token_id.view(1, 1)], dim=1)
    return input_ids.squeeze()
  • model(input_ids, None)

    • 在生成诗歌时,输入序列 input_ids 会逐渐增长,但因果掩码是隐含的,因为模型的自注意力层会自动处理序列的顺序性。

    • 生成过程中,模型只能看到当前及之前的位置,这与训练时使用因果掩码的目的相同。



 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2332192.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

适合工程建筑行业的OA系统有什么推荐?

工程行业具有项目周期长、协作链条复杂等特性&#xff0c;传统管理模式下的 “人治”“纸质化” 弊端日益凸显。OA 系统作为数字化管理的核心载体&#xff0c;通过流程标准化、数据可视化&#xff0c;精准解决工程行业项目管理核心痛点。 泛微 e-office 深度聚焦工程场景&#…

深入解析栈回溯技术:如何通过异常处理精准定位程序崩溃点

一、栈回溯 1.1 栈回溯的原理 调试程序时&#xff0c;经常发生这类错误&#xff1a; 1.读写某个地址&#xff0c;导致程序崩溃 2.调用某个空函数&#xff0c;导致程序崩溃在异常处理函数中&#xff0c;可以打印出”发生错误瞬间”的所有寄存器。 我们调试时&#xff0c;可以…

重构居家养老安全网:从 “被动响应” 到 “主动守护”

随着全球老龄化加剧&#xff0c;居家养老安全成为社会关注的核心议题。 传统养老模式依赖人工巡检或单一传感器&#xff0c;存在响应滞后、隐私泄露、场景覆盖不足等问题。 由此智绅科技应运而生&#xff0c;七彩喜智慧养老系统构筑居家养老安全网。 而物联网&#xff08;Io…

Unity6下架中国区,团结引擎接棒:这是分裂,还是本地化的开始?

就在近日&#xff0c;一则消息在国内游戏开发圈内迅速传播开来&#xff1a;Unity 6 及其后续版本已在中国大陆及港澳地区下架。这意味着&#xff0c;未来中国用户将无法直接使用 Unity 最新的主线版本。而取而代之的&#xff0c;是由 Unity 中国主导推出的本地化产品 —— 团结…

ESP8266水位监测以及温湿度数据采集

上面就是ESP8266的引脚图&#xff0c;水温检测使用的是水位监测传感器&#xff0c;温湿度测量使用的是DHT11&#xff0c;DHT11的反应时间是2秒&#xff0c;这里要注意。开发采用Arduino程序 1. 传感器初始化 功能&#xff1a;初始化DHT11温湿度传感器和串口通信。 代码实现&…

国产信创数据库:PolarDB 分布式版 V2.0,支持集中分布式一体化

阿里云PolarDB数据库管理软件&#xff08;分布式版&#xff09;V2.0 &#xff0c;安全可靠的集中分布式一体化数据库管理软件。点此查看详情https://www.aliyun.com/activity/database/polardbx-v2?spma2c6h.13046898.publish-article.8.44146ffaE0lEWT 立即咨询专家&#xf…

Axure PR 9 中继器 09 删除行

大家好&#xff0c;我是大明同学。 接着上期的内容&#xff0c;这期内容&#xff0c;我们来了解一下Axure中继器数据表删除行交互设计。 预览地址&#xff1a;https://vvlmqu.axshare.com 删除行 1.打开上期RP 文件&#xff0c;设计一个删除弹窗元件&#xff0c; 创建为动态面…

HDCP(五)

HDCP 2.2 测试用例设计详解 基于HDCP 2.2 CTS v1.1规范及协议核心机制&#xff0c;以下从正常流程与异常场景两大方向拆解测试用例设计要点&#xff0c;覆盖认证、密钥管理、拓扑验证等关键环节&#xff1a; 1. 正常流程测试 1.1 单设备认证 • 测试目标&#xff1a;验证源设…

商城APP打包教程

下载 HBuilderX 工具 HBuilderX支持插件拓展功能。App开发版已集成相关插件、开箱即用 根据自身电脑系统选择对应软件下载&#xff0c;建议选择APP开发版 2. 下载好软件安装后打开 建议直接在uniapp插件页面一键导入&#xff0c;正常情况下uniapp插件都是最新的&#xff0c;大家…

Spring 框架的核心基础:IoC 和 AOP

一、IoC&#xff08;Inversion of Control&#xff0c;控制反转&#xff09; 定义&#xff1a; IoC&#xff08;Inversion of Control&#xff0c;控制反转&#xff09;&#xff0c;就是把对象创建和依赖关系的管理交给 Spring 容器&#xff0c;而不是由程序员手动去创建对象…

SpringBoot 基础知识,HTTP 概述

1. 概述 1.1 Spring Spring 提供若干个子项目&#xff0c;每个项目用于完成特定功能 Spring 的若干个子项目都基于一个基础的框架&#xff1a;Spring Framework 框架类似于 房屋的地基 但 Spring Framework 配置繁琐&#xff0c;入门难度大 1.2 Spring Boot 于是&#xf…

《网络管理》实践环节04:SNMP监控数据采集流程及SNMP协议详细分析

兰生幽谷&#xff0c;不为莫服而不芳&#xff1b; 君子行义&#xff0c;不为莫知而止休。 1 实验目标 1. 理解SNMP网络管理原理 2. 掌握SNMP服务器采集SNMP Agent数据的方法 3. 掌握SNMP报文发送和应答流程 4. 掌握典型GetResponsePDU数据结构分析的方法 4. 具备SNMP通信…

《Uniapp-Vue 3-TS 实战开发》构建HTTP请求拦截器

引言 在 UniApp 结合 TypeScript 和 Vue3 的项目开发中&#xff0c;请求拦截器起着至关重要的作用。它能够在请求发送前和响应接收后对数据进行统一处理&#xff0c;极大地提高了代码的可维护性和功能性。本文将详细解析上述代码中请求拦截器的实现及其在 UniApp-Ts-Vue3 项目中…

从PDF中提取表格:以GB/T2260—2007为例

文章目录 先说结论前因后果思路1、PDF2CSV2、PDF2MD → MD2CSV3、针对不同表格的两种思路1&#xff09; 竖形三线表2&#xff09;五元素为一组 还没结束批量处理1、分割markdown文档2、跳过另一种格式的文档 总结一下 先说结论 结论就是&#xff0c;博主用了一天的时间去研究如…

初识MySQL · 复合查询(内外连接)

目录 前言&#xff1a; 基本查询回顾 笛卡尔积和子查询 笛卡尔积 内外连接 子查询 单行子查询 多行子查询 多列子查询 from中使用子查询 合并查询 前言&#xff1a; 在前文我们学习了MySQL的基本查询&#xff0c;就是简单的套用了select语句&#xff0c;最多不过是…

辛格迪客户案例 | 北京舒曼德医药实施电子合约系统(eSign)

01 北京舒曼德医药科技开发有限公司&#xff1a;医药科技的数字化先锋 北京舒曼德医药科技开发有限公司&#xff08;以下简称“舒曼德医药”&#xff09;作为国内医药科技领域的领军企业&#xff0c;致力于创新药物的研发、临床试验和市场推广。公司以“科技兴药、质量为先、服…

Python面向对象-开闭原则(OCP)

1. 什么是开闭原则&#xff1f; 开闭原则(Open-Closed Principle, OCP) 是面向对象设计的五大SOLID原则之一&#xff0c;由Bertrand Meyer提出。其核心定义是&#xff1a; “软件实体(类、模块、函数等)应该对扩展开放&#xff0c;对修改关闭。” 对扩展开放&#xff1a;当需求…

Class 文件和类加载机制

一、Class 文件 与 类加载机制 概述 什么是 Class 文件&#xff1f; Java 源码&#xff08;.java&#xff09;经过 javac 编译器 编译生成的字节码文件&#xff08;.class&#xff09;&#xff1b;由 JVM 识别执行&#xff0c;包含类的完整结构信息&#xff08;如字段、方法、…

Vue3+Vite+TypeScript+Element Plus开发-07.Mockjs引用与Axios封装

系列文档目录 Vue3ViteTypeScript安装 Element Plus安装与配置 主页设计与router配置 静态菜单设计 Pinia引入 Header响应式菜单缩展 Mockjs引用与Axios封装 登录设计 登录成功跳转主页 多用户动态加载菜单 Pinia持久化 动态路由-配置 文章目录 目录 系列文档目…

【Redis】背景知识

一、Redis的特性 Redis是一种基于键值对&#xff08;key-value&#xff09;的NoSQL数据库&#xff0c;与很多键值对数据库不同的是&#xff0c;Redis中的值可以是由string&#xff08;字符串&#xff09;&#xff0c;hash&#xff08;哈希&#xff09;&#xff0c;list&#xf…