[Enhancement] Modify interface of MMSegInferencer and add docs (#2658)
## Motivation

Make MMSegInferencer easier to use.

## Modification

1. Add `_load_weights_to_model` to MMSegInferencer; it retrieves `dataset_meta` from the checkpoint.
2. Modify and remove some parameters of `__call__`, `visualize` and `postprocess`.
3. Add support for saving the segmentation mask; remove dumping results to a pkl file.
4. Refine the docstrings of MMSegInferencer and SegLocalVisualizer.
5. Add user documentation for MMSegInferencer.

## BC-breaking (Optional)

Yes, some parameters are removed. We need to discuss whether to keep them with a deprecation warning or simply remove them, since MMSegInferencer was merged into mmseg only a few days ago.

Co-authored-by: xiexinch <xiexinch@outlook.com>
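A minimal sketch of the updated call interface (the model name and paths below are illustrative; see the tutorial added in this PR for details):

```
from mmseg.apis import MMSegInferencer

# The model name comes from the metafile; weights are downloaded automatically.
inferencer = MMSegInferencer(model='deeplabv3plus_r18-d8_4xb2-80k_cityscapes-512x1024')

# Rendered color maps go to outputs/vis and predicted masks to outputs/pred,
# instead of being dumped to a pkl file.
inferencer('demo/demo.png', show=False, out_dir='outputs', opacity=0.5)
```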
parent 8c1d299cb6
commit 310ec4afc7
@@ -16,11 +16,6 @@ def main():
action='store_true',
default=False,
help='Whether to display the drawn image.')
parser.add_argument(
'--save-mask',
action='store_true',
default=False,
help='Enable save the mask file')
parser.add_argument(
'--dataset-name',
default='cityscapes',
@@ -43,11 +38,7 @@ def main():
# test a single image
mmseg_inferencer(
args.img,
show=args.show,
out_dir=args.out_dir,
save_mask=args.save_mask,
opacity=args.opacity)
args.img, show=args.show, out_dir=args.out_dir, opacity=args.opacity)


if __name__ == '__main__':
@@ -4,13 +4,132 @@ MMSegmentation provides pre-trained models for semantic segmentation in [Model Z
This note will show how to use existing models for inference on given images.
As for how to test existing models on standard datasets, please see this [guide](./4_train_test.md)

## Inference API

MMSegmentation provides several interfaces for users to easily use pre-trained models for inference.

- [mmseg.apis.init_model](#mmsegapisinit_model)
- [mmseg.apis.inference_model](#mmsegapisinference_model)
- [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
- [Tutorial 3: Inference with existing models](#tutorial-3-inference-with-existing-models)
  - [Inferencer](#inferencer)
    - [Basic Usage](#basic-usage)
    - [Initialization](#initialization)
    - [Visualize prediction](#visualize-prediction)
    - [List model](#list-model)
  - [Inference API](#inference-api)
    - [mmseg.apis.init_model](#mmsegapisinit_model)
    - [mmseg.apis.inference_model](#mmsegapisinference_model)
    - [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
## Inferencer

MMSegmentation provides `MMSegInferencer`, the most **convenient** way to use models in MMSegmentation. You can get a segmentation mask for an image with only 3 lines of code.

### Basic Usage

The following example shows how to use `MMSegInferencer` to perform inference on a single image.

```
>>> from mmseg.apis import MMSegInferencer
>>> # Load models into memory
>>> inferencer = MMSegInferencer(model='deeplabv3plus_r18-d8_4xb2-80k_cityscapes-512x1024')
>>> # Inference
>>> inferencer('demo/demo.png', show=True)
```

The visualization result should look like:

<div align="center">
https://user-images.githubusercontent.com/76149310/221507927-ae01e3a7-016f-4425-b966-7b19cbbe494e.png
</div>

Moreover, you can use `MMSegInferencer` to process a list of images:

```
# Input a list of images
>>> images = [image1, image2, ...] # image1 can be a file path or a np.ndarray
>>> inferencer(images, show=True, wait_time=0.5) # wait_time is the display interval (s); 0 means display forever

# Or input an image directory
>>> images = $IMAGESDIR
>>> inferencer(images, show=True, wait_time=0.5)

# Save the rendered color maps and the predicted results
# out_dir is the directory to save the output results; img_out_dir and pred_out_dir are subdirectories of out_dir
# used to save the rendered color maps and the predicted results, respectively
>>> inferencer(images, out_dir='outputs', img_out_dir='vis', pred_out_dir='pred')
```

The inferencer has an optional parameter, `return_datasamples`, whose default value is False. By default, the
inferencer returns a `dict` with the two keys 'visualization' and 'predictions'.
If `return_datasamples=True`, the inferencer returns a [`SegDataSample`](../advanced_guides/structures.md), or a list of them.

```
result = inferencer('demo/demo.png')
# result is a `dict` with the two keys 'visualization' and 'predictions'
# 'visualization' contains the color segmentation map
print(result['visualization'].shape)
# (512, 683, 3)

# 'predictions' contains the segmentation mask with label indices
print(result['predictions'].shape)
# (512, 683)

result = inferencer('demo/demo.png', return_datasamples=True)
print(type(result))
# <class 'mmseg.structures.seg_data_sample.SegDataSample'>

# Input a list of images
results = inferencer(images)
# The output is a list
print(type(results['visualization']), results['visualization'][0].shape)
# <class 'list'> (512, 683, 3)
print(type(results['predictions']), results['predictions'][0].shape)
# <class 'list'> (512, 683)

results = inferencer(images, return_datasamples=True)
# <class 'list'>
print(type(results[0]))
# <class 'mmseg.structures.seg_data_sample.SegDataSample'>
```
### Initialization

`MMSegInferencer` must be initialized from a `model`, which can be a model name, a `Config`, or even the path of a config file.
Model names can be found in the models' metafiles; for example, one model name of MaskFormer is `maskformer_r50-d32_8xb2-160k_ade20k-512x512`, and if a model name is given, the weights of the model will be downloaded automatically. Below are the other input parameters, followed by a short initialization example:

- weights (str, optional) - Path to the checkpoint. If it is not specified and model is a model name from a metafile, the weights will be loaded
  from the metafile. Defaults to None.
- classes (list, optional) - Input classes for result rendering. As the prediction of a segmentation
  model is a segmentation map with label indices, `classes` is a list whose
  items correspond to the label indices. If classes is not defined, the visualizer will take the `cityscapes` classes by default. Defaults to None.
- palette (list, optional) - Input palette for result rendering, which is a list of colors
  corresponding to the classes. If palette is not defined, the visualizer will take the `cityscapes` palette by default. Defaults to None.
- dataset_name (str, optional) - [Dataset name or alias](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317).
  The visualizer will use the meta information of the dataset, i.e. its classes and palette,
  but `classes` and `palette` have higher priority. Defaults to None.
- device (str, optional) - Device to run inference. If None, the available device will be used automatically. Defaults to None.
- scope (str, optional) - The scope of the model. Defaults to 'mmseg'.
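For example, initialization might look like the following sketch (the checkpoint path is an illustrative placeholder):

```
>>> from mmseg.apis import MMSegInferencer
>>> # Initialize from a model name; weights are downloaded via the metafile
>>> inferencer = MMSegInferencer(model='maskformer_r50-d32_8xb2-160k_ade20k-512x512', device='cuda:0')
>>> # Or initialize from a config file plus a local checkpoint, with dataset meta information
>>> inferencer = MMSegInferencer(
...     model='configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py',
...     weights='fcn_r50-d8_4xb2-40k_cityscapes-512x1024.pth',
...     dataset_name='cityscapes')
```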
### Visualize prediction

`MMSegInferencer` supports 4 parameters for visualizing predictions; you can use them when calling the initialized inferencer:

- show (bool) - Whether to display the image in a popup window. Defaults to False.
- wait_time (float) - The interval of show (s). Defaults to 0.
- img_out_dir (str) - Subdirectory of `out_dir`, used to save the rendered color segmentation mask, so `out_dir` must be defined
  if you would like to save the predicted mask. Defaults to 'vis'.
- opacity (int, float) - The transparency of the segmentation mask. Defaults to 0.8.

Examples of these parameters can be found in [Basic Usage](#basic-usage), and a combined sketch follows.
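A combined sketch (the output directory and opacity value are illustrative):

```
>>> # Display the result for 2 seconds, render the mask half-transparent,
>>> # and write the color map to outputs/vis
>>> inferencer('demo/demo.png', show=True, wait_time=2, out_dir='outputs', img_out_dir='vis', opacity=0.5)
```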
### List model

There is a very easy way to list all model names in MMSegmentation:

```
>>> from mmseg.apis import MMSegInferencer
# models is a list of model names, and they will be printed automatically
>>> models = MMSegInferencer.list_models('mmseg')
```

## Inference API

### mmseg.apis.init_model
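A minimal sketch of how `mmseg.apis.init_model` is typically used (the config and checkpoint paths below are placeholders):

```
from mmseg.apis import init_model

config_path = 'configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.pth'

# Build the model from a config file and load the checkpoint
model = init_model(config_path, checkpoint_path, device='cuda:0')
```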
@@ -1,15 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import List, Optional, Sequence, Union

import mmcv
import mmengine
import numpy as np
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from PIL import Image

from mmseg.structures import SegDataSample
from mmseg.utils import ConfigType, SampleList, register_all_modules
from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer

InputType = Union[str, np.ndarray]
@@ -23,90 +30,164 @@ class MMSegInferencer(BaseInferencer):

Args:
model (str, optional): Path to the config file or the model name
defined in metafile. For example, it could be
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024" or
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py"
defined in metafile. Take the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/fcn/fcn.yml>`_
as an example, the `model` could be
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of the model
will be downloaded automatically. If a config file is used, like
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", the
`weights` should be defined.
weights (str, optional): Path to the checkpoint. If it is not specified
and model is a model name of metafile, the weights will be loaded
from metafile. Defaults to None.
palette (List[List[int]], optional): The palette of
segmentation map.
classes (Tuple[str], optional): Category information.
dataset_name (str, optional): Name of the datasets supported in mmseg.
classes (list, optional): Input classes for result rendering, as the
prediction of segmentation model is a segment map with label
indices, `classes` is a list which includes items corresponding to the
label indices. If classes is not defined, visualizer will take
`cityscapes` classes by default. Defaults to None.
palette (list, optional): Input palette for result rendering, which is
a list of colors corresponding to the classes. If palette is
not defined, visualizer will take `cityscapes` palette by default.
Defaults to None.
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_
visualizer will use the meta information of the dataset i.e. classes
and palette, but the `classes` and `palette` have higher priority.
Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to None.
"""
scope (str, optional): The scope of the model. Defaults to 'mmseg'.
""" # noqa

preprocess_kwargs: set = set()
forward_kwargs: set = {'mode', 'out_dir'}
visualize_kwargs: set = {
'show', 'wait_time', 'draw_pred', 'img_out_dir', 'opacity'
}
postprocess_kwargs: set = {
'pred_out_dir', 'return_datasample', 'save_mask', 'mask_dir'
}
visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

def __init__(self,
model: Union[ModelType, str],
weights: Optional[str] = None,
palette: Optional[Union[str, List]] = None,
classes: Optional[Union[str, List]] = None,
palette: Optional[Union[str, List]] = None,
dataset_name: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmseg') -> None:
# A global counter tracking the number of images processed, for
# naming of the output images
self.num_visualized_imgs = 0
register_all_modules()
self.num_pred_imgs = 0
init_default_scope(scope if scope else 'mmseg')
super().__init__(
model=model, weights=weights, device=device, scope=scope)

if device == 'cpu' or not torch.cuda.is_available():
self.model = revert_sync_batchnorm(self.model)

assert isinstance(self.visualizer, SegLocalVisualizer)
self.visualizer.set_dataset_meta(palette, classes, dataset_name)

def _load_weights_to_model(self, model: nn.Module,
checkpoint: Optional[dict],
cfg: Optional[ConfigType]) -> None:
"""Load model weights and meta information from cfg and checkpoint.

Subclasses could override this method to load extra meta information
from ``checkpoint`` and ``cfg`` to model.

Args:
model (nn.Module): Model to load weights and meta information.
checkpoint (dict, optional): The loaded checkpoint.
cfg (Config or ConfigDict, optional): The loaded config.
"""

if checkpoint is not None:
_load_checkpoint_to_model(model, checkpoint)
checkpoint_meta = checkpoint.get('meta', {})
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint_meta:
# mmsegmentation 1.x
model.dataset_meta = {
'classes': checkpoint_meta['dataset_meta'].get('classes'),
'palette': checkpoint_meta['dataset_meta'].get('palette')
}
elif 'CLASSES' in checkpoint_meta:
# mmsegmentation 0.x
classes = checkpoint_meta['CLASSES']
palette = checkpoint_meta.get('PALETTE', None)
model.dataset_meta = {'classes': classes, 'palette': palette}
else:
warnings.warn(
'dataset_meta or class names are not saved in the '
'checkpoint\'s meta data, use classes of Cityscapes by '
'default.')
model.dataset_meta = {
'classes': get_classes('cityscapes'),
'palette': get_palette('cityscapes')
}
else:
warnings.warn('Checkpoint is not loaded, and the inference '
'result is calculated by the randomly initialized '
'model!')
warnings.warn(
'weights is None, use cityscapes classes by default.')
model.dataset_meta = {
'classes': get_classes('cityscapes'),
'palette': get_palette('cityscapes')
}

def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
out_dir: str = '',
save_mask: bool = False,
mask_dir: str = 'mask',
img_out_dir: str = 'vis',
pred_out_dir: str = 'pred',
**kwargs) -> dict:
"""Call the inferencer.

Args:
inputs (Union[str, np.ndarray]): Inputs for the inferencer.
inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
:obj:`SegDataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
show (bool): Whether to display the image in a popup window.
Defaults to False.
show (bool): Whether to display the rendered color segmentation
mask in a popup window. Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw Prediction SegDataSample.
Defaults to True.
out_dir (str): Output directory of inference results. Defaults: ''.
save_mask (bool): Whether to save the pred mask as a file.
mask_dir (str): Subdirectory of `pred_out_dir`, used to save the pred
mask file.
out_dir (str): Output directory of inference results. Defaults
to ''.
img_out_dir (str): Subdirectory of `out_dir`, used to save
the rendered color segmentation mask, so `out_dir` must be defined
if you would like to save the predicted mask. Defaults to 'vis'.
pred_out_dir (str): Subdirectory of `out_dir`, used to save
the predicted mask file, so `out_dir` must be defined if you would
like to save the predicted mask. Defaults to 'pred'.

**kwargs: Other keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.

Returns:
dict: Inference and visualization results.
"""

if out_dir != '':
pred_out_dir = osp.join(out_dir, pred_out_dir)
img_out_dir = osp.join(out_dir, img_out_dir)
else:
pred_out_dir = ''
img_out_dir = ''

return super().__call__(
inputs=inputs,
return_datasamples=return_datasamples,
batch_size=batch_size,
show=show,
wait_time=wait_time,
draw_pred=draw_pred,
img_out_dir=out_dir,
pred_out_dir=out_dir,
save_mask=save_mask,
mask_dir=mask_dir,
img_out_dir=img_out_dir,
pred_out_dir=pred_out_dir,
**kwargs)

def visualize(self,
@@ -114,7 +195,6 @@ class MMSegInferencer(BaseInferencer):
preds: List[dict],
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
img_out_dir: str = '',
opacity: float = 0.8) -> List[np.ndarray]:
"""Visualize predictions.
@@ -125,9 +205,8 @@ class MMSegInferencer(BaseInferencer):
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw Prediction SegDataSample.
Defaults to True.
img_out_dir (str): Output directory of drawn images. Defaults: ''
img_out_dir (str): Output directory of rendered predictions, i.e.
color segmentation masks. Defaults: ''
opacity (int, float): The transparency of segmentation mask.
Defaults to 0.8.
@@ -140,7 +219,7 @@ class MMSegInferencer(BaseInferencer):
if getattr(self, 'visualizer') is None:
raise ValueError('Visualization needs the "visualizer" term'
'defined in the config, but got None')

self.visualizer.set_dataset_meta(**self.model.dataset_meta)
self.visualizer.alpha = opacity

results = []
@@ -153,7 +232,7 @@ class MMSegInferencer(BaseInferencer):
img_name = osp.basename(single_input)
elif isinstance(single_input, np.ndarray):
img = single_input.copy()
img_num = str(self.num_visualized_imgs).zfill(8)
img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
img_name = f'{img_num}.jpg'
else:
raise ValueError('Unsupported input type:'
@@ -169,7 +248,7 @@ class MMSegInferencer(BaseInferencer):
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=draw_pred,
draw_pred=True,
out_file=out_file)
results.append(self.visualizer.get_image())
self.num_visualized_imgs += 1
@@ -180,62 +259,65 @@ class MMSegInferencer(BaseInferencer):
preds: PredType,
visualization: List[np.ndarray],
return_datasample: bool = False,
mask_dir: str = 'mask',
save_mask: bool = True,
pred_out_dir: str = '') -> dict:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.

This method should be responsible for the following tasks:

1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.
1. Pack the predictions and visualization results and return them.
2. Save the predictions, if needed.

Args:
preds (List[Dict]): Predictions of the model.
visualization (np.ndarray): Visualized predictions.
visualization (List[np.ndarray]): The list of rendered color
segmentation masks.
return_datasample (bool): Whether to return results as datasamples.
Defaults to False.
pred_out_dir: File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
mask_dir (str): Subdirectory of `pred_out_dir`, used to save the pred
mask file.
save_mask (bool): Whether to save the pred mask as a file.

Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``

- ``visualization (Any)``: Returned by :meth:`visualize`
- ``predictions`` (dict or DataSample): Returned by
- ``predictions`` (List[np.ndarray], np.ndarray): Returned by
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
If ``return_datasample=False``, it will be the segmentation mask
with label indices.
"""
if return_datasample:
if len(preds) == 1:
return preds[0]
else:
return preds

results_dict = {}

results_dict['predictions'] = preds
results_dict['visualization'] = visualization
results_dict['predictions'] = []
results_dict['visualization'] = []

if pred_out_dir != '':
mmengine.mkdir_or_exist(pred_out_dir)
if save_mask:
preds = [preds] if isinstance(preds, SegDataSample) else preds
for pred in preds:
mmcv.imwrite(
pred.pred_sem_seg.numpy().data[0],
osp.join(pred_out_dir, mask_dir,
osp.basename(pred.metainfo['img_path'])))
else:
mmengine.dump(results_dict,
osp.join(pred_out_dir, 'results.pkl'))

if return_datasample:
return preds
for i, pred in enumerate(preds):
pred_data = pred.pred_sem_seg.numpy().data[0]
results_dict['predictions'].append(pred_data)
if visualization is not None:
vis = visualization[i]
results_dict['visualization'].append(vis)
if pred_out_dir != '':
mmengine.mkdir_or_exist(pred_out_dir)
img_name = str(self.num_pred_imgs).zfill(8) + '_pred.png'
img_path = osp.join(pred_out_dir, img_name)
output = Image.fromarray(pred_data.astype(np.uint8))
output.save(img_path)
self.num_pred_imgs += 1

if len(results_dict['predictions']) == 1:
results_dict['predictions'] = results_dict['predictions'][0]
if visualization is not None:
results_dict['visualization'] = \
results_dict['visualization'][0]
return results_dict

def _init_pipeline(self, cfg: ConfigType) -> Compose:
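As a hedged recap of the `__call__` and `postprocess` behavior shown above (the model name and output directory are illustrative):

```
inferencer = MMSegInferencer(model='fcn_r50-d8_4xb2-40k_cityscapes-512x1024')
# With out_dir set, color maps are written to outputs/vis and the predicted
# masks are saved as PNG files under outputs/pred by postprocess.
results = inferencer('demo/demo.png', out_dir='outputs', return_datasamples=False)
```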
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional

import mmcv
import numpy as np

@@ -24,6 +24,17 @@ class SegLocalVisualizer(Visualizer):
Defaults to None.
save_dir (str, optional): Save file dir for all storage backends.
If it is None, the backend storage will not save any data.
classes (list, optional): Input classes for result rendering, as the
prediction of segmentation model is a segment map with label
indices, `classes` is a list which includes items corresponding to the
label indices. If classes is not defined, visualizer will take
`cityscapes` classes by default. Defaults to None.
palette (list, optional): Input palette for result rendering, which is
a list of colors corresponding to the classes. Defaults to None.
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_
visualizer will use the meta information of the dataset i.e. classes
and palette, but the `classes` and `palette` have higher priority.
Defaults to None.
alpha (int, float): The transparency of segmentation mask.
Defaults to 0.8.

@@ -49,15 +60,15 @@ class SegLocalVisualizer(Visualizer):
>>> seg_local_visualizer.add_datasample(
...     'visualizer_example', image,
...     gt_seg_data_sample, show=True)
"""
""" # noqa

def __init__(self,
name: str = 'visualizer',
image: Optional[np.ndarray] = None,
vis_backends: Optional[Dict] = None,
save_dir: Optional[str] = None,
palette: Optional[Union[str, List]] = None,
classes: Optional[Union[str, List]] = None,
classes: Optional[List] = None,
palette: Optional[List] = None,
dataset_name: Optional[str] = None,
alpha: float = 0.8,
**kwargs):

@@ -66,17 +77,23 @@ class SegLocalVisualizer(Visualizer):
self.set_dataset_meta(palette, classes, dataset_name)

def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
classes: Optional[Tuple[str]],
palette: Optional[List[List[int]]]) -> np.ndarray:
classes: Optional[List],
palette: Optional[List]) -> np.ndarray:
"""Draw semantic seg of GT or prediction.

Args:
image (np.ndarray): The image to draw.
sem_seg (:obj:`PixelData`): Data structure for
pixel-level annotations or predictions.
classes (Tuple[str], optional): Category information.
palette (List[List[int]], optional): The palette of
segmentation map.
sem_seg (:obj:`PixelData`): Data structure for pixel-level
annotations or predictions.
classes (list, optional): Input classes for result rendering, as
the prediction of segmentation model is a segment map with
label indices, `classes` is a list which includes items
corresponding to the label indices. If classes is not defined,
visualizer will take `cityscapes` classes by default.
Defaults to None.
palette (list, optional): Input palette for result rendering, which
is a list of colors corresponding to the classes.
Defaults to None.

Returns:
np.ndarray: the drawn image whose channel order is RGB.

@@ -101,9 +118,26 @@ class SegLocalVisualizer(Visualizer):
return self.get_image()

def set_dataset_meta(self,
palette: Optional[Union[str, List]] = None,
classes: Optional[Union[str, List]] = None,
classes: Optional[List] = None,
palette: Optional[List] = None,
dataset_name: Optional[str] = None) -> None:
"""Set meta information for the visualizer.

Args:
classes (list, optional): Input classes for result rendering, as
the prediction of segmentation model is a segment map with
label indices, `classes` is a list which includes items
corresponding to the label indices. If classes is not defined,
visualizer will take `cityscapes` classes by default.
Defaults to None.
palette (list, optional): Input palette for result rendering, which
is a list of colors corresponding to the classes.
Defaults to None.
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_
visualizer will use the meta information of the dataset i.e.
classes and palette, but the `classes` and `palette` have
higher priority. Defaults to None.
""" # noqa
# Set default value. When calling
# `SegLocalVisualizer().dataset_meta=xxx`,
# it will override the default value.
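A hedged usage sketch of the updated `set_dataset_meta` interface (the save directory is illustrative):

```
from mmseg.visualization import SegLocalVisualizer

visualizer = SegLocalVisualizer(name='visualizer', save_dir='work_dir')
# classes/palette take priority over dataset_name when both are given
visualizer.set_dataset_meta(dataset_name='cityscapes')
```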
@@ -104,12 +104,10 @@ def test_inferencer():

imgs = [img, img]
infer(imgs)
results = infer(imgs, out_dir=tempfile.gettempdir(), draw_pred=True)
results = infer(imgs, out_dir=tempfile.gettempdir())

# test results
assert 'predictions' in results
assert 'visualization' in results
assert len(results['predictions']) == 2
assert results['predictions'][0].seg_logits.data.shape == torch.Size(
(19, 4, 4))
assert results['predictions'][0].pred_sem_seg.shape == torch.Size((4, 4))
assert results['predictions'][0].shape == (4, 4)