diff --git a/mmcls/apis/inference.py b/mmcls/apis/inference.py index 80ddf6f3..8e4125bc 100644 --- a/mmcls/apis/inference.py +++ b/mmcls/apis/inference.py @@ -37,7 +37,11 @@ def init_model(config, checkpoint=None, device='cuda:0', options=None): # Mapping the weights to GPU may cause unexpected video memory leak # which refers to https://github.com/open-mmlab/mmdetection/pull/6405 checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') - if 'CLASSES' in checkpoint.get('meta', {}): + if 'dataset_meta' in checkpoint.get('meta', {}): + # mmcls 1.x + model.CLASSES = checkpoint['meta']['dataset_meta']['classes'] + elif 'CLASSES' in checkpoint.get('meta', {}): + # mmcls < 1.x model.CLASSES = checkpoint['meta']['CLASSES'] else: from mmcls.datasets.categories import IMAGENET_CATEGORIES