Problem description: in multi-node, multi-GPU training the optimizer.pt file is saved, but when it is read back it turns out to be corrupted.
Original error:
Traceback (most recent call last):
File "/mnt/petrelfs/tongjingqi/train-moe/smoe/entrypoint/cpt_fpt.py", line 280, in <module>
main()
File "/mnt/petrelfs/tongjingqi/train-moe/smoe/utils/notification.py", line 146, in wrapper_sender
raise ex
File "/mnt/petrelfs/tongjingqi/train-moe/smoe/utils/notification.py", line 95, in wrapper_sender
value = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/petrelfs/tongjingqi/train-moe/smoe/entrypoint/cpt_fpt.py", line 265, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/transformers/trainer.py", line 1539, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/transformers/trainer.py", line 1752, in _inner_training_loop
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/transformers/trainer_callback.py", line 353, in on_train_begin
return self.call_event("on_train_begin", args, state, control)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/transformers/trainer_callback.py", line 397, in call_event
result = getattr(callback, event)(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/petrelfs/tongjingqi/train-moe/smoe/callbacks/save_model.py", line 134, in on_train_begin
File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/torch/serialization.py", line 798, in load
with _open_zipfile_reader(opened_file) as opened_zipfile:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/torch/serialization.py", line 283, in __init__
super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: PytorchStreamReader failed reading zip archive: invalid header or archive is corrupted
Minimal reproduction of the bug:
>>> import torch
>>> torch.load("outputs/cpt-moe-fpt-test_lr_change-1811148/optimizer.pt")
outputs/cpt-moe-fpt-test_lr_change-1811148/optimizer.pt _is_zipfile: True
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/torch/serialization.py", line 798, in load
with _open_zipfile_reader(opened_file) as opened_zipfile:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/torch/serialization.py", line 283, in __init__
super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: PytorchStreamReader failed reading zip archive: invalid header or archive is corrupted
The saving code that caused the error:
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    optimizer = kwargs['optimizer']
    optimizer_state = optimizer.state_dict()
    save_path = os.path.join(args.output_dir, OPTIMIZER_NAME)
    torch.save(optimizer_state, save_path)
Cause: in multi-node, multi-GPU training, every GPU's process runs this callback and saves the file. When several processes write the same file at the same time, the file gets corrupted.
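To make the failure mode concrete, here is a minimal, hedged sketch (not from the original code; the path is made up, and 16 writers mirrors the 16 prints observed later). Several processes write the same .pt file at once; corruption is not guaranteed on every run, but when the writes interleave, torch.load() fails with the same "invalid header or archive is corrupted" error.
import multiprocessing as mp
import torch

SAVE_PATH = "optimizer_race_demo.pt"  # hypothetical path for this demo only

def racy_save(rank: int) -> None:
    # every "rank" writes the same path, mimicking what each GPU process did in on_save
    torch.save({"rank": rank, "data": torch.randn(1024)}, SAVE_PATH)

if __name__ == "__main__":
    procs = [mp.Process(target=racy_save, args=(r,)) for r in range(16)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    try:
        torch.load(SAVE_PATH)
        print("load succeeded this run; the race does not corrupt the file every time")
    except RuntimeError as err:
        print("load failed:", err)  # e.g. PytorchStreamReader failed reading zip archive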
How the bug was found: first try to reproduce it in the simplest possible way. The error was originally raised inside a custom file-loading function, so to rule that function out, the file was loaded directly with the library function torch.load() in an interactive Python session, and the bug reproduced. Next, to rule out the saving function torch.save() itself, a minimal torch.save() round trip was run (sketched below); it saved and loaded normally, showing the problem was not the library version or anything else in the environment. That narrowed the problem down to the custom save function.
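A rough sketch of that round-trip test (the optimizer and path here are placeholders, not the ones from the training run): build a fresh optimizer state, save it, and load it back in a single process.
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
torch.save(optimizer.state_dict(), "optimizer_roundtrip_test.pt")
state = torch.load("optimizer_roundtrip_test.pt")
print(sorted(state.keys()))  # ['param_groups', 'state'], so save/load themselves work here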
The decisive clue was that a print statement inserted into the save function executed 16 times, meaning the file was being written over and over. The suspected cause was therefore a write conflict between multiple processes corrupting the file; a sketch of that kind of instrumentation follows.
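The helper below is made up (the original was just a bare print inside the callback), but it shows the idea: tag each entry into the save path with the global rank and PID, so repeated writes from different training processes are visible directly in the job log.
import os

def log_save_entry(tag: str = "on_save") -> None:
    rank = os.environ.get("RANK", "0")  # global rank exported by torchrun/SLURM-style launchers
    print(f"[{tag}] entered by RANK={rank} PID={os.getpid()}", flush=True)

# Called at the top of on_save, this prints once per training process;
# seeing it 16 times, once per writer, is exactly the repeated-write signal described above.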
Fix:
Add a check at the top of the function so that only the process with global rank 0 performs the save. Only one copy is written, so the file can no longer be corrupted by multiple processes writing it at the same time.
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    if os.environ.get("RANK", "0") == "0":
        optimizer = kwargs['optimizer']
        optimizer_state = optimizer.state_dict()
        save_path = os.path.join(args.output_dir, OPTIMIZER_NAME)
        torch.save(optimizer_state, save_path)
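An equivalent guard that does not depend on the RANK environment variable being exported is the is_world_process_zero flag that transformers maintains on TrainerState. A hedged sketch, with OPTIMIZER_NAME as a placeholder for whatever constant the original callback defines:
import os

import torch
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

OPTIMIZER_NAME = "optimizer.pt"  # placeholder; assumed to match the original constant

class SaveOptimizerCallback(TrainerCallback):
    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # True only in the global rank-0 process, so exactly one process writes the file
        if not state.is_world_process_zero:
            return
        optimizer = kwargs["optimizer"]
        torch.save(optimizer.state_dict(), os.path.join(args.output_dir, OPTIMIZER_NAME))
Either guard works; the important invariant is that exactly one process opens the output path for writing.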
After the fix, the file can indeed be read normally: