mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Enhancement] Modify interface of MMSeginferencer and add docs (#2658)
## Motivation Make MMSeginferencer easier to be used ## Modification 1. Add `_load_weights_to_model` to MMSeginferencer, it is for get `dataset_meta` from ckpt 2. Modify and remove some parameters of `__call__`, `visualization` and `postprocess` 3. Add function of save seg mask, remove dump pkl. 4. Refine docstring of MMSeginferencer and SegLocalVisualizer 5. Add the user documentation of MMSeginferencer ## BC-breaking (Optional) yes, remove some parameters, we need to discuss whether keep them with deprecated waring or just remove them as the MMSeginferencer just merged in mmseg a few days. Co-authored-by: xiexinch <xiexinch@outlook.com>
This commit is contained in:
parent
8c1d299cb6
commit
310ec4afc7
@ -16,11 +16,6 @@ def main():
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=False,
|
||||||
help='Whether to display the drawn image.')
|
help='Whether to display the drawn image.')
|
||||||
parser.add_argument(
|
|
||||||
'--save-mask',
|
|
||||||
action='store_true',
|
|
||||||
default=False,
|
|
||||||
help='Enable save the mask file')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--dataset-name',
|
'--dataset-name',
|
||||||
default='cityscapes',
|
default='cityscapes',
|
||||||
@ -43,11 +38,7 @@ def main():
|
|||||||
|
|
||||||
# test a single image
|
# test a single image
|
||||||
mmseg_inferencer(
|
mmseg_inferencer(
|
||||||
args.img,
|
args.img, show=args.show, out_dir=args.out_dir, opacity=args.opacity)
|
||||||
show=args.show,
|
|
||||||
out_dir=args.out_dir,
|
|
||||||
save_mask=args.save_mask,
|
|
||||||
opacity=args.opacity)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -4,13 +4,132 @@ MMSegmentation provides pre-trained models for semantic segmentation in [Model Z
|
|||||||
This note will show how to use existing models to inference on given images.
|
This note will show how to use existing models to inference on given images.
|
||||||
As for how to test existing models on standard datasets, please see this [guide](./4_train_test.md)
|
As for how to test existing models on standard datasets, please see this [guide](./4_train_test.md)
|
||||||
|
|
||||||
## Inference API
|
|
||||||
|
|
||||||
MMSegmentation provides several interfaces for users to easily use pre-trained models for inference.
|
MMSegmentation provides several interfaces for users to easily use pre-trained models for inference.
|
||||||
|
|
||||||
- [mmseg.apis.init_model](#mmsegapisinit_model)
|
- [Tutorial 3: Inference with existing models](#tutorial-3-inference-with-existing-models)
|
||||||
- [mmseg.apis.inference_model](#mmsegapisinference_model)
|
- [Inferencer](#inferencer)
|
||||||
- [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
|
- [Basic Usage](#basic-usage)
|
||||||
|
- [Initialization](#initialization)
|
||||||
|
- [Visualize prediction](#visualize-prediction)
|
||||||
|
- [List model](#list-model)
|
||||||
|
- [Inference API](#inference-api)
|
||||||
|
- [mmseg.apis.init_model](#mmsegapisinit_model)
|
||||||
|
- [mmseg.apis.inference_model](#mmsegapisinference_model)
|
||||||
|
- [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
|
||||||
|
|
||||||
|
## Inferencer
|
||||||
|
|
||||||
|
We provides the most **convenient** way to use the model in MMSegmentation `MMSegInferencer`. You can get segmentation mask for an image with only 3 lines of code.
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
The following example shows how to use `MMSegInferencer` to perform inference on a single image.
|
||||||
|
|
||||||
|
```
|
||||||
|
>>> from mmseg.apis import MMSegInferencer
|
||||||
|
>>> # Load models into memory
|
||||||
|
>>> inferencer = MMSegInferencer(model='deeplabv3plus_r18-d8_4xb2-80k_cityscapes-512x1024')
|
||||||
|
>>> # Inference
|
||||||
|
>>> inferencer('demo/demo.png', show=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
The visualization result should look like:
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
https://user-images.githubusercontent.com/76149310/221507927-ae01e3a7-016f-4425-b966-7b19cbbe494e.png
|
||||||
|
</div>
|
||||||
|
|
||||||
|
Moreover, you can use `MMSegInferencer` to process a list of images:
|
||||||
|
|
||||||
|
```
|
||||||
|
# Input a list of images
|
||||||
|
>>> images = [image1, image2, ...] # image1 can be a file path or a np.ndarray
|
||||||
|
>>> inferencer(images, show=True, wait_time=0.5) # wait_time is delay time, and 0 means forever.
|
||||||
|
|
||||||
|
# Or input image directory
|
||||||
|
>>> images = $IMAGESDIR
|
||||||
|
>>> inferencer(images, show=True, wait_time=0.5)
|
||||||
|
|
||||||
|
# Save visualized rendering color maps and predicted results
|
||||||
|
# out_dir is the directory to save the output results, img_out_dir and pred_out_dir are subdirectories of out_dir
|
||||||
|
# to save visualized rendering color maps and predicted results
|
||||||
|
>>> inferencer(images, out_dir='outputs', img_out_dir='vis', pred_out_dir='pred')
|
||||||
|
```
|
||||||
|
|
||||||
|
There is a optional parameter of inferencer, `return_datasamples`, whose default value is False, and
|
||||||
|
return value of inferencer is a `dict` type by default, including 2 keys 'visualization' and 'predictions'.
|
||||||
|
If `return_datasamples=True` inferencer will return [`SegDataSample`](../advanced_guides/structures.md), or list of it.
|
||||||
|
|
||||||
|
```
|
||||||
|
result = inferencer('demo/demo.png')
|
||||||
|
# result is a `dict` including 2 keys 'visualization' and 'predictions'.
|
||||||
|
# 'visualization' includes color segmentation map
|
||||||
|
print(result['visualization'].shape)
|
||||||
|
# (512, 683, 3)
|
||||||
|
|
||||||
|
# 'predictions' includes segmentation mask with label indice
|
||||||
|
print(result['predictions'].shape)
|
||||||
|
# (512, 683)
|
||||||
|
|
||||||
|
result = inferencer('demo/demo.png', return_datasamples=True)
|
||||||
|
print(type(result))
|
||||||
|
# <class 'mmseg.structures.seg_data_sample.SegDataSample'>
|
||||||
|
|
||||||
|
# Input a list of images
|
||||||
|
results = inferencer(images)
|
||||||
|
# The output is list
|
||||||
|
print(type(results['visualization']), results['visualization'][0].shape)
|
||||||
|
# <class 'list'> (512, 683, 3)
|
||||||
|
print(type(results['predictions']), results['predictions'][0].shape)
|
||||||
|
# <class 'list'> (512, 683)
|
||||||
|
|
||||||
|
results = inferencer(images, return_datasamples=True)
|
||||||
|
# <class 'list'>
|
||||||
|
print(type(results[0]))
|
||||||
|
# <class 'mmseg.structures.seg_data_sample.SegDataSample'>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Initialization
|
||||||
|
|
||||||
|
`MMSegInferencer` must be initialized from a `model`, which can be a model name or a `Config` even a path of config file.
|
||||||
|
The model names can be found in models' metafile, like one model name of maskformer is `maskformer_r50-d32_8xb2-160k_ade20k-512x512`, and if input model name and the weights of the model will be download automatically. Below are other input parameters:
|
||||||
|
|
||||||
|
- weights (str, optional) - Path to the checkpoint. If it is not specified and model is a model name of metafile, the weights will be loaded
|
||||||
|
from metafile. Defaults to None.
|
||||||
|
- classes (list, optional) - Input classes for result rendering, as the prediction of segmentation
|
||||||
|
model is a segment map with label indices, `classes` is a list which includes
|
||||||
|
items responding to the label indices. If classes is not defined, visualizer will take `cityscapes` classes by default. Defaults to None.
|
||||||
|
- palette (list, optional) - Input palette for result rendering, which is a list of color palette
|
||||||
|
responding to the classes. If palette is not defined, visualizer will take `cityscapes` palette by default. Defaults to None.
|
||||||
|
- dataset_name (str, optional)[Dataset name or alias](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317)
|
||||||
|
visulizer will use the meta information of the dataset i.e. classes and palette,
|
||||||
|
but the `classes` and `palette` have higher priority. Defaults to None.
|
||||||
|
- device (str, optional) - Device to run inference. If None, the available device will be automatically used. Defaults to None.
|
||||||
|
- scope (str, optional) - The scope of the model. Defaults to 'mmseg'.
|
||||||
|
|
||||||
|
### Visualize prediction
|
||||||
|
|
||||||
|
`MMSegInferencer` supports 4 parameters for visualize prediction, you can use them when call initialized inferencer:
|
||||||
|
|
||||||
|
- show (bool) - Whether to display the image in a popup window. Defaults to False.
|
||||||
|
- wait_time (float) - The interval of show (s). Defaults to 0.
|
||||||
|
- img_out_dir (str) - Subdirectory of `out_dir`, used to save rendering color segmentation mask, so `out_dir` must be defined
|
||||||
|
if you would like to save predicted mask. Defaults to 'vis'.
|
||||||
|
- opacity (int, float) - The transparency of segmentation mask. Defaults to 0.8.
|
||||||
|
|
||||||
|
The examples of these parameters is in [Basic Usage](#basic-usage)
|
||||||
|
|
||||||
|
### List model
|
||||||
|
|
||||||
|
There is a very easy to list all model names in MMSegmentation
|
||||||
|
|
||||||
|
```
|
||||||
|
>>> from mmseg.apis import MMSegInferencer
|
||||||
|
# models is a list of model names, and them will print automatically
|
||||||
|
>>> models = MMSegInferencer.list_models('mmseg')
|
||||||
|
```
|
||||||
|
|
||||||
|
## Inference API
|
||||||
|
|
||||||
### mmseg.apis.init_model
|
### mmseg.apis.init_model
|
||||||
|
|
||||||
|
@ -1,15 +1,22 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
import warnings
|
||||||
from typing import List, Optional, Sequence, Union
|
from typing import List, Optional, Sequence, Union
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import mmengine
|
import mmengine
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from mmcv.transforms import Compose
|
from mmcv.transforms import Compose
|
||||||
from mmengine.infer.infer import BaseInferencer, ModelType
|
from mmengine.infer.infer import BaseInferencer, ModelType
|
||||||
|
from mmengine.model import revert_sync_batchnorm
|
||||||
|
from mmengine.registry import init_default_scope
|
||||||
|
from mmengine.runner.checkpoint import _load_checkpoint_to_model
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from mmseg.structures import SegDataSample
|
from mmseg.structures import SegDataSample
|
||||||
from mmseg.utils import ConfigType, SampleList, register_all_modules
|
from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
|
||||||
from mmseg.visualization import SegLocalVisualizer
|
from mmseg.visualization import SegLocalVisualizer
|
||||||
|
|
||||||
InputType = Union[str, np.ndarray]
|
InputType = Union[str, np.ndarray]
|
||||||
@ -23,90 +30,164 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (str, optional): Path to the config file or the model name
|
model (str, optional): Path to the config file or the model name
|
||||||
defined in metafile. For example, it could be
|
defined in metafile. Take the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/fcn/fcn.yml>`_
|
||||||
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024" or
|
as an example the `model` could be
|
||||||
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py"
|
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of model
|
||||||
|
will be download automatically. If use config file, like
|
||||||
|
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", the
|
||||||
|
`weights` should be defined.
|
||||||
weights (str, optional): Path to the checkpoint. If it is not specified
|
weights (str, optional): Path to the checkpoint. If it is not specified
|
||||||
and model is a model name of metafile, the weights will be loaded
|
and model is a model name of metafile, the weights will be loaded
|
||||||
from metafile. Defaults to None.
|
from metafile. Defaults to None.
|
||||||
palette (List[List[int]], optional): The palette of
|
classes (list, optional): Input classes for result rendering, as the
|
||||||
segmentation map.
|
prediction of segmentation model is a segment map with label
|
||||||
classes (Tuple[str], optional): Category information.
|
indices, `classes` is a list which includes items responding to the
|
||||||
dataset_name (str, optional): Name of the datasets supported in mmseg.
|
label indices. If classes is not defined, visualizer will take
|
||||||
|
`cityscapes` classes by default. Defaults to None.
|
||||||
|
palette (list, optional): Input palette for result rendering, which is
|
||||||
|
a list of color palette responding to the classes. If palette is
|
||||||
|
not defined, visualizer will take `cityscapes` palette by default.
|
||||||
|
Defaults to None.
|
||||||
|
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_
|
||||||
|
visulizer will use the meta information of the dataset i.e. classes
|
||||||
|
and palette, but the `classes` and `palette` have higher priority.
|
||||||
|
Defaults to None.
|
||||||
device (str, optional): Device to run inference. If None, the available
|
device (str, optional): Device to run inference. If None, the available
|
||||||
device will be automatically used. Defaults to None.
|
device will be automatically used. Defaults to None.
|
||||||
scope (str, optional): The scope of the model. Defaults to None.
|
scope (str, optional): The scope of the model. Defaults to 'mmseg'.
|
||||||
"""
|
""" # noqa
|
||||||
|
|
||||||
preprocess_kwargs: set = set()
|
preprocess_kwargs: set = set()
|
||||||
forward_kwargs: set = {'mode', 'out_dir'}
|
forward_kwargs: set = {'mode', 'out_dir'}
|
||||||
visualize_kwargs: set = {
|
visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'}
|
||||||
'show', 'wait_time', 'draw_pred', 'img_out_dir', 'opacity'
|
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
|
||||||
}
|
|
||||||
postprocess_kwargs: set = {
|
|
||||||
'pred_out_dir', 'return_datasample', 'save_mask', 'mask_dir'
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model: Union[ModelType, str],
|
model: Union[ModelType, str],
|
||||||
weights: Optional[str] = None,
|
weights: Optional[str] = None,
|
||||||
palette: Optional[Union[str, List]] = None,
|
|
||||||
classes: Optional[Union[str, List]] = None,
|
classes: Optional[Union[str, List]] = None,
|
||||||
|
palette: Optional[Union[str, List]] = None,
|
||||||
dataset_name: Optional[str] = None,
|
dataset_name: Optional[str] = None,
|
||||||
device: Optional[str] = None,
|
device: Optional[str] = None,
|
||||||
scope: Optional[str] = 'mmseg') -> None:
|
scope: Optional[str] = 'mmseg') -> None:
|
||||||
# A global counter tracking the number of images processes, for
|
# A global counter tracking the number of images processes, for
|
||||||
# naming of the output images
|
# naming of the output images
|
||||||
self.num_visualized_imgs = 0
|
self.num_visualized_imgs = 0
|
||||||
register_all_modules()
|
self.num_pred_imgs = 0
|
||||||
|
init_default_scope(scope if scope else 'mmseg')
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model=model, weights=weights, device=device, scope=scope)
|
model=model, weights=weights, device=device, scope=scope)
|
||||||
|
|
||||||
|
if device == 'cpu' or not torch.cuda.is_available():
|
||||||
|
self.model = revert_sync_batchnorm(self.model)
|
||||||
|
|
||||||
assert isinstance(self.visualizer, SegLocalVisualizer)
|
assert isinstance(self.visualizer, SegLocalVisualizer)
|
||||||
self.visualizer.set_dataset_meta(palette, classes, dataset_name)
|
self.visualizer.set_dataset_meta(palette, classes, dataset_name)
|
||||||
|
|
||||||
|
def _load_weights_to_model(self, model: nn.Module,
|
||||||
|
checkpoint: Optional[dict],
|
||||||
|
cfg: Optional[ConfigType]) -> None:
|
||||||
|
"""Loading model weights and meta information from cfg and checkpoint.
|
||||||
|
|
||||||
|
Subclasses could override this method to load extra meta information
|
||||||
|
from ``checkpoint`` and ``cfg`` to model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): Model to load weights and meta information.
|
||||||
|
checkpoint (dict, optional): The loaded checkpoint.
|
||||||
|
cfg (Config or ConfigDict, optional): The loaded config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if checkpoint is not None:
|
||||||
|
_load_checkpoint_to_model(model, checkpoint)
|
||||||
|
checkpoint_meta = checkpoint.get('meta', {})
|
||||||
|
# save the dataset_meta in the model for convenience
|
||||||
|
if 'dataset_meta' in checkpoint_meta:
|
||||||
|
# mmsegmentation 1.x
|
||||||
|
model.dataset_meta = {
|
||||||
|
'classes': checkpoint_meta['dataset_meta'].get('classes'),
|
||||||
|
'palette': checkpoint_meta['dataset_meta'].get('palette')
|
||||||
|
}
|
||||||
|
elif 'CLASSES' in checkpoint_meta:
|
||||||
|
# mmsegmentation 0.x
|
||||||
|
classes = checkpoint_meta['CLASSES']
|
||||||
|
palette = checkpoint_meta.get('PALETTE', None)
|
||||||
|
model.dataset_meta = {'classes': classes, 'palette': palette}
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
'dataset_meta or class names are not saved in the '
|
||||||
|
'checkpoint\'s meta data, use classes of Cityscapes by '
|
||||||
|
'default.')
|
||||||
|
model.dataset_meta = {
|
||||||
|
'classes': get_classes('cityscapes'),
|
||||||
|
'palette': get_palette('cityscapes')
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
warnings.warn('Checkpoint is not loaded, and the inference '
|
||||||
|
'result is calculated by the randomly initialized '
|
||||||
|
'model!')
|
||||||
|
warnings.warn(
|
||||||
|
'weights is None, use cityscapes classes by default.')
|
||||||
|
model.dataset_meta = {
|
||||||
|
'classes': get_classes('cityscapes'),
|
||||||
|
'palette': get_palette('cityscapes')
|
||||||
|
}
|
||||||
|
|
||||||
def __call__(self,
|
def __call__(self,
|
||||||
inputs: InputsType,
|
inputs: InputsType,
|
||||||
return_datasamples: bool = False,
|
return_datasamples: bool = False,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
show: bool = False,
|
show: bool = False,
|
||||||
wait_time: int = 0,
|
wait_time: int = 0,
|
||||||
draw_pred: bool = True,
|
|
||||||
out_dir: str = '',
|
out_dir: str = '',
|
||||||
save_mask: bool = False,
|
img_out_dir: str = 'vis',
|
||||||
mask_dir: str = 'mask',
|
pred_out_dir: str = 'pred',
|
||||||
**kwargs) -> dict:
|
**kwargs) -> dict:
|
||||||
"""Call the inferencer.
|
"""Call the inferencer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (Union[str, np.ndarray]): Inputs for the inferencer.
|
inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
|
||||||
return_datasamples (bool): Whether to return results as
|
return_datasamples (bool): Whether to return results as
|
||||||
:obj:`SegDataSample`. Defaults to False.
|
:obj:`SegDataSample`. Defaults to False.
|
||||||
batch_size (int): Batch size. Defaults to 1.
|
batch_size (int): Batch size. Defaults to 1.
|
||||||
show (bool): Whether to display the image in a popup window.
|
show (bool): Whether to display the rendering color segmentation
|
||||||
Defaults to False.
|
mask in a popup window. Defaults to False.
|
||||||
wait_time (float): The interval of show (s). Defaults to 0.
|
wait_time (float): The interval of show (s). Defaults to 0.
|
||||||
draw_pred (bool): Whether to draw Prediction SegDataSample.
|
out_dir (str): Output directory of inference results. Defaults
|
||||||
Defaults to True.
|
to ''.
|
||||||
out_dir (str): Output directory of inference results. Defaults: ''.
|
img_out_dir (str): Subdirectory of `out_dir`, used to save
|
||||||
save_mask (bool): Whether save pred mask as a file.
|
rendering color segmentation mask, so `out_dir` must be defined
|
||||||
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
|
if you would like to save predicted mask. Defaults to 'vis'.
|
||||||
mask file.
|
pred_out_dir (str): Subdirectory of `out_dir`, used to save
|
||||||
|
predicted mask file, so `out_dir` must be defined if you would
|
||||||
|
like to save predicted mask. Defaults to 'pred'.
|
||||||
|
|
||||||
|
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
|
||||||
|
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
|
||||||
|
Each key in kwargs should be in the corresponding set of
|
||||||
|
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
|
||||||
|
and ``postprocess_kwargs``.
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Inference and visualization results.
|
dict: Inference and visualization results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if out_dir != '':
|
||||||
|
pred_out_dir = osp.join(out_dir, pred_out_dir)
|
||||||
|
img_out_dir = osp.join(out_dir, img_out_dir)
|
||||||
|
else:
|
||||||
|
pred_out_dir = ''
|
||||||
|
img_out_dir = ''
|
||||||
|
|
||||||
return super().__call__(
|
return super().__call__(
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
return_datasamples=return_datasamples,
|
return_datasamples=return_datasamples,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
show=show,
|
show=show,
|
||||||
wait_time=wait_time,
|
wait_time=wait_time,
|
||||||
draw_pred=draw_pred,
|
img_out_dir=img_out_dir,
|
||||||
img_out_dir=out_dir,
|
pred_out_dir=pred_out_dir,
|
||||||
pred_out_dir=out_dir,
|
|
||||||
save_mask=save_mask,
|
|
||||||
mask_dir=mask_dir,
|
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
def visualize(self,
|
def visualize(self,
|
||||||
@ -114,7 +195,6 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
preds: List[dict],
|
preds: List[dict],
|
||||||
show: bool = False,
|
show: bool = False,
|
||||||
wait_time: int = 0,
|
wait_time: int = 0,
|
||||||
draw_pred: bool = True,
|
|
||||||
img_out_dir: str = '',
|
img_out_dir: str = '',
|
||||||
opacity: float = 0.8) -> List[np.ndarray]:
|
opacity: float = 0.8) -> List[np.ndarray]:
|
||||||
"""Visualize predictions.
|
"""Visualize predictions.
|
||||||
@ -125,9 +205,8 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
show (bool): Whether to display the image in a popup window.
|
show (bool): Whether to display the image in a popup window.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
wait_time (float): The interval of show (s). Defaults to 0.
|
wait_time (float): The interval of show (s). Defaults to 0.
|
||||||
draw_pred (bool): Whether to draw Prediction SegDataSample.
|
img_out_dir (str): Output directory of rendering prediction i.e.
|
||||||
Defaults to True.
|
color segmentation mask. Defaults: ''
|
||||||
img_out_dir (str): Output directory of drawn images. Defaults: ''
|
|
||||||
opacity (int, float): The transparency of segmentation mask.
|
opacity (int, float): The transparency of segmentation mask.
|
||||||
Defaults to 0.8.
|
Defaults to 0.8.
|
||||||
|
|
||||||
@ -140,7 +219,7 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
if getattr(self, 'visualizer') is None:
|
if getattr(self, 'visualizer') is None:
|
||||||
raise ValueError('Visualization needs the "visualizer" term'
|
raise ValueError('Visualization needs the "visualizer" term'
|
||||||
'defined in the config, but got None')
|
'defined in the config, but got None')
|
||||||
|
self.visualizer.set_dataset_meta(**self.model.dataset_meta)
|
||||||
self.visualizer.alpha = opacity
|
self.visualizer.alpha = opacity
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
@ -153,7 +232,7 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
img_name = osp.basename(single_input)
|
img_name = osp.basename(single_input)
|
||||||
elif isinstance(single_input, np.ndarray):
|
elif isinstance(single_input, np.ndarray):
|
||||||
img = single_input.copy()
|
img = single_input.copy()
|
||||||
img_num = str(self.num_visualized_imgs).zfill(8)
|
img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
|
||||||
img_name = f'{img_num}.jpg'
|
img_name = f'{img_num}.jpg'
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unsupported input type:'
|
raise ValueError('Unsupported input type:'
|
||||||
@ -169,7 +248,7 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
show=show,
|
show=show,
|
||||||
wait_time=wait_time,
|
wait_time=wait_time,
|
||||||
draw_gt=False,
|
draw_gt=False,
|
||||||
draw_pred=draw_pred,
|
draw_pred=True,
|
||||||
out_file=out_file)
|
out_file=out_file)
|
||||||
results.append(self.visualizer.get_image())
|
results.append(self.visualizer.get_image())
|
||||||
self.num_visualized_imgs += 1
|
self.num_visualized_imgs += 1
|
||||||
@ -180,62 +259,65 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
preds: PredType,
|
preds: PredType,
|
||||||
visualization: List[np.ndarray],
|
visualization: List[np.ndarray],
|
||||||
return_datasample: bool = False,
|
return_datasample: bool = False,
|
||||||
mask_dir: str = 'mask',
|
|
||||||
save_mask: bool = True,
|
|
||||||
pred_out_dir: str = '') -> dict:
|
pred_out_dir: str = '') -> dict:
|
||||||
"""Process the predictions and visualization results from ``forward``
|
"""Process the predictions and visualization results from ``forward``
|
||||||
and ``visualize``.
|
and ``visualize``.
|
||||||
|
|
||||||
This method should be responsible for the following tasks:
|
This method should be responsible for the following tasks:
|
||||||
|
|
||||||
1. Convert datasamples into a json-serializable dict if needed.
|
1. Pack the predictions and visualization results and return them.
|
||||||
2. Pack the predictions and visualization results and return them.
|
2. Save the predictions, if it needed.
|
||||||
3. Dump or log the predictions.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
preds (List[Dict]): Predictions of the model.
|
preds (List[Dict]): Predictions of the model.
|
||||||
visualization (np.ndarray): Visualized predictions.
|
visualization (List[np.ndarray]): The list of rendering color
|
||||||
|
segmentation mask.
|
||||||
return_datasample (bool): Whether to return results as datasamples.
|
return_datasample (bool): Whether to return results as datasamples.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
pred_out_dir: File to save the inference results w/o
|
pred_out_dir: File to save the inference results w/o
|
||||||
visualization. If left as empty, no file will be saved.
|
visualization. If left as empty, no file will be saved.
|
||||||
Defaults to ''.
|
Defaults to ''.
|
||||||
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
|
|
||||||
mask file.
|
|
||||||
save_mask (bool): Whether save pred mask as a file.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Inference and visualization results with key ``predictions``
|
dict: Inference and visualization results with key ``predictions``
|
||||||
and ``visualization``
|
and ``visualization``
|
||||||
|
|
||||||
- ``visualization (Any)``: Returned by :meth:`visualize`
|
- ``visualization (Any)``: Returned by :meth:`visualize`
|
||||||
- ``predictions`` (dict or DataSample): Returned by
|
- ``predictions`` (List[np.ndarray], np.ndarray): Returned by
|
||||||
:meth:`forward` and processed in :meth:`postprocess`.
|
:meth:`forward` and processed in :meth:`postprocess`.
|
||||||
If ``return_datasample=False``, it usually should be a
|
If ``return_datasample=False``, it will be the segmentation mask
|
||||||
json-serializable dict containing only basic data elements such
|
with label indice.
|
||||||
as strings and numbers.
|
|
||||||
"""
|
"""
|
||||||
|
if return_datasample:
|
||||||
|
if len(preds) == 1:
|
||||||
|
return preds[0]
|
||||||
|
else:
|
||||||
|
return preds
|
||||||
|
|
||||||
results_dict = {}
|
results_dict = {}
|
||||||
|
|
||||||
results_dict['predictions'] = preds
|
results_dict['predictions'] = []
|
||||||
results_dict['visualization'] = visualization
|
results_dict['visualization'] = []
|
||||||
|
|
||||||
if pred_out_dir != '':
|
for i, pred in enumerate(preds):
|
||||||
mmengine.mkdir_or_exist(pred_out_dir)
|
pred_data = pred.pred_sem_seg.numpy().data[0]
|
||||||
if save_mask:
|
results_dict['predictions'].append(pred_data)
|
||||||
preds = [preds] if isinstance(preds, SegDataSample) else preds
|
if visualization is not None:
|
||||||
for pred in preds:
|
vis = visualization[i]
|
||||||
mmcv.imwrite(
|
results_dict['visualization'].append(vis)
|
||||||
pred.pred_sem_seg.numpy().data[0],
|
if pred_out_dir != '':
|
||||||
osp.join(pred_out_dir, mask_dir,
|
mmengine.mkdir_or_exist(pred_out_dir)
|
||||||
osp.basename(pred.metainfo['img_path'])))
|
img_name = str(self.num_pred_imgs).zfill(8) + '_pred.png'
|
||||||
else:
|
img_path = osp.join(pred_out_dir, img_name)
|
||||||
mmengine.dump(results_dict,
|
output = Image.fromarray(pred_data.astype(np.uint8))
|
||||||
osp.join(pred_out_dir, 'results.pkl'))
|
output.save(img_path)
|
||||||
|
self.num_pred_imgs += 1
|
||||||
if return_datasample:
|
|
||||||
return preds
|
|
||||||
|
|
||||||
|
if len(results_dict['predictions']) == 1:
|
||||||
|
results_dict['predictions'] = results_dict['predictions'][0]
|
||||||
|
if visualization is not None:
|
||||||
|
results_dict['visualization'] = \
|
||||||
|
results_dict['visualization'][0]
|
||||||
return results_dict
|
return results_dict
|
||||||
|
|
||||||
def _init_pipeline(self, cfg: ConfigType) -> Compose:
|
def _init_pipeline(self, cfg: ConfigType) -> Compose:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -24,6 +24,17 @@ class SegLocalVisualizer(Visualizer):
|
|||||||
Defaults to None.
|
Defaults to None.
|
||||||
save_dir (str, optional): Save file dir for all storage backends.
|
save_dir (str, optional): Save file dir for all storage backends.
|
||||||
If it is None, the backend storage will not save any data.
|
If it is None, the backend storage will not save any data.
|
||||||
|
classes (list, optional): Input classes for result rendering, as the
|
||||||
|
prediction of segmentation model is a segment map with label
|
||||||
|
indices, `classes` is a list which includes items responding to the
|
||||||
|
label indices. If classes is not defined, visualizer will take
|
||||||
|
`cityscapes` classes by default. Defaults to None.
|
||||||
|
palette (list, optional): Input palette for result rendering, which is
|
||||||
|
a list of color palette responding to the classes. Defaults to None.
|
||||||
|
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_
|
||||||
|
visulizer will use the meta information of the dataset i.e. classes
|
||||||
|
and palette, but the `classes` and `palette` have higher priority.
|
||||||
|
Defaults to None.
|
||||||
alpha (int, float): The transparency of segmentation mask.
|
alpha (int, float): The transparency of segmentation mask.
|
||||||
Defaults to 0.8.
|
Defaults to 0.8.
|
||||||
|
|
||||||
@ -49,15 +60,15 @@ class SegLocalVisualizer(Visualizer):
|
|||||||
>>> seg_local_visualizer.add_datasample(
|
>>> seg_local_visualizer.add_datasample(
|
||||||
... 'visualizer_example', image,
|
... 'visualizer_example', image,
|
||||||
... gt_seg_data_sample, show=True)
|
... gt_seg_data_sample, show=True)
|
||||||
"""
|
""" # noqa
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
name: str = 'visualizer',
|
name: str = 'visualizer',
|
||||||
image: Optional[np.ndarray] = None,
|
image: Optional[np.ndarray] = None,
|
||||||
vis_backends: Optional[Dict] = None,
|
vis_backends: Optional[Dict] = None,
|
||||||
save_dir: Optional[str] = None,
|
save_dir: Optional[str] = None,
|
||||||
palette: Optional[Union[str, List]] = None,
|
classes: Optional[List] = None,
|
||||||
classes: Optional[Union[str, List]] = None,
|
palette: Optional[List] = None,
|
||||||
dataset_name: Optional[str] = None,
|
dataset_name: Optional[str] = None,
|
||||||
alpha: float = 0.8,
|
alpha: float = 0.8,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
@ -66,17 +77,23 @@ class SegLocalVisualizer(Visualizer):
|
|||||||
self.set_dataset_meta(palette, classes, dataset_name)
|
self.set_dataset_meta(palette, classes, dataset_name)
|
||||||
|
|
||||||
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
|
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
|
||||||
classes: Optional[Tuple[str]],
|
classes: Optional[List],
|
||||||
palette: Optional[List[List[int]]]) -> np.ndarray:
|
palette: Optional[List]) -> np.ndarray:
|
||||||
"""Draw semantic seg of GT or prediction.
|
"""Draw semantic seg of GT or prediction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image (np.ndarray): The image to draw.
|
image (np.ndarray): The image to draw.
|
||||||
sem_seg (:obj:`PixelData`): Data structure for
|
sem_seg (:obj:`PixelData`): Data structure for pixel-level
|
||||||
pixel-level annotations or predictions.
|
annotations or predictions.
|
||||||
classes (Tuple[str], optional): Category information.
|
classes (list, optional): Input classes for result rendering, as
|
||||||
palette (List[List[int]], optional): The palette of
|
the prediction of segmentation model is a segment map with
|
||||||
segmentation map.
|
label indices, `classes` is a list which includes items
|
||||||
|
responding to the label indices. If classes is not defined,
|
||||||
|
visualizer will take `cityscapes` classes by default.
|
||||||
|
Defaults to None.
|
||||||
|
palette (list, optional): Input palette for result rendering, which
|
||||||
|
is a list of color palette responding to the classes.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: the drawn image which channel is RGB.
|
np.ndarray: the drawn image which channel is RGB.
|
||||||
@ -101,9 +118,26 @@ class SegLocalVisualizer(Visualizer):
|
|||||||
return self.get_image()
|
return self.get_image()
|
||||||
|
|
||||||
def set_dataset_meta(self,
|
def set_dataset_meta(self,
|
||||||
palette: Optional[Union[str, List]] = None,
|
classes: Optional[List] = None,
|
||||||
classes: Optional[Union[str, List]] = None,
|
palette: Optional[List] = None,
|
||||||
dataset_name: Optional[str] = None) -> None:
|
dataset_name: Optional[str] = None) -> None:
|
||||||
|
"""Set meta information to visualizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
classes (list, optional): Input classes for result rendering, as
|
||||||
|
the prediction of segmentation model is a segment map with
|
||||||
|
label indices, `classes` is a list which includes items
|
||||||
|
responding to the label indices. If classes is not defined,
|
||||||
|
visualizer will take `cityscapes` classes by default.
|
||||||
|
Defaults to None.
|
||||||
|
palette (list, optional): Input palette for result rendering, which
|
||||||
|
is a list of color palette responding to the classes.
|
||||||
|
Defaults to None.
|
||||||
|
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_
|
||||||
|
visulizer will use the meta information of the dataset i.e.
|
||||||
|
classes and palette, but the `classes` and `palette` have
|
||||||
|
higher priority. Defaults to None.
|
||||||
|
""" # noqa
|
||||||
# Set default value. When calling
|
# Set default value. When calling
|
||||||
# `SegLocalVisualizer().dataset_meta=xxx`,
|
# `SegLocalVisualizer().dataset_meta=xxx`,
|
||||||
# it will override the default value.
|
# it will override the default value.
|
||||||
|
@ -104,12 +104,10 @@ def test_inferencer():
|
|||||||
|
|
||||||
imgs = [img, img]
|
imgs = [img, img]
|
||||||
infer(imgs)
|
infer(imgs)
|
||||||
results = infer(imgs, out_dir=tempfile.gettempdir(), draw_pred=True)
|
results = infer(imgs, out_dir=tempfile.gettempdir())
|
||||||
|
|
||||||
# test results
|
# test results
|
||||||
assert 'predictions' in results
|
assert 'predictions' in results
|
||||||
assert 'visualization' in results
|
assert 'visualization' in results
|
||||||
assert len(results['predictions']) == 2
|
assert len(results['predictions']) == 2
|
||||||
assert results['predictions'][0].seg_logits.data.shape == torch.Size(
|
assert results['predictions'][0].shape == (4, 4)
|
||||||
(19, 4, 4))
|
|
||||||
assert results['predictions'][0].pred_sem_seg.shape == torch.Size((4, 4))
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user