[Enhancement] Modify interface of MMSegInferencer and add docs (#2658)

## Motivation

Make MMSegInferencer easier to use.

## Modification

1. Add `_load_weights_to_model` to MMSegInferencer; it retrieves
`dataset_meta` from the checkpoint.
2. Modify and remove some parameters of `__call__`, `visualize` and
`postprocess`.
3. Add the function of saving the segmentation mask and remove the pkl dump (see the sketch below).
4. Refine the docstrings of MMSegInferencer and SegLocalVisualizer.
5. Add user documentation for MMSegInferencer.
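
For orientation, a minimal sketch of the resulting interface (the model name and paths are illustrative):

```
from mmseg.apis import MMSegInferencer

# `_load_weights_to_model` reads `dataset_meta` (classes/palette) from the
# checkpoint meta, falling back to Cityscapes meta with a warning.
inferencer = MMSegInferencer(model='fcn_r50-d8_4xb2-40k_cityscapes-512x1024')

# Predicted masks are now saved as PNG files under out_dir/pred and rendered
# color maps under out_dir/vis, instead of dumping a results.pkl.
inferencer('demo/demo.png', out_dir='outputs')
```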

## BC-breaking (Optional)

Yes, some parameters are removed. We need to discuss whether to keep them
with a deprecation warning or remove them outright, as MMSegInferencer was
merged into mmseg only a few days ago.
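
A hedged before/after sketch of an affected call site (argument values are illustrative):

```
from mmseg.apis import MMSegInferencer

inferencer = MMSegInferencer(model='fcn_r50-d8_4xb2-40k_cityscapes-512x1024')

# Before this PR (parameters now removed):
# inferencer('demo/demo.png', out_dir='outputs', save_mask=True,
#            mask_dir='mask', draw_pred=True)

# After this PR: masks are saved as PNG files whenever `out_dir` is set;
# `img_out_dir` ('vis') and `pred_out_dir` ('pred') are its subdirectories.
inferencer('demo/demo.png', out_dir='outputs',
           img_out_dir='vis', pred_out_dir='pred')
```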

Co-authored-by: xiexinch <xiexinch@outlook.com>
Miao Zheng authored on 2023-03-03 14:37:54 +08:00, committed by GitHub
parent 8c1d299cb6
commit 310ec4afc7
5 changed files with 328 additions and 104 deletions


@@ -16,11 +16,6 @@ def main():
         action='store_true',
         default=False,
         help='Whether to display the drawn image.')
-    parser.add_argument(
-        '--save-mask',
-        action='store_true',
-        default=False,
-        help='Enable save the mask file')
     parser.add_argument(
         '--dataset-name',
         default='cityscapes',
@@ -43,11 +38,7 @@ def main():

     # test a single image
     mmseg_inferencer(
-        args.img,
-        show=args.show,
-        out_dir=args.out_dir,
-        save_mask=args.save_mask,
-        opacity=args.opacity)
+        args.img, show=args.show, out_dir=args.out_dir, opacity=args.opacity)


 if __name__ == '__main__':


@@ -4,13 +4,132 @@ MMSegmentation provides pre-trained models for semantic segmentation in [Model Z
 This note will show how to use existing models to inference on given images.
 As for how to test existing models on standard datasets, please see this [guide](./4_train_test.md)

-## Inference API
-
 MMSegmentation provides several interfaces for users to easily use pre-trained models for inference.

-- [mmseg.apis.init_model](#mmsegapisinit_model)
-- [mmseg.apis.inference_model](#mmsegapisinference_model)
-- [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
+- [Tutorial 3: Inference with existing models](#tutorial-3-inference-with-existing-models)
+  - [Inferencer](#inferencer)
+    - [Basic Usage](#basic-usage)
+    - [Initialization](#initialization)
+    - [Visualize prediction](#visualize-prediction)
+    - [List model](#list-model)
+  - [Inference API](#inference-api)
+    - [mmseg.apis.init_model](#mmsegapisinit_model)
+    - [mmseg.apis.inference_model](#mmsegapisinference_model)
+    - [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
+
+## Inferencer
+
+We provide the most **convenient** way to use a model in MMSegmentation: `MMSegInferencer`. You can get the segmentation mask of an image with only 3 lines of code.
+
+### Basic Usage
+
+The following example shows how to use `MMSegInferencer` to perform inference on a single image.
+
+```
+>>> from mmseg.apis import MMSegInferencer
+>>> # Load models into memory
+>>> inferencer = MMSegInferencer(model='deeplabv3plus_r18-d8_4xb2-80k_cityscapes-512x1024')
+>>> # Inference
+>>> inferencer('demo/demo.png', show=True)
+```
+
+The visualization result should look like:
+
+<div align="center">
+    <img src='https://user-images.githubusercontent.com/76149310/221507927-ae01e3a7-016f-4425-b966-7b19cbbe494e.png' />
+</div>
+
+Moreover, you can use `MMSegInferencer` to process a list of images:
+
+```
+# Input a list of images
+>>> images = [image1, image2, ...] # image1 can be a file path or a np.ndarray
+>>> inferencer(images, show=True, wait_time=0.5) # wait_time is the display interval (s); 0 means wait forever
+
+# Or input an image directory
+>>> images = $IMAGESDIR
+>>> inferencer(images, show=True, wait_time=0.5)
+
+# Save visualized rendering color maps and predicted results
+# out_dir is the directory to save the output results; img_out_dir and pred_out_dir are
+# subdirectories of out_dir, used to save visualized rendering color maps and predicted results
+>>> inferencer(images, out_dir='outputs', img_out_dir='vis', pred_out_dir='pred')
+```
+
+There is an optional parameter of the inferencer, `return_datasamples`, whose default value is False. By default the
+return value of the inferencer is a `dict` with the 2 keys 'visualization' and 'predictions'.
+If `return_datasamples=True`, the inferencer will return a [`SegDataSample`](../advanced_guides/structures.md), or a list of them.
+
+```
+result = inferencer('demo/demo.png')
+# result is a `dict` including 2 keys 'visualization' and 'predictions'.
+
+# 'visualization' includes the color segmentation map
+print(result['visualization'].shape)
+# (512, 683, 3)
+
+# 'predictions' includes the segmentation mask with label indices
+print(result['predictions'].shape)
+# (512, 683)
+
+result = inferencer('demo/demo.png', return_datasamples=True)
+print(type(result))
+# <class 'mmseg.structures.seg_data_sample.SegDataSample'>
+
+# Input a list of images
+results = inferencer(images)
+# The output is a list
+print(type(results['visualization']), results['visualization'][0].shape)
+# <class 'list'> (512, 683, 3)
+print(type(results['predictions']), results['predictions'][0].shape)
+# <class 'list'> (512, 683)
+
+results = inferencer(images, return_datasamples=True)
+print(type(results))
+# <class 'list'>
+print(type(results[0]))
+# <class 'mmseg.structures.seg_data_sample.SegDataSample'>
+```
+
+### Initialization
+
+`MMSegInferencer` must be initialized from a `model`, which can be a model name, a `Config`, or even a path to a config file.
+Model names can be found in the models' metafiles; for example, one model name of maskformer is `maskformer_r50-d32_8xb2-160k_ade20k-512x512`, and if a model name is given, the weights of the model will be downloaded automatically. Below are the other input parameters (a sketch follows after this list):
+
+- weights (str, optional) - Path to the checkpoint. If it is not specified and model is a model name from the metafile, the weights will be loaded
+  from the metafile. Defaults to None.
+- classes (list, optional) - Input classes for result rendering. As the prediction of a segmentation
+  model is a segment map with label indices, `classes` is a list whose items
+  correspond to the label indices. If classes is not defined, the visualizer will take the `cityscapes` classes by default. Defaults to None.
+- palette (list, optional) - Input palette for result rendering, which is a list of colors
+  corresponding to the classes. If palette is not defined, the visualizer will take the `cityscapes` palette by default. Defaults to None.
+- dataset_name (str, optional) - [Dataset name or alias](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317).
+  The visualizer will use the meta information of the dataset, i.e. classes and palette,
+  but `classes` and `palette` have higher priority. Defaults to None.
+- device (str, optional) - Device to run inference. If None, the available device will be automatically used. Defaults to None.
+- scope (str, optional) - The scope of the model. Defaults to 'mmseg'.
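
A minimal initialization sketch, assuming a metafile model name and a CUDA device are available (values are illustrative):

```
from mmseg.apis import MMSegInferencer

# Weights are resolved from the metafile for a known model name.
# `dataset_name` supplies the classes/palette meta information; explicit
# `classes`/`palette` arguments would take priority over it.
inferencer = MMSegInferencer(
    model='fcn_r50-d8_4xb2-40k_cityscapes-512x1024',
    dataset_name='cityscapes',
    device='cuda:0')
```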
+
+### Visualize prediction
+
+`MMSegInferencer` has 4 parameters for visualizing predictions; you can use them when calling the initialized inferencer:
+
+- show (bool) - Whether to display the image in a popup window. Defaults to False.
+- wait_time (float) - The interval of show (s). Defaults to 0.
+- img_out_dir (str) - Subdirectory of `out_dir`, used to save the rendering color segmentation mask, so `out_dir` must be defined
+  if you would like to save the predicted mask. Defaults to 'vis'.
+- opacity (int, float) - The transparency of the segmentation mask. Defaults to 0.8.
+
+Examples of these parameters are in [Basic Usage](#basic-usage); a combined sketch follows below.
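
As a recap, a hedged sketch combining the four options (paths and values are illustrative; `inferencer` is the instance built in the Initialization sketch above):

```
# Pop up each rendered mask for 2 seconds, draw it at 50% opacity, and save
# the rendered color maps under outputs/vis (img_out_dir is joined onto out_dir).
inferencer(
    'demo/demo.png',
    show=True,
    wait_time=2,
    opacity=0.5,
    out_dir='outputs',
    img_out_dir='vis')
```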
+
+### List model
+
+There is a very easy way to list all model names in MMSegmentation:
+
+```
+>>> from mmseg.apis import MMSegInferencer
+# models is a list of model names, and they will print automatically
+>>> models = MMSegInferencer.list_models('mmseg')
+```
+
+## Inference API
+
 ### mmseg.apis.init_model


@@ -1,15 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
+import warnings
 from typing import List, Optional, Sequence, Union

 import mmcv
 import mmengine
 import numpy as np
+import torch
+import torch.nn as nn
 from mmcv.transforms import Compose
 from mmengine.infer.infer import BaseInferencer, ModelType
+from mmengine.model import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
+from mmengine.runner.checkpoint import _load_checkpoint_to_model
+from PIL import Image

 from mmseg.structures import SegDataSample
-from mmseg.utils import ConfigType, SampleList, register_all_modules
+from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
 from mmseg.visualization import SegLocalVisualizer

 InputType = Union[str, np.ndarray]
@@ -23,90 +30,164 @@ class MMSegInferencer(BaseInferencer):

     Args:
         model (str, optional): Path to the config file or the model name
-            defined in metafile. For example, it could be
-            "fcn_r50-d8_4xb2-40k_cityscapes-512x1024" or
-            "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py"
+            defined in metafile. Take the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/fcn/fcn.yml>`_
+            as an example, the `model` could be
+            "fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of the
+            model will be downloaded automatically. If a config file is used,
+            like "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py",
+            `weights` should be defined.
         weights (str, optional): Path to the checkpoint. If it is not specified
             and model is a model name of metafile, the weights will be loaded
             from metafile. Defaults to None.
-        palette (List[List[int]], optional): The palette of
-            segmentation map.
-        classes (Tuple[str], optional): Category information.
-        dataset_name (str, optional): Name of the datasets supported in mmseg.
+        classes (list, optional): Input classes for result rendering. As the
+            prediction of a segmentation model is a segment map with label
+            indices, `classes` is a list whose items correspond to the label
+            indices. If classes is not defined, the visualizer will take the
+            `cityscapes` classes by default. Defaults to None.
+        palette (list, optional): Input palette for result rendering, which is
+            a list of colors corresponding to the classes. If palette is not
+            defined, the visualizer will take the `cityscapes` palette by
+            default. Defaults to None.
+        dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_.
+            The visualizer will use the meta information of the dataset, i.e.
+            classes and palette, but `classes` and `palette` have higher
+            priority. Defaults to None.
         device (str, optional): Device to run inference. If None, the available
             device will be automatically used. Defaults to None.
-        scope (str, optional): The scope of the model. Defaults to None.
-    """
+        scope (str, optional): The scope of the model. Defaults to 'mmseg'.
+    """ # noqa

     preprocess_kwargs: set = set()
     forward_kwargs: set = {'mode', 'out_dir'}
-    visualize_kwargs: set = {
-        'show', 'wait_time', 'draw_pred', 'img_out_dir', 'opacity'
-    }
-    postprocess_kwargs: set = {
-        'pred_out_dir', 'return_datasample', 'save_mask', 'mask_dir'
-    }
+    visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'}
+    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
     def __init__(self,
                  model: Union[ModelType, str],
                  weights: Optional[str] = None,
-                 palette: Optional[Union[str, List]] = None,
                  classes: Optional[Union[str, List]] = None,
+                 palette: Optional[Union[str, List]] = None,
                  dataset_name: Optional[str] = None,
                  device: Optional[str] = None,
                  scope: Optional[str] = 'mmseg') -> None:
         # A global counter tracking the number of images processed, for
         # naming of the output images
         self.num_visualized_imgs = 0
-        register_all_modules()
+        self.num_pred_imgs = 0
+        init_default_scope(scope if scope else 'mmseg')
         super().__init__(
             model=model, weights=weights, device=device, scope=scope)
+        if device == 'cpu' or not torch.cuda.is_available():
+            self.model = revert_sync_batchnorm(self.model)
         assert isinstance(self.visualizer, SegLocalVisualizer)
         self.visualizer.set_dataset_meta(palette, classes, dataset_name)
+    def _load_weights_to_model(self, model: nn.Module,
+                               checkpoint: Optional[dict],
+                               cfg: Optional[ConfigType]) -> None:
+        """Loading model weights and meta information from cfg and checkpoint.
+
+        Subclasses could override this method to load extra meta information
+        from ``checkpoint`` and ``cfg`` to model.
+
+        Args:
+            model (nn.Module): Model to load weights and meta information.
+            checkpoint (dict, optional): The loaded checkpoint.
+            cfg (Config or ConfigDict, optional): The loaded config.
+        """
+        if checkpoint is not None:
+            _load_checkpoint_to_model(model, checkpoint)
+            checkpoint_meta = checkpoint.get('meta', {})
+            # save the dataset_meta in the model for convenience
+            if 'dataset_meta' in checkpoint_meta:
+                # mmsegmentation 1.x
+                model.dataset_meta = {
+                    'classes': checkpoint_meta['dataset_meta'].get('classes'),
+                    'palette': checkpoint_meta['dataset_meta'].get('palette')
+                }
+            elif 'CLASSES' in checkpoint_meta:
+                # mmsegmentation 0.x
+                classes = checkpoint_meta['CLASSES']
+                palette = checkpoint_meta.get('PALETTE', None)
+                model.dataset_meta = {'classes': classes, 'palette': palette}
+            else:
+                warnings.warn(
+                    'dataset_meta or class names are not saved in the '
+                    'checkpoint\'s meta data, use classes of Cityscapes by '
+                    'default.')
+                model.dataset_meta = {
+                    'classes': get_classes('cityscapes'),
+                    'palette': get_palette('cityscapes')
+                }
+        else:
+            warnings.warn('Checkpoint is not loaded, and the inference '
+                          'result is calculated by the randomly initialized '
+                          'model!')
+            warnings.warn(
+                'weights is None, use cityscapes classes by default.')
+            model.dataset_meta = {
+                'classes': get_classes('cityscapes'),
+                'palette': get_palette('cityscapes')
+            }
     def __call__(self,
                  inputs: InputsType,
                  return_datasamples: bool = False,
                  batch_size: int = 1,
                  show: bool = False,
                  wait_time: int = 0,
-                 draw_pred: bool = True,
                  out_dir: str = '',
-                 save_mask: bool = False,
-                 mask_dir: str = 'mask',
+                 img_out_dir: str = 'vis',
+                 pred_out_dir: str = 'pred',
                  **kwargs) -> dict:
         """Call the inferencer.

         Args:
-            inputs (Union[str, np.ndarray]): Inputs for the inferencer.
+            inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
             return_datasamples (bool): Whether to return results as
                 :obj:`SegDataSample`. Defaults to False.
             batch_size (int): Batch size. Defaults to 1.
-            show (bool): Whether to display the image in a popup window.
-                Defaults to False.
+            show (bool): Whether to display the rendering color segmentation
+                mask in a popup window. Defaults to False.
             wait_time (float): The interval of show (s). Defaults to 0.
-            draw_pred (bool): Whether to draw Prediction SegDataSample.
-                Defaults to True.
-            out_dir (str): Output directory of inference results. Defaults: ''.
-            save_mask (bool): Whether save pred mask as a file.
-            mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
-                mask file.
+            out_dir (str): Output directory of inference results. Defaults
+                to ''.
+            img_out_dir (str): Subdirectory of `out_dir`, used to save the
+                rendering color segmentation mask, so `out_dir` must be
+                defined if you would like to save the predicted mask.
+                Defaults to 'vis'.
+            pred_out_dir (str): Subdirectory of `out_dir`, used to save the
+                predicted mask file, so `out_dir` must be defined if you
+                would like to save the predicted mask. Defaults to 'pred'.
+            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
+                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+                Each key in kwargs should be in the corresponding set of
+                ``preprocess_kwargs``, ``forward_kwargs``,
+                ``visualize_kwargs`` and ``postprocess_kwargs``.

         Returns:
             dict: Inference and visualization results.
         """
+        if out_dir != '':
+            pred_out_dir = osp.join(out_dir, pred_out_dir)
+            img_out_dir = osp.join(out_dir, img_out_dir)
+        else:
+            pred_out_dir = ''
+            img_out_dir = ''
+
         return super().__call__(
             inputs=inputs,
             return_datasamples=return_datasamples,
             batch_size=batch_size,
             show=show,
             wait_time=wait_time,
-            draw_pred=draw_pred,
-            img_out_dir=out_dir,
-            pred_out_dir=out_dir,
-            save_mask=save_mask,
-            mask_dir=mask_dir,
+            img_out_dir=img_out_dir,
+            pred_out_dir=pred_out_dir,
             **kwargs)
     def visualize(self,
@@ -114,7 +195,6 @@ class MMSegInferencer(BaseInferencer):
                   preds: List[dict],
                   show: bool = False,
                   wait_time: int = 0,
-                  draw_pred: bool = True,
                   img_out_dir: str = '',
                   opacity: float = 0.8) -> List[np.ndarray]:
         """Visualize predictions.
@@ -125,9 +205,8 @@ class MMSegInferencer(BaseInferencer):
             show (bool): Whether to display the image in a popup window.
                 Defaults to False.
             wait_time (float): The interval of show (s). Defaults to 0.
-            draw_pred (bool): Whether to draw Prediction SegDataSample.
-                Defaults to True.
-            img_out_dir (str): Output directory of drawn images. Defaults: ''
+            img_out_dir (str): Output directory of the rendered prediction,
+                i.e. the color segmentation mask. Defaults: ''
             opacity (int, float): The transparency of segmentation mask.
                 Defaults to 0.8.
@@ -140,7 +219,7 @@ class MMSegInferencer(BaseInferencer):
         if getattr(self, 'visualizer') is None:
             raise ValueError('Visualization needs the "visualizer" term'
                              'defined in the config, but got None')
-
+        self.visualizer.set_dataset_meta(**self.model.dataset_meta)
         self.visualizer.alpha = opacity

         results = []
@@ -153,7 +232,7 @@ class MMSegInferencer(BaseInferencer):
                 img_name = osp.basename(single_input)
             elif isinstance(single_input, np.ndarray):
                 img = single_input.copy()
-                img_num = str(self.num_visualized_imgs).zfill(8)
+                img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
                 img_name = f'{img_num}.jpg'
             else:
                 raise ValueError('Unsupported input type:'
@@ -169,7 +248,7 @@ class MMSegInferencer(BaseInferencer):
                 show=show,
                 wait_time=wait_time,
                 draw_gt=False,
-                draw_pred=draw_pred,
+                draw_pred=True,
                 out_file=out_file)
             results.append(self.visualizer.get_image())
             self.num_visualized_imgs += 1
@@ -180,62 +259,65 @@ class MMSegInferencer(BaseInferencer):
                     preds: PredType,
                     visualization: List[np.ndarray],
                     return_datasample: bool = False,
-                    mask_dir: str = 'mask',
-                    save_mask: bool = True,
                     pred_out_dir: str = '') -> dict:
         """Process the predictions and visualization results from ``forward``
         and ``visualize``.

         This method should be responsible for the following tasks:

-        1. Convert datasamples into a json-serializable dict if needed.
-        2. Pack the predictions and visualization results and return them.
-        3. Dump or log the predictions.
+        1. Pack the predictions and visualization results and return them.
+        2. Save the predictions, if needed.

         Args:
             preds (List[Dict]): Predictions of the model.
-            visualization (np.ndarray): Visualized predictions.
+            visualization (List[np.ndarray]): The list of rendered color
+                segmentation masks.
             return_datasample (bool): Whether to return results as datasamples.
                 Defaults to False.
             pred_out_dir: File to save the inference results w/o
                 visualization. If left as empty, no file will be saved.
                 Defaults to ''.
-            mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
-                mask file.
-            save_mask (bool): Whether save pred mask as a file.

         Returns:
             dict: Inference and visualization results with key ``predictions``
             and ``visualization``

             - ``visualization (Any)``: Returned by :meth:`visualize`
-            - ``predictions`` (dict or DataSample): Returned by
+            - ``predictions`` (List[np.ndarray], np.ndarray): Returned by
               :meth:`forward` and processed in :meth:`postprocess`.
-              If ``return_datasample=False``, it usually should be a
-              json-serializable dict containing only basic data elements such
-              as strings and numbers.
+              If ``return_datasample=False``, it will be the segmentation mask
+              with label indices.
         """
-        if return_datasample:
-            if len(preds) == 1:
-                return preds[0]
-            else:
-                return preds
-
         results_dict = {}

-        results_dict['predictions'] = preds
-        results_dict['visualization'] = visualization
-        if pred_out_dir != '':
-            mmengine.mkdir_or_exist(pred_out_dir)
-            if save_mask:
-                preds = [preds] if isinstance(preds, SegDataSample) else preds
-                for pred in preds:
-                    mmcv.imwrite(
-                        pred.pred_sem_seg.numpy().data[0],
-                        osp.join(pred_out_dir, mask_dir,
-                                 osp.basename(pred.metainfo['img_path'])))
-            else:
-                mmengine.dump(results_dict,
-                              osp.join(pred_out_dir, 'results.pkl'))
+        results_dict['predictions'] = []
+        results_dict['visualization'] = []
+        for i, pred in enumerate(preds):
+            pred_data = pred.pred_sem_seg.numpy().data[0]
+            results_dict['predictions'].append(pred_data)
+            if visualization is not None:
+                vis = visualization[i]
+                results_dict['visualization'].append(vis)
+            if pred_out_dir != '':
+                mmengine.mkdir_or_exist(pred_out_dir)
+                img_name = str(self.num_pred_imgs).zfill(8) + '_pred.png'
+                img_path = osp.join(pred_out_dir, img_name)
+                output = Image.fromarray(pred_data.astype(np.uint8))
+                output.save(img_path)
+                self.num_pred_imgs += 1
+
+        if return_datasample:
+            return preds
+
+        if len(results_dict['predictions']) == 1:
+            results_dict['predictions'] = results_dict['predictions'][0]
+            if visualization is not None:
+                results_dict['visualization'] = \
+                    results_dict['visualization'][0]
+
         return results_dict

     def _init_pipeline(self, cfg: ConfigType) -> Compose:


@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional

 import mmcv
 import numpy as np
@@ -24,6 +24,17 @@ class SegLocalVisualizer(Visualizer):
             Defaults to None.
         save_dir (str, optional): Save file dir for all storage backends.
             If it is None, the backend storage will not save any data.
+        classes (list, optional): Input classes for result rendering. As the
+            prediction of a segmentation model is a segment map with label
+            indices, `classes` is a list whose items correspond to the label
+            indices. If classes is not defined, the visualizer will take the
+            `cityscapes` classes by default. Defaults to None.
+        palette (list, optional): Input palette for result rendering, which is
+            a list of colors corresponding to the classes. Defaults to None.
+        dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_.
+            The visualizer will use the meta information of the dataset, i.e.
+            classes and palette, but `classes` and `palette` have higher
+            priority. Defaults to None.
         alpha (int, float): The transparency of segmentation mask.
             Defaults to 0.8.
@@ -49,15 +60,15 @@ class SegLocalVisualizer(Visualizer):
         >>> seg_local_visualizer.add_datasample(
         ...     'visualizer_example', image,
         ...     gt_seg_data_sample, show=True)
-    """
+    """ # noqa

     def __init__(self,
                  name: str = 'visualizer',
                  image: Optional[np.ndarray] = None,
                  vis_backends: Optional[Dict] = None,
                  save_dir: Optional[str] = None,
-                 palette: Optional[Union[str, List]] = None,
-                 classes: Optional[Union[str, List]] = None,
+                 classes: Optional[List] = None,
+                 palette: Optional[List] = None,
                  dataset_name: Optional[str] = None,
                  alpha: float = 0.8,
                  **kwargs):
@@ -66,17 +77,23 @@ class SegLocalVisualizer(Visualizer):
         self.set_dataset_meta(palette, classes, dataset_name)

     def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
-                      classes: Optional[Tuple[str]],
-                      palette: Optional[List[List[int]]]) -> np.ndarray:
+                      classes: Optional[List],
+                      palette: Optional[List]) -> np.ndarray:
         """Draw semantic seg of GT or prediction.

         Args:
             image (np.ndarray): The image to draw.
-            sem_seg (:obj:`PixelData`): Data structure for
-                pixel-level annotations or predictions.
-            classes (Tuple[str], optional): Category information.
-            palette (List[List[int]], optional): The palette of
-                segmentation map.
+            sem_seg (:obj:`PixelData`): Data structure for pixel-level
+                annotations or predictions.
+            classes (list, optional): Input classes for result rendering. As
+                the prediction of a segmentation model is a segment map with
+                label indices, `classes` is a list whose items correspond to
+                the label indices. If classes is not defined, the visualizer
+                will take the `cityscapes` classes by default.
+                Defaults to None.
+            palette (list, optional): Input palette for result rendering,
+                which is a list of colors corresponding to the classes.
+                Defaults to None.

         Returns:
             np.ndarray: the drawn image which channel is RGB.
@@ -101,9 +118,26 @@ class SegLocalVisualizer(Visualizer):
         return self.get_image()

     def set_dataset_meta(self,
-                         palette: Optional[Union[str, List]] = None,
-                         classes: Optional[Union[str, List]] = None,
+                         classes: Optional[List] = None,
+                         palette: Optional[List] = None,
                          dataset_name: Optional[str] = None) -> None:
+        """Set meta information to visualizer.
+
+        Args:
+            classes (list, optional): Input classes for result rendering. As
+                the prediction of a segmentation model is a segment map with
+                label indices, `classes` is a list whose items correspond to
+                the label indices. If classes is not defined, the visualizer
+                will take the `cityscapes` classes by default.
+                Defaults to None.
+            palette (list, optional): Input palette for result rendering,
+                which is a list of colors corresponding to the classes.
+                Defaults to None.
+            dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_.
+                The visualizer will use the meta information of the dataset,
+                i.e. classes and palette, but `classes` and `palette` have
+                higher priority. Defaults to None.
+        """ # noqa
         # Set default value. When calling
         # `SegLocalVisualizer().dataset_meta=xxx`,
         # it will override the default value.


@@ -104,12 +104,10 @@ def test_inferencer():
     imgs = [img, img]
     infer(imgs)

-    results = infer(imgs, out_dir=tempfile.gettempdir(), draw_pred=True)
+    results = infer(imgs, out_dir=tempfile.gettempdir())

     # test results
     assert 'predictions' in results
     assert 'visualization' in results
     assert len(results['predictions']) == 2
-    assert results['predictions'][0].seg_logits.data.shape == torch.Size(
-        (19, 4, 4))
-    assert results['predictions'][0].pred_sem_seg.shape == torch.Size((4, 4))
+    assert results['predictions'][0].shape == (4, 4)
assert results['predictions'][0].pred_sem_seg.shape == torch.Size((4, 4))