一、什么是断点续训:
中断的地方,继续训练。与加载预训练权重有什么区别呢?区别在于优化器参数和学习率变了。
二、如何实现“断点续训”
我们需要使用checkpoint
方法保存,模型权重,优化器权重,训练轮数。
保存模型,优化器权重可以理解,保存训练轮数是为了获得中断时的学习率。
由于在中断的时候,我们保存了中断时的模型权重,优化器权重,训练轮数,所以再次训练,加载这些参数,便可以继续训练。
实现流程:
(1)断点训练开关设置
# -------------------#
# 断点续训
# -------------------#
resume = True
resume_weights = os.path.join(save_dir, name_last_weights)
(2)使用checkpoint方式保模型权重,优化器权重,训练轮数
# -----------------------------------------------#
# 保存最后一轮模型权重,优化器权重,训练轮数
# -----------------------------------------------#
last_ckpt = {'epoch': epoch, 'model': save_state_dict, 'optimizer': optimizer.state_dict(), 'loss': val_loss}
torch.save(last_ckpt, os.path.join(save_dir, name_last_weights))
(3)模型权重,训练轮数加载
Init_Eoch = ...
model = YourModel()
# -------------------#
# 断点续训
# -------------------#
if resume:
if args.resume_weights != '':
Init_Epoch = torch.load(args.resume_weights, map_location=device)['epoch']
model.load_state_dict(torch.load(args.resume_weights, map_location=device)['model'])
(4)优化器权重加载
optimizer = optim.AdamW(model.parameters(), lr=0.0001)
# -------------------#
# 断点续训
# -------------------#
if resume:
if args.resume_weights != '':
optimizer.load_state_dict(torch.load(args.resume_weights, map_location=device)['optimizer'])
三、完整“断点续训”框架
# -------------------#
# 断点续训
# -------------------#
resume = True
resume_weights = os.path.join(save_dir, name_last_weights)
Init_Eoch = ...
model = YourModel()
# -------------------#
# 断点续训
# -------------------#
if resume:
if args.resume_weights != '':
Init_Epoch = torch.load(args.resume_weights, map_location=device)['epoch']
model.load_state_dict(torch.load(args.resume_weights, map_location=device)['model'])
optimizer = optim.AdamW(model.parameters(), lr=0.0001)
# -------------------#
# 断点续训
# -------------------#
if resume:
if args.resume_weights != '':
optimizer.load_state_dict(torch.load(args.resume_weights, map_location=device)['optimizer'])
# -----------------------------------------------#
# 保存最后一轮模型权重,优化器权重,训练轮数
# -----------------------------------------------#
last_ckpt = {'epoch': epoch, 'model': save_state_dict, 'optimizer': optimizer.state_dict(), 'loss': val_loss}
torch.save(last_ckpt, os.path.join(save_dir, name_last_weights))
四、实际应用
从第50
轮开始训练,训练到第103
轮,中断训练。
loss
变化:
检测变化:
从第104
轮继续训练,训练到第162
轮,中断训练。
loss
变化:
检测变化:
从第163
轮继续训练,训练到第320
轮,中断训练。
loss
变化:
检测变化:
从第321
轮继续训练,训练到第1000
轮,中断训练。
loss
变化:
检测变化: