fix mmcls get classes (#215)

* fix mmcls get classes

* resolve comment

* resolve comment
pull/224/head
RunningLeon 2022-03-09 10:16:49 +08:00 committed by GitHub
parent 937985e8f0
commit 120f4ac89f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 12 deletions

View File

@ -4,13 +4,12 @@ from typing import List, Sequence, Union
import mmcv
import numpy as np
import torch
from mmcls.datasets import DATASETS
from mmcls.models.classifiers.base import BaseClassifier
from mmcv.utils import Registry
from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
load_config)
get_root_logger, load_config)
def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
@ -150,20 +149,35 @@ def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
Returns:
list[str]: A list of string specifying names of different class.
"""
model_cfg = load_config(model_cfg)[0]
from mmcls.datasets import DATASETS
module_dict = DATASETS.module_dict
model_cfg = load_config(model_cfg)[0]
data_cfg = model_cfg.data
if 'train' in data_cfg:
module = module_dict[data_cfg.train.type]
elif 'val' in data_cfg:
module = module_dict[data_cfg.val.type]
elif 'test' in data_cfg:
module = module_dict[data_cfg.test.type]
else:
raise RuntimeError(f'No dataset config found in: {model_cfg}')
def _get_class_names(dataset_type: str):
dataset = data_cfg.get(dataset_type, None)
if (not dataset) or (dataset.type not in module_dict):
return None
return module.CLASSES
module = module_dict[dataset.type]
if module.CLASSES is not None:
return module.CLASSES
return module.get_classes(dataset.get('classes', None))
class_names = None
for dataset_type in ['val', 'test', 'train']:
class_names = _get_class_names(dataset_type)
if class_names is not None:
break
if class_names is None:
logger = get_root_logger()
logger.warning(f'Use generated class names, because \
it failed to parse CLASSES from config: {data_cfg}')
num_classes = model_cfg.model.head.num_classes
class_names = [str(i) for i in range(num_classes)]
return class_names
def build_classification_model(model_files: Sequence[str],