准备工作
环境:
wsl2 Ubuntu 20.04
mindspore 2.0.0
python 3.8
pytorch 2.0.1 cpu
基于自己编写的mindspore TOOD项目和MMDetection实现的pytorch权重来做迁移,
- TOOD论文pytorch mmdetection实现
tood_r50_fpn_1x_coco权重
论文中的代码也是用mmdetection实现的 - TOOD mmdetection实现
观察上面两个实现的配置文件,区别只是分类损失用的不同,我们先对照TOOD mmdetection实现。 - MindSpore TOOD项目链接
该代码基于FCOS mindspore实现的,对网络命名进行了优化,更靠近官方的pytorch风格
基于MindSpore实现TOOD forward 结构
先搭模型,结构就是resnet50+fpn+toodhead。除了模型结构,还要注意head以及fpn部分的权值初始化要与mmdetection中的实现对齐,这个在后续训练时会有影响
- 两种框架下pad的区别需要注意,区别见MindSpore官方的迁移指南 ,我尽量使用显式表达,防止出错
- resent50 backbone在训练时加载预训练权重进行初始化
- mmdetection中FPN部分的初始化为xavier初始化,我在mindspore中采用更好的kaiming初始化
- head部分卷积和一般性的偏置使用normal初始化以及zeros初始化
- head部分的分类分支偏置采用的prob初始化
- 其他部分(BN,GN)的初始化两个框架相同
权重转换
迁移其实就是在做权重的键值映射对齐,有了FCOS的迁移经验,且对网络模型部分做了命名优化,做这个会快很多。
可参考的经验:
- FCOS权重迁移经验
- https://gitee.com/lirongxi4/pt2ms_convert
一个迁移脚本,通用性一般 - MindSpore官方的迁移指南
打印两种框架的权重的名称及shape进行比对,
利用文本对比网站进行对比:
根据shape可以看到顺序完全对齐了,注意scale在pt中是一个浮点数,而在ms中是一个1x1的tensor。FPN实现的运算顺序也在代码中专门调试过,只需完成名称转换即可。
虽然可以根据顺序直接转换,但为了稳定性,还是用字典映射的方法,总结的名称转换方式如下(pytorch的名称改为mindspore的):
def tood_pth2ckpt():
ms_ckpt = ms.load_checkpoint('tood_ms.ckpt') # mindspore FCOS保存的随机权重
pth = torch.load("/mnt/f/pretrain_weight/tood_r50_fpn_1x_coco.pth", map_location=torch.device('cpu')) # pytorch FCOS权重
match_pt_kv = {} # 匹配到的pt权重的name及value的字典
match_pt_kv_mslist = [] # 匹配到的pt权重的name及value的字典, mindspore加载权重需求的格式
not_match_pt_kv = {} # 未匹配到的pt权重的name及value
matched_ms_k = [] # 被匹配到的ms权重名称
'''一般性的转换规则'''
pt2ms = {'backbone': 'tood_body.backbone', # backbone部分
'neck': 'tood_body.fpn',
'bbox_head': 'tood_body.head',
'downsample': 'down_sample_layer',
}
'''conv层的转换规则, 一致,可忽略'''
pt2ms_conv = {
"weight": "weight",
"bias": "bias",
}
'''downsample层的转换规则, 有卷积层和bn层, 分别为0,1命名,在torch中weight重复'''
pt2ms_down = {
"0.weight": "0.weight",
"1.weight": "1.gamma",
"1.bias": "1.beta",
"running_mean": "moving_mean",
"running_var": "moving_variance",
}
'''BN层的转换规则'''
pt2ms_bn = {
"running_mean": "moving_mean",
"running_var": "moving_variance",
"weight": "gamma",
"bias": "beta",
}
'''GN层的转换规则'''
pt2ms_gn = {
"weight": "gamma",
"bias": "beta",
}
for i, v in pth['state_dict'].items():
pt_name = copy.deepcopy(i)
pt_value = copy.deepcopy(v)
'''一般性的处理'''
for k, v in pt2ms.items():
if k in pt_name:
pt_name = pt_name.replace(k, v)
'''conv层的转换规则, 一致,可忽略'''
'''FPN部分特别处理'''
if 'fpn' in pt_name:
pt_name = pt_name.replace('.conv', '')
'''下采样层特别处理'''
if 'down' in pt_name:
for k, v in pt2ms_down.items():
if k in pt_name:
pt_name = pt_name.replace(k, v)
'''BN层处理'''
if 'bn' in pt_name:
for k, v in pt2ms_bn.items():
if k in pt_name:
pt_name = pt_name.replace(k, v)
'''GN层处理'''
if 'gn' in pt_name:
for k, v in pt2ms_gn.items():
if k in pt_name:
pt_name = pt_name.replace(k, v)
'''改名成功,匹配到ms中的权重了,记录'''
if pt_name in ms_ckpt.keys():
if 'scale' in pt_name:
pt_value = torch.tensor([pt_value])
assert pt_value.shape == ms_ckpt[pt_name].shape
match_pt_kv[pt_name] = pt_value
match_pt_kv_mslist.append({'name': pt_name, 'data': ms.Tensor(pt_value.numpy(), ms_ckpt[pt_name].dtype)})
matched_ms_k.append(pt_name)
else:
not_match_pt_kv[i + ' ' + pt_name] = pt_value
'''打印未匹配的pt权重名称'''
print('\n\n-----------------------------未匹配的pt权重名称----------------------------')
print('----------原名称-------- ----------转换后名称---------')
for j, v in not_match_pt_kv.items():
print(j, np.array(v.shape))
'''打印未被匹配到的ms权重名称'''
print('\n\n---------------------------未被匹配到的ms权重名称----------------------------')
for j, v in ms_ckpt.items():
if j not in matched_ms_k:
print(j, np.array(v.shape))
print('end')
return match_pt_kv_mslist
输出:
-----------------------------未匹配的pt权重名称----------------------------
----------原名称-------- ----------转换后名称---------
backbone.layer4.1.bn3.num_batches_tracked tood_body.backbone.layer4.1.bn3.num_batches_tracked []
backbone.layer4.2.bn1.num_batches_tracked tood_body.backbone.layer4.2.bn1.num_batches_tracked []
backbone.layer4.2.bn2.num_batches_tracked tood_body.backbone.layer4.2.bn2.num_batches_tracked []
backbone.layer4.2.bn3.num_batches_tracked tood_body.backbone.layer4.2.bn3.num_batches_tracked []
......
---------------------------未被匹配到的ms权重名称----------------------------
end
剩下一些bn层的num_batches_tracked状态,不需要管
接下来进行输出对齐,推理到需要padding的卷积时发现了一些问题,
mindspore中
nn.Conv2d(64, 64, kernel_size=3, stride=1,
padding=1, pad_mode='pad', has_bias=False)
不等价于pytorch的
nn.Conv2d(64, 64, kernel_size=3, stride=1,
padding=1)
查阅资料按道理应该等价的啊,结果不等价
发现是跟ms中这样等价的, 先pad,再valid卷积:
pad1 = ms.nn.Pad(((0,0),(0,0),(1,1),(1,1)))
conv2 = ms.nn.Conv2d(64, 64, kernel_size=3, stride=1,
pad_mode='valid')
不解。。。
未完待续。。。