[Feature] Support MMSegInferencer (#2413)

## Motivation

Support `MMSegInferencer` for providing an easy and clean interface for
single or multiple images inferencing.

Ref: https://github.com/open-mmlab/mmengine/pull/773
https://github.com/open-mmlab/mmocr/pull/1608

## Modification

- mmseg/apis/mmseg_inferencer.py
- mmseg/visualization/local_visualizer.py
- demo/image_demo_with_inferencer.py

## Use cases (Optional)

Based on https://github.com/open-mmlab/mmengine/tree/inference

Add a new image inference demo with `MMSegInferencer`

- demo/image_demo_with_inferencer.py

```shell
python demo/image_demo_with_inferencer.py demo/demo.png fcn_r50-d8_4xb2-40k_cityscapes-512x1024
```

---------

Co-authored-by: MeowZheng <meowzheng@outlook.com>
This commit is contained in:
谢昕辰 2023-02-23 21:16:19 +08:00 committed by GitHub
parent 039ba5d4ca
commit 53fe1ccf39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 528 additions and 15 deletions

View File

@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
from mmseg.apis import MMSegInferencer
def main():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('model', help='Config file')
parser.add_argument('--checkpoint', default=None, help='Checkpoint file')
parser.add_argument(
'--out-dir', default='', help='Path to save result file')
parser.add_argument(
'--show',
action='store_true',
default=False,
help='Whether to display the drawn image.')
parser.add_argument(
'--save-mask',
action='store_true',
default=False,
help='Enable save the mask file')
parser.add_argument(
'--dataset-name',
default='cityscapes',
help='Color palette used for segmentation map')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
args = parser.parse_args()
# build the model from a config file and a checkpoint file
mmseg_inferencer = MMSegInferencer(
args.model,
args.checkpoint,
dataset_name=args.dataset_name,
device=args.device)
# test a single image
mmseg_inferencer(
args.img,
show=args.show,
out_dir=args.out_dir,
save_mask=args.save_mask,
opacity=args.opacity)
if __name__ == '__main__':
main()

View File

@ -9,7 +9,7 @@ from .version import __version__, version_info
MMCV_MIN = '2.0.0rc4'
MMCV_MAX = '2.1.0'
MMENGINE_MIN = '0.4.0'
MMENGINE_MIN = '0.5.0'
MMENGINE_MAX = '1.0.0'

View File

@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model, show_result_pyplot
from .mmseg_inferencer import MMSegInferencer
__all__ = ['init_model', 'inference_model', 'show_result_pyplot']
__all__ = [
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer'
]

View File

@ -0,0 +1,279 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Optional, Sequence, Union
import mmcv
import mmengine
import numpy as np
from mmcv.transforms import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmseg.structures import SegDataSample
from mmseg.utils import ConfigType, SampleList, register_all_modules
from mmseg.visualization import SegLocalVisualizer
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[SegDataSample, SampleList]
class MMSegInferencer(BaseInferencer):
"""Semantic segmentation inferencer, provides inference and visualization
interfaces. Note: MMEngine >= 0.5.0 is required.
Args:
model (str, optional): Path to the config file or the model name
defined in metafile. For example, it could be
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024" or
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py"
weights (str, optional): Path to the checkpoint. If it is not specified
and model is a model name of metafile, the weights will be loaded
from metafile. Defaults to None.
palette (List[List[int]], optional): The palette of
segmentation map.
classes (Tuple[str], optional): Category information.
dataset_name (str, optional): Name of the datasets supported in mmseg.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to None.
"""
preprocess_kwargs: set = set()
forward_kwargs: set = {'mode', 'out_dir'}
visualize_kwargs: set = {
'show', 'wait_time', 'draw_pred', 'img_out_dir', 'opacity'
}
postprocess_kwargs: set = {
'pred_out_dir', 'return_datasample', 'save_mask', 'mask_dir'
}
def __init__(self,
model: Union[ModelType, str],
weights: Optional[str] = None,
palette: Optional[Union[str, List]] = None,
classes: Optional[Union[str, List]] = None,
dataset_name: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmseg') -> None:
# A global counter tracking the number of images processes, for
# naming of the output images
self.num_visualized_imgs = 0
register_all_modules()
super().__init__(
model=model, weights=weights, device=device, scope=scope)
assert isinstance(self.visualizer, SegLocalVisualizer)
self.visualizer.set_dataset_meta(palette, classes, dataset_name)
def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
out_dir: str = '',
save_mask: bool = False,
mask_dir: str = 'mask',
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (Union[str, np.ndarray]): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
:obj:`SegDataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw Prediction SegDataSample.
Defaults to True.
out_dir (str): Output directory of inference results. Defaults: ''.
save_mask (bool): Whether save pred mask as a file.
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
mask file.
Returns:
dict: Inference and visualization results.
"""
return super().__call__(
inputs=inputs,
return_datasamples=return_datasamples,
batch_size=batch_size,
show=show,
wait_time=wait_time,
draw_pred=draw_pred,
img_out_dir=out_dir,
pred_out_dir=out_dir,
save_mask=save_mask,
mask_dir=mask_dir,
**kwargs)
def visualize(self,
inputs: list,
preds: List[dict],
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
img_out_dir: str = '',
opacity: float = 0.8) -> List[np.ndarray]:
"""Visualize predictions.
Args:
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
preds (Any): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw Prediction SegDataSample.
Defaults to True.
img_out_dir (str): Output directory of drawn images. Defaults: ''
opacity (int, float): The transparency of segmentation mask.
Defaults to 0.8.
Returns:
List[np.ndarray]: Visualization results.
"""
if self.visualizer is None or (not show and img_out_dir == ''):
return None
if getattr(self, 'visualizer') is None:
raise ValueError('Visualization needs the "visualizer" term'
'defined in the config, but got None')
self.visualizer.alpha = opacity
results = []
for single_input, pred in zip(inputs, preds):
if isinstance(single_input, str):
img_bytes = mmengine.fileio.get(single_input)
img = mmcv.imfrombytes(img_bytes)
img = img[:, :, ::-1]
img_name = osp.basename(single_input)
elif isinstance(single_input, np.ndarray):
img = single_input.copy()
img_num = str(self.num_visualized_imgs).zfill(8)
img_name = f'{img_num}.jpg'
else:
raise ValueError('Unsupported input type:'
f'{type(single_input)}')
out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\
else None
self.visualizer.add_datasample(
img_name,
img,
pred,
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=draw_pred,
out_file=out_file)
results.append(self.visualizer.get_image())
self.num_visualized_imgs += 1
return results
def postprocess(self,
preds: PredType,
visualization: List[np.ndarray],
return_datasample: bool = False,
mask_dir: str = 'mask',
save_mask: bool = True,
pred_out_dir: str = '') -> dict:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.
This method should be responsible for the following tasks:
1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.
Args:
preds (List[Dict]): Predictions of the model.
visualization (np.ndarray): Visualized predictions.
return_datasample (bool): Whether to return results as datasamples.
Defaults to False.
pred_out_dir: File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
mask file.
save_mask (bool): Whether save pred mask as a file.
Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``
- ``visualization (Any)``: Returned by :meth:`visualize`
- ``predictions`` (dict or DataSample): Returned by
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
"""
results_dict = {}
results_dict['predictions'] = preds
results_dict['visualization'] = visualization
if pred_out_dir != '':
mmengine.mkdir_or_exist(pred_out_dir)
if save_mask:
preds = [preds] if isinstance(preds, SegDataSample) else preds
for pred in preds:
mmcv.imwrite(
pred.pred_sem_seg.numpy().data[0],
osp.join(pred_out_dir, mask_dir,
osp.basename(pred.metainfo['img_path'])))
else:
mmengine.dump(results_dict,
osp.join(pred_out_dir, 'results.pkl'))
if return_datasample:
return preds
return results_dict
def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline.
Return a pipeline to handle various input data, such as ``str``,
``np.ndarray``. It is an abstract method in BaseInferencer, and should
be implemented in subclasses.
The returned pipeline will be used to process a single data.
It will be used in :meth:`preprocess` like this:
.. code-block:: python
def preprocess(self, inputs, batch_size, **kwargs):
...
dataset = map(self.pipeline, dataset)
...
"""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
# Loading annotations is also not applicable
idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations')
if idx != -1:
del pipeline_cfg[idx]
load_img_idx = self._get_transform_idx(pipeline_cfg,
'LoadImageFromFile')
if load_img_idx == -1:
raise ValueError(
'LoadImageFromFile is not found in the test pipeline')
pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
return Compose(pipeline_cfg)
def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
"""Returns the index of the transform in a pipeline.
If the transform is not found, returns -1.
"""
for i, transform in enumerate(pipeline_cfg):
if transform['type'] == name:
return i
return -1

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Optional
from typing import Dict, Optional, Union
import mmcv
import mmengine.fileio as fileio
@ -437,3 +437,59 @@ class LoadBiomedicalData(BaseTransform):
f'to_xyz={self.to_xyz}, '
f'backend_args={self.backend_args})')
return repr_str
@TRANSFORMS.register_module()
class InferencerLoader(BaseTransform):
"""Load an image from ``results['img']``.
Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
:obj:`np.ndarray` in ``results['img']``. Can be used when loading image
from webcam.
Required Keys:
- img
Modified Keys:
- img
- img_path
- img_shape
- ori_shape
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
"""
def __init__(self, **kwargs) -> None:
super().__init__()
self.from_file = TRANSFORMS.build(
dict(type='LoadImageFromFile', **kwargs))
self.from_ndarray = TRANSFORMS.build(
dict(type='LoadImageFromNDArray', **kwargs))
def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
"""Transform function to add image meta information.
Args:
results (dict): Result dict with Webcam read image in
``results['img']``.
Returns:
dict: The dict contains loaded image and meta information.
"""
if isinstance(single_input, str):
inputs = dict(img_path=single_input)
elif isinstance(single_input, np.ndarray):
inputs = dict(img=single_input)
elif isinstance(single_input, dict):
inputs = single_input
else:
raise NotImplementedError
if 'img' in inputs:
return self.from_ndarray(inputs)
return self.from_file(inputs)

View File

@ -63,16 +63,7 @@ class SegLocalVisualizer(Visualizer):
**kwargs):
super().__init__(name, image, vis_backends, save_dir, **kwargs)
self.alpha: float = alpha
# Set default value. When calling
# `SegLocalVisualizer().dataset_meta=xxx`,
# it will override the default value.
if dataset_name is None:
dataset_name = 'cityscapes'
classes = classes if classes else get_classes(dataset_name)
palette = palette if palette else get_palette(dataset_name)
assert len(classes) == len(
palette), 'The length of classes should be equal to palette'
self.dataset_meta: dict = {'classes': classes, 'palette': palette}
self.set_dataset_meta(palette, classes, dataset_name)
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
classes: Optional[Tuple[str]],
@ -109,6 +100,21 @@ class SegLocalVisualizer(Visualizer):
return self.get_image()
def set_dataset_meta(self,
palette: Optional[Union[str, List]] = None,
classes: Optional[Union[str, List]] = None,
dataset_name: Optional[str] = None) -> None:
# Set default value. When calling
# `SegLocalVisualizer().dataset_meta=xxx`,
# it will override the default value.
if dataset_name is None:
dataset_name = 'cityscapes'
classes = classes if classes else get_classes(dataset_name)
palette = palette if palette else get_palette(dataset_name)
assert len(classes) == len(
palette), 'The length of classes should be equal to palette'
self.dataset_meta: dict = {'classes': classes, 'palette': palette}
@master_only
def add_datasample(
self,
@ -186,6 +192,6 @@ class SegLocalVisualizer(Visualizer):
self.show(drawn_img, win_name=name, wait_time=wait_time)
if out_file is not None:
mmcv.imwrite(drawn_img, out_file)
mmcv.imwrite(mmcv.bgr2rgb(drawn_img), out_file)
else:
self.add_image(name, drawn_img, step)

View File

@ -1,4 +1,4 @@
mmcls>=1.0.0rc0
mmcv>=2.0.0rc4
-e git+https://github.com/open-mmlab/mmdetection.git@dev-3.x#egg=mmdet
mmengine>=0.4.0,<1.0.0
mmengine>=0.5.0,<1.0.0

View File

@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import numpy as np
import torch
import torch.nn as nn
from mmengine import ConfigDict
from torch.utils.data import DataLoader, Dataset
from mmseg.apis import MMSegInferencer
from mmseg.models import EncoderDecoder
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import register_all_modules
@MODELS.register_module(name='InferExampleHead')
class ExampleDecodeHead(BaseDecodeHead):
def __init__(self, num_classes=19, out_channels=None):
super().__init__(
3, 3, num_classes=num_classes, out_channels=out_channels)
def forward(self, inputs):
return self.cls_seg(inputs[0])
@MODELS.register_module(name='InferExampleBackbone')
class ExampleBackbone(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 3)
def init_weights(self, pretrained=None):
pass
def forward(self, x):
return [self.conv(x)]
@MODELS.register_module(name='InferExampleModel')
class ExampleModel(EncoderDecoder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
class ExampleDataset(Dataset):
def __init__(self) -> None:
super().__init__()
self.pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
def __getitem__(self, idx):
return dict(img=torch.tensor([1]), img_metas=dict())
def __len__(self):
return 1
def test_inferencer():
register_all_modules()
test_dataset = ExampleDataset()
data_loader = DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False,
)
visualizer = dict(
type='SegLocalVisualizer',
vis_backends=[dict(type='LocalVisBackend')],
name='visualizer')
cfg_dict = dict(
model=dict(
type='InferExampleModel',
data_preprocessor=dict(type='SegDataPreProcessor'),
backbone=dict(type='InferExampleBackbone'),
decode_head=dict(type='InferExampleHead'),
test_cfg=dict(mode='whole')),
visualizer=visualizer,
test_dataloader=data_loader)
cfg = ConfigDict(cfg_dict)
model = MODELS.build(cfg.model)
ckpt = model.state_dict()
ckpt_filename = tempfile.mktemp()
torch.save(ckpt, ckpt_filename)
# test initialization
infer = MMSegInferencer(cfg, ckpt_filename)
# test forward
img = np.random.randint(0, 256, (4, 4, 3))
infer(img)
imgs = [img, img]
infer(imgs)
results = infer(imgs, out_dir=tempfile.gettempdir(), draw_pred=True)
# test results
assert 'predictions' in results
assert 'visualization' in results
assert len(results['predictions']) == 2
assert results['predictions'][0].seg_logits.data.shape == torch.Size(
(19, 4, 4))
assert results['predictions'][0].pred_sem_seg.shape == torch.Size((4, 4))