众所周知,使用大量数据预训练后的骨干网络可以提升整个模型的泛化能力,而我们如果将网络的骨干网络替换后则不能直接使用原来的权重。这个项目的作用是在你替换骨干网络后可以将网络预训练权重一并“偷”过来。
下给结论:将DeeplabV3+的骨干网络由Xception替换为mobilenetv3,使用预训练和不适用预训练,第1个epoch和前171个epoch的结果,以及前200轮的损失函数如下:
不适用预训练:
--->1轮miou如下:
mIoU: 16.11; mPA: 16.67; Accuracy: 96.68
使用预训练:
--->1轮miou如下:
mIoU: 40.53; mPA: 56.8; Accuracy: 97.09
不适用预训练:
--->171轮miou如下:
mIoU:64.68 ; mPA: 78.0; Accuracy: 98.82
使用预训练:
--->171轮miou如下:
mIoU:86.08 ; mPA: 92.54; Accuracy: 99.56
可以看到,将预训练权重转换并加载后,无论是对早期的梯度下降还是最终的模型性能都有显著提升。
一、权重文件的结构
pyotrch的权重文件为.pth。其由一个collections.OrderedDict构成,是一个序列化的集合。
其由一个单位由两个元素组成:该层的名称+该层的Tensor权重。通过对其遍历,可以将这两者分别取出:
for k, v in pretrained_dict.items():
而k的数据类型经过验证后为str型,在原版的自适应加载代码中,无法使用其他网络的预训练权重的原因是:权重名称不匹配,既然权重名仅为str型,那么使用replace将不一致的部分修改一致即可将权重文件添加到网络中。
二、修改权重名
首先,这种方法仅针对权重名有一定关联的情况,所以大家在移植骨干网络时尽可能使函数名和构造名与原网络保持一致。由于骨干网络位于模型开始的位置,我们输出第1层即可得知修改后的网络与原版网络的差别,本例中使用bubbliiiing的YoloV4中的mobilenetv3和deeplabV3+作为素材,将mobilenentv3网络移入deeplabv3+。
在python3中直接输出OrderedDict的某层会报错,因为其是一个序列化的数据机构,我们需要使用list(.)将其转换为无序列数据结构再输出,使用下面代码即可完成输出:
print('源模型文件格式为:{}'.format(list(pretrained_dict.keys())[0]))
print('目标模型文件格式为:{}'.format(list(model_dict.keys())[0]))
输出后可知,两个模型之间的差距为:
使用两个变量暂存两者之间的不同,在将这两个变量送入replace()函数即可完成层名的修改。如上图中,源模型和目标模型的差距是源模型比目标模型多一个model.,将其删除即可。
三、转换并保存
本例中使用的函数已经单独封装:
def model_converter(custom_model ,model_path):
model_dict = custom_model.state_dict()
pretrained_dict = torch.load(model_path, map_location = torch.device('cpu'))
load_key, no_load_key, temp_dict = [], [], {}
# 展示骨干网络的第一层
print('源模型文件格式为:{}'.format(list(pretrained_dict.keys())[0]))
print('目标模型文件格式为:{}'.format(list(model_dict.keys())[0]))
print('请将两模型之间不同的部分输入:')
orgStr = input('源模型:')
targetStr = input('目标模型:')
print('--->开始模型转换')
for k, v in pretrained_dict.items():
k = k.replace(orgStr,targetStr)
if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
temp_dict[k] = v
load_key.append(k)
else:
no_load_key.append(k)
# 将权重更新到模型中
model_dict.update(temp_dict)
custom_model.load_state_dict(model_dict)
# 保存模型
torch.save(custom_model.state_dict(), 'converted_weights.pth')
将实例化后的模型和源模型文件输入本函数即可,经过修改后匹配部分会自动转换,不匹配的部分会暂存在no_load_key中,可自行对其进行分析。经过测试,这种方法基本能完成骨干网络的转换。