409 lines
16 KiB
Python
409 lines
16 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import fnmatch
|
|
import os.path as osp
|
|
import re
|
|
import warnings
|
|
from os import PathLike
|
|
from pathlib import Path
|
|
from typing import List, Tuple, Union
|
|
|
|
from mmengine.config import Config
|
|
from modelindex.load_model_index import load
|
|
from modelindex.models.Model import Model
|
|
|
|
|
|
class ModelHub:
    """A hub to host the meta information of all pre-defined models.

    Model names are treated as case-insensitive keys: they are lower-cased
    both when a model is registered and when it is looked up.
    """

    # Maps the lower-cased model name to the metainfo parsed from a
    # model-index file. Shared by the whole process on purpose.
    _models_dict = {}
    # Becomes True once the model-index of mmpretrain has been registered.
    __mmpretrain_registered = False

    @classmethod
    def register_model_index(cls,
                             model_index_path: Union[str, PathLike],
                             config_prefix: Union[str, PathLike, None] = None):
        """Parse the model-index file and register all models.

        Args:
            model_index_path (str | PathLike): The path of the model-index
                file.
            config_prefix (str | PathLike | None): The prefix of all config
                file paths in the model-index file. If None, the folder of
                the file that declares each model is used.

        Raises:
            ValueError: If a model name (compared case-insensitively) has
                already been registered.
        """
        model_index = load(str(model_index_path))
        model_index.build_models_with_collections()

        for metainfo in model_index.models:
            model_name = metainfo.name.lower()
            # The registry is keyed by the lower-cased name, so the conflict
            # check must use the same key; otherwise a mixed-case duplicate
            # would silently overwrite the existing entry.
            if model_name in cls._models_dict:
                raise ValueError(
                    'The model name {} is conflict in {} and {}.'.format(
                        model_name, osp.abspath(metainfo.filepath),
                        osp.abspath(cls._models_dict[model_name].filepath)))
            metainfo.config = cls._expand_config_path(metainfo, config_prefix)
            cls._models_dict[model_name] = metainfo

    @classmethod
    def get(cls, model_name):
        """Get the model's metainfo by the model name.

        Args:
            model_name (str): The name of model (case-insensitive).

        Returns:
            modelindex.models.Model: A copy of the metainfo of the specified
            model, whose ``config`` path has been loaded into a config
            object.

        Raises:
            ValueError: If no model is registered under ``model_name``.
        """
        cls._register_mmpretrain_models()
        # Work on a copy so that the lazily loaded config below is attached
        # to the returned object and not to the registry entry.
        metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
        if metainfo is None:
            raise ValueError(
                f'Failed to find model "{model_name}". please use '
                '`mmpretrain.list_models` to get all available names.')
        # lazy load config
        if isinstance(metainfo.config, str):
            metainfo.config = Config.fromfile(metainfo.config)
        return metainfo

    @staticmethod
    def _expand_config_path(metainfo: 'Model',
                            config_prefix: Union[str, PathLike, None] = None):
        """Resolve the config path of ``metainfo`` to an absolute path.

        Args:
            metainfo (Model): The metainfo that provides ``config`` and
                ``filepath``.
            config_prefix (str | PathLike | None): The folder that relative
                config paths are joined to. Defaults to the folder of
                ``metainfo.filepath``.

        Returns:
            str | None: ``metainfo.config`` unchanged if it is None or
            already absolute, otherwise the absolute joined path.
        """
        if config_prefix is None:
            config_prefix = osp.dirname(metainfo.filepath)

        if metainfo.config is None or osp.isabs(metainfo.config):
            config_path: str = metainfo.config
        else:
            config_path = osp.abspath(osp.join(config_prefix, metainfo.config))

        return config_path

    @classmethod
    def _register_mmpretrain_models(cls):
        """Register the model-index of mmpretrain (done only once)."""
        # register models in mmpretrain
        if not cls.__mmpretrain_registered:
            from importlib_metadata import distribution
            root = distribution('mmpretrain').locate_file('mmpretrain')
            model_index_path = root / '.mim' / 'model-index.yml'
            ModelHub.register_model_index(
                model_index_path, config_prefix=root / '.mim')
            cls.__mmpretrain_registered = True

    @classmethod
    def has(cls, model_name):
        """Whether a model name is in the ModelHub.

        The name is compared case-insensitively, consistent with
        :meth:`get`. Only models that have already been registered are
        considered.
        """
        if isinstance(model_name, str):
            model_name = model_name.lower()
        return model_name in cls._models_dict
|
|
|
|
|
|
def get_model(model: Union[str, Config],
              pretrained: Union[str, bool] = False,
              device=None,
              device_map=None,
              offload_folder=None,
              url_mapping: Tuple[str, str] = None,
              **kwargs):
    """Get a pre-defined model or create a model from config.

    Args:
        model (str | Config): The name of model, the config file path or a
            config instance.
        pretrained (bool | str): When use name to specify model, you can
            use ``True`` to load the pre-defined pretrained weights. And you
            can also use a string to specify the path or link of weights to
            load. Defaults to False.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        device_map (str | dict | None): A map that specifies where each
            submodule should go. It doesn't need to be refined to each
            parameter/buffer name, once a given module name is inside, every
            submodule of it will be sent to the same device. You can use
            `device_map="auto"` to automatically generate the device map.
            Defaults to None.
        offload_folder (str | None): If the `device_map` contains any value
            `"disk"`, the folder where we will offload weights.
        url_mapping (Tuple[str, str], optional): The mapping of pretrained
            checkpoint link. For example, load checkpoint from a local dir
            instead of download by ``('https://.*/', './checkpoint')``.
            Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        mmengine.model.BaseModel: The result model.

    Raises:
        TypeError: If ``model`` is neither a name, a path nor a config
            instance.

    Examples:
        Get a ResNet-50 model and extract images feature:

        >>> import torch
        >>> from mmpretrain import get_model
        >>> inputs = torch.rand(16, 3, 224, 224)
        >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = model.extract_feat(inputs)
        >>> for feat in feats:
        ...     print(feat.shape)
        torch.Size([16, 256])
        torch.Size([16, 512])
        torch.Size([16, 1024])
        torch.Size([16, 2048])

        Get Swin-Transformer model with pre-trained weights and inference:

        >>> from mmpretrain import get_model, inference_model
        >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
        >>> result = inference_model(model, 'demo/demo.JPEG')
        >>> print(result['pred_class'])
        'sea snake'
    """  # noqa: E501
    if device_map is not None:
        # Fail early, before building the model, if the requirements of
        # the dispatch utility are not satisfied.
        from .utils import dispatch_model
        dispatch_model._verify_require()

    metainfo = None
    if isinstance(model, Config):
        config = copy.deepcopy(model)
        if pretrained is True and 'load_from' in config:
            pretrained = config.load_from
    elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py':
        config = Config.fromfile(model)
        if pretrained is True and 'load_from' in config:
            pretrained = config.load_from
    elif isinstance(model, str):
        metainfo = ModelHub.get(model)
        config = metainfo.config
        if pretrained is True and metainfo.weights is not None:
            pretrained = metainfo.weights
    else:
        # Report the type of the offending argument. ``config`` is not
        # bound on this branch, so it must not be referenced here.
        raise TypeError('model must be a name, a path or a Config object, '
                        f'but got {type(model)}')

    if pretrained is True:
        # ``True`` was requested but no checkpoint could be resolved above.
        warnings.warn('Unable to find pre-defined checkpoint of the model.')
        pretrained = None
    elif pretrained is False:
        pretrained = None

    if kwargs:
        config.merge_from_dict({'model': kwargs})
    # Fall back to the top-level ``data_preprocessor`` of the config when
    # the model config does not define its own.
    config.model.setdefault('data_preprocessor',
                            config.get('data_preprocessor', None))

    from mmengine.registry import DefaultScope

    from mmpretrain.registry import MODELS
    with DefaultScope.overwrite_default_scope('mmpretrain'):
        model = MODELS.build(config.model)

    dataset_meta = {}
    if pretrained:
        # Mapping the weights to GPU may cause unexpected video memory leak
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        from mmengine.runner import load_checkpoint
        if url_mapping is not None:
            pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained)
        checkpoint = load_checkpoint(model, pretrained, map_location='cpu')
        checkpoint_meta = checkpoint.get('meta', {})
        if 'dataset_meta' in checkpoint_meta:
            # mmpretrain 1.x
            dataset_meta = checkpoint['meta']['dataset_meta']
        elif 'CLASSES' in checkpoint_meta:
            # mmcls 0.x
            dataset_meta = {'classes': checkpoint['meta']['CLASSES']}

    if len(dataset_meta) == 0 and 'test_dataloader' in config:
        # No meta in the checkpoint: use the ``METAINFO`` attribute of the
        # test dataset class, if it has one.
        from mmpretrain.registry import DATASETS
        dataset_class = DATASETS.get(config.test_dataloader.dataset.type)
        dataset_meta = getattr(dataset_class, 'METAINFO', {})

    if device_map is not None:
        model = dispatch_model(
            model, device_map=device_map, offload_folder=offload_folder)
    elif device is not None:
        model.to(device)

    model._dataset_meta = dataset_meta  # save the dataset meta
    model._config = config  # save the config in the model
    model._metainfo = metainfo  # save the metainfo in the model
    model.eval()
    return model
|
|
|
|
|
|
def init_model(config, checkpoint=None, device=None, **kwargs):
    """Build a classifier from a config (deprecated alias).

    This function is kept only for backward compatibility and simply
    forwards to :func:`get_model`; new code should call :func:`get_model`
    directly.

    Args:
        config (str | :obj:`mmengine.Config`): The config object or the path
            of the config file.
        checkpoint (str, optional): The checkpoint path, forwarded as the
            ``pretrained`` argument. With None, no weights are loaded.
        device (str | torch.device | None): The device the model is moved
            to. Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        nn.Module: The constructed model.
    """
    # ``checkpoint`` is the legacy name of ``pretrained`` in ``get_model``.
    return get_model(config, pretrained=checkpoint, device=device, **kwargs)
|
|
|
|
|
|
def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]:
    """List all models available in MMPretrain.

    Args:
        pattern (str | None): A wildcard pattern to match model names. A
            trailing ``*`` is always appended, so names with any postfix
            match. Defaults to None.
        exclude_patterns (list | None): A list of wildcard patterns to
            exclude names from the matched names. A trailing ``*`` is
            appended to each of them as well. Defaults to None.
        task (str | None): The evaluation task of the model. Use the string
            ``'null'`` to select models without any results.
            Defaults to None.

    Returns:
        List[str]: a sorted list of model names.

    Examples:
        List all models:

        >>> from mmpretrain import list_models
        >>> list_models()

        List ResNet-50 models on ImageNet-1k dataset:

        >>> from mmpretrain import list_models
        >>> list_models('resnet*in1k')
        ['resnet50_8xb32_in1k',
         'resnet50_8xb32-fp16_in1k',
         'resnet50_8xb256-rsb-a1-600e_in1k',
         'resnet50_8xb256-rsb-a2-300e_in1k',
         'resnet50_8xb256-rsb-a3-100e_in1k']

        List Swin-Transformer models trained from scratch and exclude
        Swin-Transformer-V2 models:

        >>> from mmpretrain import list_models
        >>> list_models('swin', exclude_patterns=['swinv2', '*-pre'])
        ['swin-base_16xb64_in1k',
         'swin-base_3rdparty_in1k',
         'swin-base_3rdparty_in1k-384',
         'swin-large_8xb8_cub-384px',
         'swin-small_16xb64_in1k',
         'swin-small_3rdparty_in1k',
         'swin-tiny_16xb64_in1k',
         'swin-tiny_3rdparty_in1k']

        List all EVA models for image classification task.

        >>> from mmpretrain import list_models
        >>> list_models('eva', task='Image Classification')
        ['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px',
         'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px',
         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px',
         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px',
         'eva-l-p14_mim-pre_3rdparty_in1k-196px',
         'eva-l-p14_mim-pre_3rdparty_in1k-336px']
    """
    ModelHub._register_mmpretrain_models()
    registry = ModelHub._models_dict
    names = set(registry)

    if pattern is not None:
        # Always match keys with any postfix.
        names = set(fnmatch.filter(names, pattern + '*'))

    for unwanted in exclude_patterns or ():
        names -= set(fnmatch.filter(names, unwanted + '*'))

    if task is not None:

        def _matches_task(name):
            """Tell whether the registered model ``name`` fits ``task``."""
            results = registry[name].results
            if results is None:
                # Models without results are only selected by 'null'.
                return task == 'null'
            return task in [entry.task for entry in results]

        names = {name for name in names if _matches_task(name)}

    return sorted(names)
|
|
|
|
|
|
def inference_model(model, *args, **kwargs):
    """Inference an image with the inferencer.

    Automatically select inferencer to inference according to the type of
    model. It's a shortcut for a quick start, and for advanced usage, please
    use the corresponding inferencer class.

    Here is the mapping from task to inferencer:

    - Image Classification: :class:`ImageClassificationInferencer`
    - Image Retrieval: :class:`ImageRetrievalInferencer`
    - Image Caption: :class:`ImageCaptionInferencer`
    - Visual Question Answering: :class:`VisualQuestionAnsweringInferencer`
    - Visual Grounding: :class:`VisualGroundingInferencer`
    - Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer`
    - Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer`
    - NLVR: :class:`NLVRInferencer`

    If the model supports several of these tasks, the first one in the
    order listed above is selected and a warning is emitted.

    Args:
        model (BaseModel | str | Config): The loaded model, the model
            name or the config of the model.
        *args: Positional arguments to call the inferencer.
        **kwargs: Other keyword arguments to initialize and call the
            corresponding inferencer. Keywords that match a parameter of
            the inferencer constructor are used to build it; the rest are
            passed to the inferencer call.

    Returns:
        result (dict): The inference results.

    Raises:
        NotImplementedError: If no inferencer is available for the tasks
            of the model.
    """  # noqa: E501
    from mmengine.model import BaseModel

    if isinstance(model, BaseModel):
        metainfo = getattr(model, '_metainfo', None)
    else:
        metainfo = ModelHub.get(model)

    from inspect import signature

    from .image_caption import ImageCaptionInferencer
    from .image_classification import ImageClassificationInferencer
    from .image_retrieval import ImageRetrievalInferencer
    from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
                                       TextToImageRetrievalInferencer)
    from .nlvr import NLVRInferencer
    from .visual_grounding import VisualGroundingInferencer
    from .visual_question_answering import VisualQuestionAnsweringInferencer
    task_mapping = {
        'Image Classification': ImageClassificationInferencer,
        'Image Retrieval': ImageRetrievalInferencer,
        'Image Caption': ImageCaptionInferencer,
        'Visual Question Answering': VisualQuestionAnsweringInferencer,
        'Visual Grounding': VisualGroundingInferencer,
        'Text-To-Image Retrieval': TextToImageRetrievalInferencer,
        'Image-To-Text Retrieval': ImageToTextRetrievalInferencer,
        'NLVR': NLVRInferencer,
    }

    inferencer_type = None

    if metainfo is not None and metainfo.results is not None:
        tasks = {result.task for result in metainfo.results}
        # Walk the mapping instead of the task set so that the selected
        # inferencer does not depend on set iteration order.
        candidates = [
            inferencer for task, inferencer in task_mapping.items()
            if task in tasks
        ]
        if len(candidates) > 1:
            inferencer_names = [cls.__name__ for cls in candidates]
            warnings.warn('The model supports multiple tasks, auto select '
                          f'{inferencer_names[0]}, you can also use other '
                          f'inferencer {inferencer_names} directly.')
        # The candidate list is empty when none of the tasks of the model
        # is supported; leave ``inferencer_type`` as None in that case so
        # the NotImplementedError below is raised instead of IndexError.
        if candidates:
            inferencer_type = candidates[0]

    if inferencer_type is None:
        raise NotImplementedError('No available inferencer for the model')

    # Split the keyword arguments: constructor parameters are popped out
    # of ``kwargs`` and the remaining ones go to the inference call.
    init_params = signature(inferencer_type).parameters
    init_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if k in init_params}

    inferencer = inferencer_type(model, **init_kwargs)
    return inferencer(*args, **kwargs)[0]
|