pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill() 详解

news2025/3/28 12:22:45

pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill()详解

  • PyTorch `masked_fill_()` vs. `masked_fill()` 详解
    • 1️⃣ `masked_fill()` 和 `masked_fill_()` 的作用
    • 2️⃣ `masked_fill()` vs. `masked_fill_()` 示例
    • 3️⃣ 输出结果
    • 4️⃣ `masked_fill()` vs. `masked_fill_()` 区别
    • 5️⃣ `masked_fill()` 和 `masked_fill_()` 的实际应用
    • 6️⃣ `masked_fill()` 在 Transformer 自注意力中的应用
    • 7️⃣ `masked_fill_()` 在梯度计算中的应用
    • 8️⃣ 总结
      • 💡 实际应用


PyTorch masked_fill_() vs. masked_fill() 详解

在 PyTorch 中,masked_fill_()masked_fill() 主要用于 根据掩码(mask)填充张量(tensor)中的特定元素,但它们的关键区别在于 是否修改原张量(in-place 操作)


1️⃣ masked_fill()masked_fill_() 的作用

  • masked_fill(mask, value)
    • 不会修改原张量,而是返回一个新的张量。
  • masked_fill_(mask, value)
    • 会直接修改原张量(in-place 操作),节省内存。

两者的作用

  • mask 是一个 布尔张量True 代表需要填充的元素)。
  • value 是要填充的数值。

2️⃣ masked_fill() vs. masked_fill_() 示例

import torch

# 创建一个 3×3 的张量
tensor = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

# 创建一个掩码:True 代表要被替换的元素
mask = torch.tensor([
    [False, True, False],
    [True, False, False],
    [False, False, True]
])

print("原张量 tensor:\n", tensor)

# 使用 masked_fill()(不会修改原张量)
new_tensor = tensor.masked_fill(mask, -1)
print("\n新张量 new_tensor(使用 masked_fill()):\n", new_tensor)
print("\n原张量 tensor(未修改):\n", tensor)

# 使用 masked_fill_()(会修改原张量)
tensor.masked_fill_(mask, -1)
print("\n原张量 tensor(被修改):\n", tensor)

3️⃣ 输出结果

原张量 tensor:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

新张量 new_tensor(使用 masked_fill()):
tensor([[ 1, -1,  3],
        [-1,  5,  6],
        [ 7,  8, -1]])

原张量 tensor(未修改):
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

原张量 tensor(被修改):
tensor([[ 1, -1,  3],
        [-1,  5,  6],
        [ 7,  8, -1]])

4️⃣ masked_fill() vs. masked_fill_() 区别

函数是否修改原张量?返回值
masked_fill(mask, value)❌ 不修改返回新的张量
masked_fill_(mask, value)✅ 直接修改修改后的原张量

总结

  • 如果你希望创建一个新张量,而不修改原数据,用 masked_fill()
  • 如果你希望节省内存并直接修改原张量,用 masked_fill_()

5️⃣ masked_fill()masked_fill_() 的实际应用

自然语言处理(NLP)深度学习模型 中,这两个函数经常用于 掩码(masking)操作,例如:

  • 屏蔽填充(Padding Mask):防止模型处理填充的 PAD 词(如 Transformer)。
  • 屏蔽未来信息(Future Mask):用于自回归模型(如 GPT),确保预测不会使用未来的信息。

6️⃣ masked_fill() 在 Transformer 自注意力中的应用

import torch

# 假设有一个 4×4 的注意力得分矩阵
attn_scores = torch.tensor([
    [0.5, 0.7, 0.8, 0.9],
    [0.6, 0.5, 0.4, 0.8],
    [0.2, 0.4, 0.5, 0.7],
    [0.3, 0.5, 0.6, 0.8]
])

# 创建一个掩码(模拟未来时间步的屏蔽)
mask = torch.tensor([
    [False, False, False, True],
    [False, False, True, True],
    [False, True, True, True],
    [True, True, True, True]
])

# 用 -inf 屏蔽掩码位置
masked_scores = attn_scores.masked_fill(mask, float('-inf'))
print("\n注意力得分(masked_fill()):\n", masked_scores)

示例输出:

注意力得分(masked_fill()):
tensor([[ 0.5000,  0.7000,  0.8000,    -inf],
        [ 0.6000,  0.5000,    -inf,    -inf],
        [ 0.2000,    -inf,    -inf,    -inf],
        [   -inf,    -inf,    -inf,    -inf]])

📌 解释

  • False 的位置保留原始数值。
  • True 的位置填充 -inf,在 softmax 计算时会被归零,不影响其他数值。

7️⃣ masked_fill_() 在梯度计算中的应用

在 PyTorch 训练过程中,如果你想直接修改梯度计算中的变量,可以使用 masked_fill_() 进行 in-place 操作

import torch

# 创建一个需要计算梯度的张量
x = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True)

# 创建掩码
mask = torch.tensor([False, True, False, True])

# 直接修改 x
x.masked_fill_(mask, 0.0)

print("\n被修改后的 x(masked_fill_()):\n", x)

示例输出:

被修改后的 x(masked_fill_()):
tensor([0.1000, 0.0000, 0.3000, 0.0000], requires_grad=True)

📌 解释

  • 通过 masked_fill_() 直接在计算图中修改 x,避免创建新张量。

8️⃣ 总结

  • masked_fill()masked_fill_() 都用于按掩码填充张量中的特定元素
  • 主要区别
    • masked_fill() 不修改原张量,返回新的张量。
    • masked_fill_() 直接修改原张量(in-place 操作)。
  • 适用场景
    • masked_fill():适用于需要 保持原张量不变 的情况,如 Transformer 掩码处理
    • masked_fill_():适用于需要 节省内存直接修改张量 的情况,如 梯度计算

💡 实际应用

场景推荐使用
创建新张量,不修改原数据masked_fill()
直接修改原数据,减少内存占用masked_fill_()
Transformer 自注意力掩码masked_fill()
梯度计算,避免额外的计算图创建masked_fill_()

🚀 合理使用 masked_fill()masked_fill_(),可以优化你的 PyTorch 代码,提高计算效率! 🎯

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2319022.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32U575RIT6单片机(四)

作业: 使用I2C获取SHT20传感器温湿度 使用I2C获取AP3216C三合一传感器: 光照, 接近, 红外 三个功能 合并的传感器 #ifndef SHT20_H #define SHT20_H#include "stdint.h" #include "i2c.h" #include "stdio.h" //1、确定从机的设备地址(代码不…

EMQX安装与配置

EMQX安装与配置 EMQX安装与配置 https://www.emqx.com/zh/downloads-and-install/broker?osUbuntucd /usr/local/srcwget https://www.emqx.com/zh/downloads/broker/4.4.19/emqx-4.4.19-otp24.3.4.2-1-ubuntu16.04-amd64.deb sudo apt install ./emqx-4.4.19-otp24.3.4.2-1…

JVM逃逸分析作用和原理

JVM逃逸分析作用和原理 在JVM的性能优化中,我们通常会关注内存分配、垃圾回收等问题。而逃逸分析(Escape Analysis)是JVM中一种精妙的优化技术,它可以在对象分配时判断该对象是否会在方法或线程之外被访问,从而影响其…

拓展 Coco AI 功能 - 智能检索 Hexo 博客

在之前的文章中,我们成功让 Coco AI 检索 Hugo 博客,这对于博客作者来说是一大福音。然而,从 Hexo 迁移到 Hugo 的成本不容小觑,毕竟大多数开发者对 Node.js 更熟悉,而 Golang 相对陌生。那么,既然 Coco AI…

爬虫基础之爬取猫眼Top100 可视化

网站: TOP100榜 - 猫眼电影 - 一网打尽好电影 本次案例所需用到的模块 requests (发送HTTP请求) pandas(数据处理和分析 保存数据) parsel(解析HTML数据) pyecharts(数据可视化图表) pymysql(连接和操作MySQL数据库) lxml(数据解析模块) 确定爬取的内容: 电影名称 电影主演…

LS-NET-006-思科MDS 9148S 查看内存

LS-NET-006-思科MDS 9148S 查看内存 方法一:使用 show version​ 命令 该命令可显示设备的基本系统信息,包括内存总量。 登录交换机的CLI(通过控制台或SSH连接)。输入命令: show version 在输出中查找类似以下内容…

小程序API —— 54 路由与通信 - 编程式导航

在小程序中实现页面的跳转,有两种方式: 声明式导航:navigator 组件编程式导航:使用小程序提供的 API 编程式导航 API 提供了五个常用的 API 方法: wx.navigateTo():保留当前页面,跳转到应用内…

关于金融开发领域的一些专业知识总结

目录 1. 交易生命周期 1.1 证券交易所 1.1.1 交易前 1) 订单生成(Order Generation) 2) 订单管理(Order Management) 1.1.2 交易执行 3) 交易匹配(Trade Matching) 1.1.3 交易后 4) 交易确认&…

DeepSeek-R1深度解读

deepseek提出了一种通过强化学习(RL)激励大语言模型(LLMs)推理能力的方法,个人认为最让人兴奋的点是:通过RL发现了一个叫“Aha Moment”的现象,这个时刻发生在模型的中间版本中。在这个阶段&…

15-双链表-双链表基本操作

题目 来源 827. 双链表 - AcWing题库 思路 此题我只想说,千万千万别漏了头结点和尾结点,不然根本查不出来是哪里出了问题,因为传入的k会有问题;最左边插入,相当于是在头结点的右边插入(也就是0号节点的右…

【小也的Java之旅系列】01 分布式、集群、微服务的区别

前言 做Java开发多年,一直以来都有想把Java做成一个系列的想法,最近整理自己的笔记发现有很多值得写的内容,但这些内容又往往杂乱不堪。CSDN上有很多高质量的Java博客,但大多不是从一个人成长的角度去写的。而我们——一个技术人…

基于视觉的核桃分级与套膜装置研究(大纲)

基于视觉的核桃分级与套膜装置研究:从设计到实现的完整指南 (SolidWorks、OpenCV、STM32开发实践) 🌟 项目背景与目标 1.1 为什么选择视觉分级与套膜? 产业痛点: 中国核桃年产量全球第一,但…

JimuReport与deepseek结合,颠覆现有BI模式

在数字化转型的浪潮中,企业对数据的依赖程度越来越高,如何高效地分析和利用数据成为关键。JimuReport凭借其强大的报表设计能力和灵活的数据处理功能,已经成为众多企业的首选工具。如今,它即将与DeepSeek深度结合,为企…

11、STL中的set使用方法

一、了解 set 是 C 标准模板库(STL)中提供的有序关联容器之一。基于红黑树(Red-Black Tree)实现,用于存储一组唯一的元素,并按照元素的值进行排序。 set的特性 唯一性 键是唯一的。无重复。 有序性 按升序…

操作系统——(管程、线程、进程通信)

目录 一、管程机制 (1)管程定义 (2)特点: 二、进程通信 (1)概念 (2)高级通信机制 三、线程 (1)概念 (2)与进程比较…

Sqlserver安全篇之_启用和禁用Named Pipes的案列介绍

https://learn.microsoft.com/zh-cn/sql/tools/configuration-manager/named-pipes-properties?viewsql-server-ver16 https://learn.microsoft.com/zh-cn/sql/tools/configuration-manager/client-protocols-named-pipes-properties-protocol-tab?viewsql-server-ver16 默认…

Web开发-JS应用原生代码前端数据加密CryptoJS库jsencrypt库代码混淆

知识点: 1、安全开发-原生JS-数据加密&代码混淆 2、安全开发-原生JS-数据解密安全案例 一、演示案例-WEB开发-原生JS&第三方库-数据加密 前端技术JS实现: 1、非加密数据大致流程: 客户端发送->明文数据传输-服务端接受数据->…

比特币牛市还在不在

在加密货币的风云世界里,比特币的一举一动始终牵动着投资者们的神经。近期比特币的涨幅动作,再次引发了市场对于牛市是否仍在延续的激烈讨论。 在深入探索比特币市场的过程中,获取全面且及时的资讯至关重要。您可以通过访问Techub News&#…

Python、MATLAB和PPT完成数学建模竞赛中的地图绘制

参加数学建模比赛时,很多题目——诸如统计类、数据挖掘类、环保类、建议类的题目总会涉及到地理相关的情景,往往要求我们制作与地图相关的可视化内容。如下图,这是21年亚太赛的那道塞罕坝的题目,期间涉及到温度、降水和森林覆盖率…

跨平台RTSP高性能实时播放器实现思路

跨平台RTSP高性能实时播放器实现思路 目标:局域网100ms以内超低延迟 一、引言 现有播放器(如VLC)在RTSP实时播放场景中面临高延迟(通常数秒)和资源占用大的问题。本文提出一种跨平台解决方案,通过网络层…