[Fix] Fix the requirements and lazy register mmcls models. (#1275)
parent
46af7d3ed2
commit
6ea59bd846
|
@ -1,16 +1,14 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mmengine.dataset import Compose, default_collate
|
|
||||||
from mmengine.model import BaseModel
|
|
||||||
from mmengine.registry import DefaultScope
|
|
||||||
|
|
||||||
import mmcls.datasets # noqa: F401
|
if TYPE_CHECKING:
|
||||||
|
from mmengine.model import BaseModel
|
||||||
|
|
||||||
|
|
||||||
def inference_model(model: BaseModel, img: Union[str, np.ndarray]):
|
def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]):
|
||||||
"""Inference image(s) with the classifier.
|
"""Inference image(s) with the classifier.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -21,6 +19,11 @@ def inference_model(model: BaseModel, img: Union[str, np.ndarray]):
|
||||||
result (dict): The classification results that contains
|
result (dict): The classification results that contains
|
||||||
`class_name`, `pred_label` and `pred_score`.
|
`class_name`, `pred_label` and `pred_score`.
|
||||||
"""
|
"""
|
||||||
|
from mmengine.dataset import Compose, default_collate
|
||||||
|
from mmengine.registry import DefaultScope
|
||||||
|
|
||||||
|
import mmcls.datasets # noqa: F401
|
||||||
|
|
||||||
cfg = model.cfg
|
cfg = model.cfg
|
||||||
# build the data pipeline
|
# build the data pipeline
|
||||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||||
|
|
|
@ -8,18 +8,14 @@ from pathlib import Path
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from mmengine.config import Config
|
from mmengine.config import Config
|
||||||
from mmengine.runner import load_checkpoint
|
|
||||||
from mmengine.utils import get_installed_path
|
|
||||||
from modelindex.load_model_index import load
|
from modelindex.load_model_index import load
|
||||||
from modelindex.models.Model import Model
|
from modelindex.models.Model import Model
|
||||||
|
|
||||||
import mmcls.models # noqa: F401
|
|
||||||
from mmcls.registry import MODELS
|
|
||||||
|
|
||||||
|
|
||||||
class ModelHub:
|
class ModelHub:
|
||||||
"""A hub to host the meta information of all pre-defined models."""
|
"""A hub to host the meta information of all pre-defined models."""
|
||||||
_models_dict = {}
|
_models_dict = {}
|
||||||
|
__mmcls_registered = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_model_index(cls,
|
def register_model_index(cls,
|
||||||
|
@ -56,6 +52,7 @@ class ModelHub:
|
||||||
Returns:
|
Returns:
|
||||||
modelindex.models.Model: The metainfo of the specified model.
|
modelindex.models.Model: The metainfo of the specified model.
|
||||||
"""
|
"""
|
||||||
|
cls._register_mmcls_models()
|
||||||
# lazy load config
|
# lazy load config
|
||||||
metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
|
metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
|
||||||
if metainfo is None:
|
if metainfo is None:
|
||||||
|
@ -77,12 +74,16 @@ class ModelHub:
|
||||||
|
|
||||||
return config_path
|
return config_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
# register models in mmcls
|
def _register_mmcls_models(cls):
|
||||||
mmcls_root = Path(get_installed_path('mmcls'))
|
# register models in mmcls
|
||||||
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
|
if not cls.__mmcls_registered:
|
||||||
ModelHub.register_model_index(
|
from mmengine.utils import get_installed_path
|
||||||
model_index_path, config_prefix=mmcls_root / '.mim')
|
mmcls_root = Path(get_installed_path('mmcls'))
|
||||||
|
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
|
||||||
|
ModelHub.register_model_index(
|
||||||
|
model_index_path, config_prefix=mmcls_root / '.mim')
|
||||||
|
cls.__mmcls_registered = True
|
||||||
|
|
||||||
|
|
||||||
def init_model(config, checkpoint=None, device=None, **kwargs):
|
def init_model(config, checkpoint=None, device=None, **kwargs):
|
||||||
|
@ -109,10 +110,15 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
|
||||||
config.merge_from_dict({'model': kwargs})
|
config.merge_from_dict({'model': kwargs})
|
||||||
config.model.setdefault('data_preprocessor',
|
config.model.setdefault('data_preprocessor',
|
||||||
config.get('data_preprocessor', None))
|
config.get('data_preprocessor', None))
|
||||||
|
|
||||||
|
import mmcls.models # noqa: F401
|
||||||
|
from mmcls.registry import MODELS
|
||||||
|
|
||||||
model = MODELS.build(config.model)
|
model = MODELS.build(config.model)
|
||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
# Mapping the weights to GPU may cause unexpected video memory leak
|
# Mapping the weights to GPU may cause unexpected video memory leak
|
||||||
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405
|
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405
|
||||||
|
from mmengine.runner import load_checkpoint
|
||||||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
||||||
if not model.with_head:
|
if not model.with_head:
|
||||||
# Don't set CLASSES if the model is headless.
|
# Don't set CLASSES if the model is headless.
|
||||||
|
@ -216,6 +222,7 @@ def list_models(pattern=None) -> List[str]:
|
||||||
'resnet50_8xb256-rsb-a2-300e_in1k',
|
'resnet50_8xb256-rsb-a2-300e_in1k',
|
||||||
'resnet50_8xb256-rsb-a3-100e_in1k']
|
'resnet50_8xb256-rsb-a3-100e_in1k']
|
||||||
"""
|
"""
|
||||||
|
ModelHub._register_mmcls_models()
|
||||||
if pattern is None:
|
if pattern is None:
|
||||||
return sorted(list(ModelHub._models_dict.keys()))
|
return sorted(list(ModelHub._models_dict.keys()))
|
||||||
# Always match keys with any postfix.
|
# Always match keys with any postfix.
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
matplotlib
|
matplotlib
|
||||||
|
modelindex
|
||||||
numpy
|
numpy
|
||||||
packaging
|
packaging
|
||||||
rich
|
rich
|
||||||
|
|
Loading…
Reference in New Issue