fix mmcls get classes (#215)
* fix mmcls get classes * resolve comment * resolve commentpull/224/head
parent
937985e8f0
commit
120f4ac89f
|
@ -4,13 +4,12 @@ from typing import List, Sequence, Union
|
|||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcls.datasets import DATASETS
|
||||
from mmcls.models.classifiers.base import BaseClassifier
|
||||
from mmcv.utils import Registry
|
||||
|
||||
from mmdeploy.codebase.base import BaseBackendModel
|
||||
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
|
||||
load_config)
|
||||
get_root_logger, load_config)
|
||||
|
||||
|
||||
def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
|
||||
|
@ -150,20 +149,35 @@ def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
|
|||
Returns:
|
||||
list[str]: A list of string specifying names of different class.
|
||||
"""
|
||||
model_cfg = load_config(model_cfg)[0]
|
||||
from mmcls.datasets import DATASETS
|
||||
|
||||
module_dict = DATASETS.module_dict
|
||||
model_cfg = load_config(model_cfg)[0]
|
||||
data_cfg = model_cfg.data
|
||||
|
||||
if 'train' in data_cfg:
|
||||
module = module_dict[data_cfg.train.type]
|
||||
elif 'val' in data_cfg:
|
||||
module = module_dict[data_cfg.val.type]
|
||||
elif 'test' in data_cfg:
|
||||
module = module_dict[data_cfg.test.type]
|
||||
else:
|
||||
raise RuntimeError(f'No dataset config found in: {model_cfg}')
|
||||
def _get_class_names(dataset_type: str):
|
||||
dataset = data_cfg.get(dataset_type, None)
|
||||
if (not dataset) or (dataset.type not in module_dict):
|
||||
return None
|
||||
|
||||
return module.CLASSES
|
||||
module = module_dict[dataset.type]
|
||||
if module.CLASSES is not None:
|
||||
return module.CLASSES
|
||||
return module.get_classes(dataset.get('classes', None))
|
||||
|
||||
class_names = None
|
||||
for dataset_type in ['val', 'test', 'train']:
|
||||
class_names = _get_class_names(dataset_type)
|
||||
if class_names is not None:
|
||||
break
|
||||
|
||||
if class_names is None:
|
||||
logger = get_root_logger()
|
||||
logger.warning(f'Use generated class names, because \
|
||||
it failed to parse CLASSES from config: {data_cfg}')
|
||||
num_classes = model_cfg.model.head.num_classes
|
||||
class_names = [str(i) for i in range(num_classes)]
|
||||
return class_names
|
||||
|
||||
|
||||
def build_classification_model(model_files: Sequence[str],
|
||||
|
|
Loading…
Reference in New Issue