AlphaFold3 的 data_modules 模块中的 OpenFoldDataLoader 类继承自 PyTorch 的 torch.utils.data.DataLoader。
该类在原始 DataLoader 的基础上,主要增加了批数据增强(为批次附加采样得到的属性)以及控制循环迭代次数(recycling)的处理:按配置给出的概率分布,为批次随机采样循环迭代次数 no_recycling_iters。
源代码:
class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
    """Create the loader and precompute the batch-property distributions.

    Every positional argument and every keyword argument other than
    ``config``, ``stage`` and ``generator`` is forwarded unchanged to
    ``torch.utils.data.DataLoader``.

    Args:
        config: Data config (keyword-only, required). It is indexed with
            ``stage`` and also read through ``config.common`` when the
            sampling distributions are built.
        stage: Key used to select the stage-specific sub-config.
        generator: Optional random generator, stored on the instance and
            later handed to ``torch.multinomial`` when sampling properties.
            It is NOT forwarded to the parent constructor.
            NOTE(review): DataLoader also keeps a ``generator`` attribute;
            this assignment replaces whatever the parent stored — confirm
            this is intended.
    """
    super().__init__(*args, **kwargs)
    # The three assignments are independent of each other; config and stage
    # just have to exist before the distributions are prepared below.
    self.generator = generator
    self.stage = stage
    self.config = config
    self._prep_batch_properties_probs()
def _prep_batch_properties_probs(self):
keyed_probs = []
stage_cfg = self.config[self.stage]
max_iters = self.config.common.max_recycling_iters
if stage_cfg.uniform_recycling:
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
else:
recycling_probs = [
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
keys, probs = zip(*keyed_probs)
max_len = max([len(p) for p in probs])
padding = [[0.] * (max_len - len(p)) for p in probs]
self.prop_keys = keys
self.prop_probs_tensor = torch.tensor(
[p + pad for p, pad in zip(probs, padding)],
dtype=torch.float32,
)
def _add_batch_properties(self, batch):
# TODO: gt_features might change
gt_features = batch.pop('gt_features', None)
samples = torch.multinomial(
self.prop_probs_tensor,
num_samples=1, # 1 per row
replacement=True,
generator=self.generator