1 DeepMove
1.1 构造函数
1.2 初始化权重
1.3 forward
1.4 predict
def predict(self, batch):
score = self.forward(batch)
if self.evaluate_method == 'sample':
# build pos_neg_inedx
pos_neg_index = torch.cat((batch['target'].unsqueeze(1), batch['neg_loc']), dim=1)
score = torch.gather(score, 1, pos_neg_index)
return score
- 如果评估方法是
'sample'
,则执行以下步骤:- 构建正负样本索引 (
pos_neg_index
): 使用torch.cat
函数将批次中的目标位置 (batch['target']
) 与负样本位置 (batch['neg_loc']
) 结合。这里,目标位置通过unsqueeze(1)
方法添加一个维度以匹配负样本位置的维度,使其成为batch_size x (1 + num_negatives)
的形状。 - 选择得分: 使用
torch.gather
方法根据pos_neg_index
从得分张量中选择相关的得分。这一步骤的目的是从模型输出的所有可能位置的得分中,仅提取出与正样本和负样本对应的得分。
- 构建正负样本索引 (
1.5 calculate_loss
def calculate_loss(self, batch):
criterion = nn.NLLLoss().to(self.device)
scores = self.forward(batch)
return criterion(scores, batch['target'])
调用 criterion(scores, batch['target'])
来计算模型输出得分和批次中的目标标签 (batch['target']
) 之间的损失。