update metrics
parent
93e9970ede
commit
87a0ba6f31
|
@ -39,6 +39,7 @@ class CombinedMetrics(AvgMetrics):
|
|||
eval(metric_name)(**metric_params))
|
||||
else:
|
||||
self.metric_func_list.append(eval(metric_name)())
|
||||
self.reset()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
metric_dict = OrderedDict()
|
||||
|
@ -54,6 +55,10 @@ class CombinedMetrics(AvgMetrics):
|
|||
def avg(self):
|
||||
return self.metric_func_list[0].avg
|
||||
|
||||
def reset(self):
|
||||
for metric in self.metric_func_list:
|
||||
if hasattr(metric, "reset"):
|
||||
metric.reset()
|
||||
|
||||
def build_metrics(config):
|
||||
metrics_list = CombinedMetrics(copy.deepcopy(config))
|
||||
|
|
|
@ -33,6 +33,9 @@ class TopkAcc(AvgMetrics):
|
|||
if isinstance(topk, int):
|
||||
topk = [topk]
|
||||
self.topk = topk
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.avg_meters = {"top{}".format(k): AverageMeter("top{}".format(k)) for k in self.topk}
|
||||
|
||||
def forward(self, x, label):
|
||||
|
@ -316,6 +319,9 @@ class HammingDistance(MultiLabelMetric):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.avg_meters = {"HammingDistance": AverageMeter("HammingDistance")}
|
||||
|
||||
def forward(self, output, target):
|
||||
|
@ -343,6 +349,10 @@ class AccuracyScore(MultiLabelMetric):
|
|||
assert base in ["sample", "label"
|
||||
], 'must be one of ["sample", "label"]'
|
||||
self.base = base
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.avg_meters = {"AccuracyScore": AverageMeter("AccuracyScore")}
|
||||
|
||||
def forward(self, output, target):
|
||||
preds = super()._multi_hot_encode(output)
|
||||
|
|
Loading…
Reference in New Issue