[Fix] Fix for `inference_model` cannot get classes information in checkpoint. (#1093)

* Fix for MMCLS1.x not being able to get classes information in checkpoint during inference

Let MMCLS1.x get classes information from checkpoint during inference instead of using imagenet classes initialization

* Update inference.py
pull/1123/head
kitecats 2022-10-14 08:27:01 +08:00 committed by GitHub
parent 31c67ffed4
commit 06c919efc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 1 deletions

View File

@ -37,7 +37,11 @@ def init_model(config, checkpoint=None, device='cuda:0', options=None):
# Mapping the weights to GPU may cause unexpected video memory leak
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'CLASSES' in checkpoint.get('meta', {}):
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmcls 1.x
model.CLASSES = checkpoint['meta']['dataset_meta']['classes']
elif 'CLASSES' in checkpoint.get('meta', {}):
# mmcls < 1.x
model.CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets.categories import IMAGENET_CATEGORIES