diff --git a/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml b/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml index f18f3346b..9640b7be9 100644 --- a/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml +++ b/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml @@ -138,4 +138,4 @@ Metric: topk: [1, 5] Eval: - Recallk: - topk: 1 + topk: [1] diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index d2e66bc54..8ec438ece 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -69,7 +69,7 @@ class mINP(nn.Layer): class Recallk(nn.Layer): def __init__(self, topk=(1, 5)): super().__init__() - assert isinstance(topk, (int, list)) + assert isinstance(topk, (int, list, tuple)) if isinstance(topk, int): topk = [topk] self.topk = topk @@ -97,6 +97,9 @@ class RetriMetric(nn.Layer): gallery_img_id, self.max_rank) if "Recallk" in self.config.keys(): topk = self.config['Recallk']['topk'] + assert isinstance(topk, (int, list, tuple)) + if isinstance(topk, int): + topk = [topk] for k in topk: metric_dict["recall{}".format(k)] = all_cmc[k - 1] if "mAP" in self.config.keys():