mmclassification/mmpretrain/apis/base.py

# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import default_collate
from mmengine.fileio import get_file_backend
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from mmpretrain.structures import DataSample
from mmpretrain.utils import track
from .model import get_model, list_models
ModelType = Union[BaseModel, str, Config]
InputType = Union[str, np.ndarray, list]
class BaseInferencer:
"""Base inferencer for various tasks.
The BaseInferencer provides the standard workflow for inference as follows:
1. Preprocess the input data by :meth:`preprocess`.
2. Forward the data to the model by :meth:`forward`. ``BaseInferencer``
assumes the model inherits from :class:`mmengine.model.BaseModel` and
will call `model.test_step` in :meth:`forward` by default.
3. Visualize the results by :meth:`visualize`.
4. Postprocess and return the results by :meth:`postprocess`.
When calling a subclass of BaseInferencer (one that does not override
``__call__``), the workflow is executed in this order.
All subclasses of BaseInferencer could define the following class
attributes for customization:
- ``preprocess_kwargs``: The keys of the kwargs that will be passed to
:meth:`preprocess`.
- ``forward_kwargs``: The keys of the kwargs that will be passed to
:meth:`forward`
- ``visualize_kwargs``: The keys of the kwargs that will be passed to
:meth:`visualize`
- ``postprocess_kwargs``: The keys of the kwargs that will be passed to
:meth:`postprocess`
Each of the attributes mentioned above should be a ``set`` of keys
(strings), and a key must not appear in more than one set. :meth:`__call__`
dispatches all the arguments to the corresponding methods according to
these ``xxx_kwargs`` sets.
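For instance (an illustrative sketch, not a real subclass in this
package), a subclass may declare the sets as below so that a call like
``inferencer(inputs, show=True)`` routes ``show`` to :meth:`visualize`
only:

.. code-block:: python

    class MyInferencer(BaseInferencer):
        # Keys declared here are picked out of ``**kwargs`` in
        # ``__call__`` and forwarded to the matching method.
        preprocess_kwargs = set()
        forward_kwargs = set()
        visualize_kwargs = {'show'}
        postprocess_kwargs = {'print_result'}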
Subclasses inherited from ``BaseInferencer`` should implement
:meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`:
- _init_pipeline: Return a callable object to preprocess the input data.
- visualize: Visualize the results returned by :meth:`forward`.
- postprocess: Postprocess the results returned by :meth:`forward` and
:meth:`visualize`.
Args:
model (BaseModel | str | Config): A model name or a path to the config
file, or a :obj:`BaseModel` object. The model name can be found
by ``cls.list_models()`` and you can also query it in
:doc:`/modelzoo_statistics`.
pretrained (str | bool): Path to the checkpoint. If True, it will try
to find a pre-defined weight from the model you specified (only works
when the ``model`` is a model name). Defaults to True.
device (str | torch.device | None): Transfer the model to the target
device. Defaults to None.
device_map (str | dict | None): A map that specifies where each
submodule should go. It doesn't need to cover every parameter/buffer
name; once a given module name is included, every submodule of it will
be sent to the same device. You can use ``device_map='auto'`` to
automatically generate the device map. Defaults to None.
offload_folder (str | None): The folder to offload weights into if the
``device_map`` contains the value ``"disk"``. Defaults to None.
**kwargs: Other keyword arguments to initialize the model (only works
when the ``model`` is a model name).
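Examples:
    A minimal usage sketch. The concrete inferencer class, model name and
    image path below are only illustrative; the model name can be any
    name returned by ``cls.list_models()``.

    .. code-block:: python

        from mmpretrain import ImageClassificationInferencer

        inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
        # ``__call__`` runs preprocess -> forward -> visualize -> postprocess.
        results = inferencer('demo/demo.JPEG', batch_size=1)
        print(results[0]['pred_class'])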
"""
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = set()
postprocess_kwargs: set = set()
def __init__(self,
model: ModelType,
pretrained: Union[bool, str] = True,
device: Union[str, torch.device, None] = None,
device_map=None,
offload_folder=None,
**kwargs) -> None:
if isinstance(model, BaseModel):
if isinstance(pretrained, str):
load_checkpoint(model, pretrained, map_location='cpu')
if device_map is not None:
from .utils import dispatch_model
model = dispatch_model(
model,
device_map=device_map,
offload_folder=offload_folder)
elif device is not None:
model.to(device)
else:
model = get_model(
model,
pretrained,
device=device,
device_map=device_map,
offload_folder=offload_folder,
**kwargs)
model.eval()
self.config = model._config
self.model = model
self.pipeline = self._init_pipeline(self.config)
self.visualizer = None
def __call__(
self,
inputs,
return_datasamples: bool = False,
batch_size: int = 1,
**kwargs,
) -> dict:
"""Call the inferencer.
Args:
inputs (InputType): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
:obj:`BaseDataElement`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
**kwargs: Keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
"""
(
preprocess_kwargs,
forward_kwargs,
visualize_kwargs,
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)
ori_inputs = self._inputs_to_list(inputs)
inputs = self.preprocess(
ori_inputs, batch_size=batch_size, **preprocess_kwargs)
preds = []
for data in track(inputs, 'Inference'):
preds.extend(self.forward(data, **forward_kwargs))
visualization = self.visualize(ori_inputs, preds, **visualize_kwargs)
results = self.postprocess(preds, visualization, return_datasamples,
**postprocess_kwargs)
return results
def _inputs_to_list(self, inputs: InputType) -> list:
"""Preprocess the inputs to a list.
Cast the input data to a list of data.
- list or tuple: return inputs
- str:
- Directory path: return all files in the directory
- other cases: return a list containing the string. The string
could be a path to a file, a URL or another type of string
depending on the task.
- other: return a list with one item.
Args:
inputs (str | array | list): Inputs for the inferencer.
Returns:
list: List of input for the :meth:`preprocess`.
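Examples:
    An illustrative sketch (the paths below are hypothetical):

    .. code-block:: python

        inferencer._inputs_to_list('./images/')   # a local directory
        # ['./images/a.jpg', './images/b.jpg', ...]
        inferencer._inputs_to_list('demo.jpg')    # a single file path
        # ['demo.jpg']
        inferencer._inputs_to_list(np.zeros((224, 224, 3)))
        # [array([[...]])]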
"""
if isinstance(inputs, str):
backend = get_file_backend(inputs)
if hasattr(backend, 'isdir') and backend.isdir(inputs):
# Backends like HttpsBackend do not implement `isdir`, so only
# backends that implement `isdir` can accept the inputs as a
# directory.
file_list = backend.list_dir_or_file(inputs, list_dir=False)
inputs = [
backend.join_path(inputs, file) for file in file_list
]
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
return list(inputs)
def preprocess(self, inputs: InputType, batch_size: int = 1, **kwargs):
"""Process the inputs into a model-feedable format.
Customize your preprocess by overriding this method. Preprocess should
return an iterable object, of which each item will be used as the
input of ``model.test_step``.
``BaseInferencer.preprocess`` will return an iterable of chunked data,
which will be used in ``__call__`` like this:
.. code-block:: python
def __call__(self, inputs, batch_size=1, **kwargs):
chunked_data = self.preprocess(inputs, batch_size, **kwargs)
for batch in chunked_data:
preds = self.forward(batch, **kwargs)
Args:
inputs (InputType): Inputs given by the user.
batch_size (int): Batch size. Defaults to 1.
Yields:
Any: Data processed by the ``pipeline`` and ``default_collate``.
"""
chunked_data = self._get_chunk_data(
map(self.pipeline, inputs), batch_size)
yield from map(default_collate, chunked_data)
@torch.no_grad()
def forward(self, inputs: Union[dict, tuple], **kwargs):
"""Feed the inputs to the model."""
return self.model.test_step(inputs)
def visualize(self,
inputs: list,
preds: List[DataSample],
show: bool = False,
**kwargs) -> List[np.ndarray]:
"""Visualize predictions.
Customize your visualization by overriding this method. ``visualize``
should return visualization results, which could be np.ndarray or any
other objects.
Args:
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
preds (Any): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
Returns:
List[np.ndarray]: Visualization results.
"""
if show:
raise NotImplementedError(
f'The `visualize` method of {self.__class__.__name__} '
'is not implemented.')
@abstractmethod
def postprocess(
self,
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasample=False,
**kwargs,
) -> dict:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.
This method should be responsible for the following tasks:
1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.
Customize your postprocess by overriding this method. Make sure
``postprocess`` will return a dict with visualization results and
inference results.
Args:
preds (List[DataSample]): Predictions of the model.
visualization (List[np.ndarray]): Visualized predictions.
return_datasample (bool): Whether to return results as datasamples.
Defaults to False.
Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``
- ``visualization (Any)``: Returned by :meth:`visualize`
- ``predictions`` (dict or DataSample): Returned by
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
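Examples:
    A minimal sketch of a possible implementation (the fields read from
    the data samples are illustrative and vary between tasks):

    .. code-block:: python

        def postprocess(self, preds, visualization,
                        return_datasample=False, **kwargs):
            if return_datasample:
                predictions = preds
            else:
                predictions = [{'pred_label': p.pred_label.item()}
                               for p in preds]
            return {'predictions': predictions,
                    'visualization': visualization}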
"""
@abstractmethod
def _init_pipeline(self, cfg: Config) -> Callable:
"""Initialize the test pipeline.
Return a pipeline to handle various input data, such as ``str``,
``np.ndarray``. It is an abstract method in BaseInferencer, and should
be implemented in subclasses.
The returned pipeline will be used to process a single data.
It will be used in :meth:`preprocess` like this:
.. code-block:: python
def preprocess(self, inputs, batch_size, **kwargs):
...
dataset = map(self.pipeline, dataset)
...
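Examples:
    A minimal sketch of a possible implementation in a subclass (the
    config fields used below are illustrative; real subclasses usually
    build the pipeline from the test dataloader settings):

    .. code-block:: python

        from mmengine.dataset import Compose
        from mmpretrain.registry import TRANSFORMS

        def _init_pipeline(self, cfg):
            test_pipeline = cfg.test_dataloader.dataset.pipeline
            return Compose([TRANSFORMS.build(t) for t in test_pipeline])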
"""
def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
"""Get batch data from dataset.
Args:
inputs (Iterable): An iterable dataset.
chunk_size (int): Equivalent to batch size.
Yields:
list: batch data.
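Examples:
    A sketch of the chunking behaviour (plain integers stand in for
    processed samples):

    .. code-block:: python

        list(inferencer._get_chunk_data(range(5), chunk_size=2))
        # [[0, 1], [2, 3], [4]]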
"""
inputs_iter = iter(inputs)
while True:
try:
chunk_data = []
for _ in range(chunk_size):
processed_data = next(inputs_iter)
chunk_data.append(processed_data)
yield chunk_data
except StopIteration:
if chunk_data:
yield chunk_data
break
def _dispatch_kwargs(self, **kwargs) -> Tuple[dict, dict, dict, dict]:
"""Dispatch kwargs to preprocess(), forward(), visualize() and
postprocess() according to the actual demands.
Returns:
Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
forward, visualize and postprocess respectively.
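Examples:
    An illustrative sketch, assuming a subclass that declares
    ``visualize_kwargs = {'show'}`` and ``postprocess_kwargs = {'print_result'}``:

    .. code-block:: python

        pre_kw, fwd_kw, vis_kw, post_kw = inferencer._dispatch_kwargs(
            show=True, print_result=False)
        # vis_kw == {'show': True}, post_kw == {'print_result': False}
        # Any key outside the four sets raises a ValueError.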
"""
# Ensure each argument only matches one function
method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \
self.visualize_kwargs | self.postprocess_kwargs
union_kwargs = method_kwargs | set(kwargs.keys())
if union_kwargs != method_kwargs:
unknown_kwargs = union_kwargs - method_kwargs
raise ValueError(
f'unknown argument {unknown_kwargs} for `preprocess`, '
'`forward`, `visualize` and `postprocess`')
preprocess_kwargs = {}
forward_kwargs = {}
visualize_kwargs = {}
postprocess_kwargs = {}
for key, value in kwargs.items():
if key in self.preprocess_kwargs:
preprocess_kwargs[key] = value
if key in self.forward_kwargs:
forward_kwargs[key] = value
if key in self.visualize_kwargs:
visualize_kwargs[key] = value
if key in self.postprocess_kwargs:
postprocess_kwargs[key] = value
return (
preprocess_kwargs,
forward_kwargs,
visualize_kwargs,
postprocess_kwargs,
)
@staticmethod
def list_models(pattern: Optional[str] = None):
"""List models defined in metafile of corresponding packages.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.
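Examples:
    .. code-block:: python

        # List every available model name, or filter with a wildcard
        # pattern (the pattern below is only an example).
        BaseInferencer.list_models()
        BaseInferencer.list_models('*blip*')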
"""
return list_models(pattern=pattern)