欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/133378772
在模型训练的过程中,加载数据部分,极其容易出现异常,以及不可控的因素,需要通过异常捕获的方式,及时处理,常用方式就是使用 collate_fn
,除此之外,还可以直接跳过错误样本,运行下一个样本进行补充。
PyTorch Dataset 类是一个抽象类,用于表示一个数据集,可以将数据和标签封装成一个可迭代的对象。要使用 Dataset 类,我们需要继承它,并实现两个方法:
__getitem__(self, index)
:根据给定的索引,返回数据集中的一个样本和对应的标签。__len__(self)
:返回数据集中的样本数量。
即:
- 将数据获取封装成单独函数。
- 使用
while True
持续监控,如果运行正确,即break
跳过。 - 如果运行失败,则打印日志,选择下一个样本运行,即
idx += 1
。 - 注意,索引不要溢出。
源码如下:
def __getitem__(self, idx):
# TODO: 解决数据异常问题,KeyError,尽量保持数据干净
while True:
try:
feats = self.getitem_wrapper(idx)
break
except Exception as e:
name = self.idx_to_chain_id(idx)
logger.error(f"err sample: {name} !!!")
idx += 1
idx = idx % len(self._chain_ids) # 避免溢出
return feats