mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add BaseInferencer (#874)
* [Feature] Add BaseInferencer (#773) * Update BaseInferencer * Fix ci * Fix CI and rename iferencer to infer * Fix CI * Add renamed file * Add test file * Adjust interface sequence * refine preprocess * Update unit test Update unit test * Update unit test * Fix unit test * Fix as comment * Minor refine * Fix docstring and support load image from different backend * Support load collate_fn from downstream repos, refine dispatch * Minor refine * Fix lint * refine grammar * Remove FileClient * Refine docstring * add rich * Add list_models * Add list_models * Remove backend args * Minor refine * Fix typos in docs and type hints (#787) * [Fix] Add _inputs_to_list (#795) * Add preprocess inputs * Add type hint * update api/infer in index.rst * rename preprocess_inputs to _inputs_to_list * Fix doc format * Update infer.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * [Fix] Fix alias type (#801) * [Enhance] Support loading model config from checkpoint (#864) * first commit * [Enhance] Support build model from weight * minor refine * Fix type hint * refine comments * Update docstring * refine as comment * Add method * Refine docstring * Fix as comment * refine comments * Refine warning message * Fix unit test and refine comments * replace MODULE2PACKAGE to MODULE2PAKCAGE * Fix typo and syntax error in docstring Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
This commit is contained in:
parent
ad590e45a2
commit
2d8f2be375
14
docs/en/api/infer.rst
Normal file
14
docs/en/api/infer.rst
Normal file
@ -0,0 +1,14 @@
|
||||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
|
||||
mmengine.infer
|
||||
===================================
|
||||
|
||||
.. currentmodule:: mmengine.infer
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
BaseInferencer
|
@ -82,6 +82,7 @@ You can switch between Chinese and English documents in the lower-left corner of
|
||||
mmengine.evaluator <api/evaluator>
|
||||
mmengine.structures <api/structures>
|
||||
mmengine.dataset <api/dataset>
|
||||
mmengine.infer <api/infer>
|
||||
mmengine.device <api/device>
|
||||
mmengine.hub <api/hub>
|
||||
mmengine.logging <api/logging>
|
||||
|
14
docs/zh_cn/api/infer.rst
Normal file
14
docs/zh_cn/api/infer.rst
Normal file
@ -0,0 +1,14 @@
|
||||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
|
||||
mmengine.infer
|
||||
===================================
|
||||
|
||||
.. currentmodule:: mmengine.infer
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
BaseInferencer
|
@ -82,6 +82,7 @@
|
||||
mmengine.evaluator <api/evaluator>
|
||||
mmengine.structures <api/structures>
|
||||
mmengine.dataset <api/dataset>
|
||||
mmengine.infer <api/infer>
|
||||
mmengine.device <api/device>
|
||||
mmengine.hub <api/hub>
|
||||
mmengine.logging <api/logging>
|
||||
|
4
mmengine/infer/__init__.py
Normal file
4
mmengine/infer/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .infer import BaseInferencer
|
||||
|
||||
__all__ = ['BaseInferencer']
|
648
mmengine/infer/infer.py
Normal file
648
mmengine/infer/infer.py
Normal file
@ -0,0 +1,648 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
import re
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence,
|
||||
Tuple, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rich.progress import track
|
||||
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.config.utils import MODULE2PACKAGE
|
||||
from mmengine.dataset import COLLATE_FUNCTIONS, pseudo_collate
|
||||
from mmengine.device import get_device
|
||||
from mmengine.fileio import (get_file_backend, isdir, join_path,
|
||||
list_dir_or_file, load)
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import MODELS, VISUALIZERS, DefaultScope
|
||||
from mmengine.runner.checkpoint import (_load_checkpoint,
|
||||
_load_checkpoint_to_model)
|
||||
from mmengine.structures import InstanceData
|
||||
from mmengine.utils import get_installed_path, is_installed
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
InstanceList = List[InstanceData]
|
||||
InputType = Union[str, np.ndarray, torch.Tensor]
|
||||
InputsType = Union[InputType, Sequence[InputType]]
|
||||
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
|
||||
ResType = Union[Dict, List[Dict]]
|
||||
ConfigType = Union[Config, ConfigDict]
|
||||
ModelType = Union[dict, ConfigType, str]
|
||||
|
||||
|
||||
class InferencerMeta(ABCMeta):
|
||||
"""Check the legality of the inferencer.
|
||||
|
||||
All Inferencers should not define duplicated keys for
|
||||
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` and
|
||||
``postprocess_kwargs``.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert isinstance(self.preprocess_kwargs, set)
|
||||
assert isinstance(self.forward_kwargs, set)
|
||||
assert isinstance(self.visualize_kwargs, set)
|
||||
assert isinstance(self.postprocess_kwargs, set)
|
||||
|
||||
all_kwargs = (
|
||||
self.preprocess_kwargs | self.forward_kwargs
|
||||
| self.visualize_kwargs | self.postprocess_kwargs)
|
||||
|
||||
assert len(all_kwargs) == (
|
||||
len(self.preprocess_kwargs) + len(self.forward_kwargs) +
|
||||
len(self.visualize_kwargs) + len(self.postprocess_kwargs)), (
|
||||
f'Class define error! {self.__name__} should not '
|
||||
'define duplicated keys for `preprocess_kwargs`, '
|
||||
'`forward_kwargs`, `visualize_kwargs` and '
|
||||
'`postprocess_kwargs` are not allowed.')
|
||||
|
||||
|
||||
class BaseInferencer(metaclass=InferencerMeta):
|
||||
"""Base inferencer for downstream tasks.
|
||||
|
||||
The BaseInferencer provides the standard workflow for inference as follows:
|
||||
|
||||
1. Preprocess the input data by :meth:`preprocess`.
|
||||
2. Forward the data to the model by :meth:`forward`. ``BaseInferencer``
|
||||
assumes the model inherits from :class:`mmengine.models.BaseModel` and
|
||||
will call `model.test_step` in :meth:`forward` by default.
|
||||
3. Visualize the results by :meth:`visualize`.
|
||||
4. Postprocess and return the results by :meth:`postprocess`.
|
||||
|
||||
When we call the subclasses inherited from BaseInferencer (not overriding
|
||||
``__call__``), the workflow will be executed in order.
|
||||
|
||||
All subclasses of BaseInferencer could define the following class
|
||||
attributes for customization:
|
||||
|
||||
- ``preprocess_kwargs``: The keys of the kwargs that will be passed to
|
||||
:meth:`preprocess`.
|
||||
- ``forward_kwargs``: The keys of the kwargs that will be passed to
|
||||
:meth:`forward`
|
||||
- ``visualize_kwargs``: The keys of the kwargs that will be passed to
|
||||
:meth:`visualize`
|
||||
- ``postprocess_kwargs``: The keys of the kwargs that will be passed to
|
||||
:meth:`postprocess`
|
||||
|
||||
All attributes mentioned above should be a ``set`` of keys (strings),
|
||||
and each key should not be duplicated. Actually, :meth:`__call__` will
|
||||
dispatch all the arguments to the corresponding methods according to the
|
||||
``xxx_kwargs`` mentioned above, therefore, the key in sets should
|
||||
be unique to avoid ambiguous dispatching.
|
||||
|
||||
Warning:
|
||||
If subclasses defined the class attributes mentioned above with
|
||||
duplicated keys, an ``AssertionError`` will be raised during import
|
||||
process.
|
||||
|
||||
Subclasses inherited from ``BaseInferencer`` should implement
|
||||
:meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`:
|
||||
|
||||
- _init_pipeline: Return a callable object to preprocess the input data.
|
||||
- visualize: Visualize the results returned by :meth:`forward`.
|
||||
- postprocess: Postprocess the results returned by :meth:`forward` and
|
||||
:meth:`visualize`.
|
||||
|
||||
Args:
|
||||
model (str, optional): Path to the config file or the model name
|
||||
defined in metafile. Take the `mmdet metafile <https://github.com/open-mmlab/mmdetection/blob/master/configs/retinanet/metafile.yml>`_
|
||||
as an example, the `model` could be `retinanet_r18_fpn_1x_coco` or
|
||||
its alias. If model is not specified, user must provide the
|
||||
`weights` saved by MMEngine which contains the config string.
|
||||
Defaults to None.
|
||||
weights (str, optional): Path to the checkpoint. If it is not specified
|
||||
and model is a model name of metafile, the weights will be loaded
|
||||
from metafile. Defaults to None.
|
||||
device (str, optional): Device to run inference. If None, the available
|
||||
device will be automatically used. Defaults to None.
|
||||
scope (str, optional): The scope of the model. Defaults to None.
|
||||
|
||||
Note:
|
||||
Since ``Inferencer`` could be used to infer batch data,
|
||||
`collate_fn` should be defined. If `collate_fn` is not defined in config
|
||||
file, the `collate_fn` will be `pseudo_collate` by default.
|
||||
""" # noqa: E501
|
||||
|
||||
preprocess_kwargs: set = set()
|
||||
forward_kwargs: set = set()
|
||||
visualize_kwargs: set = set()
|
||||
postprocess_kwargs: set = set()
|
||||
|
||||
def __init__(self,
|
||||
model: Union[ModelType, str, None] = None,
|
||||
weights: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
scope: Optional[str] = None) -> None:
|
||||
if scope is None:
|
||||
default_scope = DefaultScope.get_current_instance()
|
||||
if default_scope is not None:
|
||||
scope = default_scope.scope_name
|
||||
self.scope = scope
|
||||
# Load config to cfg
|
||||
cfg: ConfigType
|
||||
if isinstance(model, str):
|
||||
if osp.isfile(model):
|
||||
cfg = Config.fromfile(model)
|
||||
else:
|
||||
# Load config and weights from metafile. If `weights` is
|
||||
# assigned, the weights defined in metafile will be ignored.
|
||||
cfg, _weights = self._load_model_from_metafile(model)
|
||||
if weights is None:
|
||||
weights = _weights
|
||||
elif isinstance(model, (Config, ConfigDict)):
|
||||
cfg = copy.deepcopy(model)
|
||||
elif isinstance(model, dict):
|
||||
cfg = copy.deepcopy(ConfigDict(model))
|
||||
elif model is None:
|
||||
if weights is None:
|
||||
raise ValueError(
|
||||
'If model is None, the weights must be specified since '
|
||||
'the config needs to be loaded from the weights')
|
||||
cfg = ConfigDict()
|
||||
else:
|
||||
raise TypeError('model must be a filepath or any ConfigType'
|
||||
f'object, but got {type(model)}')
|
||||
|
||||
if device is None:
|
||||
device = get_device()
|
||||
|
||||
self.model = self._init_model(cfg, weights, device) # type: ignore
|
||||
self.pipeline = self._init_pipeline(cfg)
|
||||
self.collate_fn = self._init_collate(cfg)
|
||||
self.visualizer = self._init_visualizer(cfg)
|
||||
self.cfg = cfg
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: InputsType,
|
||||
return_datasamples: bool = False,
|
||||
batch_size: int = 1,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Call the inferencer.
|
||||
|
||||
Args:
|
||||
inputs (InputsType): Inputs for the inferencer.
|
||||
return_datasamples (bool): Whether to return results as
|
||||
:obj:`BaseDataElement`. Defaults to False.
|
||||
batch_size (int): Batch size. Defaults to 1.
|
||||
**kwargs: Key words arguments passed to :meth:`preprocess`,
|
||||
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
|
||||
Each key in kwargs should be in the corresponding set of
|
||||
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
|
||||
and ``postprocess_kwargs``.
|
||||
|
||||
Returns:
|
||||
dict: Inference and visualization results.
|
||||
"""
|
||||
(
|
||||
preprocess_kwargs,
|
||||
forward_kwargs,
|
||||
visualize_kwargs,
|
||||
postprocess_kwargs,
|
||||
) = self._dispatch_kwargs(**kwargs)
|
||||
|
||||
ori_inputs = self._inputs_to_list(inputs)
|
||||
inputs = self.preprocess(
|
||||
ori_inputs, batch_size=batch_size, **preprocess_kwargs)
|
||||
preds = []
|
||||
for data in track(inputs, description='Inference'):
|
||||
preds.extend(self.forward(data, **forward_kwargs))
|
||||
visualization = self.visualize(
|
||||
ori_inputs, preds,
|
||||
**visualize_kwargs) # type: ignore # noqa: E501
|
||||
results = self.postprocess(preds, visualization, return_datasamples,
|
||||
**postprocess_kwargs)
|
||||
return results
|
||||
|
||||
def _inputs_to_list(self, inputs: InputsType) -> list:
|
||||
"""Preprocess the inputs to a list.
|
||||
|
||||
Preprocess inputs to a list according to its type:
|
||||
|
||||
- list or tuple: return inputs
|
||||
- str:
|
||||
- Directory path: return all files in the directory
|
||||
- other cases: return a list containing the string. The string
|
||||
could be a path to file, a url or other types of string according
|
||||
to the task.
|
||||
|
||||
Args:
|
||||
inputs (InputsType): Inputs for the inferencer.
|
||||
|
||||
Returns:
|
||||
list: List of input for the :meth:`preprocess`.
|
||||
"""
|
||||
if isinstance(inputs, str):
|
||||
backend = get_file_backend(inputs)
|
||||
if hasattr(backend, 'isdir') and isdir(inputs):
|
||||
# Backends like HttpsBackend do not implement `isdir`, so only
|
||||
# those backends that implement `isdir` could accept the inputs
|
||||
# as a directory
|
||||
filename_list = list_dir_or_file(inputs, list_dir=False)
|
||||
inputs = [
|
||||
join_path(inputs, filename) for filename in filename_list
|
||||
]
|
||||
|
||||
if not isinstance(inputs, (list, tuple)):
|
||||
inputs = [inputs]
|
||||
|
||||
return list(inputs)
|
||||
|
||||
def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
|
||||
"""Process the inputs into a model-feedable format.
|
||||
|
||||
Customize your preprocess by overriding this method. Preprocess should
|
||||
return an iterable object, of which each item will be used as the
|
||||
input of ``model.test_step``.
|
||||
|
||||
``BaseInferencer.preprocess`` will return an iterable chunked data,
|
||||
which will be used in __call__ like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def __call__(self, inputs, batch_size=1, **kwargs):
|
||||
chunked_data = self.preprocess(inputs, batch_size, **kwargs)
|
||||
for batch in chunked_data:
|
||||
preds = self.forward(batch, **kwargs)
|
||||
|
||||
Args:
|
||||
inputs (InputsType): Inputs given by user.
|
||||
batch_size (int): batch size. Defaults to 1.
|
||||
|
||||
Yields:
|
||||
Any: Data processed by the ``pipeline`` and ``collate_fn``.
|
||||
"""
|
||||
chunked_data = self._get_chunk_data(
|
||||
map(self.pipeline, inputs), batch_size)
|
||||
yield from map(self.collate_fn, chunked_data)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, inputs: Union[dict, tuple], **kwargs) -> Any:
|
||||
"""Feed the inputs to the model."""
|
||||
return self.model.test_step(inputs)
|
||||
|
||||
@abstractmethod
|
||||
def visualize(self,
|
||||
inputs: list,
|
||||
preds: Any,
|
||||
show: bool = False,
|
||||
**kwargs) -> List[np.ndarray]:
|
||||
"""Visualize predictions.
|
||||
|
||||
Customize your visualization by overriding this method. visualize
|
||||
should return visualization results, which could be np.ndarray or any
|
||||
other objects.
|
||||
|
||||
Args:
|
||||
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
|
||||
preds (Any): Predictions of the model.
|
||||
show (bool): Whether to display the image in a popup window.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[np.ndarray]: Visualization results.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(
|
||||
self,
|
||||
preds: Any,
|
||||
visualization: List[np.ndarray],
|
||||
return_datasample=False,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Process the predictions and visualization results from ``forward``
|
||||
and ``visualize``.
|
||||
|
||||
This method should be responsible for the following tasks:
|
||||
|
||||
1. Convert datasamples into a json-serializable dict if needed.
|
||||
2. Pack the predictions and visualization results and return them.
|
||||
3. Dump or log the predictions.
|
||||
|
||||
Customize your postprocess by overriding this method. Make sure
|
||||
``postprocess`` will return a dict with visualization results and
|
||||
inference results.
|
||||
|
||||
Args:
|
||||
preds (List[Dict]): Predictions of the model.
|
||||
visualization (np.ndarray): Visualized predictions.
|
||||
return_datasample (bool): Whether to return results as datasamples.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
dict: Inference and visualization results with key ``predictions``
|
||||
and ``visualization``
|
||||
|
||||
- ``visualization (Any)``: Returned by :meth:`visualize`
|
||||
- ``predictions`` (dict or DataSample): Returned by
|
||||
:meth:`forward` and processed in :meth:`postprocess`.
|
||||
If ``return_datasample=False``, it usually should be a
|
||||
json-serializable dict containing only basic data elements such
|
||||
as strings and numbers.
|
||||
"""
|
||||
|
||||
def _load_model_from_metafile(self, model: str) -> Tuple[Config, str]:
|
||||
"""Load config and weights from metafile.
|
||||
|
||||
Args:
|
||||
model (str): model name defined in metafile.
|
||||
|
||||
Returns:
|
||||
Tuple[Config, str]: Loaded Config and weights path defined in
|
||||
metafile.
|
||||
"""
|
||||
model = model.lower()
|
||||
|
||||
assert self.scope is not None, (
|
||||
'scope should be initialized if you want '
|
||||
'to load config from metafile.')
|
||||
assert self.scope in MODULE2PACKAGE, (
|
||||
f'{self.scope} not in {MODULE2PACKAGE}!,'
|
||||
'please pass a valid scope.')
|
||||
project = MODULE2PACKAGE[self.scope]
|
||||
assert is_installed(project), f'Please install {project}'
|
||||
package_path = get_installed_path(project)
|
||||
for model_cfg in BaseInferencer._get_models_from_package(package_path):
|
||||
model_name = model_cfg['Name'].lower()
|
||||
model_aliases = model_cfg.get('Alias', [])
|
||||
if isinstance(model_aliases, str):
|
||||
model_aliases = [model_aliases.lower()]
|
||||
else:
|
||||
model_aliases = [alias.lower() for alias in model_aliases]
|
||||
if (model_name == model or model in model_aliases):
|
||||
cfg = Config.fromfile(
|
||||
osp.join(package_path, '.mim', model_cfg['Config']))
|
||||
weights = model_cfg['Weights']
|
||||
weights = weights[0] if isinstance(weights, list) else weights
|
||||
return cfg, weights
|
||||
raise ValueError(f'Cannot find model: {model} in {project}')
|
||||
|
||||
def _init_model(
|
||||
self,
|
||||
cfg: ConfigType,
|
||||
weights: Optional[str],
|
||||
device: str = 'cpu',
|
||||
) -> nn.Module:
|
||||
"""Initialize the model with the given config and checkpoint on the
|
||||
specific device.
|
||||
|
||||
Args:
|
||||
cfg (ConfigType): Config containing the model information.
|
||||
weights (str, optional): Path to the checkpoint.
|
||||
device (str, optional): Device to run inference. Defaults to 'cpu'.
|
||||
|
||||
Returns:
|
||||
nn.Module: Model loaded with checkpoint.
|
||||
"""
|
||||
checkpoint: Optional[dict] = None
|
||||
if weights is not None:
|
||||
checkpoint = _load_checkpoint(weights, map_location='cpu')
|
||||
|
||||
if not cfg:
|
||||
assert checkpoint is not None
|
||||
try:
|
||||
# Prefer to get config from `message_hub` since `message_hub`
|
||||
# is a more stable module to store all runtime information.
|
||||
# However, the early version of MMEngine will not save config
|
||||
# in `message_hub`, so we will try to load config from `meta`.
|
||||
cfg_string = checkpoint['message_hub']['runtime_info']['cfg']
|
||||
except KeyError:
|
||||
assert 'meta' in checkpoint, (
|
||||
'If model(config) is not provided, the checkpoint must'
|
||||
'contain the config string in `meta` or `message_hub`, '
|
||||
'but both `meta` and `message_hub` are not found in the '
|
||||
'checkpoint.')
|
||||
meta = checkpoint['meta']
|
||||
if 'cfg' in meta:
|
||||
cfg_string = meta['cfg']
|
||||
else:
|
||||
raise ValueError(
|
||||
'Cannot find the config in the checkpoint.')
|
||||
cfg.update(
|
||||
Config.fromstring(cfg_string, file_format='.py')._cfg_dict)
|
||||
|
||||
# Delete the `pretrained` field to prevent model from loading the
|
||||
# the pretrained weights unnecessarily.
|
||||
if cfg.model.get('pretrained') is not None:
|
||||
del cfg.model.pretrained
|
||||
|
||||
model = MODELS.build(cfg.model)
|
||||
model.cfg = cfg
|
||||
self._load_weights_to_model(model, checkpoint, cfg)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def _load_weights_to_model(self, model: nn.Module,
|
||||
checkpoint: Optional[dict],
|
||||
cfg: Optional[ConfigType]) -> None:
|
||||
"""Loading model weights and meta information from cfg and checkpoint.
|
||||
|
||||
Subclasses could override this method to load extra meta information
|
||||
from ``checkpoint`` and ``cfg`` to model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to load weights and meta information.
|
||||
checkpoint (dict, optional): The loaded checkpoint.
|
||||
cfg (Config or ConfigDict, optional): The loaded config.
|
||||
"""
|
||||
if checkpoint is not None:
|
||||
_load_checkpoint_to_model(model, checkpoint)
|
||||
else:
|
||||
warnings.warn('Checkpoint is not loaded, and the inference '
|
||||
'result is calculated by the randomly initialized '
|
||||
'model!')
|
||||
|
||||
def _init_collate(self, cfg: ConfigType) -> Callable:
|
||||
"""Initialize the ``collate_fn`` with the given config.
|
||||
|
||||
The returned ``collate_fn`` will be used to collate the batch data.
|
||||
If will be used in :meth:`preprocess` like this
|
||||
|
||||
.. code-block:: python
|
||||
def preprocess(self, inputs, batch_size, **kwargs):
|
||||
...
|
||||
dataloader = map(self.collate_fn, dataloader)
|
||||
yield from dataloader
|
||||
|
||||
Args:
|
||||
cfg (ConfigType): Config which could contained the `collate_fn`
|
||||
information. If `collate_fn` is not defined in config, it will
|
||||
be :func:`pseudo_collate`.
|
||||
|
||||
Returns:
|
||||
Callable: Collate function.
|
||||
"""
|
||||
try:
|
||||
with COLLATE_FUNCTIONS.switch_scope_and_registry(
|
||||
self.scope) as registry:
|
||||
collate_fn = registry.get(cfg.test_dataloader.collate_fn)
|
||||
except AttributeError:
|
||||
collate_fn = pseudo_collate
|
||||
return collate_fn # type: ignore
|
||||
|
||||
@abstractmethod
|
||||
def _init_pipeline(self, cfg: ConfigType) -> Callable:
|
||||
"""Initialize the test pipeline.
|
||||
|
||||
Return a pipeline to handle various input data, such as ``str``,
|
||||
``np.ndarray``. It is an abstract method in BaseInferencer, and should
|
||||
be implemented in subclasses.
|
||||
|
||||
The returned pipeline will be used to process a single data.
|
||||
It will be used in :meth:`preprocess` like this:
|
||||
|
||||
.. code-block:: python
|
||||
def preprocess(self, inputs, batch_size, **kwargs):
|
||||
...
|
||||
dataset = map(self.pipeline, dataset)
|
||||
...
|
||||
"""
|
||||
|
||||
def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]:
|
||||
"""Initialize visualizers.
|
||||
|
||||
Args:
|
||||
cfg (ConfigType): Config containing the visualizer information.
|
||||
|
||||
Returns:
|
||||
Visualizer or None: Visualizer initialized with config.
|
||||
"""
|
||||
if 'visualizer' not in cfg:
|
||||
return None
|
||||
timestamp = str(datetime.timestamp(datetime.now()))
|
||||
name = cfg.visualizer.get('name', timestamp)
|
||||
if Visualizer.check_instance_created(name):
|
||||
name = f'{name}-{timestamp}'
|
||||
cfg.visualizer.name = name
|
||||
return VISUALIZERS.build(cfg.visualizer)
|
||||
|
||||
def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
|
||||
"""Get batch data from dataset.
|
||||
|
||||
Args:
|
||||
inputs (Iterable): An iterable dataset.
|
||||
chunk_size (int): Equivalent to batch size.
|
||||
|
||||
Yields:
|
||||
list: batch data.
|
||||
"""
|
||||
inputs_iter = iter(inputs)
|
||||
while True:
|
||||
try:
|
||||
chunk_data = []
|
||||
for _ in range(chunk_size):
|
||||
processed_data = next(inputs_iter)
|
||||
chunk_data.append(processed_data)
|
||||
yield chunk_data
|
||||
except StopIteration:
|
||||
if chunk_data:
|
||||
yield chunk_data
|
||||
break
|
||||
|
||||
def _dispatch_kwargs(self, **kwargs) -> Tuple[Dict, Dict, Dict, Dict]:
|
||||
"""Dispatch kwargs to preprocess(), forward(), visualize() and
|
||||
postprocess() according to the actual demands.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
|
||||
forward, visualize and postprocess respectively.
|
||||
"""
|
||||
# Ensure each argument only matches one function
|
||||
method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \
|
||||
self.visualize_kwargs | self.postprocess_kwargs
|
||||
|
||||
union_kwargs = method_kwargs | set(kwargs.keys())
|
||||
if union_kwargs != method_kwargs:
|
||||
unknown_kwargs = union_kwargs - method_kwargs
|
||||
raise ValueError(
|
||||
f'unknown argument {unknown_kwargs} for `preprocess`, '
|
||||
'`forward`, `visualize` and `postprocess`')
|
||||
|
||||
preprocess_kwargs = {}
|
||||
forward_kwargs = {}
|
||||
visualize_kwargs = {}
|
||||
postprocess_kwargs = {}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key in self.preprocess_kwargs:
|
||||
preprocess_kwargs[key] = value
|
||||
elif key in self.forward_kwargs:
|
||||
forward_kwargs[key] = value
|
||||
elif key in self.visualize_kwargs:
|
||||
visualize_kwargs[key] = value
|
||||
else:
|
||||
postprocess_kwargs[key] = value
|
||||
|
||||
return (
|
||||
preprocess_kwargs,
|
||||
forward_kwargs,
|
||||
visualize_kwargs,
|
||||
postprocess_kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_models_from_package(package_path: str):
|
||||
"""Load model config defined in metafile from package path.
|
||||
|
||||
Args:
|
||||
package_path (str): Path to the package.
|
||||
|
||||
Yields:
|
||||
dict: Model config defined in metafile.
|
||||
"""
|
||||
meta_indexes = load(osp.join(package_path, '.mim', 'model-index.yml'))
|
||||
for meta_path in meta_indexes['Import']:
|
||||
# meta_path example: mmcls/.mim/configs/conformer/metafile.yml
|
||||
meta_path = osp.join(package_path, '.mim', meta_path)
|
||||
metainfo = load(meta_path)
|
||||
yield from metainfo['Models']
|
||||
|
||||
@staticmethod
|
||||
def list_models(scope: Optional[str] = None, patterns: str = r'.*'):
|
||||
"""List models defined in metafile of corresponding packages.
|
||||
|
||||
Args:
|
||||
scope (str, optional): The scope to which the model belongs.
|
||||
Defaults to None.
|
||||
patterns (str, optional): Regular expressions for the searched
|
||||
models. Once matched with ``Alias`` or ``Name`` filed in
|
||||
metafile, corresponding model will be added to the return list.
|
||||
Defaults to '.*'.
|
||||
|
||||
Returns:
|
||||
dict: Model dict with model name and its alias.
|
||||
"""
|
||||
matched_models = []
|
||||
if scope is None:
|
||||
default_scope = DefaultScope.get_current_instance()
|
||||
assert default_scope is not None, (
|
||||
'scope should be initialized if you want '
|
||||
'to load config from metafile.')
|
||||
assert scope in MODULE2PACKAGE, (
|
||||
f'{scope} not in {MODULE2PACKAGE}!, please make pass a valid '
|
||||
'scope.')
|
||||
project = MODULE2PACKAGE[scope]
|
||||
assert is_installed(project), (f'Please install {project}')
|
||||
package_path = get_installed_path(project)
|
||||
|
||||
for model_cfg in BaseInferencer._get_models_from_package(package_path):
|
||||
model_name = [model_cfg['Name']]
|
||||
model_name.extend(model_cfg.get('Alias', []))
|
||||
for name in model_name:
|
||||
if re.match(patterns, name) is not None:
|
||||
matched_models.append(name)
|
||||
output_str = ''
|
||||
for name in matched_models:
|
||||
output_str += f'model_name: {name}\n'
|
||||
print_log(output_str, logger='current')
|
||||
return matched_models
|
@ -195,7 +195,7 @@ class Registry:
|
||||
return self._get_root_registry()
|
||||
|
||||
@contextmanager
|
||||
def switch_scope_and_registry(self, scope: str) -> Generator:
|
||||
def switch_scope_and_registry(self, scope: Optional[str]) -> Generator:
|
||||
"""Temporarily switch default scope to the target scope, and get the
|
||||
corresponding registry.
|
||||
|
||||
@ -203,7 +203,7 @@ class Registry:
|
||||
registry, otherwise yield the current itself.
|
||||
|
||||
Args:
|
||||
scope (str): The target scope.
|
||||
scope (str, optional): The target scope.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine.registry import Registry, DefaultScope, MODELS
|
||||
|
@ -31,7 +31,7 @@ class ToyModel(BaseModel):
|
||||
self.linear1 = nn.Linear(2, 2)
|
||||
self.linear2 = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, inputs, data_samples, mode='tensor'):
|
||||
def forward(self, inputs, data_samples=None, mode='tensor'):
|
||||
if isinstance(inputs, list):
|
||||
inputs = torch.stack(inputs)
|
||||
if isinstance(data_samples, list):
|
||||
|
@ -3,5 +3,6 @@ matplotlib
|
||||
numpy
|
||||
pyyaml
|
||||
regex;sys_platform=='win32'
|
||||
rich
|
||||
termcolor
|
||||
yapf
|
||||
|
221
tests/test_infer/test_infer.py
Normal file
221
tests/test_infer/test_infer.py
Normal file
@ -0,0 +1,221 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.infer import BaseInferencer
|
||||
from mmengine.registry import VISUALIZERS, DefaultScope
|
||||
from mmengine.testing import RunnerTestCase
|
||||
from mmengine.utils import is_installed, is_list_of
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
|
||||
class ToyInferencer(BaseInferencer):
|
||||
preprocess_kwargs = {'pre_arg'}
|
||||
forward_kwargs = {'for_arg'}
|
||||
visualize_kwargs = {'vis_arg'}
|
||||
postprocess_kwargs = {'pos_arg'}
|
||||
|
||||
def preprocess(self, inputs, batch_size=1, pre_arg=None, **kwargs):
|
||||
return super().preprocess(inputs, batch_size, **kwargs)
|
||||
|
||||
def forward(self, inputs, for_arg=None, **kwargs):
|
||||
return super().forward(inputs, **kwargs)
|
||||
|
||||
def visualize(self, inputs, preds, vis_arg=None, **kwargs):
|
||||
return inputs
|
||||
|
||||
def postprocess(self,
|
||||
preds,
|
||||
imgs,
|
||||
return_datasamples,
|
||||
pos_arg=None,
|
||||
**kwargs):
|
||||
return imgs, preds
|
||||
|
||||
def _init_pipeline(self, cfg):
|
||||
|
||||
def pipeline(img):
|
||||
if isinstance(img, str):
|
||||
img = np.load(img, allow_pickle=True)
|
||||
img = torch.from_numpy(img).float()
|
||||
elif isinstance(img, np.ndarray):
|
||||
img = torch.from_numpy(img).float()
|
||||
else:
|
||||
img = torch.tensor(img).float()
|
||||
return img
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
class ToyVisualizer(Visualizer):
|
||||
...
|
||||
|
||||
|
||||
class TestBaseInferencer(RunnerTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
runner = self.build_runner(copy.deepcopy(self.epoch_based_cfg))
|
||||
runner.train()
|
||||
self.cfg_path = osp.join(runner.work_dir, f'{runner.timestamp}.py')
|
||||
self.ckpt_path = osp.join(runner.work_dir, 'epoch_1.pth')
|
||||
VISUALIZERS.register_module(module=ToyVisualizer, name='ToyVisualizer')
|
||||
|
||||
def test_custom_inferencer(self):
|
||||
# Inferencer should not define ***_kwargs with duplicate keys.
|
||||
with self.assertRaisesRegex(AssertionError, 'Class define error'):
|
||||
|
||||
class CustomInferencer(BaseInferencer):
|
||||
preprocess_kwargs = set('a')
|
||||
forward_kwargs = set('a')
|
||||
|
||||
def tearDown(self):
|
||||
VISUALIZERS._module_dict.pop('ToyVisualizer')
|
||||
return super().tearDown()
|
||||
|
||||
def test_init(self):
|
||||
# Pass model as Config
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
ToyInferencer(cfg, self.ckpt_path)
|
||||
# Pass model as ConfigDict
|
||||
ToyInferencer(cfg._cfg_dict, self.ckpt_path)
|
||||
# Pass model as normal dict
|
||||
ToyInferencer(dict(cfg._cfg_dict), self.ckpt_path)
|
||||
# Pass model as string point to path of config
|
||||
ToyInferencer(self.cfg_path, self.ckpt_path)
|
||||
|
||||
cfg.model.pretrained = 'fake_path'
|
||||
inferencer = ToyInferencer(cfg, self.ckpt_path)
|
||||
self.assertNotIn('pretrained', inferencer.cfg.model)
|
||||
|
||||
# Pass invalid model
|
||||
with self.assertRaisesRegex(TypeError, 'model must'):
|
||||
ToyInferencer([self.epoch_based_cfg], self.ckpt_path)
|
||||
|
||||
# Pass model as model name defined in metafile
|
||||
if is_installed('mmdet'):
|
||||
from mmdet.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
ToyInferencer(
|
||||
'faster-rcnn_s50_fpn_syncbn-backbone+head_ms-range-1x_coco',
|
||||
'https://download.openmmlab.com/mmdetection/v2.0/resnest/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco_20200926_125502-20289c16.pth', # noqa: E501
|
||||
)
|
||||
|
||||
checkpoint = self.ckpt_path
|
||||
ToyInferencer(weights=checkpoint)
|
||||
|
||||
def test_call(self):
|
||||
num_imgs = 12
|
||||
imgs = []
|
||||
img_paths = []
|
||||
for i in range(num_imgs):
|
||||
img = np.random.random((1, 2))
|
||||
img_path = osp.join(self.temp_dir.name, f'{i}.npy')
|
||||
img.dump(img_path)
|
||||
imgs.append(img)
|
||||
img_paths.append(img_path)
|
||||
|
||||
inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
|
||||
inferencer(imgs)
|
||||
inferencer(img_paths)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_installed('mmdet'), reason='mmdet is not installed')
|
||||
def test_load_model_from_meta(self):
|
||||
from mmdet.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
|
||||
inferencer._load_model_from_metafile('retinanet_r18_fpn_1x_coco')
|
||||
with self.assertRaisesRegex(ValueError, 'Cannot find model'):
|
||||
inferencer._load_model_from_metafile('fake_model')
|
||||
# TODO: Test alias
|
||||
|
||||
def test_init_model(self):
|
||||
inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
|
||||
model = inferencer._init_model(self.iter_based_cfg, self.ckpt_path)
|
||||
self.assertFalse(model.training)
|
||||
|
||||
def test_get_chunk_data(self):
|
||||
inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
|
||||
data = list(range(1, 11))
|
||||
chunk_data = inferencer._get_chunk_data(data, 3)
|
||||
self.assertEqual(
|
||||
list(chunk_data), [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]])
|
||||
|
||||
def test_init_visualizer(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
|
||||
visualizer = inferencer._init_visualizer(cfg)
|
||||
self.assertIsNone(visualizer, None)
|
||||
cfg.visualizer = dict(type='ToyVisualizer')
|
||||
visualizer = inferencer._init_visualizer(cfg)
|
||||
self.assertIsInstance(visualizer, ToyVisualizer)
|
||||
|
||||
# Visualizer could be built with the same name repeatedly.
|
||||
cfg.visualizer = dict(type='ToyVisualizer', name='toy')
|
||||
visualizer = inferencer._init_visualizer(cfg)
|
||||
visualizer = inferencer._init_visualizer(cfg)
|
||||
|
||||
def test_dispatch_kwargs(self):
|
||||
inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
|
||||
kwargs = dict(
|
||||
pre_arg=dict(a=1),
|
||||
for_arg=dict(c=2),
|
||||
vis_arg=dict(b=3),
|
||||
pos_arg=dict(d=4))
|
||||
pre_arg, for_arg, vis_arg, pos_arg = inferencer._dispatch_kwargs(
|
||||
**kwargs)
|
||||
self.assertEqual(pre_arg, dict(pre_arg=dict(a=1)))
|
||||
self.assertEqual(for_arg, dict(for_arg=dict(c=2)))
|
||||
self.assertEqual(vis_arg, dict(vis_arg=dict(b=3)))
|
||||
self.assertEqual(pos_arg, dict(pos_arg=dict(d=4)))
|
||||
# Test unknown arg.
|
||||
kwargs = dict(return_datasample=dict())
|
||||
with self.assertRaisesRegex(ValueError, 'unknown'):
|
||||
inferencer._dispatch_kwargs(**kwargs)
|
||||
|
||||
def test_preprocess(self):
|
||||
inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
|
||||
data = list(range(1, 11))
|
||||
pre_data = inferencer.preprocess(data, batch_size=3)
|
||||
target_data = [
|
||||
[torch.tensor(1),
|
||||
torch.tensor(2),
|
||||
torch.tensor(3)],
|
||||
[torch.tensor(4),
|
||||
torch.tensor(5),
|
||||
torch.tensor(6)],
|
||||
[torch.tensor(7),
|
||||
torch.tensor(8),
|
||||
torch.tensor(9)],
|
||||
[torch.tensor(10)],
|
||||
]
|
||||
self.assertEqual(list(pre_data), target_data)
|
||||
os.mkdir(osp.join(self.temp_dir.name, 'imgs'))
|
||||
for i in range(1, 11):
|
||||
img = np.array(1)
|
||||
img.dump(osp.join(self.temp_dir.name, 'imgs', f'{i}.npy'))
|
||||
# Passing a directory of images.
|
||||
inputs = inferencer._inputs_to_list(
|
||||
osp.join(self.temp_dir.name, 'imgs'))
|
||||
dataloader = inferencer.preprocess(inputs, batch_size=3)
|
||||
for data in dataloader:
|
||||
self.assertTrue(is_list_of(data, torch.Tensor))
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_installed('mmdet'), reason='mmdet is not installed')
|
||||
def test_list_models(self):
|
||||
model_list = BaseInferencer.list_models('mmdet')
|
||||
self.assertTrue(len(model_list) > 0)
|
||||
DefaultScope._instance_dict.clear()
|
||||
with self.assertRaisesRegex(AssertionError, 'scope should be'):
|
||||
BaseInferencer.list_models()
|
||||
with self.assertRaisesRegex(AssertionError, 'unknown not in'):
|
||||
BaseInferencer.list_models('unknown')
|
Loading…
x
Reference in New Issue
Block a user