在做深度学习项目时,从头训练一个模型是需要大量时间和算力的,我们通常采用加载预训练权重的方法,而我们往往面临以下几种情况:
未修改网络,A与B一致
很简单,直接.load_state_dict()
net = ANet(num_classses = 5,init_weights=True)
net.to(device)
net.load_state_dict(torch.load('weight/B_weight.pth'))
修改了网络,A与B不一致
[pytorch官方文档](Search — PyTorch master documentation):
load_state_dict(state_dict, strict=True)
将 state_dict 中的参数和缓冲区复制到此模块及其后代中。如果 strict 为 True,则 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。
state_dict是包含参数和持久缓冲区的字典,可以看出 strict默认为True,所以默认状态下是严格要求state_dict中的key与torch.nn.Module.state_dict返回的key完全一致的
load_state_dict()函数有两个返回值:
missing_keys 是包含缺失键的 str 列表
unexpected_keys 是包含意外键的 str 列表
方法一:
将strict改为false,加载键值相同的部分。
model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict'] #读取预训练模型权重
model.load_state_dict(weights, strict=False) #strict
但是此时还存在一种情况:键值相同但shape不同,故应进行if…in…的判断:
ANet = torch.load('ANet.pt') # 加载预训练权重模型(.pt文件)参数
#现成的模型的话,如resnet50 = models.resnet50(pretrained=True)
#采用:pretrained_dict = resnet50().state_dict()
model = Model() # 创建模型
model_dict = model.state_dict() # 得到模型的参数字典
# 判断预训练模型中网络的模块是否修改后的网络中也存在,并且shape相同,如果相同则取出
pretrained_dict = {k: v for k, v in ANet.items() if k in model_dict and (v.shape == model_dict[k].shape)}
# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict, strict=False)
方法二:
1.将权重导入原模型,之后在加载后的原模型基础上进行修改。
2.修改权重文件参数,再进行导入
适用于改动不大的模型