update metrics

pull/1925/head
cuicheng01 2022-05-16 04:31:26 +00:00
parent 93e9970ede
commit 87a0ba6f31
2 changed files with 15 additions and 0 deletions

View File

@ -39,6 +39,7 @@ class CombinedMetrics(AvgMetrics):
eval(metric_name)(**metric_params))
else:
self.metric_func_list.append(eval(metric_name)())
self.reset()
def forward(self, *args, **kwargs):
metric_dict = OrderedDict()
@ -54,6 +55,10 @@ class CombinedMetrics(AvgMetrics):
def avg(self):
return self.metric_func_list[0].avg
def reset(self):
for metric in self.metric_func_list:
if hasattr(metric, "reset"):
metric.reset()
def build_metrics(config):
metrics_list = CombinedMetrics(copy.deepcopy(config))

View File

@ -33,6 +33,9 @@ class TopkAcc(AvgMetrics):
if isinstance(topk, int):
topk = [topk]
self.topk = topk
self.reset()
def reset(self):
self.avg_meters = {"top{}".format(k): AverageMeter("top{}".format(k)) for k in self.topk}
def forward(self, x, label):
@ -316,6 +319,9 @@ class HammingDistance(MultiLabelMetric):
def __init__(self):
super().__init__()
self.reset()
def reset(self):
self.avg_meters = {"HammingDistance": AverageMeter("HammingDistance")}
def forward(self, output, target):
@ -343,6 +349,10 @@ class AccuracyScore(MultiLabelMetric):
assert base in ["sample", "label"
], 'must be one of ["sample", "label"]'
self.base = base
self.reset()
def reset(self):
self.avg_meters = {"AccuracyScore": AverageMeter("AccuracyScore")}
def forward(self, output, target):
preds = super()._multi_hot_encode(output)