random_masking 函数测试

news2025/3/21 21:21:28

文章目录

  • 1. description
  • 2. excel
  • 3. pytorch code

1. description

  • 功能:按一定比例的随机部分样本,简单来说就是按照一定的比例将行向量从小到大的顺序提取出来。
  • 思考1: 用了均匀分布,并且按照一定比例,取前prob概率来表示
  • 思考2:用了torch.argsort 来生成idx_shuffle 来的到快速从小到大排序的
  • 思考3:用了torch.gather 来配合idx_shuffle 来找到最小的部分数据
  • 思考4:用了torch.gather+idx_restore+ones_like matrix 组合的方式生成mask矩阵
  • 小结:主要配合torch.argsort+torch.gather的方式生成相关的mask矩阵和最小矩阵和恢复矩阵,代码的巧妙运用很具备参考意义。

2. excel

在这里插入图片描述

3. pytorch code

  • pytorch
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)
torch.manual_seed(2324)


def random_masking(x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore


if __name__ == "__main__":
    run_code = 0
    bs = 2
    seq_len = 8
    seq_dim = 10
    mx_total = bs * seq_dim * seq_len
    a_matrix = torch.arange(mx_total).reshape((bs, seq_len, seq_dim))
    a_x_masked, a_mask, a_ids_restore = random_masking(a_matrix, mask_ratio=0.4)
    print(f"a_matrix=\n{a_matrix}")
    print(f"a_x_masked=\n{a_x_masked}")
    print(f"a_mask=\n{a_mask}")
    print(f"a_ids_restore=\n{a_ids_restore}")
  • result
a_matrix=
tensor([[[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
         [ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19],
         [ 20,  21,  22,  23,  24,  25,  26,  27,  28,  29],
         [ 30,  31,  32,  33,  34,  35,  36,  37,  38,  39],
         [ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49],
         [ 50,  51,  52,  53,  54,  55,  56,  57,  58,  59],
         [ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69],
         [ 70,  71,  72,  73,  74,  75,  76,  77,  78,  79]],

        [[ 80,  81,  82,  83,  84,  85,  86,  87,  88,  89],
         [ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
         [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
         [110, 111, 112, 113, 114, 115, 116, 117, 118, 119],
         [120, 121, 122, 123, 124, 125, 126, 127, 128, 129],
         [130, 131, 132, 133, 134, 135, 136, 137, 138, 139],
         [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
         [150, 151, 152, 153, 154, 155, 156, 157, 158, 159]]])
a_x_masked=
tensor([[[ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19],
         [  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
         [ 20,  21,  22,  23,  24,  25,  26,  27,  28,  29],
         [ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69]],

        [[ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
         [120, 121, 122, 123, 124, 125, 126, 127, 128, 129],
         [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
         [ 80,  81,  82,  83,  84,  85,  86,  87,  88,  89]]])
a_mask=
tensor([[0., 0., 0., 1., 1., 1., 0., 1.],
        [0., 0., 1., 1., 0., 1., 0., 1.]])
a_ids_restore=
tensor([[1, 0, 2, 4, 7, 6, 3, 5],
        [3, 0, 6, 4, 1, 5, 2, 7]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2318563.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TDengine 中的流式计算

简介 TDengine 中的流计算,功能相当于简化版的 FLINK , 具有实时计算,计算结果可以输出到超级表中存储,同时也可用于窗口预计算,加快查询速度。 创建流式计算 CREATE STREAM [IF NOT EXISTS] stream_name [stream_o…

Java 大视界 -- Java 大数据在智慧交通自动驾驶仿真与测试数据处理中的应用(136)

💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…

JVM的一些知识

JVM简介 JVM 是 Java Virtual Machine 的简称,意为 Java 虚拟机。 虚拟机是指通过软件模拟的具有完整硬件功能的、运行在一个完全隔离的环境中的完整计算机系统。常见的虚拟机:JVM、VMwave、Virtual Box。 JVM 和其他两个虚拟机的区别: VMw…

C语言每日一练——day_7

引言 针对初学者,每日练习几个题,快速上手C语言。第七天。(连续更新中) 采用在线OJ的形式 什么是在线OJ? 在线判题系统(英语:Online Judge,缩写OJ)是一种在编程竞赛中用…

Java使用FFmpegFrameGrabber进行视频拆帧,结合Thumbnails压缩图片保存到文件夹

引入依赖 <dependency><groupId>net.coobird</groupId><artifactId>thumbnailator</artifactId><version>0.4.17</version></dependency><dependency><groupId>org.bytedeco</groupId><artifactId>ja…

用hexo初始化博客执行hexo init时碰到的问题

用hexo初始化博客执行hexo init时碰到的问题 $ hexo init myblog INFO Cloning hexo-starter https://github.com/hexojs/hexo-starter.git fatal: unable to access https://github.com/hexojs/hexo-starter.git/: SSL certificate problem: unable to get local issuer cer…

4.1--入门知识扫盲,ISO知识体系介绍(看一遍,协议啥的全部记住)

OSI七层模型&#xff1a;网络世界的"七重天"生存指南&#xff08;附快递小哥版图解&#xff09; “如果你觉得网络分层很抽象&#xff0c;那就想象自己在寄快递” —— 来自一个被三次握手逼疯的程序员 开场白&#xff1a;网络通信就像送外卖 假设你要给隔壁妹子送奶…

AI训练如何获取海量数据,论平台的重要性

引言&#xff1a;数据——AI时代的“新石油” 在人工智能和大模型技术飞速发展的今天&#xff0c;数据已成为驱动技术进步的 “ 燃料 ”。无论是训练聊天机器人、优化推荐算法&#xff0c;还是开发自动驾驶系统&#xff0c;都需要海量、多样化的数据支持。 然而&#xff0c;获…

Git 使用SSH登陆

一、SSH介绍 SSH连接相比于HTTP连接会简单一点&#xff0c;因为SSH连接通过了私钥与公钥进行身份认证&#xff0c;这样就不需要像HTTP一样&#xff0c;每次clone或者操作仓库都需要输入密码 其中私钥和密钥是需要在自己电脑上生成的&#xff0c;通过命令即可生成一个私钥和一个…

织梦DedeCMS修改文章【标题、短标题、关键词】长度限制

在后台虽然可以设置标题的长度&#xff0c;但是数据库的字段固定是60个字符&#xff0c;短标题是36字符&#xff0c;关键词30字符&#xff0c;所以这里教大家修改一下织梦DedeCMS修改【标题】【短标题】【关键词】长度限制 一、后台配置 1、进入dede后台管理 -> 系统 ->…

Powershell WSL部署ubuntu22.04.5子系统

前提条件WSL 安装 wsl 安装参考1wsl 安装csdn参考2wsl 百度网盘离线下载 本地目录安装ubuntu22.04.5 子系统 powershell 管理员打开执行(实现,下载安装ubuntu子系统,用户创建,远程ssh登录设置,防火墙端口开放)子系统IP 查看方法wsl

umi自带的tailwindcss修改为手动安装

1》为什么改为手动&#xff1f; 主要是为了解决这个报错问题&#xff0c;虽然重新运行也可解决&#xff0c;但是总是要运行2-3次&#xff0c;比较麻烦 2》如何手动 1&#xff0c;先在devDependencies下安装这两个包 pnpm install postcss8.5.1 -D "autoprefixer"…

麒麟V10 arm cpu aarch64 下编译 RocketMQ-Client-CPP 2.2.0

国产自主可控服务器需要访问RocketMQ消息队列&#xff0c;最新的CSDK是2020年发布的 rocketmq-client-cpp-2.2.0 这个版本支持TLS模式。 用默认的版本安装遇到一些问题&#xff0c;记录一下。 下载Releases apache/rocketmq-client-cpp GitHubhttps://github.com/apache/roc…

使用码云搭建CocoaPods远程私有库

一、创建远程私有索引库 用来存放私有框架的详细描述信息.podspec文件 1. 创建私有库 假设码云上创建的私有库为repo-spec 2. 查看本地已存在的索引库 pod repo list 3. 将远程私有索引库添加到本地 pod repo add [https://gitee.com/jingluoguo/repo-spec.git](https://gi…

专访LayaAir引擎最有价值专家-施杨

在 LayaAir 引擎的资源商店中&#xff0c;许多开发者都会注意到一个熟悉的名字——“射手座”。他不仅贡献了大量高质量的 Shader 资源&#xff0c;让一些开发者通过他的作品了解到 LayaAir 引擎在 3D 视觉效果上的更多可能&#xff0c;也让大家能够以低成本直接学习并应用这些…

自然语言处理:文本聚类

介绍 大家好&#xff0c;博主又来和大家分享自然语言处理领域的知识了。今天给大家分享的内容是自然语言处理中的文本聚类。 文本聚类在自然语言处理领域占据着重要地位&#xff0c;它能将大量无序的文本按照内容的相似性自动划分成不同的类别&#xff0c;极大地提高了文本处…

RabbitMQ 集群降配

这里写自定义目录标题 摘要检查状态1. 检查 RabbitMQ 服务状态2. 检查 RabbitMQ 端口监听3. 检查 RabbitMQ 管理插件是否启用4. 检查开机自启状态5. 确认集群高可用性6. 检查使用该集群的服务是否做了断开重连 实操1. 负载均衡配置2. 逐个节点降配&#xff08;滚动操作&#xf…

数据结构:二叉树(一)·(重点)

前言 什么树&#xff1f;what&#xff1f; 树的概念与结构 概念&#xff1a; 树是⼀种⾮线性的数据结构&#xff0c;它是由 n &#xff08; n>0 &#xff09; 个有限结点组成⼀个具有层次关系的集合。 结构&#xff1a; 有⼀个特殊的结点&#xff0c;称为根结点&#…

DevEco Studio的使用

目录 1.创建ArkTS工程 2.ArkTS工程目录结构&#xff08;Stage模型&#xff09; 构建第一个页面 构建第二个页面 实现页面间的跳转 1.创建ArkTS工程 若首次打开DevEco Studio&#xff0c;请点击Create Project创建工程。如果已经打开了一个工程&#xff0c;请在菜单栏选择…

数据开发岗笔试题>>sql(hive) ,excel [2025]

sql SELECT user_id, AVG(loan_amount) AS avg_loan_amount FROM loan GROUP BY user_id HAVING AVG(loan_amount) > 20000; 授信表&#xff1a;credit 字段包含user_id(用户id)&#xff0c;credit_id(授信id)&#xff0c;credit_time(授信时间yyyy-MM-dd HH:mm:ss)&#x…