这篇文章对mmdetection训练得到的模型权重,或者说checkpoints文件进行分析,一般模型保存在work-dir文件夹下,具体路径要参考训练用到的config,即配置文件。保存的模型一般是.pt的文件。
.pt模型文件读取方法
这种模型文件可以用torch.load()函数进行解析
import torch
pth_path = 'work-dir/your_check_point.pt'
model = torch.load(pth_path)
这里我们就可以看到这个model实际上不是什么复杂的东西,就是一个很大的dict
这个model一般包括三个key、value。
meta
第一个:meta,包含一些基本信息。就是告诉你这个模型是在什么背景下被训练得到的,用的mmdet是什么版本,随机种子seed是多少,config是什么,方便你复现复刻出来这个model
state_dict
这个是模型关键。一般网上下载的预训练权重只有这个,其是一个大的OrderedDict里面包含了这个模型按顺序得到的各层参数,看下图就明白个大概了。
一般要利用一个checkpoint(.pt的模型权重文件) ,也就是主要读取这里面的信息,来进行refine或者infer。
optimizer
里面存放的是优化器的状态,方便用这个.pt文件进行resume,即意外中断实验的时候进行继续实验,结合mmdet的train.py里的resume_from命令理解。