Follow me on CSDN: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/133673820
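When training a model with the LightningModule framework, errors caused by bad data can crash a run and seriously hurt training stability, so they should be caught promptly with try-except. Concretely, when an error occurs, training_step catches the exception and returns None. In addition, on_before_zero_grad needs its own exception handling to cope with the case where training_step has returned None.
validation_step can be handled the same way.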
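The catch-and-return-None idea above does not depend on Lightning itself. The following is a minimal sketch under that assumption: `skip_on_exception` and `toy_step` are hypothetical names (not part of PyTorch Lightning or the original code). The decorator logs any exception raised by the wrapped step and returns None instead of re-raising.

```python
import functools
import logging

logger = logging.getLogger(__name__)


def skip_on_exception(step_fn):
    """Hypothetical decorator: log any exception raised by step_fn and return None."""
    @functools.wraps(step_fn)
    def wrapper(*args, **kwargs):
        try:
            return step_fn(*args, **kwargs)
        except Exception as e:  # deliberately broad, mirroring the post's approach
            logger.info(f"[CL] {step_fn.__name__}, exception: {e}")
            return None
    return wrapper


@skip_on_exception
def toy_step(batch, batch_idx):
    # Stand-in for a real training step: fails on an empty batch
    return sum(batch) / len(batch)


print(toy_step([1.0, 3.0], 0))  # 2.0
print(toy_step([], 1))          # None (the ZeroDivisionError was logged and swallowed)
```

The decorator form avoids repeating the same try-except in every hook. The Lightning documentation for training_step describes a None return as "skip to the next batch" only under automatic optimization, with restrictions for some distributed strategies, so it is worth checking the docs for the version in use.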
The source code is as follows:
```python
import logging

import pytorch_lightning as pl

# The post does not show how `logger` is created; a standard-library logger is assumed here
logger = logging.getLogger(__name__)


class MyObject(pl.LightningModule):
    def __init__(self, config, args):
        # ... (elided in the original post)
        ...

    def training_step_wrapper(self, batch, batch_idx, log_interval=10):
        # train key process
        ...

    def training_step(self, batch, batch_idx, log_interval=10):
        """
        Typically, each step takes about 50 seconds.
        Reference: https://github.com/Lightning-AI/lightning/pull/3566
        """
        try:
            return self.training_step_wrapper(batch, batch_idx, log_interval)
        except Exception as e:
            logger.info(f"[CL] training_step, exception: {e}")
            # Returning None tells Lightning to skip this batch
            return None

    def on_before_zero_grad(self, *args, **kwargs):
        try:
            self.ema.update(self.model)
        except Exception as e:
            # Support the case where training_step returned None
            logger.info(f"[CL] on_before_zero_grad, exception: {e}")

    def validation_step_wrapper(self, batch, batch_idx):
        # val key process
        ...

    def validation_step(self, batch, batch_idx):
        try:
            self.validation_step_wrapper(batch, batch_idx)
        except Exception as e:
            logger.info(f"[CL] validation_step, exception: {e}")
```
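To illustrate why a None return keeps training running, here is a toy simulation of a training loop. It is a simplified model, not Lightning's actual loop: it only assumes the documented contract that a None result means the batch is skipped, and it does not reproduce the exact order of Lightning's hooks (for example, when on_before_zero_grad fires). `run_epoch` and `safe_step` are hypothetical names.

```python
from typing import Callable, Iterable, List, Optional


def run_epoch(batches: Iterable[List[float]],
              step: Callable[[List[float], int], Optional[float]]) -> dict:
    """Toy loop: a None result from `step` means that batch is skipped."""
    losses = []
    skipped = []
    updates = 0  # counts simulated parameter updates
    for batch_idx, batch in enumerate(batches):
        loss = step(batch, batch_idx)
        if loss is None:
            skipped.append(batch_idx)  # no update for this batch; the loop goes on
            continue
        losses.append(loss)
        updates += 1
    return {"losses": losses, "skipped": skipped, "updates": updates}


def safe_step(batch, batch_idx):
    try:
        return sum(batch) / len(batch)  # an empty batch raises ZeroDivisionError
    except Exception as e:
        print(f"[CL] training_step, exception: {e}")
        return None


result = run_epoch([[1.0, 3.0], [], [2.0]], safe_step)
print(result)  # {'losses': [2.0, 2.0], 'skipped': [1], 'updates': 2}
```

Without the try-except in safe_step, the empty batch would raise out of run_epoch and end the whole epoch; with it, only batch 1 is lost.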
Common errors include:
Index out of bounds:
index 0 is out of bounds for dimension 0 with size 0
Missing dictionary key:
num_res = int(np_example["seq_length"])
KeyError: 'seq_length'
Empty input to a computation:
V, _, W = torch.linalg.svd(C)
free() error:
free(): invalid next size (fast)
munmap_chunk()
Invalid (null) pointer:
munmap_chunk(): invalid pointer
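The dictionary and index errors above are ordinary Python exceptions, so the try-except wrappers catch them. A minimal sketch, assuming `np_example` is a plain dict standing in for a real feature dict; the index error is reproduced with a Python list rather than a tensor, so its message differs from the PyTorch one. `read_num_res`, `first_residue`, and `try_step` are hypothetical helpers. By contrast, `free(): invalid next size (fast)` and `munmap_chunk(): invalid pointer` are printed by glibc when it detects heap corruption, after which it aborts the process; they are not Python exceptions, so try-except cannot catch them.

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")


def read_num_res(np_example: dict) -> int:
    # Mirrors `num_res = int(np_example["seq_length"])` from the error list
    return int(np_example["seq_length"])


def first_residue(residues: list):
    # An empty sequence reproduces an index-out-of-bounds failure
    return residues[0]


def try_step(fn, *args):
    """Catch ordinary Python exceptions the same way the wrappers above do."""
    try:
        return fn(*args)
    except Exception as e:
        logger.info(f"[CL] {fn.__name__}, exception: {type(e).__name__}: {e}")
        return None


print(try_step(read_num_res, {"seq_length": 128}))  # 128
print(try_step(read_num_res, {}))                   # None, logs KeyError: 'seq_length'
print(try_step(first_residue, []))                  # None, logs IndexError: list index out of range
```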