[Fix] Fix the requirements and lazy register mmcls models. (#1275)

pull/1240/head
Ma Zerun 2022-12-19 13:01:11 +08:00 committed by GitHub
parent 46af7d3ed2
commit 6ea59bd846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 17 deletions

View File

@ -1,16 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
from typing import TYPE_CHECKING, Union
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate
from mmengine.model import BaseModel
from mmengine.registry import DefaultScope
import mmcls.datasets # noqa: F401
if TYPE_CHECKING:
from mmengine.model import BaseModel
def inference_model(model: BaseModel, img: Union[str, np.ndarray]):
def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]):
"""Inference image(s) with the classifier.
Args:
@ -21,6 +19,11 @@ def inference_model(model: BaseModel, img: Union[str, np.ndarray]):
result (dict): The classification results that contains
`class_name`, `pred_label` and `pred_score`.
"""
from mmengine.dataset import Compose, default_collate
from mmengine.registry import DefaultScope
import mmcls.datasets # noqa: F401
cfg = model.cfg
# build the data pipeline
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline

View File

@ -8,18 +8,14 @@ from pathlib import Path
from typing import List, Union
from mmengine.config import Config
from mmengine.runner import load_checkpoint
from mmengine.utils import get_installed_path
from modelindex.load_model_index import load
from modelindex.models.Model import Model
import mmcls.models # noqa: F401
from mmcls.registry import MODELS
class ModelHub:
"""A hub to host the meta information of all pre-defined models."""
_models_dict = {}
__mmcls_registered = False
@classmethod
def register_model_index(cls,
@ -56,6 +52,7 @@ class ModelHub:
Returns:
modelindex.models.Model: The metainfo of the specified model.
"""
cls._register_mmcls_models()
# lazy load config
metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
if metainfo is None:
@ -77,12 +74,16 @@ class ModelHub:
return config_path
# register models in mmcls
mmcls_root = Path(get_installed_path('mmcls'))
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
ModelHub.register_model_index(
model_index_path, config_prefix=mmcls_root / '.mim')
@classmethod
def _register_mmcls_models(cls):
# register models in mmcls
if not cls.__mmcls_registered:
from mmengine.utils import get_installed_path
mmcls_root = Path(get_installed_path('mmcls'))
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
ModelHub.register_model_index(
model_index_path, config_prefix=mmcls_root / '.mim')
cls.__mmcls_registered = True
def init_model(config, checkpoint=None, device=None, **kwargs):
@ -109,10 +110,15 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
config.merge_from_dict({'model': kwargs})
config.model.setdefault('data_preprocessor',
config.get('data_preprocessor', None))
import mmcls.models # noqa: F401
from mmcls.registry import MODELS
model = MODELS.build(config.model)
if checkpoint is not None:
# Mapping the weights to GPU may cause unexpected video memory leak
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405
from mmengine.runner import load_checkpoint
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if not model.with_head:
# Don't set CLASSES if the model is headless.
@ -216,6 +222,7 @@ def list_models(pattern=None) -> List[str]:
'resnet50_8xb256-rsb-a2-300e_in1k',
'resnet50_8xb256-rsb-a3-100e_in1k']
"""
ModelHub._register_mmcls_models()
if pattern is None:
return sorted(list(ModelHub._models_dict.keys()))
# Always match keys with any postfix.

View File

@ -1,4 +1,5 @@
matplotlib
modelindex
numpy
packaging
rich