mask2former利用不确定性采样点选择提高模型性能

news2024/11/26 9:58:16

在机器学习和深度学习的训练过程中,不确定性高的点通常代表模型在这些点上的预测不够可靠或有较高的误差。因此,关注这些不确定性高的点,通过计算这些点的损失并进行梯度更新,可以有效地提高模型的整体性能。确定性高的点预测结果已经比较准确,相应地对模型的训练贡献较小,所以可以减少对这些点的关注或完全忽略它们的损失计算。

代码复现参考仓库:https://github.com/NielsRogge/Transformers-Tutorials

在这篇博客中,我们将详细解释 mask2former 中的一段代码,该代码通过不确定性采样点来选择重要点,并探讨其在模型训练中的重要性。mask2former原文描述比较简单,如下:
在这里插入图片描述

代码源自transformers库中的modeling_mask2former.py,主要讲解如下代码:

    def sample_points_using_uncertainty(
        self,
        logits: torch.Tensor,
        uncertainty_function,
        num_points: int,
        oversample_ratio: int,
        importance_sample_ratio: float,
    ) -> torch.Tensor:
        """
        This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The
        uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit
        prediction as input.

        Args:
            logits (`float`):
                Logit predictions for P points.
            uncertainty_function:
                A function that takes logit predictions for P points and returns their uncertainties.
            num_points (`int`):
                The number of points P to sample.
            oversample_ratio (`int`):
                Oversampling parameter.
            importance_sample_ratio (`float`):
                Ratio of points that are sampled via importance sampling.

        Returns:
            point_coordinates (`torch.Tensor`):
                Coordinates for P sampled points.
        """

        num_boxes = logits.shape[0]
        num_points_sampled = int(num_points * oversample_ratio)

        # Get random point coordinates
        point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
        # Get sampled prediction value for the point coordinates
        point_logits = sample_point(logits, point_coordinates, align_corners=False)
        # Calculate the uncertainties based on the sampled prediction values of the points
        point_uncertainties = uncertainty_function(point_logits)#[n1+n2, 1, 37632],理解为,值越大,不确定性越高

        num_uncertain_points = int(importance_sample_ratio * num_points)#9408
        num_random_points = num_points - num_uncertain_points#3136

        idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]#[n1+n2, 9408]这行代码的作用是从每个 num_boxes 的不确定性值中选择 num_uncertain_points 个最大值的索引。这些索引将用于从原始的点坐标张量 point_coordinates 中选择相应的点,这些点将被认为是基于不确定性的重要性采样点。
        shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)#这两行代码的主要目的是确保在从 point_coordinates 中选择点时,能够正确地访问全局索引,使得每个 box 的采样点能够准确地映射到整个张量中的位置。
        idx += shift[:, None]
        point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)#[n1+n2, 9408, 2]

        if num_random_points > 0:
            point_coordinates = torch.cat(
                [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
                dim=1,
            )
        return point_coordinates

以下是 sample_points_using_uncertainty 函数的参数解释:

  • logits (torch.Tensor): P 个点的 logit 预测值。
  • uncertainty_function: 一个函数,接受 P 个点的 logit 预测值并返回它们的不确定性。
  • num_points (int): 需要采样的点的数量 P。
  • oversample_ratio (int): 过采样参数,用于增加采样点的数量,以确保能在不确定性采样中选到合适的点。
  • importance_sample_ratio (float): 使用重要性采样选出的点的比例。

函数步骤解释

  1. 计算总采样点数

    num_boxes = logits.shape[0]
    num_points_sampled = int(num_points * oversample_ratio)
    

    num_boxes 是指预测的盒子数量,num_points_sampled 是经过过采样之后的总采样点数。

  2. 生成随机点的坐标

    point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
    

    在 [0, 1] * [0, 1] 空间内生成随机点的坐标。

  3. 获取这些随机点的预测值

    point_logits = sample_point(logits, point_coordinates, align_corners=False)
    

    对随机点的坐标进行采样,获取它们的预测 logit 值。

  4. 计算这些点的不确定性

    point_uncertainties = uncertainty_function(point_logits)
    

    使用 uncertainty_function 计算这些点的不确定性。

  5. 确定不确定性采样和随机采样的点数

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    

    根据 importance_sample_ratio 确定通过不确定性采样的点数 num_uncertain_points,以及剩余的随机采样点数 num_random_points

  6. 选择不确定性最高的点

    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
    idx += shift[:, None]
    point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
    

    使用 torch.topk 函数选择每个盒子中不确定性最高的 num_uncertain_points 个点,并获取它们的坐标。

  7. 添加随机点

    if num_random_points > 0:
        point_coordinates = torch.cat(
            [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
            dim=1,
        )
    

    如果需要添加随机采样点,将它们与不确定性采样点合并。

  8. 返回采样点的坐标

    return point_coordinates
    

    最终返回所有采样点的坐标。

关键代码解读

1. 偏移量的生成
 shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)

这行代码的目的是为每个 box 生成一个偏移量(shift),用于转换局部索引为全局索引。

  • torch.arange(num_boxes, dtype=torch.long, device=logits.device) 生成一个从 0 到 num_boxes-1 的张量。
  • num_points_sampled 是每个 box 中采样的点的数量。
  • 乘法操作 num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) 为每个 box 生成一个偏移量。例如,假设 num_points_sampled 为 100,那么生成的偏移量张量为 [0, 100, 200, 300, ...]

这些偏移量将用于将局部索引(即每个 box 内的索引)转换为全局索引(即在整个 point_coordinates 中的索引)。

2. 局部索引转换为全局索引
  idx += shift[:, None]

这行代码将局部索引转换为全局索引。

  • idxtorch.topk 返回的不确定性最高的点的局部索引,形状为 [num_boxes, num_uncertain_points]
  • shift[:, None] 的形状是 [num_boxes, 1],通过这种方式将每个 box 的偏移量广播到与 idx 的形状匹配。

通过将 shift 加到 idx 上,每个 box 的局部索引将变成全局索引。例如,如果第一个 box 的偏移量为 100,那么第一个 box 内的局部索引 [0, 1, 2, ...] 将变为 [100, 101, 102, ...]

总结

通过 sample_points_using_uncertainty 函数,我们可以有效地选择不确定性高的点进行训练,提高模型在这些关键点上的表现,同时减少确定性高的点的计算开销。这种不确定性采样方法结合了重要性采样和随机采样,确保了模型训练的高效性和鲁棒性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1816784.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

互联网全栈开发:产品经理、后端开发、前端开发、运维、测试等

我们都知道互联网公司,有几个较为重要的职业: 产品经理 后端开发 前端开发 运维 测试 这些技术往往相互隔阂,存在技术壁垒,而我开通了抖音号,常在抖音中发送这些视频,我的抖音号:1056668488。请大家麻…

EVA-CLIP实战

摘要 EVA-CLIP,这是一种基于对比语言图像预训练(CLIP)技术改进的模型,通过引入新的表示学习、优化和增强技术,显著提高了CLIP的训练效率和效果。EVA-CLIP系列模型在保持较低训练成本的同时,实现了与先前具有相似参数数量的CLIP模型相比更高的性能。特别地,文中提到的EV…

10 款最佳免费 Google SEO 工具

谷歌提供了免费测试和报告的工具,以帮助网站所有者和 SEO 专业人员分析和提高其网站的搜索性能。这些是最好的免费谷歌搜索引擎优化工具,用于升级您的搜索引擎优化,以及帮助您发现新的关键字机会以及帮助您发现新的关键字机会的工具。 无论您…

Nature最新!浙大王浩华团队:一种创新方法使量子态传输的保真度大大提高

在量子计算的快速发展过程中,量子信息传输技术(量子态传输)的进步至关重要。 然而,当前固态量子系统在实现量子信息传输方面存在一些显著的挑战,例如量子混沌或者系统不完美,其传输的保真度和效率通常难以…

VMware Ubuntu虚拟机上设置SSH连接,win直接用ssh连接虚拟机

要在Ubuntu虚拟机上设置SSH连接,并进行一些特定配置,您可以按照以下步骤进行操作: 步骤 1:安装OpenSSH Server 打开终端。 更新包列表并安装OpenSSH Server: sudo apt update sudo apt install openssh-server安装完…

51单片机实验05 -点阵

目录 一,熟悉矩阵led小灯 1,点亮矩阵的一只led 2,点亮矩阵的一排led 3,点亮矩阵的全部led static 关键字 unsigned 关键字 4,点阵的静态显示 2)心形矩阵显示代码 3)效果 二,课…

跑起来字节跳动音频超分开源项目versatile_audio_super_resolution

已部署在AutoDL上https://www.codewithgpu.com/i/haoheliu/versatile_audio_super_resolution/versatile_audio_super_resolution ipynb: 音乐 By 邓文怡 一个深圳的小姑娘%cd /root/versatile_audio_super_resolution/运行目录# 读取一个mp3音频文件,然后将它转换…

数据安全交换系统 与网闸有什么区别?

数据安全交换系统是指用于安全地传输、共享和交换数据的一种系统。这样的系统通常包括一系列安全性和隐私保护功能,确保数据在传输和存储过程中不会被未经授权的用户访问、泄露或篡改。 数据安全交换系统和网闸在功能和定位上有一些区别: 功能&#xff…

PDU模块中浪涌保护模块与空开模块的应用

由于PDU具体应用的特殊性,其在规划设计时具有应用场景的针对性,同时PDU的高度定制化的特点,是其他电气联接与保护产品所不具备的。 PDU基础的输出输入功能外,其电路的控制与电压保护器同时也极为重要。空气开关和浪涌保护器相关功…

Java课程设计:基于Java+Swing+MySQL的图书管理系统(内附源码)

文章目录 一、项目介绍二、项目展示三、源码展示四、源码获取 一、项目介绍 图书管理系统是一个常见的软件项目,广泛应用于图书馆、学校、企业等需要管理图书资源的场景。该系统通常涵盖图书信息录入、查询、借阅、归还等核心功能,是实现图书资源高效管理的重要工具。 随着信…

coap:使用californium建立coap server和client的简单示例

【pom.xml】 <dependency><groupId>org.eclipse.californium</groupId><artifactId>californium-core</artifactId><version>2.0.0-M7</version> </dependency> <dependency><groupId>org.eclipse.californium&l…

元宇宙3D虚拟代言人凸显企业形象和品牌风格

在虚拟社交的新时代浪潮中&#xff0c;拥有一个个性鲜明的AI数字人形象&#xff0c;无疑能让你在虚拟的海洋中独领风骚。深圳华锐视点作为你的数字形象创造的合作伙伴&#xff0c;为你呈现了一个丰富多彩的素材库与高度灵活的编辑工具。在这里&#xff0c;你可以依据个人喜好和…

爆款AI工具大盘点:最强文本、视频、音乐生成AI,适用岗位全解析!

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

内存管理--3.用幻灯片讲解C++手动内存管理

用幻灯片讲解C手动内存管理 1.栈内存的基本元素 2.栈内存的聚合对象 3.手动分配内存和释放内存 注意&#xff1a;手动分配内存&#xff0c;指的是在堆内存中。 除非实现自己的数据结构&#xff0c;否则永远不要手动分配内存! 即使这样&#xff0c;您也应该通过std::allocator…

Redis 配置及操作整理

本篇文章介绍了Redis在window中如何安装和修改配置及Redis几种数据类型及操作命令。 目录 window环境安装 修改配置 设置密码 设置最大内存大小 其他参数介绍 启动服务 使用客户端 客户端连接 验证密码 Redis数据类型 String 设置 运算 其它 Hash 设置 获取 …

文件操作学不懂,小代老师带你深入理解文件操作(上卷)

文件操作学不懂&#xff0c;小代老师带你深入理解文件操作上卷 1. 为什么使用⽂件&#xff1f;2. 什么是⽂件&#xff1f;2.1 程序⽂件2.2 数据⽂件2.3 文件名 3. 二进制文件和文本文件&#xff1f; 1. 为什么使用⽂件&#xff1f; 如果没有⽂件&#xff0c;我们写的程序的数据…

旋转方块加载动画

效果图: 完整代码: <!DOCTYPE html> <html> <head><meta charset="UTF-8" /><title>旋转方块加载动画</title><style type="text/css">body {background: #ECF0F1;display: flex;justify-content: center;al…

java自学阶段二:JavaWeb开发50(Spring和Springboot学习)

Spring、Springboot基础知识学习 目录 学习目标Spring基础概念IOC控制反转DI依赖注入事务管理AOP面向切面编程Spring案例说明&#xff08;Postman使用、Restful开发规范、lombok、Restful、nginx了解&#xff09; 一&#xff1a;学习目标&#xff1a; 1&#xff09;了解Sprin…

如何基于 Elasticsearch 实现排序沉底或前置

在搜索场景的应用中&#xff0c;存在希望根据某个或某些字段来调整排序评分&#xff0c;从而实现排序沉底或置顶效果的使用需求。以商机管理中的扫街场景为例&#xff0c;当我们在扫街场景中需要寻找一个商户时&#xff0c;希望这个商户离的近、GMV 潜力大、被他人跟进过的次数…

前端计网面试题(二)

一、在浏览器中输入url并且按下回车之后发生了什么&#xff1f; 首先解析url&#xff0c;判断url是否合法&#xff0c;如果合法再判断是否完整。如果不合法&#xff0c;则使用用户默认的搜索引擎进行搜索。DNS域名解析获取URL对应的ip地址。&#xff08;首先看本地是否有缓存&…