Update __init__.py

pull/802/head
Felix 2021-06-08 11:51:03 +08:00 committed by GitHub
parent 3e1e25afca
commit 0f98709701
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 2 deletions

View File

@ -16,8 +16,7 @@ from paddle import nn
import copy
from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk
from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
class CombinedMetrics(nn.Layer):
def __init__(self, config_list):
@ -25,13 +24,21 @@ class CombinedMetrics(nn.Layer):
self.metric_func_list = []
assert isinstance(config_list, list), (
'operator config should be a list')
self.retri_config = dict() # retrieval metrics config
for config in config_list:
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
metric_name = list(config)[0]
if metric_name in ["Recallk", "mAP", "mINP"]:
self.retri_config[metric_name] = config[metric_name]
continue
metric_params = config[metric_name]
self.metric_func_list.append(eval(metric_name)(**metric_params))
if self.retri_config:
self.metric_func_list.append(RetriMetric(self.retri_config))
def __call__(self, *args, **kwargs):
metric_dict = OrderedDict()
for idx, metric_func in enumerate(self.metric_func_list):