diff --git a/mmcls/apis/inference.py b/mmcls/apis/inference.py index 4a2160424..e1e8a0be3 100644 --- a/mmcls/apis/inference.py +++ b/mmcls/apis/inference.py @@ -1,16 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Union +from typing import TYPE_CHECKING, Union import numpy as np import torch -from mmengine.dataset import Compose, default_collate -from mmengine.model import BaseModel -from mmengine.registry import DefaultScope -import mmcls.datasets # noqa: F401 +if TYPE_CHECKING: + from mmengine.model import BaseModel -def inference_model(model: BaseModel, img: Union[str, np.ndarray]): +def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]): """Inference image(s) with the classifier. Args: @@ -21,6 +19,11 @@ def inference_model(model: BaseModel, img: Union[str, np.ndarray]): result (dict): The classification results that contains `class_name`, `pred_label` and `pred_score`. """ + from mmengine.dataset import Compose, default_collate + from mmengine.registry import DefaultScope + + import mmcls.datasets # noqa: F401 + cfg = model.cfg # build the data pipeline test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline diff --git a/mmcls/apis/model.py b/mmcls/apis/model.py index 11407cc0c..a4882102b 100644 --- a/mmcls/apis/model.py +++ b/mmcls/apis/model.py @@ -8,18 +8,14 @@ from pathlib import Path from typing import List, Union from mmengine.config import Config -from mmengine.runner import load_checkpoint -from mmengine.utils import get_installed_path from modelindex.load_model_index import load from modelindex.models.Model import Model -import mmcls.models # noqa: F401 -from mmcls.registry import MODELS - class ModelHub: """A hub to host the meta information of all pre-defined models.""" _models_dict = {} + __mmcls_registered = False @classmethod def register_model_index(cls, @@ -56,6 +52,7 @@ class ModelHub: Returns: modelindex.models.Model: The metainfo of the specified model. """ + cls._register_mmcls_models() # lazy load config metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower())) if metainfo is None: @@ -77,12 +74,16 @@ class ModelHub: return config_path - -# register models in mmcls -mmcls_root = Path(get_installed_path('mmcls')) -model_index_path = mmcls_root / '.mim' / 'model-index.yml' -ModelHub.register_model_index( - model_index_path, config_prefix=mmcls_root / '.mim') + @classmethod + def _register_mmcls_models(cls): + # register models in mmcls + if not cls.__mmcls_registered: + from mmengine.utils import get_installed_path + mmcls_root = Path(get_installed_path('mmcls')) + model_index_path = mmcls_root / '.mim' / 'model-index.yml' + ModelHub.register_model_index( + model_index_path, config_prefix=mmcls_root / '.mim') + cls.__mmcls_registered = True def init_model(config, checkpoint=None, device=None, **kwargs): @@ -109,10 +110,15 @@ def init_model(config, checkpoint=None, device=None, **kwargs): config.merge_from_dict({'model': kwargs}) config.model.setdefault('data_preprocessor', config.get('data_preprocessor', None)) + + import mmcls.models # noqa: F401 + from mmcls.registry import MODELS + model = MODELS.build(config.model) if checkpoint is not None: # Mapping the weights to GPU may cause unexpected video memory leak # which refers to https://github.com/open-mmlab/mmdetection/pull/6405 + from mmengine.runner import load_checkpoint checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') if not model.with_head: # Don't set CLASSES if the model is headless. @@ -216,6 +222,7 @@ def list_models(pattern=None) -> List[str]: 'resnet50_8xb256-rsb-a2-300e_in1k', 'resnet50_8xb256-rsb-a3-100e_in1k'] """ + ModelHub._register_mmcls_models() if pattern is None: return sorted(list(ModelHub._models_dict.keys())) # Always match keys with any postfix. diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 9f814b0ea..e0acc881b 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,4 +1,5 @@ matplotlib +modelindex numpy packaging rich