mmcls现在叫mmpretrain,以前叫mmclassification,这里为了统一称为mmcls。在基于MM框架的下游任务,例如检测(mmdetection)中可以使用mmcls中的backbone进行特征提取,但这就需要知道网络的参数以及输出特征的维度。本文简单介绍了在mmdetection中使用mmcls中backbone的方法。mmdetection中需要配置backbone、模型权重及neck的特征维度等信息。
1 查找mmcls预训练模型
查找mmcls支持的网络的方法有多种:
- 在mmpretrain的README中;
- 在modelzoo种查找模型库统计 — MMClassification 1.0.0rc6 文档 (mmpretrain.readthedocs.io)
- 直接看repo的configs目录下的列表
2 获取网络参数(配置)及预训练权重
找到网络后还需要找到网络参数及预训练权重。以replknet为例,获取网络参数可以直接看mmpretrain/configs/replknet中的配置文件,例如replknet-31B_32xb64_in1k.py,但配置文件可能并没有直接写模型配置信息,而是依赖其他配置文件,如下图中的replknet-31B_in1k.py
继续找到上述配置文件,可以看到网络配置:
预训练权重可以在mmpretrain/configs/replknet下的README中找到,例如:
https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k_20221118-54ed5c46.pth
预训练权重也可以在modelzoo中查找:
3 获取特征输出维度
首先在modelzoo中查到已有模型的名称,然后使用mmcls.get_model获取模型,输出指定层的特征维度。
import torch
from mmcls import get_model, inference_model
inputs = torch.rand(16, 3, 224, 224)
# 构建模型
model_name = 'replknet-31B_in21k-pre_3rdparty_in1k'
model = get_model(model_name, pretrained=False, backbone=dict(out_indices=(0, 1, 2, 3)))
# model = get_model(model_name, pretrained=False, backbone=dict(out_scales=(0, 1, 2, 3))) # mvitv2
feats = model.extract_feat(inputs)
for feat in feats:
print(feat.shape)
可以看到输出为 [128, 256, 512, 1024]:
torch.Size([16, 128])
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
4 mmdetection中使用
在mmdetection中修改配置文件中backbone,预训练权重和neck中的in_channels等信息。同时应该注意网络的优化器配置的参数。
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k-384px_20221118-76c92b24.pth' # noqa
model = dict(
backbone=dict(
_delete_=True,
type='mmcls.RepLKNet',
arch='31B',
out_indices=[0, 1, 2, 3],
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
neck=dict(
_delete_=True,
type='mmdet.FPN',
in_channels=[128, 256, 512, 1024],
out_channels=256,
num_outs=5))