这里写目录标题
- 前言:
- 计算mean recall的详细过程
- 1. **准备数据**:
- 2. **计算每个类别的recall**:
- 具体代码片段
- 准备groundtruth数据
- 准备预测数据
- 计算recall
- 计算mean recall
前言:
计算流程这里参考maskrcnn_benchmark/data/datasets/evaluation/vg/sgg_eval.py这个scene graph generation benchmark的github官网来完成相关的任务。
计算mean recall的详细过程
以下是如何利用预测三元组和groundtruth三元组计算mean recall的详细过程:
1. 准备数据:
使用如下两个变量来分别保存groundtruth三元组和predicate的三元组。
prepare_gt
方法会处理groundtruth三元组数据,并将其存储在一个字典中。prepare_pred
方法会处理预测三元组数据,并将其存储在一个字典中。
2. 计算每个类别的recall:
calculate_recall
方法会遍历所有的groundtruth和predicate数据,计算每个关系类别的recall。- 对于每个类别,计算公式为:
其中,TP是True Positives,FN是False Negatives。(这句的意思看这句代码就理解了,即:float(len(match)) / float(gt_rels.shape[0]
)也就是说,(正确匹配的三元组)/所有groundtruth三元组)
- 计算mean recall:
-
calculate_mean_recall
方法会计算所有类别的平均recall。 -
首先,计算每个类别的recall。
-
然后,计算所有类别recall的平均值:
其中,Recall_i是第i个关系类别的recall,N是类别的总数。
-
具体代码片段
以下是一些关键代码片段的解释:(这些片段是从github文件中专门拿出来的)
准备groundtruth数据
def prepare_gt(self):
for gt in self.gts:
gt_entry = {}
gt_entry['relations'] = gt['relations']
gt_entry['boxes'] = gt['boxes']
gt_entry['labels'] = gt['labels']
self.gt_entries.append(gt_entry)
准备预测数据
def prepare_pred(self):
for pred in self.preds:
pred_entry = {}
pred_entry['relations'] = pred['relations']
pred_entry['boxes'] = pred['boxes']
pred_entry['labels'] = pred['labels']
self.pred_entries.append(pred_entry)
计算recall
注意这里的TP和FN,就是上面公式了的TP和FN。
def calculate_recall(self):
for i, gt_entry in enumerate(self.gt_entries):
pred_entry = self.pred_entries[i]
for rel in gt_entry['relations']:
gt_rel = (rel[0], rel[1], rel[2])
if gt_rel in pred_entry['relations']:
self.tp[rel[2]] += 1
else:
self.fn[rel[2]] += 1
计算mean recall
def calculate_mean_recall(self):
recalls = []
for i in range(self.num_rel_classes):
if self.tp[i] + self.fn[i] > 0:
recalls.append(self.tp[i] / (self.tp[i] + self.fn[i]))
mean_recall = sum(recalls) / len(recalls)
return mean_recall
通过这些方法,你可以计算出每个类别的recall,并进一步计算出mean recall。希望这些解释对你理解这段代码有所帮助!如果有更多问题,请随时问我。