[Feature] Support getting model from the name defined in the model-index file. (#1236)

* [Feature] Support getting model from the name defined in the model-index file.

* Add unit tests.

* Prevent import `timm` if the `TIMMBackbone` is not used.

* Fix Windows CI.

* Move `init_model` to `mmcls.apis.hub`, and support pass nn.Module to all
model components.

* Fix requirements

* Rename `hub.py` to `model.py` and add unit tests.
pull/1243/head
Ma Zerun 2022-12-06 17:00:22 +08:00 committed by GitHub
parent d990982fc0
commit c127c474b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 395 additions and 85 deletions

View File

@ -13,12 +13,17 @@ These are some high-level APIs for classification tasks.
:local:
:backlinks: top
Model
------------------
.. autofunction:: list_models
.. autofunction:: get_model
.. autofunction:: init_model
Inference
------------------
.. autosummary::
:toctree: generated
:nosignatures:
init_model
inference_model
.. autofunction:: inference_model

View File

@ -3,6 +3,7 @@ import mmcv
import mmengine
from mmengine.utils import digit_version
from .apis import * # noqa: F401, F403
from .version import __version__
mmcv_minimum_version = '2.0.0rc1'

View File

@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model
from .inference import inference_model
from .model import ModelHub, get_model, init_model, list_models
__all__ = ['init_model', 'inference_model']
__all__ = [
'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub'
]

View File

@ -1,62 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Union
import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import Compose, pseudo_collate
from mmengine.runner import load_checkpoint
from mmengine.dataset import Compose, default_collate
from mmengine.model import BaseModel
from mmengine.registry import DefaultScope
from mmcls.models import build_classifier
from mmcls.utils import register_all_modules
import mmcls.datasets # noqa: F401
def init_model(config, checkpoint=None, device='cuda:0', options=None):
"""Initialize a classifier from config file.
Args:
config (str or :obj:`mmengine.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
options (dict): Options to override some settings in the used config.
Returns:
nn.Module: The constructed classifier.
"""
register_all_modules()
if isinstance(config, str):
config = Config.fromfile(config)
elif not isinstance(config, Config):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
if options is not None:
config.merge_from_dict(options)
config.model.setdefault('data_preprocessor',
config.get('data_preprocessor', None))
model = build_classifier(config.model)
if checkpoint is not None:
# Mapping the weights to GPU may cause unexpected video memory leak
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmcls 1.x
model.CLASSES = checkpoint['meta']['dataset_meta']['classes']
elif 'CLASSES' in checkpoint.get('meta', {}):
# mmcls < 1.x
model.CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets.categories import IMAGENET_CATEGORIES
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use imagenet by default.')
model.CLASSES = IMAGENET_CATEGORIES
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model
def inference_model(model, img):
def inference_model(model: BaseModel, img: Union[str, np.ndarray]):
"""Inference image(s) with the classifier.
Args:
@ -67,7 +21,6 @@ def inference_model(model, img):
result (dict): The classification results that contains
`class_name`, `pred_label` and `pred_score`.
"""
register_all_modules()
cfg = model.cfg
# build the data pipeline
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
@ -79,9 +32,10 @@ def inference_model(model, img):
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
test_pipeline_cfg.pop(0)
data = dict(img=img)
test_pipeline = Compose(test_pipeline_cfg)
with DefaultScope.overwrite_default_scope('mmcls'):
test_pipeline = Compose(test_pipeline_cfg)
data = test_pipeline(data)
data = pseudo_collate([data])
data = default_collate([data])
# forward the model
with torch.no_grad():

220
mmcls/apis/model.py 100644
View File

@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import fnmatch
import os.path as osp
import warnings
from os import PathLike
from pathlib import Path
from typing import List, Union
from mmengine.config import Config
from mmengine.runner import load_checkpoint
from mmengine.utils import get_installed_path
from modelindex.load_model_index import load
from modelindex.models.Model import Model
import mmcls.models # noqa: F401
from mmcls.registry import MODELS
class ModelHub:
"""A hub to host the meta information of all pre-defined models."""
_models_dict = {}
@classmethod
def register_model_index(cls,
model_index_path: Union[str, PathLike],
config_prefix: Union[str, PathLike, None] = None):
"""Parse the model-index file and register all models.
Args:
model_index_path (str | PathLike): The path of the model-index
file.
config_prefix (str | PathLike | None): The prefix of all config
file paths in the model-index file.
"""
model_index = load(str(model_index_path))
model_index.build_models_with_collections()
for metainfo in model_index.models:
model_name = metainfo.name.lower()
if metainfo.name in cls._models_dict:
raise ValueError(
'The model name {} is conflict in {} and {}.'.format(
model_name, osp.abspath(metainfo.filepath),
osp.abspath(cls._models_dict[model_name].filepath)))
metainfo.config = cls._expand_config_path(metainfo, config_prefix)
cls._models_dict[model_name] = metainfo
@classmethod
def get(cls, model_name):
"""Get the model's metainfo by the model name.
Args:
model_name (str): The name of model.
Returns:
modelindex.models.Model: The metainfo of the specified model.
"""
# lazy load config
metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
if metainfo is None:
raise ValueError(f'Failed to find model {model_name}.')
if isinstance(metainfo.config, str):
metainfo.config = Config.fromfile(metainfo.config)
return metainfo
@staticmethod
def _expand_config_path(metainfo: Model,
config_prefix: Union[str, PathLike] = None):
if config_prefix is None:
config_prefix = osp.dirname(metainfo.filepath)
if metainfo.config is None or osp.isabs(metainfo.config):
config_path: str = metainfo.config
else:
config_path = osp.abspath(osp.join(config_prefix, metainfo.config))
return config_path
# register models in mmcls
mmcls_root = Path(get_installed_path('mmcls'))
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
ModelHub.register_model_index(
model_index_path, config_prefix=mmcls_root / '.mim')
def init_model(config, checkpoint=None, device=None, **kwargs):
"""Initialize a classifier from config file.
Args:
config (str | :obj:`mmengine.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
device (str | torch.device | None): Transfer the model to the target
device. Defaults to None.
**kwargs: Other keyword arguments of the model config.
Returns:
nn.Module: The constructed model.
"""
if isinstance(config, (str, PathLike)):
config = Config.fromfile(config)
elif not isinstance(config, Config):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
if kwargs:
config.merge_from_dict({'model': kwargs})
config.model.setdefault('data_preprocessor',
config.get('data_preprocessor', None))
model = MODELS.build(config.model)
if checkpoint is not None:
# Mapping the weights to GPU may cause unexpected video memory leak
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'dataset_meta' in checkpoint.get('meta', {}):
# mmcls 1.x
model.CLASSES = checkpoint['meta']['dataset_meta']['classes']
elif 'CLASSES' in checkpoint.get('meta', {}):
# mmcls < 1.x
model.CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets.categories import IMAGENET_CATEGORIES
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use imagenet by default.')
model.CLASSES = IMAGENET_CATEGORIES
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model
def get_model(model_name, pretrained=False, device=None, **kwargs):
"""Get a pre-defined model by the name of model.
Args:
model_name (str): The name of model.
pretrained (bool | str): If True, load the pre-defined pretrained
weights. If a string, load the weights from it. Defaults to False.
device (str | torch.device | None): Transfer the model to the target
device. Defaults to None.
**kwargs: Other keyword arguments of the model config.
Returns:
mmengine.model.BaseModel: The result model.
Examples:
Get a ResNet-50 model and extract images feature:
>>> import torch
>>> from mmcls import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
... print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
Get Swin-Transformer model with pre-trained weights and inference:
>>> from mmcls import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
'sea snake'
""" # noqa: E501
metainfo = ModelHub.get(model_name)
if isinstance(pretrained, str):
ckpt = pretrained
elif pretrained:
if metainfo.weights is None:
raise ValueError(
f"The model {model_name} doesn't have pretrained weights.")
ckpt = metainfo.weights
else:
ckpt = None
if metainfo.config is None:
raise ValueError(
f"The model {model_name} doesn't support building by now.")
model = init_model(metainfo.config, ckpt, device=device, **kwargs)
return model
def list_models(pattern=None) -> List[str]:
"""List all models available in MMClassification.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
Examples:
List all models:
>>> from mmcls import list_models
>>> print(list_models())
List ResNet-50 models on ImageNet-1k dataset:
>>> from mmcls import list_models
>>> print(list_models('resnet*in1k'))
['resnet50_8xb32_in1k',
'resnet50_8xb32-fp16_in1k',
'resnet50_8xb256-rsb-a1-600e_in1k',
'resnet50_8xb256-rsb-a2-300e_in1k',
'resnet50_8xb256-rsb-a3-100e_in1k']
"""
if pattern is None:
return sorted(list(ModelHub._models_dict.keys()))
# Always match keys with any postfix.
matches = fnmatch.filter(ModelHub._models_dict.keys(), pattern + '*')
return matches

View File

@ -1,9 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
import timm
except ImportError:
timm = None
import warnings
from mmengine.logging import MMLogger
@ -68,10 +63,13 @@ class TIMMBackbone(BaseBackbone):
in_channels=3,
init_cfg=None,
**kwargs):
if timm is None:
raise RuntimeError(
try:
import timm
except ImportError:
raise ImportError(
'Failed to import timm. Please run "pip install timm". '
'"pip install dataclasses" may also be needed for Python 3.6.')
if not isinstance(pretrained, bool):
raise TypeError('pretrained must be bool, not str for model path')
if features_only and checkpoint_path:

View File

@ -4,6 +4,7 @@ from collections import OrderedDict
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.registry import MODELS
@ -96,7 +97,9 @@ class HuggingFaceClassifier(BaseClassifier):
**kwargs)
self.model = AutoModelForImageClassification.from_config(config)
self.loss_module = MODELS.build(loss)
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
self.loss_module = loss
self.with_cp = with_cp
if self.with_cp:

View File

@ -2,6 +2,7 @@
from typing import List, Optional
import torch
import torch.nn as nn
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
@ -61,13 +62,16 @@ class ImageClassifier(BaseClassifier):
super(ImageClassifier, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
self.backbone = MODELS.build(backbone)
if not isinstance(backbone, nn.Module):
backbone = MODELS.build(backbone)
if neck is not None and not isinstance(neck, nn.Module):
neck = MODELS.build(neck)
if head is not None and not isinstance(head, nn.Module):
head = MODELS.build(head)
if neck is not None:
self.neck = MODELS.build(neck)
if head is not None:
self.head = MODELS.build(head)
self.backbone = backbone
self.neck = neck
self.head = head
def forward(self,
inputs: torch.Tensor,

View File

@ -4,6 +4,7 @@ from collections import OrderedDict
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.registry import MODELS
@ -79,7 +80,10 @@ class TimmClassifier(BaseClassifier):
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
from timm.models import create_model
self.model = create_model(*args, **kwargs)
self.loss_module = MODELS.build(loss)
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
self.loss_module = loss
self.with_cp = with_cp
if self.with_cp:

View File

@ -2,6 +2,7 @@
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.evaluation.metrics import Accuracy
@ -34,7 +35,9 @@ class ClsHead(BaseHead):
super(ClsHead, self).__init__(init_cfg=init_cfg)
self.topk = topk
self.loss_module = MODELS.build(loss)
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
self.loss_module = loss
self.cal_acc = cal_acc
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:

View File

@ -157,7 +157,9 @@ class ArcFaceClsHead(ClsHead):
init_cfg: Optional[dict] = None):
super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
self.loss_module = MODELS.build(loss)
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
self.loss_module = loss
assert num_subcenters >= 1 and num_classes >= 0
self.in_channels = in_channels

View File

@ -2,6 +2,7 @@
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from mmengine.structures import LabelData
from mmcls.registry import MODELS
@ -36,7 +37,9 @@ class MultiLabelClsHead(BaseHead):
init_cfg: Optional[dict] = None):
super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
self.loss_module = MODELS.build(loss)
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
self.loss_module = loss
if thr is None and topk is None:
thr = 0.5

View File

@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Union
import mmengine.dist as dist
import torch
import torch.nn as nn
from mmengine.runner import Runner
from torch.utils.data import DataLoader
@ -75,8 +76,13 @@ class ImageToImageRetriever(BaseRetriever):
super(ImageToImageRetriever, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
self.image_encoder = MODELS.build(image_encoder)
self.head = MODELS.build(head) if head else None
if not isinstance(image_encoder, nn.Module):
image_encoder = MODELS.build(image_encoder)
if head is not None and not isinstance(head, nn.Module):
head = MODELS.build(head)
self.image_encoder = image_encoder
self.head = head
self.similarity = similarity_fn

View File

@ -4,6 +4,6 @@ interrogate
isort==4.3.21
mmdet>=3.0.0rc0
pytest
sklearn
scikit-learn
xdoctest >= 0.10.0
yapf

View File

@ -0,0 +1,13 @@
Models:
- Name: test_model
Metadata:
FLOPs: 319000000
Parameters: 3500000
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 71.86
Top 5 Accuracy: 90.42
Task: Image Classification
Weights: test_weight.pth
Config: test_config.py

View File

@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from unittest import TestCase
from unittest.mock import patch
from mmengine import Config
from mmcls.apis import ModelHub, get_model, init_model, list_models
from mmcls.models import ImageClassifier, MobileNetV2
class TestModelHub(TestCase):
def test_mmcls_models(self):
self.assertIn('resnet18_8xb32_in1k', ModelHub._models_dict)
def test_register_model_index(self):
model_index_path = osp.join(osp.dirname(__file__), '../data/meta.yml')
ModelHub.register_model_index(model_index_path)
self.assertIn('test_model', ModelHub._models_dict)
self.assertEqual(
ModelHub._models_dict['test_model'].config,
osp.abspath(
osp.join(osp.dirname(model_index_path), 'test_config.py')))
with self.assertRaisesRegex(ValueError, 'meta.yml'):
# test name conflict
ModelHub.register_model_index(model_index_path)
# test specify config prefix
del ModelHub._models_dict['test_model']
ModelHub.register_model_index(
model_index_path, config_prefix='configs')
self.assertEqual(ModelHub._models_dict['test_model'].config,
osp.abspath(osp.join('configs', 'test_config.py')))
def test_get_model(self):
metainfo = ModelHub.get('resnet18_8xb32_in1k')
self.assertIsInstance(metainfo.weights, str)
self.assertIsInstance(metainfo.config, Config)
class TestHubAPIs(TestCase):
def test_list_models(self):
models_names = list_models()
self.assertIsInstance(models_names, list)
models_names = list_models(pattern='swin*in1k')
for model_name in models_names:
self.assertTrue(
model_name.startswith('swin') and 'in1k' in model_name)
def test_get_model(self):
model = get_model('mobilenet-v2_8xb32_in1k')
self.assertIsInstance(model, ImageClassifier)
self.assertIsInstance(model.backbone, MobileNetV2)
with patch('mmcls.apis.model.init_model') as mock:
model = get_model('mobilenet-v2_8xb32_in1k', pretrained=True)
model = get_model('mobilenet-v2_8xb32_in1k', pretrained='test.pth')
weight = mock.call_args_list[0][0][1]
self.assertIn('https', weight)
weight = mock.call_args_list[1][0][1]
self.assertEqual('test.pth', weight)
with self.assertRaisesRegex(ValueError, 'Failed to find'):
get_model('unknown-model')
with self.assertRaisesRegex(ValueError, "doesn't support building"):
get_model('swinv2-base-w12_3rdparty_in21k-192px')
def test_init_model(self):
# test init from config object
cfg = ModelHub.get('mobilenet-v2_8xb32_in1k').config
model = init_model(cfg)
self.assertIsInstance(model, ImageClassifier)
self.assertIsInstance(model.backbone, MobileNetV2)
# test init from config file
cfg = ModelHub._models_dict['mobilenet-v2_8xb32_in1k'].config
self.assertIsInstance(cfg, str)
model = init_model(cfg)
self.assertIsInstance(model, ImageClassifier)
self.assertIsInstance(model.backbone, MobileNetV2)
# test modify configs of the model
model = init_model(cfg, head=dict(num_classes=10))
self.assertEqual(model.head.num_classes, 10)