From 3affc481c862644c36b72c2f87aa02ebd173cb40 Mon Sep 17 00:00:00 2001 From: agim-a <2293614+agim-a@users.noreply.github.com> Date: Thu, 15 Apr 2021 10:19:23 -0400 Subject: [PATCH] [Fix] check for CLASSES in checkpoint meta (#207) - check for CLASSES in checkpoint meta when key meta does not exists --- mmcls/apis/inference.py | 2 +- tools/test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mmcls/apis/inference.py b/mmcls/apis/inference.py index 61d161e18..5483c86a4 100644 --- a/mmcls/apis/inference.py +++ b/mmcls/apis/inference.py @@ -36,7 +36,7 @@ def init_model(config, checkpoint=None, device='cuda:0', options=None): if checkpoint is not None: map_loc = 'cpu' if device == 'cpu' else None checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc) - if 'CLASSES' in checkpoint['meta']: + if 'CLASSES' in checkpoint.get('meta', {}): model.CLASSES = checkpoint['meta']['CLASSES'] else: from mmcls.datasets import ImageNet diff --git a/tools/test.py b/tools/test.py index f11473af4..8f0e10ea4 100644 --- a/tools/test.py +++ b/tools/test.py @@ -111,7 +111,7 @@ def main(): wrap_fp16_model(model) checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') - if 'CLASSES' in checkpoint['meta']: + if 'CLASSES' in checkpoint.get('meta', {}): CLASSES = checkpoint['meta']['CLASSES'] else: from mmcls.datasets import ImageNet