diff --git a/demo/image_demo_with_inferencer.py b/demo/image_demo_with_inferencer.py
new file mode 100644
index 000000000..ce40f2224
--- /dev/null
+++ b/demo/image_demo_with_inferencer.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from argparse import ArgumentParser
+
+from mmseg.apis import MMSegInferencer
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('img', help='Image file')
+    parser.add_argument('model', help='Config file or model name in metafile')
+    parser.add_argument('--checkpoint', default=None, help='Checkpoint file')
+    parser.add_argument(
+        '--out-dir', default='', help='Directory to save the results')
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        default=False,
+        help='Whether to display the drawn image.')
+    parser.add_argument(
+        '--save-mask',
+        action='store_true',
+        default=False,
+        help='Whether to save the predicted mask file')
+    parser.add_argument(
+        '--dataset-name',
+        default='cityscapes',
+        help='Name of the dataset used to set classes and palette')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--opacity',
+        type=float,
+        default=0.5,
+        help='Opacity of painted segmentation map. In (0, 1] range.')
+    args = parser.parse_args()
+
+    # build the model from a config file and a checkpoint file
+    mmseg_inferencer = MMSegInferencer(
+        args.model,
+        args.checkpoint,
+        dataset_name=args.dataset_name,
+        device=args.device)
+
+    # test a single image
+    mmseg_inferencer(
+        args.img,
+        show=args.show,
+        out_dir=args.out_dir,
+        save_mask=args.save_mask,
+        opacity=args.opacity)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/mmseg/__init__.py b/mmseg/__init__.py
index 1a7627af5..9f171ccb0 100644
--- a/mmseg/__init__.py
+++ b/mmseg/__init__.py
@@ -9,7 +9,7 @@ from .version import __version__, version_info
 MMCV_MIN = '2.0.0rc4'
 MMCV_MAX = '2.1.0'
-MMENGINE_MIN = '0.4.0'
+MMENGINE_MIN = '0.5.0'
 MMENGINE_MAX = '1.0.0'
diff --git a/mmseg/apis/__init__.py b/mmseg/apis/__init__.py
index 9933b99b3..d22dc3f0a 100644
--- a/mmseg/apis/__init__.py
+++ b/mmseg/apis/__init__.py
@@ -1,4 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .inference import inference_model, init_model, show_result_pyplot
+from .mmseg_inferencer import MMSegInferencer
 
-__all__ = ['init_model', 'inference_model', 'show_result_pyplot']
+__all__ = [
+    'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer'
+]
diff --git a/mmseg/apis/mmseg_inferencer.py b/mmseg/apis/mmseg_inferencer.py
new file mode 100644
index 000000000..deb57b9b8
--- /dev/null
+++ b/mmseg/apis/mmseg_inferencer.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List, Optional, Sequence, Union
+
+import mmcv
+import mmengine
+import numpy as np
+from mmcv.transforms import Compose
+from mmengine.infer.infer import BaseInferencer, ModelType
+
+from mmseg.structures import SegDataSample
+from mmseg.utils import ConfigType, SampleList, register_all_modules
+from mmseg.visualization import SegLocalVisualizer
+
+InputType = Union[str, np.ndarray]
+InputsType = Union[InputType, Sequence[InputType]]
+PredType = Union[SegDataSample, SampleList]
+
+
+class MMSegInferencer(BaseInferencer):
+    """Semantic segmentation inferencer that provides inference and
+    visualization interfaces. Note: MMEngine >= 0.5.0 is required.
+
+    Args:
+        model (str, optional): Path to the config file or the model name
+            defined in metafile. For example, it could be
+            "fcn_r50-d8_4xb2-40k_cityscapes-512x1024" or
+            "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py".
+        weights (str, optional): Path to the checkpoint. If it is not specified
+            and ``model`` is a model name in the metafile, the weights will be
+            loaded from the metafile. Defaults to None.
+        palette (List[List[int]], optional): The palette of the
+            segmentation map.
+        classes (Tuple[str], optional): Category information.
+        dataset_name (str, optional): Name of a dataset supported by mmseg.
+        device (str, optional): Device to run inference. If None, the available
+            device will be automatically used. Defaults to None.
+        scope (str, optional): The scope of the model. Defaults to 'mmseg'.
+    """
+
+    preprocess_kwargs: set = set()
+    forward_kwargs: set = {'mode', 'out_dir'}
+    visualize_kwargs: set = {
+        'show', 'wait_time', 'draw_pred', 'img_out_dir', 'opacity'
+    }
+    postprocess_kwargs: set = {
+        'pred_out_dir', 'return_datasample', 'save_mask', 'mask_dir'
+    }
+
+    def __init__(self,
+                 model: Union[ModelType, str],
+                 weights: Optional[str] = None,
+                 palette: Optional[Union[str, List]] = None,
+                 classes: Optional[Union[str, List]] = None,
+                 dataset_name: Optional[str] = None,
+                 device: Optional[str] = None,
+                 scope: Optional[str] = 'mmseg') -> None:
+        # A global counter tracking the number of images processed, for
+        # naming of the output images
+        self.num_visualized_imgs = 0
+        register_all_modules()
+        super().__init__(
+            model=model, weights=weights, device=device, scope=scope)
+
+        assert isinstance(self.visualizer, SegLocalVisualizer)
+        self.visualizer.set_dataset_meta(palette, classes, dataset_name)
+
+    def __call__(self,
+                 inputs: InputsType,
+                 return_datasamples: bool = False,
+                 batch_size: int = 1,
+                 show: bool = False,
+                 wait_time: int = 0,
+                 draw_pred: bool = True,
+                 out_dir: str = '',
+                 save_mask: bool = False,
+                 mask_dir: str = 'mask',
+                 **kwargs) -> dict:
+        """Call the inferencer.
+
+        Args:
+            inputs (Union[str, np.ndarray]): Inputs for the inferencer.
+            return_datasamples (bool): Whether to return results as
+                :obj:`SegDataSample`. Defaults to False.
+            batch_size (int): Batch size. Defaults to 1.
+            show (bool): Whether to display the image in a popup window.
+                Defaults to False.
+            wait_time (float): The interval of show, in seconds. Defaults to 0.
+            draw_pred (bool): Whether to draw the prediction SegDataSample.
+                Defaults to True.
+            out_dir (str): Output directory of inference results. Defaults to ''.
+            save_mask (bool): Whether to save the predicted mask as a file.
+            mask_dir (str): Subdirectory of ``out_dir``, used to save the
+                predicted mask files.
+
+        Returns:
+            dict: Inference and visualization results.
+        """
+        return super().__call__(
+            inputs=inputs,
+            return_datasamples=return_datasamples,
+            batch_size=batch_size,
+            show=show,
+            wait_time=wait_time,
+            draw_pred=draw_pred,
+            img_out_dir=out_dir,
+            pred_out_dir=out_dir,
+            save_mask=save_mask,
+            mask_dir=mask_dir,
+            **kwargs)
+
+    def visualize(self,
+                  inputs: list,
+                  preds: List[dict],
+                  show: bool = False,
+                  wait_time: int = 0,
+                  draw_pred: bool = True,
+                  img_out_dir: str = '',
+                  opacity: float = 0.8) -> List[np.ndarray]:
+        """Visualize predictions.
+
+        Args:
+            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
+            preds (Any): Predictions of the model.
+            show (bool): Whether to display the image in a popup window.
+                Defaults to False.
+            wait_time (float): The interval of show, in seconds. Defaults to 0.
+            draw_pred (bool): Whether to draw the prediction SegDataSample.
+                Defaults to True.
+            img_out_dir (str): Output directory of drawn images. Defaults to ''.
+            opacity (float): The transparency of the segmentation mask.
+                Defaults to 0.8.
+
+        Returns:
+            List[np.ndarray]: Visualization results.
+        """
+        if self.visualizer is None or (not show and img_out_dir == ''):
+            return None
+
+        if getattr(self, 'visualizer') is None:
+            raise ValueError('Visualization needs the "visualizer" term '
+                             'defined in the config, but got None')
+
+        self.visualizer.alpha = opacity
+
+        results = []
+
+        for single_input, pred in zip(inputs, preds):
+            if isinstance(single_input, str):
+                img_bytes = mmengine.fileio.get(single_input)
+                img = mmcv.imfrombytes(img_bytes)
+                img = img[:, :, ::-1]
+                img_name = osp.basename(single_input)
+            elif isinstance(single_input, np.ndarray):
+                img = single_input.copy()
+                img_num = str(self.num_visualized_imgs).zfill(8)
+                img_name = f'{img_num}.jpg'
+            else:
+                raise ValueError('Unsupported input type: '
+                                 f'{type(single_input)}')
+
+            out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\
+                else None
+
+            self.visualizer.add_datasample(
+                img_name,
+                img,
+                pred,
+                show=show,
+                wait_time=wait_time,
+                draw_gt=False,
+                draw_pred=draw_pred,
+                out_file=out_file)
+            results.append(self.visualizer.get_image())
+            self.num_visualized_imgs += 1
+
+        return results
+
+    def postprocess(self,
+                    preds: PredType,
+                    visualization: List[np.ndarray],
+                    return_datasample: bool = False,
+                    mask_dir: str = 'mask',
+                    save_mask: bool = True,
+                    pred_out_dir: str = '') -> dict:
+        """Process the predictions and visualization results from ``forward``
+        and ``visualize``.
+
+        This method should be responsible for the following tasks:
+
+        1. Convert datasamples into a json-serializable dict if needed.
+        2. Pack the predictions and visualization results and return them.
+        3. Dump or log the predictions.
+
+        Args:
+            preds (List[Dict]): Predictions of the model.
+            visualization (List[np.ndarray]): Visualized predictions.
+            return_datasample (bool): Whether to return results as datasamples.
+                Defaults to False.
+            pred_out_dir (str): Directory to save the inference results without
+                visualization. If left empty, no file will be saved.
+                Defaults to ''.
+            mask_dir (str): Subdirectory of ``pred_out_dir``, used to save the
+                predicted mask files. Defaults to 'mask'.
+            save_mask (bool): Whether to save the predicted mask as a file.
+
+        Returns:
+            dict: Inference and visualization results with keys ``predictions``
+                and ``visualization``.
+
+                - ``visualization`` (Any): Returned by :meth:`visualize`.
+                - ``predictions`` (dict or DataSample): Returned by
+                  :meth:`forward` and processed in :meth:`postprocess`.
+                  If ``return_datasample=False``, it usually should be a
+                  json-serializable dict containing only basic data elements
+                  such as strings and numbers.
+        """
+        results_dict = {}
+
+        results_dict['predictions'] = preds
+        results_dict['visualization'] = visualization
+
+        if pred_out_dir != '':
+            mmengine.mkdir_or_exist(pred_out_dir)
+            if save_mask:
+                preds = [preds] if isinstance(preds, SegDataSample) else preds
+                for pred in preds:
+                    mmcv.imwrite(
+                        pred.pred_sem_seg.numpy().data[0],
+                        osp.join(pred_out_dir, mask_dir,
+                                 osp.basename(pred.metainfo['img_path'])))
+            else:
+                mmengine.dump(results_dict,
+                              osp.join(pred_out_dir, 'results.pkl'))
+
+        if return_datasample:
+            return preds
+
+        return results_dict
+
+    def _init_pipeline(self, cfg: ConfigType) -> Compose:
+        """Initialize the test pipeline.
+
+        Return a pipeline to handle various input data, such as ``str`` and
+        ``np.ndarray``. It is an abstract method in BaseInferencer, and should
+        be implemented in subclasses.
+
+        The returned pipeline will be used to process a single input. It
+        will be used in :meth:`preprocess` like this:
+
+        .. code-block:: python
+
+            def preprocess(self, inputs, batch_size, **kwargs):
+                ...
+                dataset = map(self.pipeline, dataset)
+                ...
+        """
+        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+        # Loading annotations is not applicable during inference
+        idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations')
+        if idx != -1:
+            del pipeline_cfg[idx]
+        load_img_idx = self._get_transform_idx(pipeline_cfg,
+                                               'LoadImageFromFile')
+
+        if load_img_idx == -1:
+            raise ValueError(
+                'LoadImageFromFile is not found in the test pipeline')
+        pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
+        return Compose(pipeline_cfg)
+
+    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
+        """Returns the index of the transform in a pipeline.
+
+        If the transform is not found, returns -1.
+        """
+        for i, transform in enumerate(pipeline_cfg):
+            if transform['type'] == name:
+                return i
+        return -1
diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py
index 492f2063e..d2e93b1ab 100644
--- a/mmseg/datasets/transforms/loading.py
+++ b/mmseg/datasets/transforms/loading.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 
 import mmcv
 import mmengine.fileio as fileio
@@ -437,3 +437,59 @@ class LoadBiomedicalData(BaseTransform):
                 f'to_xyz={self.to_xyz}, '
                 f'backend_args={self.backend_args})')
         return repr_str
+
+
+@TRANSFORMS.register_module()
+class InferencerLoader(BaseTransform):
+    """Load an image from ``results['img']``.
+
+    Similar to :obj:`LoadImageFromFile`, but the image may already be loaded
+    as an :obj:`np.ndarray` in ``results['img']``. Can be used when loading
+    images from a webcam.
+
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img
+    - img_path
+    - img_shape
+    - ori_shape
+
+    Args:
+        to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is a uint8 array.
+            Defaults to False.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__()
+        self.from_file = TRANSFORMS.build(
+            dict(type='LoadImageFromFile', **kwargs))
+        self.from_ndarray = TRANSFORMS.build(
+            dict(type='LoadImageFromNDArray', **kwargs))
+
+    def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
+        """Transform function to load the image and its meta information.
+
+        Args:
+            single_input (str, np.ndarray or dict): An image path, an image
+                array, or a result dict with the image in ``results['img']``.
+
+        Returns:
+            dict: The dict containing the loaded image and meta information.
+        """
+        if isinstance(single_input, str):
+            inputs = dict(img_path=single_input)
+        elif isinstance(single_input, np.ndarray):
+            inputs = dict(img=single_input)
+        elif isinstance(single_input, dict):
+            inputs = single_input
+        else:
+            raise NotImplementedError
+
+        if 'img' in inputs:
+            return self.from_ndarray(inputs)
+        return self.from_file(inputs)
diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py
index 27443f2c5..f4db83594 100644
--- a/mmseg/visualization/local_visualizer.py
+++ b/mmseg/visualization/local_visualizer.py
@@ -63,16 +63,7 @@ class SegLocalVisualizer(Visualizer):
                  **kwargs):
         super().__init__(name, image, vis_backends, save_dir, **kwargs)
         self.alpha: float = alpha
-        # Set default value. When calling
-        # `SegLocalVisualizer().dataset_meta=xxx`,
-        # it will override the default value.
-        if dataset_name is None:
-            dataset_name = 'cityscapes'
-        classes = classes if classes else get_classes(dataset_name)
-        palette = palette if palette else get_palette(dataset_name)
-        assert len(classes) == len(
-            palette), 'The length of classes should be equal to palette'
-        self.dataset_meta: dict = {'classes': classes, 'palette': palette}
+        self.set_dataset_meta(palette, classes, dataset_name)
 
     def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
                       classes: Optional[Tuple[str]],
@@ -109,6 +100,21 @@ class SegLocalVisualizer(Visualizer):
 
         return self.get_image()
 
+    def set_dataset_meta(self,
+                         palette: Optional[Union[str, List]] = None,
+                         classes: Optional[Union[str, List]] = None,
+                         dataset_name: Optional[str] = None) -> None:
+        # Set default value. When calling
+        # `SegLocalVisualizer().dataset_meta=xxx`,
+        # it will override the default value.
+        if dataset_name is None:
+            dataset_name = 'cityscapes'
+        classes = classes if classes else get_classes(dataset_name)
+        palette = palette if palette else get_palette(dataset_name)
+        assert len(classes) == len(
+            palette), 'The length of classes should equal that of palette'
+        self.dataset_meta: dict = {'classes': classes, 'palette': palette}
+
     @master_only
     def add_datasample(
         self,
@@ -186,6 +192,6 @@ class SegLocalVisualizer(Visualizer):
             self.show(drawn_img, win_name=name, wait_time=wait_time)
 
         if out_file is not None:
-            mmcv.imwrite(drawn_img, out_file)
+            mmcv.imwrite(mmcv.bgr2rgb(drawn_img), out_file)
         else:
             self.add_image(name, drawn_img, step)
diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt
index 8dd00e5cc..707ba66ff 100644
--- a/requirements/mminstall.txt
+++ b/requirements/mminstall.txt
@@ -1,4 +1,4 @@
 mmcls>=1.0.0rc0
 mmcv>=2.0.0rc4
 -e git+https://github.com/open-mmlab/mmdetection.git@dev-3.x#egg=mmdet
-mmengine>=0.4.0,<1.0.0
+mmengine>=0.5.0,<1.0.0
diff --git a/tests/test_apis/test_inferencer.py b/tests/test_apis/test_inferencer.py
new file mode 100644
index 000000000..44eb17157
--- /dev/null
+++ b/tests/test_apis/test_inferencer.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import tempfile + +import numpy as np +import torch +import torch.nn as nn +from mmengine import ConfigDict +from torch.utils.data import DataLoader, Dataset + +from mmseg.apis import MMSegInferencer +from mmseg.models import EncoderDecoder +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS +from mmseg.utils import register_all_modules + + +@MODELS.register_module(name='InferExampleHead') +class ExampleDecodeHead(BaseDecodeHead): + + def __init__(self, num_classes=19, out_channels=None): + super().__init__( + 3, 3, num_classes=num_classes, out_channels=out_channels) + + def forward(self, inputs): + return self.cls_seg(inputs[0]) + + +@MODELS.register_module(name='InferExampleBackbone') +class ExampleBackbone(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 3) + + def init_weights(self, pretrained=None): + pass + + def forward(self, x): + return [self.conv(x)] + + +@MODELS.register_module(name='InferExampleModel') +class ExampleModel(EncoderDecoder): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class ExampleDataset(Dataset): + + def __init__(self) -> None: + super().__init__() + self.pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') + ] + + def __getitem__(self, idx): + return dict(img=torch.tensor([1]), img_metas=dict()) + + def __len__(self): + return 1 + + +def test_inferencer(): + register_all_modules() + test_dataset = ExampleDataset() + data_loader = DataLoader( + test_dataset, + batch_size=1, + sampler=None, + num_workers=0, + shuffle=False, + ) + + visualizer = dict( + type='SegLocalVisualizer', + vis_backends=[dict(type='LocalVisBackend')], + name='visualizer') + + cfg_dict = dict( + model=dict( + type='InferExampleModel', + data_preprocessor=dict(type='SegDataPreProcessor'), + backbone=dict(type='InferExampleBackbone'), + decode_head=dict(type='InferExampleHead'), + test_cfg=dict(mode='whole')), + visualizer=visualizer, + test_dataloader=data_loader) + cfg = ConfigDict(cfg_dict) + model = MODELS.build(cfg.model) + + ckpt = model.state_dict() + ckpt_filename = tempfile.mktemp() + torch.save(ckpt, ckpt_filename) + + # test initialization + infer = MMSegInferencer(cfg, ckpt_filename) + + # test forward + img = np.random.randint(0, 256, (4, 4, 3)) + infer(img) + + imgs = [img, img] + infer(imgs) + results = infer(imgs, out_dir=tempfile.gettempdir(), draw_pred=True) + + # test results + assert 'predictions' in results + assert 'visualization' in results + assert len(results['predictions']) == 2 + assert results['predictions'][0].seg_logits.data.shape == torch.Size( + (19, 4, 4)) + assert results['predictions'][0].pred_sem_seg.shape == torch.Size((4, 4))
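For reference, here is a minimal usage sketch of the interface this patch adds. It is not part of the diff; the config path, checkpoint file name, and demo image below are placeholders to be replaced with real files.

```python
# Usage sketch for the new MMSegInferencer, mirroring
# demo/image_demo_with_inferencer.py above. Paths are placeholders.
import numpy as np

from mmseg.apis import MMSegInferencer

inferencer = MMSegInferencer(
    'configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py',
    weights='fcn_r50-d8_cityscapes.pth',  # hypothetical checkpoint path
    dataset_name='cityscapes',  # sets classes/palette via set_dataset_meta()
    device='cuda:0')  # or 'cpu'

# A path input is loaded by LoadImageFromFile; the drawn image is written
# under out_dir and, with save_mask=True, the raw mask under out_dir/mask.
results = inferencer(
    'demo/demo.png', out_dir='outputs', save_mask=True, opacity=0.5)

# An ndarray input (e.g. a webcam frame) is routed through the new
# InferencerLoader transform to LoadImageFromNDArray.
frame = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
results = inferencer(frame)

# postprocess() packs the outputs into a dict.
print(results.keys())  # dict_keys(['predictions', 'visualization'])
```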