TPAMI 2024 带Token迁移的整体预训练Transformer金字塔网络——Fast-iTPN

news2025/1/13 15:58:55

Fast-iTPN: Integrally Pre-Trained Transformer Pyramid Network with Token Migration

https://github.com/sunsmarterjie/iTPN/blob/main

https://arxiv.org/pdf/2211.12735

Introduction

背景

近年来,视觉模型取得了两大进展,一是将Vision Transformer(ViT)作为网络主干,二是使用Masked Image Modeling(MIM)方法进行模型预训练。这两者的结合在多种下游任务中取得了先进性能,包括图像分类、目标检测和实例/语义分割。

挑战

然而,预训练与下游微调之间的迁移差距仍然存在。具体来说,下游任务(尤其是细粒度识别任务如检测和分割)需要层次化特征,但大多数预训练任务(如BEiT和MAE)都是基于简单的ViT,缺乏层次化设计。即使使用层次化ViT,预训练也仅影响主干网络,而特征金字塔(neck部分)未经过训练,这增加了下游任务微调的风险。

iTPN的提出

目的:为了缓解这一问题,本文提出了积分预训练的金字塔Transformer网络(iTPN),旨在联合优化网络主干和特征金字塔,从而最小化表示模型与下游任务之间的迁移差距。

方法:iTPN基于HiViT(一种MIM友好的层次Transformer),并为其配备了特征金字塔。通过两个关键技术贡献来联合优化主干和特征金字塔:1) 在预训练阶段插入特征金字塔进行重构,并在微调阶段重用这些权重;2) 提出掩码特征建模(MFM)以更好地预训练特征金字塔。

MFM的优势

MFM通过两个步骤来预训练特征金字塔:首先,将原始图像输入到移动平均主干中以计算中间目标;然后,使用金字塔各阶段的输出来重建这些中间目标。MFM与MIM互补,提高了重建和识别的准确性,并可以吸收来自预训练教师(如CLIP)的知识,进一步提高性能。

Fast-iTPN的改进

改进背景:使用层次化架构和特征金字塔时,全局自注意力导致的计算成本会累积。为了缓解这一问题,本文对iTPN进行了升级,提出了Fast-iTPN。

改进方法:Fast-iTPN通过两个灵活的设计来加速推理并减少内存开销:1) 令牌迁移(Token Migration),即根据相似性度量从主干中丢弃冗余令牌,并在没有自注意力操作的特征金字塔中补充这些令牌;2) 令牌收集(Token Gathering),通过引入少量收集令牌来聚合来自所有窗口的全局信息,从而用窗口注意力替换全局注意力,显著加速推理过程且性能损失可忽略不计。

ITPN

动机与背景

在视觉模型的发展中,尽管Vision Transformer(ViT)架构和Masked Image Modeling(MIM)方法结合取得了显著进展,但在上游预训练与下游微调之间仍存在较大的迁移差距。特别是对于需要层次特征的下游任务(如检测和分割),这一差距尤为明显。传统的预训练方法往往只针对骨干网络(backbone)进行优化,而忽略了特征金字塔(neck,如特征金字塔网络FPN)的预训练。

为了缓解这一问题,本文提出了积分预训练的金字塔Transformer网络(iTPN),旨在同时优化骨干网络和特征金字塔,从而最小化表示模型与下游任务之间的迁移差距。

技术贡献

首个预训练特征金字塔

插入特征金字塔:在预训练阶段,iTPN在骨干网络(如HiViT)后插入了一个特征金字塔。这样,在预训练过程中,特征金字塔就能够被优化,并在下游任务中复用其权重。

统一上下游的脖子:通过将特征金字塔整合到预训练阶段(用于重建)并在微调阶段复用其权重(用于识别),实现了上下游脖子的一致性。

掩码特征建模(MFM)

计算中间目标:MFM通过将一个移动平均的骨干网络应用于原始图像来计算中间目标。

多阶段监督:使用特征金字塔的每一阶段的输出来重建这些中间目标,从而实现对特征金字塔的多阶段监督。

适应教师模型:MFM还可以吸收来自预训练教师模型(如CLIP)的知识,进一步提高性能。

网络架构

iTPN的整体架构由两部分组成:骨干网络(如HiViT)和特征金字塔。骨干网络用于提取初步的视觉特征,而特征金字塔则对这些特征进行进一步的处理和聚合,以适应不同的下游任务。

骨干网络

本文采用HiViT作为骨干网络,HiViT通过引入两个基于MLP的阶段来构建层次特征,避免了在全局注意力阶段使用卷积操作或窗口注意力,从而保证了计算效率和与MIM的兼容性。

特征金字塔

特征金字塔通过逐步上采样和融合来自骨干网络的不同层次特征来构建多尺度特征表示。在预训练阶段,这些特征被用于重建由移动平均骨干网络计算的中间目标。

训练过程

iTPN的训练过程分为两个阶段:预训练和微调。

预训练

输入:原始图像被划分为一系列的图像块(tokens),其中一部分被随机掩码。

前向传播:掩码后的图像块通过骨干网络进行处理,生成初步的特征表示。这些特征随后被传递到特征金字塔中进行进一步的处理。

重建目标:特征金字塔的每一阶段都试图重建由移动平均骨干网络计算的中间目标。

损失函数:重建损失(例如,均方误差)用于监督特征金字塔的训练。

微调

在微调阶段,预训练的特征金字塔和骨干网络的权重被冻结或微调,以适应特定的下游任务(如图像分类、目标检测和语义分割)。

技术优势

积分预训练:同时优化骨干网络和特征金字塔,减少了迁移差距。

多阶段监督:MFM通过对特征金字塔的每一阶段进行监督,提高了特征的泛化能力。

适应性强:iTPN可以灵活地适应不同的预训练目标和下游任务。

Fast-iTPN

Fast-iTPN(快速整体预训练Transformer金字塔网络)是在iTPN基础上进行的改进,旨在通过两种灵活的设计来减少计算内存开销并加速推理过程。这两种设计分别是:Token迁移和Token聚合。这些设计不仅保持了模型在下游任务中的性能,还显著提升了模型的推理速度。

Fast-iTPN的主要设计

Token迁移(Token Migration)

Token迁移机制通过两个步骤来实现:

丢弃冗余Token:根据一个相似性度量(如余弦相似度),从主干网络(如HiViT)中丢弃一部分冗余的Token。这些被丢弃的Token通常是信息重复的或者对全局表示贡献较小的。

补充到特征金字塔:将被丢弃的Token补充到特征金字塔中,但不进行自注意力操作。这一步骤有效地利用了这些被丢弃的Token,并且因为特征金字塔中没有自注意力操作,所以补充Token的计算成本相对较低。

Token聚合(Token Gathering)

Token聚合通过引入少量的聚合Token来进一步减少全局自注意力操作的计算成本。这些聚合Token的作用是从所有窗口聚合全局信息,从而使得全局注意力可以被窗口注意力替代。具体来说:

聚合Token的作用:聚合Token在每个窗口中接收来自其他Token的信息,并将这些信息进行聚合,然后传递给下一层。通过这种方式,每个聚合Token都能够捕获到来自整个输入图像的全局信息。

减少计算成本:由于聚合Token只需要在窗口内进行自注意力操作,因此相比于全局自注意力,这种机制可以显著降低计算成本。同时,由于聚合Token的数量远少于输入Token的总数,因此总体计算量也大大减少。

技术细节

特征金字塔的设计

Fast-iTPN在HiViT的基础上构建了一个特征金字塔,该金字塔通过上采样和特征融合操作将不同层的特征图整合到一起。这种设计使得模型在不同尺度上都能够学习到丰富的特征表示,从而更好地适应下游任务的需求。

掩码特征建模(Masked Feature Modeling, MFM)

MFM是Fast-iTPN中用于预训练特征金字塔的方法。它通过以下两个步骤来实现:

计算中间目标:将原始图像输入到一个移动平均主干网络中,以计算得到中间目标。这些中间目标包含了丰富的特征信息,可以作为预训练过程中的监督信号。

使用金字塔各阶段输出进行重建:利用特征金字塔每个阶段的输出来重建这些中间目标。这一步骤不仅预训练了特征金字塔本身,还通过重建任务促进了金字塔内部不同层之间的特征交互。

模型训练

Fast-iTPN在训练过程中同时优化了主干网络和特征金字塔。由于特征金字塔在预训练阶段就被引入了,因此它在后续的微调阶段能够直接使用预训练得到的权重,从而减少了迁移差距并提升了下游任务的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2045048.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SAP LE学习笔记03 - 在IM(在库管理)中收货到仓库的流程,关联 WM移动Type与IM移动Type

上一章讲了 LE-WM的 WM和IM之间的关系。 SAP LE学习笔记02 - WM和库存管理(IM)之间的关系,保管Lot(Quant)-CSDN博客 本章继续将LE-WM的知识。 目录 1,在IM(在库管理)中收货到仓库的流程 a)&…

Golang | Leetcode Golang题解之第337题打家劫舍III

题目: 题解: func rob(root *TreeNode) int {val : dfs(root)return max(val[0], val[1]) }func dfs(node *TreeNode) []int {if node nil {return []int{0, 0}}l, r : dfs(node.Left), dfs(node.Right)selected : node.Val l[1] r[1]notSelected : …

EMC学习笔记5——辐射骚扰发射

辐射骚扰发射是基本的实验项目,目的是检验设备在工作时有没有产生意外的过强电磁辐射。 例如发电机,在工作时会产生意外的电磁波辐射,因为电子设备中隐藏了一些天线,这些隐藏的天线在辐射电磁波。 一、两种基本的天线结构 如前面…

智能小家电能否利用亚马逊VC搭上跨境快车?——WAYLI威利跨境助力商家

智能小家电行业在全球化背景下,正迎来前所未有的发展机遇。亚马逊为品牌商和制造商提供的一站式服务平台,为智能小家电企业提供了搭乘跨境快车、拓展国际市场的绝佳机会。 首先,亚马逊VC平台能够帮助智能小家电企业简化与亚马逊的合作流程&am…

Hive3:三种常用的复杂数据类型

一、Array类型 1、数据示例 2、实操 元数据 zhangsan beijing,shanghai,tianjin,hangzhou wangwu changchun,chengdu,wuhan,beijin创建表 CREATE TABLE myhive.test_array(name string, work_locations array<string>) ROW FORMAT DELIMITED FIELDS TERMINATED BY \t…

远程调用-OpenFeign

目录 1.RestTemplate存在问题 2.OpenFeign介绍 一、主要特点 二、应用场景 3.OpenFeign快速上手 3.1引入依赖 3.2添加注解 3.3编写OpenFeign的客户端 3.4远程调用 ​编辑3.5测试 4.OpenFeign参数传递 4.1传递单个参数 4.2传递多个参数 4.3传递对象 4.4传递JSO…

搬瓦工美国西海岸CN2 GIA VPS测评

很多人想知道搬瓦工美国CN2 GIA VPS系列怎么样&#xff1f;实际情况是&#xff1a;搬瓦工在美国西海岸的sanjose和losangeles运作着2.5Gbps-10Gbps CN2GIA带宽的VPS&#xff0c;底层虚拟为KVM&#xff0c;纯SSD阵列&#xff0c;支持在多机房之间切换。由于三网强制走同样的路由…

每日OJ_牛客_QQ2 微信红包

目录 牛客_QQ2 微信红包 解析代码 牛客_QQ2 微信红包 微信红包_牛客题霸_牛客网 解析代码 本题很多思路&#xff0c;第一种排序思路&#xff0c;如果一个数出现次数超过一半了&#xff0c;排序过后&#xff0c;必然排在中间&#xff0c;则最后遍历整个数组查看是否符合即可。…

在java环境下判断某个元素是否存在

1、在做web功能自动化时&#xff0c;有时需要通过判断某个元素是否存在来决定下一步的操作&#xff0c;但如果直接通过如下命令来进行查找时&#xff0c;如果不存在程序会报错&#xff0c;无法达到想要的效果&#xff0c;而java中也没有可以直接调用的工具类&#xff0c;因此就…

【java工具类】计算两个经纬度点之间的距离

计算两个经纬度点之间的距离 1、计算两个经纬度点之间的距离2、代码如下 1、计算两个经纬度点之间的距离 2、代码如下 public class DistanceCalculatorUtils {// 地球半径&#xff0c;单位为米private static final double EARTH_RADIUS 6371000;/*** 计算两个经纬度点之间的…

使用 HTTPS 代理在本地测试 AWS Lambdas

​ 欢迎来到雲闪世界。AWS Lambda 通常是在云中部署和执行代码的最简单方法之一&#xff0c;尤其是在使用sam CLI部署代码时。无服务器资源定义的简单性加上在本地打包资源并确保它们在 AWS 上运行的能力&#xff0c;提供了美妙的开发体验。 但有时&#xff0c;当构建和…

torch.roll()函数使用方法

官方文档在这里&#xff0c;说的比较清楚&#xff0c;但是举的例子不是很直观。我们再详细解释一下&#xff1a; torch.roll(input, shifts, dimsNone) → Tensor input&#xff1a;输入的tensorshifts&#xff1a;滚动的方向和长度&#xff0c;若为正&#xff0c;则向索引大…

Web 服务基础介绍

目录 1.1 互联网发展历程回顾 1.2 Web 服务介绍 1.2.1 Apache 经典的 Web 服务端 1.2.1.1 Apache prefork 模型 1.2.1.2 Apache worker 模型 1.2.1.3 Apache event模型 1.2.2 Nginx-高性能的 Web 服务端 1.2.3 用户访问体验和性能 1.2.3.1 用户访问体验统计 1.2.3.2 …

数字孪生技术框架:从数据到决策的桥梁

随着科技的飞速发展&#xff0c;数字孪生技术作为一种创新的信息化手段&#xff0c;正逐步渗透到各个行业领域&#xff0c;成为推动数字化转型的重要力量。数字孪生技术框架&#xff0c;作为支撑这一技术体系的核心架构&#xff0c;以其独特的层级结构&#xff0c;实现了从数据…

Matlab进阶绘图第66期—特征渲染的滑珠气泡图

特征渲染的滑珠气泡图是在滑珠散点图的基础上&#xff0c;添加散点大小与颜色参数&#xff0c;通过散点的尺寸与颜色表示两个额外的特征。 由于Matlab中没有现成的函数绘制特征渲染的滑珠气泡图&#xff0c;因此需要大家自行解决。 本文利用自己制作的BubbleScatter工具&…

奥威BI数据可视化展示:如何充分发挥数据价值

奥威BI数据可视化展示&#xff1a;如何充分发挥数据价值 在大数据时代&#xff0c;数据已成为企业最宝贵的资产之一。然而&#xff0c;仅仅拥有海量数据并不足以带来竞争优势&#xff0c;关键在于如何有效地挖掘、分析和展示这些数据&#xff0c;从而转化为有价值的洞察和决策…

更改Docker默认存储位置

Docker镜像和容器等数据默认保存在目录/var/lib/docker目录下&#xff0c;我们可以更改Docker 的默认存储位置&#xff0c;比如改到数据盘。需注决&#xff0c;变更存储位置时&#xff0c;原来的镜像和容器有可能丢失。 1、确认docker默认存放目录 [rootkfk12 ~]# docker inf…

算法的学习笔记—链表中环的入口结点(牛客JZ23)

&#x1f600;前言 在链表的操作中&#xff0c;环形链表是一个常见且需要特别处理的结构。当我们遇到一个包含环的链表时&#xff0c;如何找到环的入口结点是一个经典的问题。本文将详细介绍使用双指针技术来解决这一问题&#xff0c;并提供一个基于 Java 的实现代码。 &#x…

搭建内网开发环境(五)|基于nexus搭建npm私服

引言 在前面一篇教程中&#xff0c;通过 nexus 搭建了 maven 的私服&#xff0c;并通过脚本将本地的依赖文件批量上传到私服中&#xff0c;本文介绍通过 nexus 搭建 npm 私服&#xff0c;同样也通过脚本将本地依赖文件同步到私服中。 搭建内网开发环境&#xff08;一&#xff…

目标检测 | yolov6 原理和介绍

前言&#xff1a;目标检测 | yolov5 原理和介绍 后续&#xff1a; 1.简介 YOLOv6是由美团视觉智能部研发的一款目标检测框架&#xff0c;专注于工业应用&#xff0c;致力于提供极致的检测精度和推理效率。相较于YOLOv4和YOLOv5&#xff0c;YOLOv6在网络结构方面进行了深入优化…