Update __init__.py
parent
3e1e25afca
commit
0f98709701
|
@ -16,8 +16,7 @@ from paddle import nn
|
|||
import copy
|
||||
from collections import OrderedDict
|
||||
|
||||
from .metrics import TopkAcc, mAP, mINP, Recallk
|
||||
|
||||
from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
|
||||
|
||||
class CombinedMetrics(nn.Layer):
|
||||
def __init__(self, config_list):
|
||||
|
@ -25,13 +24,21 @@ class CombinedMetrics(nn.Layer):
|
|||
self.metric_func_list = []
|
||||
assert isinstance(config_list, list), (
|
||||
'operator config should be a list')
|
||||
|
||||
self.retri_config = dict() # retrieval metrics config
|
||||
for config in config_list:
|
||||
assert isinstance(config,
|
||||
dict) and len(config) == 1, "yaml format error"
|
||||
metric_name = list(config)[0]
|
||||
if metric_name in ["Recallk", "mAP", "mINP"]:
|
||||
self.retri_config[metric_name] = config[metric_name]
|
||||
continue
|
||||
metric_params = config[metric_name]
|
||||
self.metric_func_list.append(eval(metric_name)(**metric_params))
|
||||
|
||||
if self.retri_config:
|
||||
self.metric_func_list.append(RetriMetric(self.retri_config))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
metric_dict = OrderedDict()
|
||||
for idx, metric_func in enumerate(self.metric_func_list):
|
||||
|
|
Loading…
Reference in New Issue