mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix mmcls get classes (#215)
* fix mmcls get classes * resolve comment * resolve comment
This commit is contained in:
parent
937985e8f0
commit
120f4ac89f
@ -4,13 +4,12 @@ from typing import List, Sequence, Union
|
|||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mmcls.datasets import DATASETS
|
|
||||||
from mmcls.models.classifiers.base import BaseClassifier
|
from mmcls.models.classifiers.base import BaseClassifier
|
||||||
from mmcv.utils import Registry
|
from mmcv.utils import Registry
|
||||||
|
|
||||||
from mmdeploy.codebase.base import BaseBackendModel
|
from mmdeploy.codebase.base import BaseBackendModel
|
||||||
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
|
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
|
||||||
load_config)
|
get_root_logger, load_config)
|
||||||
|
|
||||||
|
|
||||||
def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
|
def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
|
||||||
@ -150,20 +149,35 @@ def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
|
|||||||
Returns:
|
Returns:
|
||||||
list[str]: A list of string specifying names of different class.
|
list[str]: A list of string specifying names of different class.
|
||||||
"""
|
"""
|
||||||
model_cfg = load_config(model_cfg)[0]
|
from mmcls.datasets import DATASETS
|
||||||
|
|
||||||
module_dict = DATASETS.module_dict
|
module_dict = DATASETS.module_dict
|
||||||
|
model_cfg = load_config(model_cfg)[0]
|
||||||
data_cfg = model_cfg.data
|
data_cfg = model_cfg.data
|
||||||
|
|
||||||
if 'train' in data_cfg:
|
def _get_class_names(dataset_type: str):
|
||||||
module = module_dict[data_cfg.train.type]
|
dataset = data_cfg.get(dataset_type, None)
|
||||||
elif 'val' in data_cfg:
|
if (not dataset) or (dataset.type not in module_dict):
|
||||||
module = module_dict[data_cfg.val.type]
|
return None
|
||||||
elif 'test' in data_cfg:
|
|
||||||
module = module_dict[data_cfg.test.type]
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f'No dataset config found in: {model_cfg}')
|
|
||||||
|
|
||||||
|
module = module_dict[dataset.type]
|
||||||
|
if module.CLASSES is not None:
|
||||||
return module.CLASSES
|
return module.CLASSES
|
||||||
|
return module.get_classes(dataset.get('classes', None))
|
||||||
|
|
||||||
|
class_names = None
|
||||||
|
for dataset_type in ['val', 'test', 'train']:
|
||||||
|
class_names = _get_class_names(dataset_type)
|
||||||
|
if class_names is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
if class_names is None:
|
||||||
|
logger = get_root_logger()
|
||||||
|
logger.warning(f'Use generated class names, because \
|
||||||
|
it failed to parse CLASSES from config: {data_cfg}')
|
||||||
|
num_classes = model_cfg.model.head.num_classes
|
||||||
|
class_names = [str(i) for i in range(num_classes)]
|
||||||
|
return class_names
|
||||||
|
|
||||||
|
|
||||||
def build_classification_model(model_files: Sequence[str],
|
def build_classification_model(model_files: Sequence[str],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user