ML-Decoder: Scalable and Versatile Classification Head

news2025/1/18 18:03:38

1、引言

论文链接:https://openaccess.thecvf.com/content/WACV2023/papers/Ridnik_ML-Decoder_Scalable_and_Versatile_Classification_Head_WACV_2023_paper.pdf

        因为 transformer 解码器分类头[1] 在少类别多标签分类数据集上表现得很好,但由于其查询复杂度为 O(n^2),n 为类别数量,故 transformer 解码器分类头对于多类别数据集是不可行的,且 transformer 解码器分类头只适用于多标签分类任务,故 Tal Ridnik 等引入了一种新的基于多头注意力机制的分类头——ML-Decoder[2]。ML-Decoder 可以用于单标签分类、多标签分类和多标签 ZSL(zero shot learning) 任务,它提供更好的精度-速度 trade-off,可以用于上万类别的数据集,可以作为各种分类头的 drop-in 替代品,结合词查询可以用于 ZSL。

2、方法

        ML-Decoder 流如图 1 右所示,相对于  transformer 解码器分类头,ML-Decoder 有一下改变。

图1  transformer-decoder vs. ML-Decoder

2.1  移除自注意力机制

        通过删除自注意力机制将 ML-Decoder 的查询复杂度由 O(n^2) 降至 O(n),并未影响表示能力。

2.2  组解码

        为了使查询数量与类别数量无关,使用固定的 k 组查询,而不是一个类别对应一个查询。在前馈神经网络后,通过组全连接层在将每个组查询扩展到 g=n/k 个输出的同时池化嵌入维度。如图 2 所示。

图2  组全连接方案(g=4)

2.3  固定查询        

        查询总是被输入到一个多头注意力层,该注意力层会先对查询应用一个可学习的投影计算。因此,将查询权重设置为可学习的是多余的——可学习的投影可以将任何固定值查询转换为可学习查询获得的任何值。

3、模块介绍

3.1  Cross-Attention

        Cross-Attention 的核心其实就是多头注意力机制,输入的 q 为固定查询,k 和 v 均为图像嵌入。Cross-Attention 和 Feed-Forward 模块构成所谓的 TransformerDecoder(Layer),python 代码如下所示:

class TransformerDecoder(nn.Module):
    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm0 = nn.LayerNorm(d_model)

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)

        # Implementation of Feedforward model
        self.feed_forward = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model)
        )

        self.norm1 = nn.LayerNorm(d_model)

    def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
        tgt = tgt + self.dropout(tgt)
        tgt = self.norm0(tgt)
        tgt0 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout(tgt0)
        tgt0 = self.feed_forward(tgt)
        tgt = tgt + self.dropout(tgt0)
        return self.norm1(tgt)

3.2  Group Fully Connected Pooling  

        Group Fully Connected Pooling的目的是将每个组查询扩展到 g=n/k 个输出的同时池化嵌入维度。即将每组查询结果与对应的可学习的 (hidde_dim, g) 维矩阵相乘,python 代码如下所示:

class GroupFC(object):
    def __init__(self, groups: int):
        self.groups = groups

    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
        """
        计算每组类的 logits 值(未加偏置)
        :param h: shape=(b, groups, hidden_dim)
        :param duplicate_pooling: shape=(groups, hidden_dim, duplicate_factor), duplicate_factor 每组的类别数
        :param out_extrap: shape=(b, groups, duplicate_factor)
        :return:
        """
        for i in range(h.shape[1]):
            h_i = h[:, i, :]
            w_i = duplicate_pooling[i, :, :]
            out_extrap[:, i, :] = torch.matmul(h_i, w_i)

4、总结

        作者开源的 ML-Decoder 的 python 实现代码在:https://github.com/Alibaba-MIIL/ML_Decoder/blob/main/src_files/ml_decoder/ml_decoder.py

        论文[2] 在 paper with code 上的战绩如图 3 所示,表现还是不错的。

图3  来自论文[2] 的结果

        由于当参数 zsl != 0 时 wordvec_proj 的输入 query_embed = None,本人还未学习过 ZSL 领域,且使用该代码时报错(zsl = 0,当然应该是我的原因,但懒得排错了),于是参考作者的代码写了一个 MLDecoder 类(只考虑 zsl = 0),剩下的代码如下所示。

class MLDecoder(nn.Module):
    """
    Args:
        groups: 查询/类别组数
        hidden_dim: Transformer 解码器特征维度
        in_dim: 输入 tensor 特征维度(CNN 编码器输出为通道数,Transformer 编码器输出为最后一个维度)
    """

    def __init__(self, num_classes, groups, in_dim=2048, hidden_dim=768, mlp_dim=2048, nhead=8, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, hidden_dim)

        # non-learnable queries
        self.query_embed = nn.Embedding(groups, hidden_dim)
        self.query_embed.requires_grad_(False)

        self.num_classes = num_classes

        self.decoder = TransformerDecoder(d_model=hidden_dim, nhead=nhead, dim_feedforward=mlp_dim, dropout=dropout)

        # group fully-connected
        self.duplicate_factor = math.ceil(num_classes / groups)  # 每组类别数量,math.ceil: 向上取整
        self.duplicate_pooling = torch.nn.Parameter(torch.zeros((groups, hidden_dim, self.duplicate_factor)))
        self.duplicate_pooling_bias = torch.nn.Parameter(torch.zeros(num_classes))
        torch.nn.init.xavier_normal_(self.duplicate_pooling)
        self.group_fc = GroupFC(groups)

    def forward(self, x):
        # 确保解码器输入 shape 为 [b, h * w, c]
        if len(x.shape) == 4:
            x = x.flatten(2).transpose(1, 2)

        x = F.relu(self.proj(x), True)  # (b, h * w, hidden_dim)

        # Cross-Attention + Feed-Forward
        query_embed = self.query_embed.weight  # (groups, hidden_dim)
        # tensor.expend: 增大一个维度至指定大小, 不增大的维度为-1,例如将 shape 由 (b, n, c)->(b, 2n, c), 参数 size=(-1, 2n,-1)
        tgt = query_embed[None].expand(x.shape[0], -1, -1)  # (b, groups, hidden_dim)
        h = self.decoder(tgt, x)  # (b, groups, hidden_dim)

        # Group Fully Connected Pooling
        out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
        self.group_fc(h, self.duplicate_pooling, out_extrap)
        h_out = out_extrap.flatten(1)[:, :self.num_classes]  # (b, num_classes)
        return h_out + self.duplicate_pooling_bias

参考文献

[1] Shilong Liu, Lei Zhang, Xiao Yang, Hang Su, and Jun Zhu. Query2label: A simple transformer way to multi-label classification. arXiv preprint arXiv:2107.10834, 2021.

[2] Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben Baruch, and Asaf Noy. Ml-decoder: Scalable and versatile classification head. In IEEE/CVF Winter Conference on Applications of Computer Vision, WACV 2023, Waikoloa, HI, USA, January 2-7, 2023, pages 32–41. IEEE, 2023.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1557408.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【应用层协议原理】

文章目录 第二章 应用层2.1 应用层协议原理2.1.1 网络应用的体系结构2.1.2 客户-服务器(C/S)体系结构2.1.3 对等体(P2P)体系结构2.2.4 C/S和P2P体系结构的混合体2.2.5 进程通信问题1:对进程进行编址(addres…

厦门攸信技术亮相新技术研讨会,展现物流自动化解决方案新高度!

今日,厦门攸信信息技术有限公司受邀参加了一场备受行业关注的电子制造高端盛会——一步步新技术研讨会,凭借卓越的智能制造与物流自动化技术在会议中大放异彩。作为一家引领行业发展的企业,厦门攸信技术不仅展示了其深厚的技术底蕴&#xff0…

算法之美:堆排序原理剖析及应用案例分解实现

这段时间持续更新关于“二叉树”的专栏文章,关心的小伙伴们对于二叉树的基本原理已经有了初步的了解。接下来,我将会更深入地探究二叉树的原理,并且展示如何将这些原理应用到更广泛的场景中去。文章将延续前面文章的风格,尽量精炼…

数据结构 - 图

参考链接:数据结构:图(Graph)【详解】_图数据结构-CSDN博客 图的定义 图(Graph)是由顶点的有穷非空集合 V ( G ) 和顶点之间边的集合 E ( G ) 组成,通常表示为: G ( V , E ) ,其中, G 表示个图, V 是图 G…

深入理解 Hadoop 上的 Hive 查询执行流程

在 Hadoop 生态系统中,Hive 是一个重要的分支,它构建在 Hadoop 之上,提供了一个开源的数据仓库系统。它的主要功能是查询和分析存储在 Hadoop 文件中的大型数据集,包括结构化和半结构化数据。Hive 在数据查询、分析和汇总方面发挥…

Linux(CentOS7)安装 MySQL8

目录 下载 上传 解压 创建配置文件 初始化 MySQL 服务 启动 MySQL 服务 连接 MySQL 创建软链接 下载 官方地址: MySQL :: Download MySQL Community Serverhttps://dev.mysql.com/downloads/mysql/选择版本前需先看一下服务器的 glibc 版本 ldd --versio…

计算机视觉之三维重建(5)---双目立体视觉

文章目录 一、平行视图1.1 示意图1.2 平行视图的基础矩阵1.3 平行视图的极几何1.4 平行视图的三角测量 二、图像校正三、对应点问题3.1 相关匹配法3.2 归一化相关匹配法3.3 窗口问题3.4 相关法存在的问题3.5 约束问题 一、平行视图 1.1 示意图 如下图即是一个平行视图。特点&a…

基于Apriori关联规则的电影推荐系统(python实现)

基于Apriori关联规则的电影推荐系统 1、效果图 2、算法原理 Apriori算法是一种用于挖掘关联规则的频繁项集算法,它采用逐层搜索的迭代方法来发现数据库中项集之间的关系并形成规则。 其核心思想是利用Apriori性质来压缩搜索空间,即如果一个项集是非频繁的,那么它的所有父…

结构体类型,结构体变量的创建和初始化 以及结构中存在的内存对齐

一般结构体类型的声明 struct 结构体类型名 { member-list; //成员表列 }variable-list; //变量表列 例如描述⼀个学⽣: struct Stu { char name[20]; //名字 int age; //年龄 char sex[5]; //性别 }; //结构体变量的初始化 int main() { S…

Django详细教程(二) - 部门用户管理案例

文章目录 前言一、新建项目二、新建app三、设计表结构四、新建数据库五、新建静态文件六、部门管理1.部门展示2.部门添加3.部门删除4.部门编辑 七、模板继承八、用户管理1.辨析三种方法方法一:原始方法方法二:Form组件(简便)方法三:ModelForm…

macOS搭建php环境以及调试Symfony

macOS搭建php环境以及调试Symfony macOS搭建php环境以及调试Symfony 古老的传说运行环境快速前置安装环境 php 的安装安装 Xdebug 来调试 php如何找到你的 php.iniXdebug 安装成功 创建并调试的 Hello world 安装 PHP Debug 安装 Symfony 安装 Composer安装 Symfony CLI 创建 …

vue系统——v-html

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>v-html指令</title> </head> <body&…

微信小程序备案流程详细操作指南

自2023年9月1日起&#xff0c;所有新上架的微信小程序均需事先完成备案手续&#xff0c;方能成功上线。而对于已经上架的存量小程序&#xff0c;也需要在2024年3月31日前完成备案工作。若在规定时间内未完成备案&#xff0c;平台将依据备案相关规定&#xff0c;自2024年4月1日起…

大语言模型---强化学习

本文章参考&#xff0c;原文链接&#xff1a;https://blog.csdn.net/qq_35812205/article/details/133563158 SFT使用交叉熵损失函数&#xff0c;目标是调整参数使模型输出与标准答案一致&#xff0c;不能从整体把控output质量 RLHF&#xff08;分为奖励模型训练、近端策略优化…

java数组与集合框架(二)-- 集合框架,Iterator迭代器,list

集合框架&#xff1a; 用于存储数据的容器。 Java 集合框架概述 一方面&#xff0c;面向对象语言对事物的体现都是以对象的形式&#xff0c;为了方便对多个对象的操作&#xff0c;就要对对象进行存储。另一方面&#xff0c;使用Array存储对象方面具有一些弊端&#xff0c;而…

小狐狸ChatGPT付费AI创作系统V2.8.0独立版 + H5端 + 小程序前端

狐狸GPT付费体验系统的开发基于国外很火的ChatGPT&#xff0c;这是一种基于人工智能技术的问答系统&#xff0c;可以实现智能回答用户提出的问题。相比传统的问答系统&#xff0c;ChatGPT可以更加准确地理解用户的意图&#xff0c;提供更加精准的答案。同时&#xff0c;小狐狸G…

09_Web组件

文章目录 Web组件Listener监听器ServletContextListener执行过程 Filter过滤器Filter与Servlet的执行 案例&#xff08;登录案例&#xff09; 小结Web组件 Web组件 JavaEE的Web组件&#xff08;三大Web组件&#xff09;&#xff1a; Servlet → 处理请求对应的业务Listener →…

图论做题笔记:dfs

Leetcode - 797&#xff1a;所有可能的路径 题目&#xff1a; 给你一个有 n 个节点的 有向无环图&#xff08;DAG&#xff09;&#xff0c;请你找出所有从节点 0 到节点 n-1 的路径并输出&#xff08;不要求按特定顺序&#xff09; graph[i] 是一个从节点 i 可以访问的所有节…

公司官网怎么才会被百度收录

在互联网时代&#xff0c;公司官网是企业展示自身形象、产品与服务的重要窗口。然而&#xff0c;即使拥有精美的官网&#xff0c;如果不被搜索引擎收录&#xff0c;就无法被用户发现。本文将介绍公司官网如何被百度收录的一些方法和步骤。 1. 创建和提交网站地图 创建网站地图…

el-select的错误提示不生效、el-select验证失灵、el-select的blur规则失灵

发现问题 在使用el-select进行表单验证的时候&#xff0c;发现点击下拉列表没选的情况下&#xff0c;他不会提示没有选择选项的信息&#xff0c;我设置了rule如下 <!--el-select--><el-form-item label"等级" prop"level"><el-select v-m…