diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index c370710b3..a331a8c46 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -153,7 +153,7 @@ class Engine(object): # build metric if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[ - "Metric"]: + "Metric"] and self.config["Metric"]["Train"]: metric_config = self.config["Metric"]["Train"] if hasattr(self.train_dataloader, "collate_fn" ) and self.train_dataloader.collate_fn is not None: diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 7928ecb0e..0c803ccfd 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -61,7 +61,7 @@ class TopkAcc(AvgMetrics): self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"], x.shape[0]) - self.topk = filter(lambda k: k <= output_dims, self.topk) + self.topk = list(filter(lambda k: k <= output_dims, self.topk)) return metric_dict