diff --git a/mmcls/apis/inference.py b/mmcls/apis/inference.py index 970b9eab..c2a29429 100644 --- a/mmcls/apis/inference.py +++ b/mmcls/apis/inference.py @@ -40,11 +40,11 @@ def init_model(config, checkpoint=None, device='cuda:0', options=None): if 'CLASSES' in checkpoint.get('meta', {}): model.CLASSES = checkpoint['meta']['CLASSES'] else: - from mmcls.datasets import ImageNet + from mmcls.datasets.categories import IMAGENET_CATEGORIES warnings.simplefilter('once') warnings.warn('Class names are not saved in the checkpoint\'s ' 'meta data, use imagenet by default.') - model.CLASSES = ImageNet.CLASSES + model.CLASSES = IMAGENET_CATEGORIES model.cfg = config # save the config in the model for convenience model.to(device) model.eval()