pull/790/head
weishengyu 2021-06-04 22:39:10 +08:00
parent b23e72b17c
commit 5dc93ac013
1 changed files with 3 additions and 12 deletions

View File

@ -33,20 +33,11 @@ class CombinedMetrics(nn.Layer):
metric_params = config[metric_name]
self.metric_func_list.append(eval(metric_name)(**metric_params))
def __call__(self,
similarities_matrix,
query_img_id,
gallery_img_id,
x=None,
label=None):
def __call__(self, **kwargs):
metric_dict = OrderedDict()
for idx, metric_func in enumerate(self.metric_func_list):
if x is None:
metric_dict.update(metric_func(x, label))
else:
metric_dict.update(
metric_func(similarities_matrix, query_img_id,
gallery_img_id))
metric_dict.update(metric_func(**kwargs))
return metric_dict