parent
b838c15465
commit
b457c393eb
|
@ -153,7 +153,7 @@ class Engine(object):
|
|||
|
||||
# build metric
|
||||
if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
|
||||
"Metric"]:
|
||||
"Metric"] and self.config["Metric"]["Train"]:
|
||||
metric_config = self.config["Metric"]["Train"]
|
||||
if hasattr(self.train_dataloader, "collate_fn"
|
||||
) and self.train_dataloader.collate_fn is not None:
|
||||
|
|
|
@ -61,7 +61,7 @@ class TopkAcc(AvgMetrics):
|
|||
self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"],
|
||||
x.shape[0])
|
||||
|
||||
self.topk = filter(lambda k: k <= output_dims, self.topk)
|
||||
self.topk = list(filter(lambda k: k <= output_dims, self.topk))
|
||||
|
||||
return metric_dict
|
||||
|
||||
|
|
Loading…
Reference in New Issue