解决Vision Transformer在任意尺寸图像上微调的问题:使用timm库

news2025/1/20 10:46:56

解决Vision Transformer在任意尺寸图像上微调的问题:使用timm库

文章目录

          • 一、ViT的微调问题的本质
          • 二、Positional Embedding如何处理
            • 1,绝对位置编码
            • 2,相对位置编码
            • 3,对位置编码进行插值
          • 三、Patch Embedding Layer如何处理
          • 四、使用timm库来对任意尺寸进行微调

一、ViT的微调问题的本质

自从ViT被提出以来,在CV领域引起了新的研究热潮。理论上来说,Transformer的输入是一个序列,并且其参数主要来自于Transformer Block中的Linear层,因此Transformer可以处理任意长度的输入序列。但是在Vision Transformer中,由于需要将二维的图像通过Patch Embedding Layer映射为一个一维的序列,并且需要添加pos_embedding来保留位置信息。因此当patch_size和img_size发生改变时,会造成pos_embbeding的长度和Patch Embedding Layer的参数发生改变,从而导致预训练权重无法直接加载。更多有关ViT的实现细节和原理,可以参考Vision Transformer , 通用 Vision Backbone 超详细解读 (原理分析+代码解读)。

二、Positional Embedding如何处理

在Vision Transformer中有两种主流的编码方式:相对位置编码和绝对位置编码。

1,绝对位置编码

绝对位置编码依据token每个的绝对位置分配一个固定的值,其本质上是一组一维向量,有两种实现方式:

# 可学习的位置编码,ViT中使用, +1是因为有cls_token
	self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))


# 根据正余弦获取位置编码, Transformer中使用
def get_positional_embeddings(sequence_length,dim):
    result = torch.ones(sequence_length,dim)
    for i in range(sequence_length):
        for j in range(dim):
            result[i][j] = np.sin(i/(10000**(j/dim))) if j %2==0 else np.cos(i/(10000**((j-1)/dim)))
    return result

在forward过程中,绝对位置编码会在最开始直接和token相加:

	tokens += self.pos_embedding[:, :(n + 1)]
2,相对位置编码

相对位置编码,依据每个token的query相对于key的位置来分配位置编码,典型例子就是swin transformer,其本质是构建一个可学习的二维table,然后依据相对位置索引(x,y)来从table中取值,具体可以参考:有关swin transformer相对位置编码的理解

不过,在swin transformer中,query和key都是来自于同一个window,因此query和key的数量相同,构建位置编码的方式相对来说比较简单。如果query和key的数量不同,例如Focal Transformer中多层次的self-attention,其位置编码的方式可以参考:Focal Transformer。

对于相对位置编码的构造,还有一种方式是CrossFormer中提出的Dynamic Position Bias。其核心思想为构建一个MLP,其输入是二维的相对位置索引,输出是指定dim的位置偏置。这个和根据正余弦获取位置编码有点类似,只不过一个是依据一维的绝对坐标来生成位置编码,一个是依据二维的相对坐标来生成位置编码。

image-20231122172434141

在forward过程中,相对位置编码不会在一开始与token相加,而是在Attention Layer中以Bias的形式参与self-attention计算,核心代码如下:

        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
3,对位置编码进行插值

综上,我们可以依据实现方式将位置编码分为两大类:可学习的位置编码(例如,ViT、Swin Transformer、Focal Transformer等)和生成式的位置编码(例如,正余弦位置编码和CrossFormer中的DPB)。更多有关位置编码的内容,可以参考论文:Rethinking and Improving Relative Position Encoding for Vision Transformer。对于生成式的位置编码而言,其编码方式与序列长度无关,因此当patch_size和img_size改变而造成num_patches改变时,仍然可以加载与位置编码有关的预训练权重。

但是,对于可学习的位置编码而言,num_patches改变时,无法直接加载与位置编码的预训练权重。以ViT为例,其参数一般是一个shape为[N+1, C]的tensor。与cls_token有关的位置编码不用改变,我们只需要关心与img patch相关的位置编码即可,其shape为[N, C]。当num_patches变为n时,所需要位置编码shape为[n, C]。这显然无法直接加载预训练权重。

Pytorch官方提供了一种思路,通过插值算法,来获取新的权重。我们不妨将原始的位置编码想象为一个shape为[ N , N , C \sqrt{N}, \sqrt{N}, C N ,N ,C]的tensor,将所需要的位置编码想象为一个shape为[ n , n , C \sqrt{n}, \sqrt{n}, C n ,n ,C]。这样我们就可以通过插值算法,将原始的权重映射到所需要的权重上。核心代码如下:

# (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)

不过,Pytorch官方的这个代码,只能适配当num_patches是一个完全平方数的情况,因为需要开根号操作。实际上,num_patches一般是通过如下方式计算获得,理论上来说通过插值算法是可以适配到任意尺寸的num_patches的。

n u m _ p a t c h e s = i m g _ s i z e h p a t c h _ s i z e h i m g _ s i z e w p a t c h _ s i z e w (1) num\_patches=\frac{img\_size_h}{patch\_size_h}\frac{img\_size_w}{patch\_size_w} \tag{1} num_patches=patch_sizehimg_sizehpatch_sizewimg_sizew(1)

从上式可以看出,pos_embedding主要与img_size/patch_size有关,因此当把img_size和patch_size等比例缩放时,是不需要调整pos_embedding的。

在timm库中,提供了resample_abs_pos_embed函数,并将其集成到了VisionTransformer类中,所以我们在使用时无需自己考虑对位置编码进行插值处理。

三、Patch Embedding Layer如何处理

Patch Embedding Layer用于将二维的图像转为一维的输入序列,其实现方式通常有两种,如下所示:

### 基于MLP的实现方式
	patch_dim = in_channels * patch_height * patch_width
	self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), # 使用einops库
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

### 基于Conv2d的实现方式
	self.patch_embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

从这两种实现可以看出,Patch Embedding Layer的参数主要与patch_size和in_channels有关,而与img_size无关。Pytorch官方和Timm库都采用基于Conv2d的方式来实现,当patch_size和in_channels改变时,无法直接加载预训练权重。

Pytorch官方并未给出解决方案,timm库通过resample_patch_embed来解决这一问题,并且也集成到了VisionTransformer类中。在使用时,我们也不需要考虑手动对Patch Embedding Layer的权重进行调整。

四、使用timm库来对任意尺寸进行微调

首先需要安装timm库

pip install timm
# 如果安装的Pytorch2.0及以上版本,无需考虑一下步骤
# 如果是其他版本的Pytorch,需要下载functorch库
pip install functorch==版本号
# 具体版本号,需要依据自己环境中的pytorch版本来
# 例如:0.20.0对应Pytorch1.12.0,0.2.1对应Pytorch1.12.1
# 对应关系可以去github上查看:https://github.com/pytorch/functorch/releases

代码示例如下:

import timm
from timm.models.registry import register_model

@register_model # 注册模型
def vit_tiny_patch4_64(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-Tiny (Vit-Ti/16)
    """
    # 在model_args中对需要部分参数进行修改,此处调整了img_size, patch_size和in_chans
    model_args = dict(img_size = 64, patch_size=4, in_chans=1, embed_dim=192, depth=12, num_heads=3) 
    # vit_tiny_patch16_224是想要加载的预训练权重对应的模型
    model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 
    return model

# 注册模型之后,就可以通过create_model来创建模型了
vit = timm.create_model('vit_tiny_patch4_64', pretrained = True) 

不过,由于预训练权重在线下载一般比较慢,可以通过pretrained_cfg来实现加载本地模型,代码如下:

    vit = timm.create_model('vit_tiny_patch4_64')
    cfg = vit.default_cfg
    print(cfg['url']) # 查看下载的url来手动下载
    cfg['file'] = 'vit-tiny.npz' # 这里修改为你下载的模型
    vit = timm.create_model('vit_tiny_patch4_64', pretrained=True, pretrained_cfg=cfg).cuda()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1238861.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Public Key Retrieval is not allowed客户端连接

使用DBeavear或navicat连接mysql服务器时,报错Public Key Retrieval is not allowed 原因: 客户端默认禁用 SSL/TLS 协议,客户端会使用服务器的公钥进行传输,默认情况下客户端不会主动去找服务器拿公钥,进而会出现…

汇编-CALL和RET指令

CALL指令调用一个过程, 使处理器从新的内存位置开始执行。过程使用RET(从过程返回) 指令将处理器转回到该过程被调用的程序点上。 CALL指令的动作: 1.将CALL指令的下一条指令地址压栈(作为子过程返回的地址) 2.将被调过程的地址复制到指令指针寄存器E…

编译 CUDA加速的 OpenCV-4.8.0 版本

文章目录 前言一、编译环境二、前期准备三、CMake编译四、VS编译OpenCV.sln五、问题 前言 由于项目需要用上CUDA加速的OpenCV,编译时也踩了不少坑,所以这里记录一下。 一、编译环境 我的编译环境是: Win10 RTX4050 CUDA-12.0 CUDNN 8.9.…

关于测试用例,你知道多少?

前言 在此之前我搜集一些关于测试用例的知识,后来在我们的群里专门定了一期讨论,来探讨测试用例,毕竟这是一个很大的话题,很难做到面面俱到,但我会尽量全面,用通俗的语言来说测试用例。 1、测试用例(test …

【React-Router】路由快速上手

1. 创建路由开发环境 # 使用CRA创建项目 npm create-react-app react-router-pro# 安装最新的ReactRouter包 npm i react-router-dom2. 快速开始 // index.jsimport React from react; import ReactDOM from react-dom/client; import ./index.css; import App from ./App; i…

传统企业如何实现数字化转型?如何加快企业数字化转型?

科技的发展给社会带来了各种变革,技术日新月异,很多传统的东西都被大众抛之脑后,在这个以技术和数据运营为导向的数字化时代,传统企业想要保持足够的核心竞争力,就必须跟上时代的步伐,进行企业数字化转型&a…

git clone慢的解决办法

在网站 https://www.ipaddress.com/ 分别搜索: github.global.ssl.fastly.net github.com 得到ip: 打开hosts文件 sudo vim /etc/hosts 在hosts文件末尾添加 140.82.114.3 github.com 151.101.1.194 github.global-ssl.fastly.net 151.101.65.194 g…

【RocketMq系列-02】RocketMq的架构解析和高性能设计

RocketMq系列整体栏目 内容链接地址【一】RocketMq安装和基本概念https://zhenghuisheng.blog.csdn.net/article/details/134486709【二】RocketMq的架构解析和高性能设计/font>https://zhenghuisheng.blog.csdn.net/article/details/134559514 RocketMq的架构解析和高性能设…

[点云分割] 基于最小切割的分割

效果&#xff1a; 代码&#xff1a; #include <iostream> #include <vector>#include <pcl/point_types.h> #include <pcl/io/pcd_io.h> #include <pcl/visualization/cloud_viewer.h> #include <pcl/filters/filter_indices.h> #include…

B站已经部分上线前台实名,如不同意实名,后期账号流量将收影响!

B站部分百万粉丝博主的主页显示账号运营人名字的政策是从10月31日开始的。当天&#xff0c;B站官方发布了《哔哩哔哩关于头部“自媒体”账号前台实名的公告》&#xff0c;表明了其前台实名制的实施计划。 B站部分上线前台实名的过程可以追溯到2021年。当时&#xff0c;中国政府…

报表系统是什么?如何快速帮助企业数字化转型?

在信息洪流中&#xff0c;企业需要应对日益增长的数据量和复杂业务环境&#xff0c;这需要借助科技手段来驾驭数据管理和决策分析。报表系统&#xff0c;作为企业决策的重要工具&#xff0c;就如同航海的罗盘&#xff0c;帮助企业在数据的海洋中快速定位&#xff0c;从而提高管…

最常用的5款报表系统

在这个信息化飞速发展的时代&#xff0c;报表系统已经成为了企业管理和决策的重要工具。随着市场的需求不断增长&#xff0c;报表系统也在不断地更新和完善。如今&#xff0c;市面上有数不尽的报表系统&#xff0c;但是哪款才是最常用的呢&#xff1f;接下来&#xff0c;我们将…

慕尼黑电子展Samtec Demo | 回环测试带来Samtec产品组合优异表现

【摘要/前言】 大家好&#xff01;Electronica虎家展台Demo系列回来咯。 实践出真知&#xff0c;再好的纸面数据都不如来一场实际的测试和演示。Samtec团队始终在努力为客户带来卓越的产品和优质服务。而这其中&#xff0c;Demo演示的存在至关重要。演示过程可以为大家带来了…

AppLink结合金蝶云星空作订单信息同步流程

此次通过AppLink&#xff0c;根据请求数据金蝶云星空做销售订单信息同步拉取 在获取订单信息前需要得到金蝶云星空授权&#xff0c;详细授权步骤可查看&#xff1a;金蝶云星空授权指南 根据请求数据在金蝶云星空保存销售订单 当webhook接收到数据时触发流程 步骤1&#xff…

维护工程师面经

文章目录 前言技能要求数据结构定义分类常用的数据结构 数据库原理数据的三级模式结构事务查询方式视图数据库范式 Java相关知识点总结 前言 本博客仅做学习笔记&#xff0c;如有侵权&#xff0c;联系后即刻更改 科普&#xff1a; 参考网址 技能要求 数据结构 参考网址 定…

虚拟机centos设置网络模式(桥接|NAT)

前言 桥接模式是通过物理网卡直接与外部网络建立联系的&#xff0c;而NAT模式则是通过虚拟网卡VMnet1或VMnet8通过宿主机共享IP与外部建立网络关系当需要将虚拟机资源共享给局域网用户使用时&#xff0c;宜采用桥接模式&#xff1b;当需要保护虚拟机资源&#xff0c;确保只能由…

数据结构【DS】特殊二叉树

完全二叉树 叶子结点只能出现在最下层和次下层, 最下层的叶子结点集中在树的左部完全二叉树中, 度为1的节点数 0个或者1个【计算时可以用这个快速计算, 配合&#x1d45b;0&#x1d45b;21】若n为奇数&#xff0c;则分支节点每个都有左右孩子&#xff1b;若n为偶数&#xff0…

Jmeter 压测实战保姆级入门教程

1、Jmeter本地安装 1.1、下载安装 软件下载地址&#xff1a; https://mirrors.tuna.tsinghua.edu.cn/apache/jmeter/binaries/ 选择一个压缩包下载即可 然后解压缩后进入bin目录直接执行命令jmeter即可启动 1.2 修改语言 默认是英文的&#xff0c;修改中文&#xff0c;点击…

探秘互联网医院系统的技术内幕:代码解析与创新

随着科技的飞速发展&#xff0c;互联网医院系统正日益改变着传统医疗服务的面貌。这些系统的背后&#xff0c;隐藏着精密而创新的技术。本文将深入研究互联网医院系统的技术内幕&#xff0c;透过代码解析&#xff0c;揭示这些系统如何实现医疗服务数字化的伟大使命。 1. 实时…

优质猫罐头有哪些品牌?分享5款宠物店自用值得推荐的猫罐头!

不知不觉已经开宠物店7年啦&#xff0c;店里的猫猫大大小小也算是尝试过很多品牌的猫罐头了。优质猫罐头有哪些品牌&#xff1f;在猫罐头的选购上一开始我也是踩了很多坑&#xff0c;各种踩雷。我深知猫罐头的各种门道&#xff0c;新手一不小心就会着道了。 优质猫罐头有哪些品…