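DeepSpeed has a bug: the learning-rate scheduler state is not saved during training. As a result, if training is interrupted and then restarted, the scheduler starts over from the beginning instead of continuing from the scheduler state at the last checkpoint. Others have reported this bug on the DeepSpeed GitHub as well: https://github.com/microsoft/DeepSpeed/issues/3875
To fix this, we need to write our own code that saves the scheduler state.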
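To see why the missing state matters, here is a toy sketch in plain Python (not DeepSpeed or PyTorch code). It assumes a linear warm-up schedule; the function `warmup_lr`, the peak learning rate, and the warm-up length are hypothetical values chosen only for illustration. The learning rate depends entirely on the step counter, so losing the counter on restart sends the learning rate back to its starting value.

```python
def warmup_lr(step, peak_lr=1e-6, warmup_steps=20):
    """Hypothetical linear warm-up: the LR grows with the step counter up to peak_lr."""
    return peak_lr * min(1.0, step / warmup_steps)


steps_done = 10  # steps completed before the interruption

# If the step counter is restored, the next step continues the warm-up.
lr_if_counter_restored = warmup_lr(steps_done + 1)

# If the counter is lost, the next step is treated as step 1 again.
lr_if_counter_reset = warmup_lr(1)

print(f"restored: {lr_if_counter_restored:.2e}, reset: {lr_if_counter_reset:.2e}")
```

The step counter (plus a few cached values) is essentially what a scheduler's `state_dict()` carries, which is why saving and reloading it is enough to resume the schedule.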
The approach is to add a callback class that is solely responsible for saving the scheduler state and for loading it again when training restarts:
    import os

    import torch
    from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

    # Assumed file name for the saved scheduler state.
    SCHEDULER_NAME = "scheduler.pt"


    class SchedulerStateCallback(TrainerCallback):
        def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            # Only the rank-0 process writes the file.
            if os.environ.get("RANK", "0") == "0":
                scheduler = kwargs['lr_scheduler']
                scheduler_state = scheduler.state_dict()
                save_path = os.path.join(args.output_dir, SCHEDULER_NAME)
                torch.save(scheduler_state, save_path)
                # The optimizer state is already saved by the DeepSpeed framework,
                # so there is no need to save it again here.
                # optimizer = kwargs['optimizer']
                # optimizer_state = optimizer.state_dict()
                # save_path = os.path.join(args.output_dir, OPTIMIZER_NAME)
                # torch.save(optimizer_state, save_path)

        def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            # When training starts, try to load the most recent scheduler state.
            # load_path = os.path.join(args.output_dir, OPTIMIZER_NAME)
            # if os.path.exists(load_path):
            #     optimizer = kwargs['optimizer']
            #     optimizer_state = torch.load(load_path)
            #     optimizer.load_state_dict(optimizer_state)
            load_path = os.path.join(args.output_dir, SCHEDULER_NAME)
            if os.path.exists(load_path):
                scheduler = kwargs['lr_scheduler']
                scheduler_state = torch.load(load_path)
                scheduler.load_state_dict(scheduler_state)
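The save/load round trip that the callback relies on can be checked outside of any trainer. The following is a minimal sketch, assuming a plain PyTorch `LambdaLR` with a linear warm-up; the helper `make_scheduler`, the peak learning rate, the warm-up length, and the file name are hypothetical choices for illustration, not values taken from DeepSpeed.

```python
import os
import tempfile

import torch
from torch.optim.lr_scheduler import LambdaLR

SCHEDULER_NAME = "scheduler.pt"   # file name is an assumption
PEAK_LR, WARMUP_STEPS = 1e-6, 20  # hypothetical values


def make_scheduler():
    """Build a throwaway optimizer with a linear warm-up scheduler."""
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=PEAK_LR)
    scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / WARMUP_STEPS))
    return optimizer, scheduler


with tempfile.TemporaryDirectory() as output_dir:
    # First run: take 10 steps, then save the scheduler state to disk.
    optimizer, scheduler = make_scheduler()
    for _ in range(10):
        optimizer.step()
        scheduler.step()
    save_path = os.path.join(output_dir, SCHEDULER_NAME)
    torch.save(scheduler.state_dict(), save_path)

    # Restarted run: a newly built scheduler begins at the start of the warm-up ...
    _, resumed = make_scheduler()
    fresh_lr = resumed.get_last_lr()[0]

    # ... until the saved state is loaded back into it.
    resumed.load_state_dict(torch.load(save_path))
    resumed_lr = resumed.get_last_lr()[0]

print(f"fresh: {fresh_lr:.2e}, resumed: {resumed_lr:.2e}")
```

In a Hugging Face `Trainer` setup, the callback itself would typically be registered by passing an instance in the `callbacks` list of the `Trainer` constructor or by calling `trainer.add_callback(...)`.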
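The result is shown below. When training restarts from checkpoint 10, the learning rate continues from where it left off (5.5e-7) instead of starting over from the beginning (0.5e-7):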