diff --git a/mmdeploy/codebase/mmcls/deploy/classification_model.py b/mmdeploy/codebase/mmcls/deploy/classification_model.py index f7f3bbfe7..260d72a80 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification_model.py +++ b/mmdeploy/codebase/mmcls/deploy/classification_model.py @@ -4,13 +4,12 @@ from typing import List, Sequence, Union import mmcv import numpy as np import torch -from mmcls.datasets import DATASETS from mmcls.models.classifiers.base import BaseClassifier from mmcv.utils import Registry from mmdeploy.codebase.base import BaseBackendModel from mmdeploy.utils import (Backend, get_backend, get_codebase_config, - load_config) + get_root_logger, load_config) def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs): @@ -150,20 +149,35 @@ def get_classes_from_config(model_cfg: Union[str, mmcv.Config]): Returns: list[str]: A list of string specifying names of different class. """ - model_cfg = load_config(model_cfg)[0] + from mmcls.datasets import DATASETS + module_dict = DATASETS.module_dict + model_cfg = load_config(model_cfg)[0] data_cfg = model_cfg.data - if 'train' in data_cfg: - module = module_dict[data_cfg.train.type] - elif 'val' in data_cfg: - module = module_dict[data_cfg.val.type] - elif 'test' in data_cfg: - module = module_dict[data_cfg.test.type] - else: - raise RuntimeError(f'No dataset config found in: {model_cfg}') + def _get_class_names(dataset_type: str): + dataset = data_cfg.get(dataset_type, None) + if (not dataset) or (dataset.type not in module_dict): + return None - return module.CLASSES + module = module_dict[dataset.type] + if module.CLASSES is not None: + return module.CLASSES + return module.get_classes(dataset.get('classes', None)) + + class_names = None + for dataset_type in ['val', 'test', 'train']: + class_names = _get_class_names(dataset_type) + if class_names is not None: + break + + if class_names is None: + logger = get_root_logger() + logger.warning(f'Use generated class names, because \ + it failed to parse CLASSES from config: {data_cfg}') + num_classes = model_cfg.model.head.num_classes + class_names = [str(i) for i in range(num_classes)] + return class_names def build_classification_model(model_files: Sequence[str],