基于dinoV2分类模型修改

news2024/11/15 9:15:53

前言

dinoV2已经发布有一段时间了,faecbook豪言直接说前面的结构我们都不需要进行修改,只需要修改最后的全连接层就可以达到一个很好的效果。我们激动的揣摸了下自己激动的小手已经迫不及待了,这里我使用dinoV2进行了实验,来分享下实验结果。

  • dinoV2官方地址:github链接

一、模型介绍

1、预训练模型介绍

# dinov2_vits14_pretrain.pth 结构 
# s,b,l,g 主要是blocks 模块数量不同,

DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Identity()
)

2、项目文件介绍

这里可以直接用hubconf.py文件里面进行调用,大家可以根据需求来进行选择。
在这里插入图片描述
在这里插入图片描述导入模型第一次都是从网络进行导入,对于国内用户可能不成功,这里大家可以修改为本地导入,传入已经下载好的预训练模型就行。这里给大家分享一个百度网盘的地址,提取码:mhdq,更多模型大家从官网下载。
导入代码如下:

  • 注意 : dinov2_vitl14 此为L模型大小导入方法,需要和模型大小进行对应。
# hubconf.py文件 中导入
model = dinov2_vitl14(weights={'LVD142M':'/media/wqg/minio/model/dinoV2/dinov2_vitl14_pretrain.pth'})

这里如果直接使用model.eval()
模型输出是(bs,embed_dim)如果是一张图,使用dinov2_vits14模型,则输出是 (1,384)
b,l,g,的embed_dim大家可以通过model.embed_dim进行查看。

3、模型输出

由于我实验的时候发现仅仅只使用x_norm_clstoken效果一直不理想,我这里用到了x_norm_regtokens。
这里可以参考github中的finetune中的导入方法。

# 实例化模型代码
from functools import partial
from dinov2.eval.linear import create_linear_input
from dinov2.eval.linear import LinearClassifier
from dinov2.eval.utils import ModelWithIntermediateLayers

model = dinov2_vits14(weights={'LVD142M':'./model/dinoV2/dinov2_vits14_pretrain.pth'})
autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16)
self.feature_model = ModelWithIntermediateLayers( model, n_last_blocks=1, autocast_ctx=autocast_ctx).to(device)



# 实例化分类模型全连接层。
self.embed_dim = model.embed_dim
 # 100对应的是你需要分类的类别数量
self.classifier = LinearClassifier( self.embed_dim*2, use_n_blocks=1, use_avgpool=True, num_classes=100).to(device)  

# 冻结骨干网络
for param in model.feature_model.parameters():
    param.requires_grad = False

这里的self.feature_model 输出是有2个维度的,一个是x_norm_regtokens,shape为(bs,pach_h*pach_w,embed_dim),pach_h = input_h/14,pach_w = input_w/14.
另一个是x_norm_clstoken,shape为(bs,embed_dim)。一般情况下x_norm_clstoken用来分类就已经足够了

4、完整代码

from modeling.dinov2.eval.linear import LinearClassifier,create_linear_input
from modeling.dinov2.eval.utils import ModelWithIntermediateLayers
from functools import partial

from modeling.dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
from modeling.dinov2.hub.backbones import dinov2_vitb14_reg, dinov2_vitg14_reg, dinov2_vitl14_reg, dinov2_vits14_reg

class HubConf(nn.Module):
    def __init__(self,cfg,pretrain_choice = 'frozen'):
        super(HubConf, self).__init__()

        model_path = cfg.MODEL.PRETRAIN_PATH
        self.cfg = cfg
        self.base = dinov2_vits14(weights={'LVD142M':'./model/dinoV2/dinov2_vits14_pretrain.pth'})
        self.in_planes = self.base.embed_dim

        autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16)
        self.feature_model = ModelWithIntermediateLayers(self.base, n_last_blocks=1, autocast_ctx=autocast_ctx)
        if pretrain_choice == 'frozen':
            for param in self.feature_model.parameters():
                param.requires_grad = False

        
        self.classifier = LinearClassifier(self.in_planes*2, use_n_blocks=1, use_avgpool=True, num_classes=cfg.MODEL.nc)


    def forward(self, x):
        global_feat = self.feature_model(x)  # ((b,256, embed_dim ),(b, embed_dim )) ((1,256,384),(1,384))
        out = self.classifier(global_feat)
        return  out

    def load_param(self, trained_path, device='cpu'):
        param_dict = torch.load(trained_path, map_location=device)
        for i in param_dict:
            #if 'classifier' in i:
            if i not in self.state_dict():
                print('not load param ', i)
                continue
            self.state_dict()[i].copy_(param_dict[i])


二、模型修改

这里骨干网络已经完全冻结,没有什么需要修改的,只需要对x_norm_regtokens进行添加卷积操作。

1、添加卷积

# neck结构,在输出后添加卷积的过程。

def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1,
                 act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv1d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm1d(c2)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class neck_dinov2(nn.Module):
    def __init__(self,c0,c1,nc,dropout= 0.5):
        super().__init__()
        self.conv1 = Conv(c0,c0*2)
        self.conv2 = Conv(c0*2,c0)
        self.drop = nn.Dropout(p=dropout, inplace=True)
        self.line = LinearClassifier(c1*2, use_n_blocks=1, use_avgpool=True, num_classes=nc)

    def forward(self,x):
        x1 = copy.copy(x[0][0])
        x1 = self.drop(self.conv2(self.conv1(x1)))
        x = [[x1,copy.copy(x[0][1])]]

        return self.line(x)

2、完整代码

我这里实验的是多头输出,大家单头的可以只实验一次neck结构就行。


class HubConf(nn.Module):
    def __init__(self,cfg,pretrain_choice = 'frozen'):
        super(HubConf, self).__init__()

        model_path = cfg.MODEL.PRETRAIN_PATH
        self.cfg = cfg
        self.base = eval(cfg.MODEL.NAME)(weights={'LVD142M':model_path})
        self.in_planes = self.base.embed_dim

        self.consize = int((cfg.INPUT.SIZE_TRAIN[0]/14)*(cfg.INPUT.SIZE_TRAIN[1]/14))

        autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16)
        self.feature_model = ModelWithIntermediateLayers(self.base, n_last_blocks=1, autocast_ctx=autocast_ctx)
        if pretrain_choice == 'frozen':
            for param in self.feature_model.parameters():
                param.requires_grad = False

        self.line = LinearClassifier(self.in_planes * 2, use_n_blocks=1, use_avgpool=True, num_classes=100)

        self.country_cls = neck_dinov2(self.consize, self.in_planes, cfg.MODEL.nc1, dropout=cfg.MODEL.DROPOUT)  # 分类头1
        self.cn_cls = neck_dinov2(self.consize,self.in_planes, cfg.MODEL.nc2, dropout=cfg.MODEL.DROPOUT)  # 分类头2
        self.ct_cls = neck_dinov2(self.consize,self.in_planes, cfg.MODEL.nc3, dropout=cfg.MODEL.DROPOUT)  # 分类头3


    def forward(self, x):

        global_feat = self.feature_model(x)  # ((bs, pach_h*pach_w,embed_dim ),(bs, embed_dim ))    ((1,(224/14)*(224/14), 384),(1, 384))

        country_score = self.country_cls(global_feat)
        cn_score = self.cn_cls(global_feat)
        ct_score = self.ct_cls(global_feat)

        return (country_score, cn_score,ct_score)


    def load_param(self, trained_path, device='cuda:0'):
        param_dict = torch.load(trained_path, map_location=device)
        for i in param_dict:
            #if 'classifier' in i:
            if i not in self.state_dict():
                print('not load param ', i)
                continue
            self.state_dict()[i].copy_(param_dict[i])


三、实验自己的数据

1、车辆品牌分类。

  • 车辆品牌为单分类,目前类别有178类,输入图像大小为(126,252),输入图片为车头或者车辆尾部截图。
  • 使用单一的LinearClassifier分类效果不如resnet50的全训练效果,个人分析主要原因是车标太小了,全连接无法准确的学习到,所以我在x_norm_regtokens维度添加了卷积操作。
  • 可视化特征图。使用的骨干为dinov2_vitb14_pretrain,可视化效果如下

在这里插入图片描述

  • 可视化代码
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.decomposition import PCA
import matplotlib
from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14


patch_h = 50
patch_w = 100
feat_dim = 384

transform = T.Compose([
    T.GaussianBlur(9, sigma=(0.1, 2.0)),
    T.Resize((patch_h * 14, patch_w * 14)),
    T.CenterCrop((patch_h * 14, patch_w * 14)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# dinov2_vits14 = torch.hub.load('', 'dinov2_vits14', source='local').cuda()
vits14 = torch.hub.load('', 'dinov2_vits14', weights={'LVD142M':'./model/dinoV2/dinov2_vits14_pretrain.pth'},source='local').cuda()

features = torch.zeros(4, patch_h * patch_w, feat_dim)
imgs_tensor = torch.zeros(4, 3, patch_h * 14, patch_w * 14).cuda()

img_path = f'/home/wqg/桌面/car_face_crop/face/face_0003600_111963.jpg'
img = Image.open(img_path).convert('RGB')
imgs_tensor[0] = transform(img)[:3]
with torch.no_grad():
    features_dict = vits14.forward_features(imgs_tensor)
    features = features_dict['x_norm_patchtokens']

features = features.reshape(4 * patch_h * patch_w, feat_dim).cpu()
pca = PCA(n_components=3)
pca.fit(features)
pca_features = pca.transform(features)
pca_features[:, 0] = (pca_features[:, 0] - pca_features[:, 0].min()) / (
            pca_features[:, 0].max() - pca_features[:, 0].min())

pca_features_fg = pca_features[:, 0] > 0.3
pca_features_bg = ~pca_features_fg

b = np.where(pca_features_bg)

pca.fit(features[pca_features_fg])
pca_features_rem = pca.transform(features[pca_features_fg])
for i in range(3):
    # transform using mean and std, I personally found this transformation gives a better visualization
    pca_features_rem[:, i] = (pca_features_rem[:, i] - pca_features_rem[:, i].mean()) / (
                pca_features_rem[:, i].std() ** 2) + 0.5

pca_features_rgb = pca_features.copy()
pca_features_rgb[pca_features_fg] = pca_features_rem
pca_features_rgb[b] = 0

pca_features_rgb = pca_features_rgb.reshape(4, patch_h, patch_w, 3)
plt.imshow(pca_features_rgb[0][..., ::-1])
plt.savefig('features.png')
plt.show()
plt.close()

2、车辆属性分类。

  • 车辆属性分类为多头输出,其中需要输出车辆类型,车辆颜色,车辆朝向等。
  • 只使用LinearClassifier作为每个分类头进行输出既可获得较好的效果。

四、结论

  • 使用dinoV2在大图上做细粒度分类效果不如整体训练效果,需要再通过卷积获得更小区域目标的强化学习。
  • 使用dinoV2在分类整体图像效果时,可以直接得到一个较好的效果,比原有的模型输出效果更好,无须再训练backbone部分,

相关引用链接:

  • dinoV2github: https://github.com/facebookresearch/dinov2
  • dinoV2 finetune:https://github.com/xuwangyin/dinov2-finetune/tree/main
  • dinoV2预训练权重:链接: https://pan.baidu.com/s/1ly7JpCu4Oi5gVBKixafXQg 提取码: mhdq

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1387632.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c语言学习总结———编译和链接

再次来做一下学习总结,今天我们总结一下关于编译和链接的学习吧! 1. 翻译环境和运⾏环境 在ANSI C的任何⼀种实现中,存在两个不同的环境。 第1种是翻译环境,在这个环境中源代码被转换为可执⾏的机器指令。 第2种是执⾏环境&…

uni-app修改头像和个人信息

效果图 代码&#xff08;总&#xff09; <script setup lang"ts"> import { reqMember, reqMemberProfile } from /services/member/member import type { MemberResult, Gender } from /services/member/type import { onLoad } from dcloudio/uni-app impor…

postgresql迁移到mysql

1.工具方法&#xff1a;Navicat Premium16 2. 手工方法&#xff1a; 迁移流程 下面是将 Postgresql 数据库迁移到 MySQL 的步骤流程&#xff1a; 步骤描述1. 创建MySQL表结构在MySQL中创建与Postgresql中的表结构相同的表2. 导出Postgresql数据将Postgresql中的数据导出为SQ…

python下常用的爬虫模块

目录 一&#xff1a;requests 二&#xff1a;BeautifulSoup 三&#xff1a;Scrapy 四&#xff1a;Selenium 一&#xff1a;requests requests 是一个用于发送 HTTP 请求的 Python 库。它提供了简洁的 API 来发送各种类型的 HTTP 请求&#xff0c;如 GET、POST、PUT、DELETE…

SpringBoot异常处理(Whitelabel Error Page和自定义全局异常处理页面)和整合ajax异常处理

SpringBoot异常处理&#xff08;Whitelabel Error Page和自定义全局异常处理页面&#xff09;和整合ajax异常处理 1、springboot自带的异常处理页面Whitelabel Error Page SpringBoot默认的处理异常的机制&#xff1a;SpringBoot 默认的已经提供了一套处理异常的机制。一旦程…

【python】OpenCV—Histogram(9)

学习参考来自 Python下opencv使用笔记&#xff08;九&#xff09;&#xff08;图像直方图&#xff09; 更多学习笔记可以参考 【python】OpenCV—RGB&#xff08;1&#xff09;【python】OpenCV—Rectangle, Circle, Selective Search&#xff08;1.2&#xff09;【python】…

clickhouse join查询算法

算法对比&#xff1a; 使用方法&#xff1a; SELECT town,max(price) AS max_price,any(population) AS population FROM uk_xxx_paid JOIN uk_xxx_table ON lower(uk_price_paid.town) lower(uk_populations_table.city) GROUP BY town ORDER BY max_price DESC SETTINGS jo…

为什么我建议企业一定要自己的erp管理系统!

在商业世界中&#xff0c;企业就像是一艘船&#xff0c;需要在波涛汹涌的大海中稳稳地航行。然而&#xff0c;如果没有一套有效的管理系统&#xff0c;这艘船就可能迷失方向&#xff0c;甚至触礁沉没。对于那些没有引入ERP系统的企业来说&#xff0c;他们正面临着种种挑战。 信…

搭建储能监控云平台:实现能源管理的智能化

搭建储能监控云平台&#xff1a;实现能源管理的智能化 在全球能源变革的大背景下&#xff0c;储能技术的重要性日益凸显。储能监控云平台作为能源管理的智能解决方案&#xff0c;可以为企业提供全方位的储能系统监控与数据分析&#xff0c;提高能源利用率&#xff0c;降低能源成…

MiniTab的拟合回归模型的分析

拟合回归模型概述 使用拟合回归模型和普通最小二乘法可以描述一组预测变量和一个连续响应之间的关系。可以包括交互作用项和多项式项、执行逐步回归和变换偏斜数据。 例如&#xff0c;房地产评估人员想了解城市公寓与多个预测变量&#xff08;包括建筑面积、可用单元数量、建…

【算法与数据结构】343、LeetCode整数拆分

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析&#xff1a;博主做这道题的时候一直在思考&#xff0c;如何找到 k k k个正整数&#xff0c; k k k究竟为多少合适。…

迈向高效LLM微调:低秩适应(LoRA)技术的原理与实践

在快速发展的人工智能领域中&#xff0c;以高效和有效的方式使用大型语言模型&#xff08;LLM&#xff09;变得越来越重要。在本文中&#xff0c;您将学习如何以计算高效的方式使用低秩适应&#xff08;LoRA&#xff09;对LLM进行调整&#xff01; 为什么需要微调&#xff1f;…

【Java数据结构】03-二叉树,树和森林

4 二叉树、树和森林 重点章节&#xff0c;在选择&#xff0c;填空&#xff0c;综合中都有考察到。 4.1 掌握二叉树、树和森林的定义以及它们之间的异同点 1. 二叉树&#xff08;Binary Tree&#xff09; 定义&#xff1a; 二叉树是一种特殊的树结构&#xff0c;其中每个节点…

vue-ESlint代码规范及修复

1. 介绍 ESLint:是一个代码检查工具&#xff0c;用来检查你的代码是否符合指定的规则(你和你的团队可以自行约定一套规则)。 在创建项目时&#xff0c;我们使用的是 JavaScript Standard Style 代码风格的规则。 规范网址&#xff1a;https://standardjs.com/rules-zhcn.htm…

街机模拟游戏逆向工程(HACKROM)教程:[2]68K汇编的一些规则

指令中的符号(#,$,%) 在指令中&#xff0c;我们最常见到的符号有#和$。 这其中的"#"符号是告诉汇编程序&#xff0c;这个符号后面的数值为一个立即数&#xff0c;而不是一个偏移值或一个地址。立即数可以理解为"单纯的一个数值"。我们会在后面通过一些实…

Dtop环球嘉年华“全球Web 3.0商业场景应用峰会暨2024战略研讨会”曼谷圆满举办

Dtop环球嘉年华“全球Web 3.0商业场景应用峰会暨2024战略研讨会” &#xff08;Global Web 3.0 Business Scene Application Summit And 2024 Strategic Symposium&#xff09;在2024年1月12日于泰国曼谷举办&#xff0c;峰会以“全球Web 3.0商业场景应用生态”为主题&#xff…

vue3中,vue-echarts基本使用(柱状图、饼图、折线图)

注意&#xff1a;vue-echarts在使用前要先安装echarts&#xff0c;不要只安装vue-echarts这一个 echarts官网地址&#xff1a;Apache EChartsApache ECharts&#xff0c;一款基于JavaScript的数据可视化图表库&#xff0c;提供直观&#xff0c;生动&#xff0c;可交互&#xf…

SpringMVC中五种数据提交的方式

单个数据注入&#xff1a;在方法中声明一个和表单提交的参数名称相同的参数&#xff0c;由框架按照名称直接注入。对象封装注入&#xff1a;在方法中声明一个自定义的实体类参数&#xff0c;框架调用实体类中相应的setter方法注入属性值&#xff0c;只要保证实体类中成员变量的…

JAVAEE初阶 多线程进阶(二)

多线程进阶相关知识点 一.CAS1.1 CAS的原子类1.2 实现自旋锁1.3CAS中的ABA问题1.4 ABA问题的解决 二. callable接口三.reentrantLock3.1 reentrantLock与synchronized区别 四.信息量 semaphore五. CountDownLatch六. concurrentHashMap6.1 concurrentHashMap的优点 一.CAS CAS …

【SpringMVC】—— 如何配置使用SpringMVC(详细步骤)

目录 引言 使用 1、新建模块 2、导入坐标 3、创建SpringMVC控制器类 4、初始化SpringMVC环境 5、初始化Servlet容器&#xff0c;加载SpringMVC环境 6、配置运行 引言 SpringMVC是一种基于Java实现MVC模型的轻量级Web框架&#xff0c;SpringMVC是表现层(web层)的框架,也…