mirror of https://github.com/open-mmlab/mmengine.git
[Feature] Add BaseInferencer (#874)
* [Feature] Add BaseInferencer (#773)
* Update BaseInferencer
* Fix CI
* Fix CI and rename iferencer to infer
* Fix CI
* Add renamed file
* Add test file
* Adjust interface sequence
* Refine preprocess
* Update unit test
* Update unit test
* Fix unit test
* Fix as comment
* Minor refine
* Fix docstring and support loading images from different backends
* Support loading collate_fn from downstream repos, refine dispatch
* Minor refine
* Fix lint
* Refine grammar
* Remove FileClient
* Refine docstring
* Add rich
* Add list_models
* Add list_models
* Remove backend args
* Minor refine
* Fix typos in docs and type hints (#787)
* [Fix] Add _inputs_to_list (#795)
* Add preprocess inputs
* Add type hint
* Update api/infer in index.rst
* Rename preprocess_inputs to _inputs_to_list
* Fix doc format
* Update infer.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* [Fix] Fix alias type (#801)
* [Enhance] Support loading model config from checkpoint (#864)
* first commit
* [Enhance] Support building model from weights
* Minor refine
* Fix type hint
* Refine comments
* Update docstring
* Refine as comment
* Add method
* Refine docstring
* Fix as comment
* Refine comments
* Refine warning message
* Fix unit test and refine comments
* Replace MODULE2PAKCAGE with MODULE2PACKAGE
* Fix typo and syntax error in docstring

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
This commit is contained in:
parent
ad590e45a2
commit
2d8f2be375
docs/en/api/infer.rst (new file, 14 lines)

```rst
.. role:: hidden
    :class: hidden-section

mmengine.infer
===================================

.. currentmodule:: mmengine.infer

.. autosummary::
   :toctree: generated
   :nosignatures:
   :template: classtemplate.rst

   BaseInferencer
```
The new page is added to the English API reference toctree:

```diff
@@ -82,6 +82,7 @@ You can switch between Chinese and English documents in the lower-left corner of
    mmengine.evaluator <api/evaluator>
    mmengine.structures <api/structures>
    mmengine.dataset <api/dataset>
+   mmengine.infer <api/infer>
    mmengine.device <api/device>
    mmengine.hub <api/hub>
    mmengine.logging <api/logging>
```
docs/zh_cn/api/infer.rst (new file, 14 lines)

```rst
.. role:: hidden
    :class: hidden-section

mmengine.infer
===================================

.. currentmodule:: mmengine.infer

.. autosummary::
   :toctree: generated
   :nosignatures:
   :template: classtemplate.rst

   BaseInferencer
```
and to the Chinese API reference toctree:

```diff
@@ -82,6 +82,7 @@
    mmengine.evaluator <api/evaluator>
    mmengine.structures <api/structures>
    mmengine.dataset <api/dataset>
+   mmengine.infer <api/infer>
    mmengine.device <api/device>
    mmengine.hub <api/hub>
    mmengine.logging <api/logging>
```
mmengine/infer/__init__.py (new file, 4 lines)

```python
# Copyright (c) OpenMMLab. All rights reserved.
from .infer import BaseInferencer

__all__ = ['BaseInferencer']
```
mmengine/infer/infer.py (new file, 648 lines)
```python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import re
import warnings
from abc import ABCMeta, abstractmethod
from datetime import datetime
from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence,
                    Tuple, Union)

import numpy as np
import torch
import torch.nn as nn
from rich.progress import track

from mmengine.config import Config, ConfigDict
from mmengine.config.utils import MODULE2PACKAGE
from mmengine.dataset import COLLATE_FUNCTIONS, pseudo_collate
from mmengine.device import get_device
from mmengine.fileio import (get_file_backend, isdir, join_path,
                             list_dir_or_file, load)
from mmengine.logging import print_log
from mmengine.registry import MODELS, VISUALIZERS, DefaultScope
from mmengine.runner.checkpoint import (_load_checkpoint,
                                        _load_checkpoint_to_model)
from mmengine.structures import InstanceData
from mmengine.utils import get_installed_path, is_installed
from mmengine.visualization import Visualizer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray, torch.Tensor]
InputsType = Union[InputType, Sequence[InputType]]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict]]
ConfigType = Union[Config, ConfigDict]
ModelType = Union[dict, ConfigType, str]


class InferencerMeta(ABCMeta):
    """Check the legality of the inferencer.

    All Inferencers should not define duplicated keys for
    ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` and
    ``postprocess_kwargs``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.preprocess_kwargs, set)
        assert isinstance(self.forward_kwargs, set)
        assert isinstance(self.visualize_kwargs, set)
        assert isinstance(self.postprocess_kwargs, set)

        all_kwargs = (
            self.preprocess_kwargs | self.forward_kwargs
            | self.visualize_kwargs | self.postprocess_kwargs)

        assert len(all_kwargs) == (
            len(self.preprocess_kwargs) + len(self.forward_kwargs) +
            len(self.visualize_kwargs) + len(self.postprocess_kwargs)), (
                f'Class define error! {self.__name__} should not '
                'define duplicated keys for `preprocess_kwargs`, '
                '`forward_kwargs`, `visualize_kwargs` and '
                '`postprocess_kwargs`.')
```
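The length comparison above is what catches duplicates: set union silently drops repeated keys, so a duplicated key makes the union strictly smaller than the sum of the parts. A quick standalone illustration:

```python
preprocess_kwargs = {'pre_arg'}
visualize_kwargs = {'pre_arg', 'show'}  # 'pre_arg' duplicated on purpose

merged = preprocess_kwargs | visualize_kwargs
# The metaclass assertion would fail for a class defined like this.
assert len(merged) < len(preprocess_kwargs) + len(visualize_kwargs)
```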
```python
class BaseInferencer(metaclass=InferencerMeta):
    """Base inferencer for downstream tasks.

    The BaseInferencer provides the standard workflow for inference as
    follows:

    1. Preprocess the input data by :meth:`preprocess`.
    2. Forward the data to the model by :meth:`forward`. ``BaseInferencer``
       assumes the model inherits from :class:`mmengine.models.BaseModel` and
       will call `model.test_step` in :meth:`forward` by default.
    3. Visualize the results by :meth:`visualize`.
    4. Postprocess and return the results by :meth:`postprocess`.

    When we call a subclass inherited from BaseInferencer (that does not
    override ``__call__``), the workflow is executed in this order.

    All subclasses of BaseInferencer could define the following class
    attributes for customization:

    - ``preprocess_kwargs``: The keys of the kwargs that will be passed to
      :meth:`preprocess`.
    - ``forward_kwargs``: The keys of the kwargs that will be passed to
      :meth:`forward`.
    - ``visualize_kwargs``: The keys of the kwargs that will be passed to
      :meth:`visualize`.
    - ``postprocess_kwargs``: The keys of the kwargs that will be passed to
      :meth:`postprocess`.

    All attributes mentioned above should be a ``set`` of keys (strings).
    :meth:`__call__` dispatches all the arguments to the corresponding
    methods according to the ``xxx_kwargs`` sets above, therefore the keys in
    these sets must be unique to avoid ambiguous dispatching.

    Warning:
        If subclasses define the class attributes mentioned above with
        duplicated keys, an ``AssertionError`` will be raised during the
        import process.

    Subclasses inherited from ``BaseInferencer`` should implement
    :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`:

    - _init_pipeline: Return a callable object to preprocess the input data.
    - visualize: Visualize the results returned by :meth:`forward`.
    - postprocess: Postprocess the results returned by :meth:`forward` and
      :meth:`visualize`.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. Take the `mmdet metafile <https://github.com/open-mmlab/mmdetection/blob/master/configs/retinanet/metafile.yml>`_
            as an example, the `model` could be `retinanet_r18_fpn_1x_coco` or
            its alias. If model is not specified, the user must provide the
            `weights` saved by MMEngine which contains the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and model is a model name from a metafile, the weights
            will be loaded from the metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to None.

    Note:
        Since an ``Inferencer`` may be used to infer batch data, a
        ``collate_fn`` is required. If ``collate_fn`` is not defined in the
        config file, ``pseudo_collate`` is used by default.
    """  # noqa: E501

    preprocess_kwargs: set = set()
    forward_kwargs: set = set()
    visualize_kwargs: set = set()
    postprocess_kwargs: set = set()

    def __init__(self,
                 model: Union[ModelType, str, None] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = None) -> None:
        if scope is None:
            default_scope = DefaultScope.get_current_instance()
            if default_scope is not None:
                scope = default_scope.scope_name
        self.scope = scope
        # Load config to cfg
        cfg: ConfigType
        if isinstance(model, str):
            if osp.isfile(model):
                cfg = Config.fromfile(model)
            else:
                # Load config and weights from metafile. If `weights` is
                # assigned, the weights defined in metafile will be ignored.
                cfg, _weights = self._load_model_from_metafile(model)
                if weights is None:
                    weights = _weights
        elif isinstance(model, (Config, ConfigDict)):
            cfg = copy.deepcopy(model)
        elif isinstance(model, dict):
            cfg = copy.deepcopy(ConfigDict(model))
        elif model is None:
            if weights is None:
                raise ValueError(
                    'If model is None, the weights must be specified since '
                    'the config needs to be loaded from the weights')
            cfg = ConfigDict()
        else:
            raise TypeError('model must be a filepath or any ConfigType '
                            f'object, but got {type(model)}')

        if device is None:
            device = get_device()

        self.model = self._init_model(cfg, weights, device)  # type: ignore
        self.pipeline = self._init_pipeline(cfg)
        self.collate_fn = self._init_collate(cfg)
        self.visualizer = self._init_visualizer(cfg)
        self.cfg = cfg

    def __call__(
        self,
        inputs: InputsType,
        return_datasamples: bool = False,
        batch_size: int = 1,
        **kwargs,
    ) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            return_datasamples (bool): Whether to return results as
                :obj:`BaseDataElement`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            **kwargs: Keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        ori_inputs = self._inputs_to_list(inputs)
        inputs = self.preprocess(
            ori_inputs, batch_size=batch_size, **preprocess_kwargs)
        preds = []
        for data in track(inputs, description='Inference'):
            preds.extend(self.forward(data, **forward_kwargs))
        visualization = self.visualize(
            ori_inputs, preds, **visualize_kwargs)  # type: ignore
        results = self.postprocess(preds, visualization, return_datasamples,
                                   **postprocess_kwargs)
        return results

    def _inputs_to_list(self, inputs: InputsType) -> list:
        """Preprocess the inputs to a list.

        Cast the inputs to a list according to their type:

        - list or tuple: return inputs
        - str:
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to a file, a URL or another kind of string
              depending on the task.

        Args:
            inputs (InputsType): Inputs for the inferencer.

        Returns:
            list: List of inputs for :meth:`preprocess`.
        """
        if isinstance(inputs, str):
            backend = get_file_backend(inputs)
            if hasattr(backend, 'isdir') and isdir(inputs):
                # Backends like HttpsBackend do not implement `isdir`, so
                # only backends that implement `isdir` can accept the inputs
                # as a directory
                filename_list = list_dir_or_file(inputs, list_dir=False)
                inputs = [
                    join_path(inputs, filename) for filename in filename_list
                ]

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        return list(inputs)
```
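``_inputs_to_list`` goes through ``mmengine.fileio`` so that remote backends work too; for local paths the mapping boils down to the following standard-library sketch (a simplification, not the method itself):

```python
import os
import os.path as osp


def inputs_to_list_local(inputs):
    # Directory -> every file inside it; a lone input -> one-element list.
    if isinstance(inputs, str) and osp.isdir(inputs):
        return [osp.join(inputs, f) for f in sorted(os.listdir(inputs))]
    if not isinstance(inputs, (list, tuple)):
        return [inputs]
    return list(inputs)


assert inputs_to_list_local('img.jpg') == ['img.jpg']
assert inputs_to_list_local(('a.jpg', 'b.jpg')) == ['a.jpg', 'b.jpg']
```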
```python
    def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
        """Process the inputs into a model-feedable format.

        Customize your preprocess by overriding this method. Preprocess
        should return an iterable object, of which each item will be used as
        the input of ``model.test_step``.

        ``BaseInferencer.preprocess`` returns an iterable of chunked data,
        which is consumed in ``__call__`` like this:

        .. code-block:: python

            def __call__(self, inputs, batch_size=1, **kwargs):
                chunked_data = self.preprocess(inputs, batch_size, **kwargs)
                for batch in chunked_data:
                    preds = self.forward(batch, **kwargs)

        Args:
            inputs (InputsType): Inputs given by user.
            batch_size (int): batch size. Defaults to 1.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
        """
        chunked_data = self._get_chunk_data(
            map(self.pipeline, inputs), batch_size)
        yield from map(self.collate_fn, chunked_data)
```
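The whole method is just ``map(pipeline) -> chunk -> map(collate_fn)``. A minimal runnable sketch of that flow with toy data, assuming an environment where mmengine is installed (``pseudo_collate`` leaves a batch of tensors as a plain list):

```python
import torch
from mmengine.dataset import pseudo_collate


def chunk(iterable, size):
    # Stand-in for _get_chunk_data: yield fixed-size batches lazily.
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        yield batch


pipeline = lambda x: torch.tensor(x).float()  # stand-in for a real pipeline
batches = list(map(pseudo_collate, chunk(map(pipeline, range(1, 11)), 3)))
assert len(batches) == 4 and len(batches[-1]) == 1
```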
```python
    @torch.no_grad()
    def forward(self, inputs: Union[dict, tuple], **kwargs) -> Any:
        """Feed the inputs to the model."""
        return self.model.test_step(inputs)

    @abstractmethod
    def visualize(self,
                  inputs: list,
                  preds: Any,
                  show: bool = False,
                  **kwargs) -> List[np.ndarray]:
        """Visualize predictions.

        Customize your visualization by overriding this method. ``visualize``
        should return visualization results, which could be ``np.ndarray`` or
        any other type.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.

        Returns:
            List[np.ndarray]: Visualization results.
        """

    @abstractmethod
    def postprocess(
        self,
        preds: Any,
        visualization: List[np.ndarray],
        return_datasample=False,
        **kwargs,
    ) -> dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Convert datasamples into a json-serializable dict if needed.
        2. Pack the predictions and visualization results and return them.
        3. Dump or log the predictions.

        Customize your postprocess by overriding this method. Make sure
        ``postprocess`` returns a dict with both visualization results and
        inference results.

        Args:
            preds (List[Dict]): Predictions of the model.
            visualization (np.ndarray): Visualized predictions.
            return_datasample (bool): Whether to return results as
                datasamples. Defaults to False.

        Returns:
            dict: Inference and visualization results with keys
            ``predictions`` and ``visualization``

            - ``visualization (Any)``: Returned by :meth:`visualize`
            - ``predictions`` (dict or DataSample): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it usually should be a
              json-serializable dict containing only basic data elements
              such as strings and numbers.
        """

    def _load_model_from_metafile(self, model: str) -> Tuple[Config, str]:
        """Load config and weights from metafile.

        Args:
            model (str): model name defined in metafile.

        Returns:
            Tuple[Config, str]: Loaded Config and weights path defined in
            metafile.
        """
        model = model.lower()

        assert self.scope is not None, (
            'scope should be initialized if you want '
            'to load config from metafile.')
        assert self.scope in MODULE2PACKAGE, (
            f'{self.scope} not in {MODULE2PACKAGE}!, '
            'please pass a valid scope.')
        project = MODULE2PACKAGE[self.scope]
        assert is_installed(project), f'Please install {project}'
        package_path = get_installed_path(project)
        for model_cfg in BaseInferencer._get_models_from_package(package_path):
            model_name = model_cfg['Name'].lower()
            model_aliases = model_cfg.get('Alias', [])
            if isinstance(model_aliases, str):
                model_aliases = [model_aliases.lower()]
            else:
                model_aliases = [alias.lower() for alias in model_aliases]
            if (model_name == model or model in model_aliases):
                cfg = Config.fromfile(
                    osp.join(package_path, '.mim', model_cfg['Config']))
                weights = model_cfg['Weights']
                weights = weights[0] if isinstance(weights, list) else weights
                return cfg, weights
        raise ValueError(f'Cannot find model: {model} in {project}')
```
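The matching is case-insensitive over ``Name`` and ``Alias``, and ``Weights`` may be a list, in which case only the first entry is used. A toy illustration with a hypothetical metafile entry (all names and the URL are invented):

```python
model_cfg = {
    'Name': 'RetinaNet_R18_FPN_1x_COCO',  # hypothetical entry
    'Alias': 'retinanet18',
    'Config': 'configs/retinanet/retinanet_r18.py',
    'Weights': ['https://example.com/retinanet_r18.pth'],
}
query = 'retinanet18'
aliases = model_cfg.get('Alias', [])
aliases = [aliases.lower()] if isinstance(aliases, str) \
    else [a.lower() for a in aliases]
assert model_cfg['Name'].lower() == query or query in aliases
weights = model_cfg['Weights']
weights = weights[0] if isinstance(weights, list) else weights
```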
```python
    def _init_model(
        self,
        cfg: ConfigType,
        weights: Optional[str],
        device: str = 'cpu',
    ) -> nn.Module:
        """Initialize the model with the given config and checkpoint on the
        specific device.

        Args:
            cfg (ConfigType): Config containing the model information.
            weights (str, optional): Path to the checkpoint.
            device (str, optional): Device to run inference. Defaults to
                'cpu'.

        Returns:
            nn.Module: Model loaded with checkpoint.
        """
        checkpoint: Optional[dict] = None
        if weights is not None:
            checkpoint = _load_checkpoint(weights, map_location='cpu')

        if not cfg:
            assert checkpoint is not None
            try:
                # Prefer to get config from `message_hub` since `message_hub`
                # is a more stable module to store all runtime information.
                # However, the early version of MMEngine will not save config
                # in `message_hub`, so we will try to load config from `meta`.
                cfg_string = checkpoint['message_hub']['runtime_info']['cfg']
            except KeyError:
                assert 'meta' in checkpoint, (
                    'If model(config) is not provided, the checkpoint must '
                    'contain the config string in `meta` or `message_hub`, '
                    'but both `meta` and `message_hub` are not found in the '
                    'checkpoint.')
                meta = checkpoint['meta']
                if 'cfg' in meta:
                    cfg_string = meta['cfg']
                else:
                    raise ValueError(
                        'Cannot find the config in the checkpoint.')
            cfg.update(
                Config.fromstring(cfg_string, file_format='.py')._cfg_dict)

        # Delete the `pretrained` field to prevent model from loading the
        # pretrained weights unnecessarily.
        if cfg.model.get('pretrained') is not None:
            del cfg.model.pretrained

        model = MODELS.build(cfg.model)
        model.cfg = cfg
        self._load_weights_to_model(model, checkpoint, cfg)
        model.to(device)
        model.eval()
        return model
```
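``Config.fromstring`` is what makes weights-only initialization possible: the config string dumped into the checkpoint is parsed back into a full ``Config``. A minimal runnable check, assuming mmengine is installed:

```python
from mmengine.config import Config

cfg_string = "model = dict(type='ToyModel', layers=2)"
cfg = Config.fromstring(cfg_string, file_format='.py')
assert cfg.model.type == 'ToyModel'
```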
```python
    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Loading model weights and meta information from cfg and
        checkpoint.

        Subclasses could override this method to load extra meta information
        from ``checkpoint`` and ``cfg`` to model.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint.
            cfg (Config or ConfigDict, optional): The loaded config.
        """
        if checkpoint is not None:
            _load_checkpoint_to_model(model, checkpoint)
        else:
            warnings.warn('Checkpoint is not loaded, and the inference '
                          'result is calculated by the randomly initialized '
                          'model!')

    def _init_collate(self, cfg: ConfigType) -> Callable:
        """Initialize the ``collate_fn`` with the given config.

        The returned ``collate_fn`` will be used to collate the batch data.
        It will be used in :meth:`preprocess` like this:

        .. code-block:: python

            def preprocess(self, inputs, batch_size, **kwargs):
                ...
                dataloader = map(self.collate_fn, dataloader)
                yield from dataloader

        Args:
            cfg (ConfigType): Config which could contain the `collate_fn`
                information. If `collate_fn` is not defined in the config, it
                will be :func:`pseudo_collate`.

        Returns:
            Callable: Collate function.
        """
        try:
            with COLLATE_FUNCTIONS.switch_scope_and_registry(
                    self.scope) as registry:
                collate_fn = registry.get(cfg.test_dataloader.collate_fn)
        except AttributeError:
            collate_fn = pseudo_collate
        return collate_fn  # type: ignore
```
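The ``except AttributeError`` fallback relies on ``ConfigDict`` raising ``AttributeError`` for missing keys, so a config without a ``test_dataloader.collate_fn`` entry silently falls back to ``pseudo_collate``. A sketch of that path, assuming mmengine's ``ConfigDict`` behavior:

```python
from mmengine.config import ConfigDict
from mmengine.dataset import pseudo_collate

cfg = ConfigDict()  # no `test_dataloader` section at all
try:
    collate_fn = cfg.test_dataloader.collate_fn
except AttributeError:
    collate_fn = pseudo_collate  # same fallback as _init_collate
```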
```python
    @abstractmethod
    def _init_pipeline(self, cfg: ConfigType) -> Callable:
        """Initialize the test pipeline.

        Return a pipeline to handle various input data, such as ``str``,
        ``np.ndarray``. It is an abstract method in BaseInferencer, and
        should be implemented in subclasses.

        The returned pipeline will be used to process a single piece of
        data. It will be used in :meth:`preprocess` like this:

        .. code-block:: python

            def preprocess(self, inputs, batch_size, **kwargs):
                ...
                dataset = map(self.pipeline, dataset)
                ...
        """

    def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]:
        """Initialize visualizers.

        Args:
            cfg (ConfigType): Config containing the visualizer information.

        Returns:
            Visualizer or None: Visualizer initialized with config.
        """
        if 'visualizer' not in cfg:
            return None
        timestamp = str(datetime.timestamp(datetime.now()))
        name = cfg.visualizer.get('name', timestamp)
        if Visualizer.check_instance_created(name):
            name = f'{name}-{timestamp}'
        cfg.visualizer.name = name
        return VISUALIZERS.build(cfg.visualizer)

    def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
        """Get batch data from dataset.

        Args:
            inputs (Iterable): An iterable dataset.
            chunk_size (int): Equivalent to batch size.

        Yields:
            list: batch data.
        """
        inputs_iter = iter(inputs)
        while True:
            try:
                chunk_data = []
                for _ in range(chunk_size):
                    processed_data = next(inputs_iter)
                    chunk_data.append(processed_data)
                yield chunk_data
            except StopIteration:
                if chunk_data:
                    yield chunk_data
                break
```
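A standalone copy of that batching pattern shows the contract, including the trailing partial batch (the same expectation the unit test below encodes):

```python
def get_chunk_data(inputs, chunk_size):
    inputs_iter = iter(inputs)
    while True:
        try:
            chunk = []
            for _ in range(chunk_size):
                chunk.append(next(inputs_iter))
            yield chunk
        except StopIteration:
            if chunk:
                yield chunk  # keep the last, partially filled batch
            break


assert list(get_chunk_data(range(1, 11), 3)) == [
    [1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]
```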
```python
    def _dispatch_kwargs(self, **kwargs) -> Tuple[Dict, Dict, Dict, Dict]:
        """Dispatch kwargs to preprocess(), forward(), visualize() and
        postprocess() according to the actual demands.

        Returns:
            Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
            forward, visualize and postprocess respectively.
        """
        # Ensure each argument only matches one function
        method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \
            self.visualize_kwargs | self.postprocess_kwargs

        union_kwargs = method_kwargs | set(kwargs.keys())
        if union_kwargs != method_kwargs:
            unknown_kwargs = union_kwargs - method_kwargs
            raise ValueError(
                f'unknown argument {unknown_kwargs} for `preprocess`, '
                '`forward`, `visualize` and `postprocess`')

        preprocess_kwargs = {}
        forward_kwargs = {}
        visualize_kwargs = {}
        postprocess_kwargs = {}

        for key, value in kwargs.items():
            if key in self.preprocess_kwargs:
                preprocess_kwargs[key] = value
            elif key in self.forward_kwargs:
                forward_kwargs[key] = value
            elif key in self.visualize_kwargs:
                visualize_kwargs[key] = value
            else:
                postprocess_kwargs[key] = value

        return (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        )
```
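The routing is plain set membership. A standalone re-sketch using the same key names as the unit tests below:

```python
pre_keys, fwd_keys = {'pre_arg'}, {'for_arg'}
vis_keys, pos_keys = {'vis_arg'}, {'pos_arg'}


def dispatch(**kwargs):
    # Reject anything no method declared, then bucket by membership.
    known = pre_keys | fwd_keys | vis_keys | pos_keys
    unknown = set(kwargs) - known
    if unknown:
        raise ValueError(f'unknown argument {unknown}')
    buckets = ({}, {}, {}, {})
    for key, value in kwargs.items():
        for keys, bucket in zip(
                (pre_keys, fwd_keys, vis_keys, pos_keys), buckets):
            if key in keys:
                bucket[key] = value
    return buckets


assert dispatch(pre_arg=1, vis_arg=3) == (
    {'pre_arg': 1}, {}, {'vis_arg': 3}, {})
```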
```python
    @staticmethod
    def _get_models_from_package(package_path: str):
        """Load model config defined in metafile from package path.

        Args:
            package_path (str): Path to the package.

        Yields:
            dict: Model config defined in metafile.
        """
        meta_indexes = load(osp.join(package_path, '.mim', 'model-index.yml'))
        for meta_path in meta_indexes['Import']:
            # meta_path example: mmcls/.mim/configs/conformer/metafile.yml
            meta_path = osp.join(package_path, '.mim', meta_path)
            metainfo = load(meta_path)
            yield from metainfo['Models']

    @staticmethod
    def list_models(scope: Optional[str] = None, patterns: str = r'.*'):
        """List models defined in the metafiles of the corresponding package.

        Args:
            scope (str, optional): The scope to which the model belongs.
                Defaults to None.
            patterns (str, optional): Regular expression for the searched
                models. Once matched against the ``Alias`` or ``Name`` field
                in a metafile, the corresponding model is added to the
                returned list. Defaults to '.*'.

        Returns:
            list: Matched model names and aliases.
        """
        matched_models = []
        if scope is None:
            default_scope = DefaultScope.get_current_instance()
            assert default_scope is not None, (
                'scope should be initialized if you want '
                'to load config from metafile.')
            scope = default_scope.scope_name
        assert scope in MODULE2PACKAGE, (
            f'{scope} not in {MODULE2PACKAGE}!, please pass a valid scope.')
        project = MODULE2PACKAGE[scope]
        assert is_installed(project), f'Please install {project}'
        package_path = get_installed_path(project)

        for model_cfg in BaseInferencer._get_models_from_package(package_path):
            model_name = [model_cfg['Name']]
            model_name.extend(model_cfg.get('Alias', []))
            for name in model_name:
                if re.match(patterns, name) is not None:
                    matched_models.append(name)
        output_str = ''
        for name in matched_models:
            output_str += f'model_name: {name}\n'
        print_log(output_str, logger='current')
        return matched_models
```
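Typical use of ``list_models``, assuming mmdet is installed and its modules registered; note that ``re.match`` anchors the pattern at the start of the name:

```python
from mmengine.infer import BaseInferencer

# Lists only names/aliases starting with 'retinanet'.
models = BaseInferencer.list_models(scope='mmdet', patterns=r'retinanet')
```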
The commit also relaxes ``Registry.switch_scope_and_registry`` to accept an optional scope, since the inferencer may call it with ``scope=None``:

```diff
@@ -195,7 +195,7 @@ class Registry:
         return self._get_root_registry()
 
     @contextmanager
-    def switch_scope_and_registry(self, scope: str) -> Generator:
+    def switch_scope_and_registry(self, scope: Optional[str]) -> Generator:
         """Temporarily switch default scope to the target scope, and get the
         corresponding registry.
 
@@ -203,7 +203,7 @@ class Registry:
         registry, otherwise yield the current itself.
 
         Args:
-            scope (str): The target scope.
+            scope (str, optional): The target scope.
 
         Examples:
             >>> from mmengine.registry import Registry, DefaultScope, MODELS
```
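With the relaxed signature, passing ``scope=None`` (as ``BaseInferencer._init_collate`` does when no scope was resolved) is now valid; a short sketch, assuming mmengine is importable:

```python
from mmengine.registry import MODELS

# scope=None keeps the current default scope instead of raising.
with MODELS.switch_scope_and_registry(None) as registry:
    print(registry.scope)
```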
The ``ToyModel`` test helper makes ``data_samples`` optional, matching how the inferencer calls ``test_step`` without ground-truth samples:

```diff
@@ -31,7 +31,7 @@ class ToyModel(BaseModel):
         self.linear1 = nn.Linear(2, 2)
         self.linear2 = nn.Linear(2, 1)
 
-    def forward(self, inputs, data_samples, mode='tensor'):
+    def forward(self, inputs, data_samples=None, mode='tensor'):
         if isinstance(inputs, list):
             inputs = torch.stack(inputs)
         if isinstance(data_samples, list):
```
``rich`` (which provides the progress bar used in ``__call__`` via ``rich.progress.track``) is added to the requirements:

```diff
@@ -3,5 +3,6 @@ matplotlib
 numpy
 pyyaml
 regex;sys_platform=='win32'
+rich
 termcolor
 yapf
```
tests/test_infer/test_infer.py (new file, 221 lines)

```python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp

import numpy as np
import pytest
import torch

from mmengine.infer import BaseInferencer
from mmengine.registry import VISUALIZERS, DefaultScope
from mmengine.testing import RunnerTestCase
from mmengine.utils import is_installed, is_list_of
from mmengine.visualization import Visualizer


class ToyInferencer(BaseInferencer):
    preprocess_kwargs = {'pre_arg'}
    forward_kwargs = {'for_arg'}
    visualize_kwargs = {'vis_arg'}
    postprocess_kwargs = {'pos_arg'}

    def preprocess(self, inputs, batch_size=1, pre_arg=None, **kwargs):
        return super().preprocess(inputs, batch_size, **kwargs)

    def forward(self, inputs, for_arg=None, **kwargs):
        return super().forward(inputs, **kwargs)

    def visualize(self, inputs, preds, vis_arg=None, **kwargs):
        return inputs

    def postprocess(self,
                    preds,
                    imgs,
                    return_datasamples,
                    pos_arg=None,
                    **kwargs):
        return imgs, preds

    def _init_pipeline(self, cfg):

        def pipeline(img):
            if isinstance(img, str):
                img = np.load(img, allow_pickle=True)
                img = torch.from_numpy(img).float()
            elif isinstance(img, np.ndarray):
                img = torch.from_numpy(img).float()
            else:
                img = torch.tensor(img).float()
            return img

        return pipeline


class ToyVisualizer(Visualizer):
    ...


class TestBaseInferencer(RunnerTestCase):

    def setUp(self) -> None:
        super().setUp()
        runner = self.build_runner(copy.deepcopy(self.epoch_based_cfg))
        runner.train()
        self.cfg_path = osp.join(runner.work_dir, f'{runner.timestamp}.py')
        self.ckpt_path = osp.join(runner.work_dir, 'epoch_1.pth')
        VISUALIZERS.register_module(module=ToyVisualizer, name='ToyVisualizer')

    def test_custom_inferencer(self):
        # Inferencer should not define ***_kwargs with duplicate keys.
        with self.assertRaisesRegex(AssertionError, 'Class define error'):

            class CustomInferencer(BaseInferencer):
                preprocess_kwargs = set('a')
                forward_kwargs = set('a')

    def tearDown(self):
        VISUALIZERS._module_dict.pop('ToyVisualizer')
        return super().tearDown()

    def test_init(self):
        # Pass model as Config
        cfg = copy.deepcopy(self.epoch_based_cfg)
        ToyInferencer(cfg, self.ckpt_path)
        # Pass model as ConfigDict
        ToyInferencer(cfg._cfg_dict, self.ckpt_path)
        # Pass model as normal dict
        ToyInferencer(dict(cfg._cfg_dict), self.ckpt_path)
        # Pass model as a string pointing to the path of the config
        ToyInferencer(self.cfg_path, self.ckpt_path)

        cfg.model.pretrained = 'fake_path'
        inferencer = ToyInferencer(cfg, self.ckpt_path)
        self.assertNotIn('pretrained', inferencer.cfg.model)

        # Pass invalid model
        with self.assertRaisesRegex(TypeError, 'model must'):
            ToyInferencer([self.epoch_based_cfg], self.ckpt_path)

        # Pass model as model name defined in metafile
        if is_installed('mmdet'):
            from mmdet.utils import register_all_modules

            register_all_modules()
            ToyInferencer(
                'faster-rcnn_s50_fpn_syncbn-backbone+head_ms-range-1x_coco',
                'https://download.openmmlab.com/mmdetection/v2.0/resnest/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco_20200926_125502-20289c16.pth',  # noqa: E501
            )

        checkpoint = self.ckpt_path
        ToyInferencer(weights=checkpoint)

    def test_call(self):
        num_imgs = 12
        imgs = []
        img_paths = []
        for i in range(num_imgs):
            img = np.random.random((1, 2))
            img_path = osp.join(self.temp_dir.name, f'{i}.npy')
            img.dump(img_path)
            imgs.append(img)
            img_paths.append(img_path)

        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        inferencer(imgs)
        inferencer(img_paths)

    @pytest.mark.skipif(
        not is_installed('mmdet'), reason='mmdet is not installed')
    def test_load_model_from_meta(self):
        from mmdet.utils import register_all_modules

        register_all_modules()
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        inferencer._load_model_from_metafile('retinanet_r18_fpn_1x_coco')
        with self.assertRaisesRegex(ValueError, 'Cannot find model'):
            inferencer._load_model_from_metafile('fake_model')
        # TODO: Test alias

    def test_init_model(self):
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        model = inferencer._init_model(self.iter_based_cfg, self.ckpt_path)
        self.assertFalse(model.training)

    def test_get_chunk_data(self):
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        data = list(range(1, 11))
        chunk_data = inferencer._get_chunk_data(data, 3)
        self.assertEqual(
            list(chunk_data), [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]])

    def test_init_visualizer(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        visualizer = inferencer._init_visualizer(cfg)
        self.assertIsNone(visualizer)
        cfg.visualizer = dict(type='ToyVisualizer')
        visualizer = inferencer._init_visualizer(cfg)
        self.assertIsInstance(visualizer, ToyVisualizer)

        # Visualizer could be built with the same name repeatedly.
        cfg.visualizer = dict(type='ToyVisualizer', name='toy')
        visualizer = inferencer._init_visualizer(cfg)
        visualizer = inferencer._init_visualizer(cfg)

    def test_dispatch_kwargs(self):
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        kwargs = dict(
            pre_arg=dict(a=1),
            for_arg=dict(c=2),
            vis_arg=dict(b=3),
            pos_arg=dict(d=4))
        pre_arg, for_arg, vis_arg, pos_arg = inferencer._dispatch_kwargs(
            **kwargs)
        self.assertEqual(pre_arg, dict(pre_arg=dict(a=1)))
        self.assertEqual(for_arg, dict(for_arg=dict(c=2)))
        self.assertEqual(vis_arg, dict(vis_arg=dict(b=3)))
        self.assertEqual(pos_arg, dict(pos_arg=dict(d=4)))
        # Test unknown arg.
        kwargs = dict(return_datasample=dict())
        with self.assertRaisesRegex(ValueError, 'unknown'):
            inferencer._dispatch_kwargs(**kwargs)

    def test_preprocess(self):
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        data = list(range(1, 11))
        pre_data = inferencer.preprocess(data, batch_size=3)
        target_data = [
            [torch.tensor(1),
             torch.tensor(2),
             torch.tensor(3)],
            [torch.tensor(4),
             torch.tensor(5),
             torch.tensor(6)],
            [torch.tensor(7),
             torch.tensor(8),
             torch.tensor(9)],
            [torch.tensor(10)],
        ]
        self.assertEqual(list(pre_data), target_data)
        os.mkdir(osp.join(self.temp_dir.name, 'imgs'))
        for i in range(1, 11):
            img = np.array(1)
            img.dump(osp.join(self.temp_dir.name, 'imgs', f'{i}.npy'))
        # Passing a directory of images.
        inputs = inferencer._inputs_to_list(
            osp.join(self.temp_dir.name, 'imgs'))
        dataloader = inferencer.preprocess(inputs, batch_size=3)
        for data in dataloader:
            self.assertTrue(is_list_of(data, torch.Tensor))

    @pytest.mark.skipif(
        not is_installed('mmdet'), reason='mmdet is not installed')
    def test_list_models(self):
        model_list = BaseInferencer.list_models('mmdet')
        self.assertTrue(len(model_list) > 0)
        DefaultScope._instance_dict.clear()
        with self.assertRaisesRegex(AssertionError, 'scope should be'):
            BaseInferencer.list_models()
        with self.assertRaisesRegex(AssertionError, 'unknown not in'):
            BaseInferencer.list_models('unknown')
```
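The suite can be run on its own, assuming a development install with pytest available:

```python
import pytest

# -v echoes each test name as it runs.
raise SystemExit(pytest.main(['tests/test_infer/test_infer.py', '-v']))
```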