[Fix] Fix the requirements and lazy register mmcls models. (#1275)
parent
46af7d3ed2
commit
6ea59bd846
|
@ -1,16 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Union
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.dataset import Compose, default_collate
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.registry import DefaultScope
|
||||
|
||||
import mmcls.datasets # noqa: F401
|
||||
if TYPE_CHECKING:
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
|
||||
def inference_model(model: BaseModel, img: Union[str, np.ndarray]):
|
||||
def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]):
|
||||
"""Inference image(s) with the classifier.
|
||||
|
||||
Args:
|
||||
|
@ -21,6 +19,11 @@ def inference_model(model: BaseModel, img: Union[str, np.ndarray]):
|
|||
result (dict): The classification results that contains
|
||||
`class_name`, `pred_label` and `pred_score`.
|
||||
"""
|
||||
from mmengine.dataset import Compose, default_collate
|
||||
from mmengine.registry import DefaultScope
|
||||
|
||||
import mmcls.datasets # noqa: F401
|
||||
|
||||
cfg = model.cfg
|
||||
# build the data pipeline
|
||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
|
|
|
@ -8,18 +8,14 @@ from pathlib import Path
|
|||
from typing import List, Union
|
||||
|
||||
from mmengine.config import Config
|
||||
from mmengine.runner import load_checkpoint
|
||||
from mmengine.utils import get_installed_path
|
||||
from modelindex.load_model_index import load
|
||||
from modelindex.models.Model import Model
|
||||
|
||||
import mmcls.models # noqa: F401
|
||||
from mmcls.registry import MODELS
|
||||
|
||||
|
||||
class ModelHub:
|
||||
"""A hub to host the meta information of all pre-defined models."""
|
||||
_models_dict = {}
|
||||
__mmcls_registered = False
|
||||
|
||||
@classmethod
|
||||
def register_model_index(cls,
|
||||
|
@ -56,6 +52,7 @@ class ModelHub:
|
|||
Returns:
|
||||
modelindex.models.Model: The metainfo of the specified model.
|
||||
"""
|
||||
cls._register_mmcls_models()
|
||||
# lazy load config
|
||||
metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
|
||||
if metainfo is None:
|
||||
|
@ -77,12 +74,16 @@ class ModelHub:
|
|||
|
||||
return config_path
|
||||
|
||||
|
||||
# register models in mmcls
|
||||
mmcls_root = Path(get_installed_path('mmcls'))
|
||||
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
|
||||
ModelHub.register_model_index(
|
||||
model_index_path, config_prefix=mmcls_root / '.mim')
|
||||
@classmethod
|
||||
def _register_mmcls_models(cls):
|
||||
# register models in mmcls
|
||||
if not cls.__mmcls_registered:
|
||||
from mmengine.utils import get_installed_path
|
||||
mmcls_root = Path(get_installed_path('mmcls'))
|
||||
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
|
||||
ModelHub.register_model_index(
|
||||
model_index_path, config_prefix=mmcls_root / '.mim')
|
||||
cls.__mmcls_registered = True
|
||||
|
||||
|
||||
def init_model(config, checkpoint=None, device=None, **kwargs):
|
||||
|
@ -109,10 +110,15 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
|
|||
config.merge_from_dict({'model': kwargs})
|
||||
config.model.setdefault('data_preprocessor',
|
||||
config.get('data_preprocessor', None))
|
||||
|
||||
import mmcls.models # noqa: F401
|
||||
from mmcls.registry import MODELS
|
||||
|
||||
model = MODELS.build(config.model)
|
||||
if checkpoint is not None:
|
||||
# Mapping the weights to GPU may cause unexpected video memory leak
|
||||
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405
|
||||
from mmengine.runner import load_checkpoint
|
||||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
||||
if not model.with_head:
|
||||
# Don't set CLASSES if the model is headless.
|
||||
|
@ -216,6 +222,7 @@ def list_models(pattern=None) -> List[str]:
|
|||
'resnet50_8xb256-rsb-a2-300e_in1k',
|
||||
'resnet50_8xb256-rsb-a3-100e_in1k']
|
||||
"""
|
||||
ModelHub._register_mmcls_models()
|
||||
if pattern is None:
|
||||
return sorted(list(ModelHub._models_dict.keys()))
|
||||
# Always match keys with any postfix.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
matplotlib
|
||||
modelindex
|
||||
numpy
|
||||
packaging
|
||||
rich
|
||||
|
|
Loading…
Reference in New Issue