[Enhancement] Modify interface of MMSegInferencer and add docs (#2658)

## Motivation

Make `MMSegInferencer` easier to use.

## Modification

1. Add `_load_weights_to_model` to `MMSegInferencer`; it is used to get
`dataset_meta` from the checkpoint.
2. Modify and remove some parameters of `__call__`, `visualize` and
`postprocess`.
3. Add the ability to save the predicted segmentation mask as an image file, and remove the pkl dump.
4. Refine the docstrings of `MMSegInferencer` and `SegLocalVisualizer`.
5. Add user documentation for `MMSegInferencer`.
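
For reference, a minimal sketch of the updated interface (the model name and image path are the ones used in the new documentation below; treat this as an illustration, not the full API):

```
from mmseg.apis import MMSegInferencer

# Initialize from a model name in the metafile; the weights are downloaded automatically
inferencer = MMSegInferencer(model='deeplabv3plus_r18-d8_4xb2-80k_cityscapes-512x1024')

# `out_dir` now controls saving: rendered color maps go to `out_dir`/`img_out_dir`
# and predicted masks are saved as PNG files under `out_dir`/`pred_out_dir`,
# replacing the removed `save_mask`/`mask_dir` parameters and the pkl dump
result = inferencer('demo/demo.png', out_dir='outputs')
```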

## BC-breaking (Optional)

Yes, some parameters are removed. We need to discuss whether to keep them
with a deprecation warning or just remove them, given that MMSegInferencer
was merged into mmseg only a few days ago.

Co-authored-by: xiexinch <xiexinch@outlook.com>
Miao Zheng 2023-03-03 14:37:54 +08:00 committed by GitHub
parent 8c1d299cb6
commit 310ec4afc7
5 changed files with 328 additions and 104 deletions

View File

@@ -16,11 +16,6 @@ def main():
action='store_true',
default=False,
help='Whether to display the drawn image.')
parser.add_argument(
'--save-mask',
action='store_true',
default=False,
help='Enable save the mask file')
parser.add_argument(
'--dataset-name',
default='cityscapes',
@@ -43,11 +38,7 @@ def main():
# test a single image
mmseg_inferencer(
args.img,
show=args.show,
out_dir=args.out_dir,
save_mask=args.save_mask,
opacity=args.opacity)
args.img, show=args.show, out_dir=args.out_dir, opacity=args.opacity)
if __name__ == '__main__':

View File

@@ -4,13 +4,132 @@ MMSegmentation provides pre-trained models for semantic segmentation in [Model Z
This note will show how to use existing models to inference on given images.
As for how to test existing models on standard datasets, please see this [guide](./4_train_test.md)
## Inference API
MMSegmentation provides several interfaces for users to easily use pre-trained models for inference.
- [mmseg.apis.init_model](#mmsegapisinit_model)
- [mmseg.apis.inference_model](#mmsegapisinference_model)
- [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
- [Tutorial 3: Inference with existing models](#tutorial-3-inference-with-existing-models)
- [Inferencer](#inferencer)
- [Basic Usage](#basic-usage)
- [Initialization](#initialization)
- [Visualize prediction](#visualize-prediction)
- [List model](#list-model)
- [Inference API](#inference-api)
- [mmseg.apis.init_model](#mmsegapisinit_model)
- [mmseg.apis.inference_model](#mmsegapisinference_model)
- [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
## Inferencer
`MMSegInferencer` provides the most **convenient** way to use models in MMSegmentation. You can get the segmentation mask of an image with only 3 lines of code.
### Basic Usage
The following example shows how to use `MMSegInferencer` to perform inference on a single image.
```
>>> from mmseg.apis import MMSegInferencer
>>> # Load models into memory
>>> inferencer = MMSegInferencer(model='deeplabv3plus_r18-d8_4xb2-80k_cityscapes-512x1024')
>>> # Inference
>>> inferencer('demo/demo.png', show=True)
```
The visualization result should look like:
<div align="center">
<img src="https://user-images.githubusercontent.com/76149310/221507927-ae01e3a7-016f-4425-b966-7b19cbbe494e.png" />
</div>
Moreover, you can use `MMSegInferencer` to process a list of images:
```
# Input a list of images
>>> images = [image1, image2, ...] # image1 can be a file path or a np.ndarray
>>> inferencer(images, show=True, wait_time=0.5) # wait_time is the display delay in seconds; 0 means wait indefinitely.
# Or input image directory
>>> images = '/path/to/image_directory'
>>> inferencer(images, show=True, wait_time=0.5)
# Save the rendered color maps and the predicted results
# out_dir is the directory to save the output results, img_out_dir and pred_out_dir are subdirectories of out_dir
# to save visualized rendering color maps and predicted results
>>> inferencer(images, out_dir='outputs', img_out_dir='vis', pred_out_dir='pred')
```
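
When passing many images, the forward pass can also be batched; `batch_size` is part of the `__call__` signature shown later in this PR:
```
# Process the list in batches of 4 instead of one image at a time
>>> results = inferencer(images, batch_size=4)
```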
There is an optional parameter of the inferencer, `return_datasamples`, whose default value is False; by default, the
return value of the inferencer is a `dict` including 2 keys, 'visualization' and 'predictions'.
If `return_datasamples=True`, the inferencer will return a [`SegDataSample`](../advanced_guides/structures.md), or a list of them.
```
result = inferencer('demo/demo.png')
# result is a `dict` including 2 keys, 'visualization' and 'predictions'
# 'visualization' contains the color segmentation map
print(result['visualization'].shape)
# (512, 683, 3)
# 'predictions' contains the segmentation mask with label indices
print(result['predictions'].shape)
# (512, 683)
result = inferencer('demo/demo.png', return_datasamples=True)
print(type(result))
# <class 'mmseg.structures.seg_data_sample.SegDataSample'>
# Input a list of images
results = inferencer(images)
# The outputs are lists
print(type(results['visualization']), results['visualization'][0].shape)
# <class 'list'> (512, 683, 3)
print(type(results['predictions']), results['predictions'][0].shape)
# <class 'list'> (512, 683)
results = inferencer(images, return_datasamples=True)
# results is a list of SegDataSample
print(type(results[0]))
# <class 'mmseg.structures.seg_data_sample.SegDataSample'>
```
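
When `return_datasamples=True`, the full prediction is available on the returned [`SegDataSample`](../advanced_guides/structures.md); a small sketch (the field names follow the structures used elsewhere in this PR):
```
>>> result = inferencer('demo/demo.png', return_datasamples=True)
>>> # pred_sem_seg holds the predicted mask, seg_logits holds the raw logits
>>> mask = result.pred_sem_seg.data      # shape (1, H, W)
>>> logits = result.seg_logits.data      # shape (num_classes, H, W)
```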
### Initialization
`MMSegInferencer` must be initialized from a `model`, which can be a model name, a `Config` object, or even a path to a config file.
Model names can be found in the models' metafiles; for example, one model name of MaskFormer is `maskformer_r50-d32_8xb2-160k_ade20k-512x512`. If a model name is given, the weights of the model will be downloaded automatically. Below are the other input parameters (a short initialization sketch follows this list):
- weights (str, optional) - Path to the checkpoint. If it is not specified and `model` is a model name in the metafile, the weights will be loaded
from the metafile. Defaults to None.
- classes (list, optional) - Input classes for result rendering. As the prediction of a segmentation
model is a segmentation map with label indices, `classes` is a list whose
items correspond to the label indices. If classes is not defined, the visualizer will take `cityscapes` classes by default. Defaults to None.
- palette (list, optional) - Input palette for result rendering, which is a list of colors
corresponding to the classes. If palette is not defined, the visualizer will take the `cityscapes` palette by default. Defaults to None.
- dataset_name (str, optional) - [Dataset name or alias](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317).
The visualizer will use the meta information of the dataset, i.e. its classes and palette,
but `classes` and `palette` have higher priority. Defaults to None.
- device (str, optional) - Device to run inference. If None, the available device will be used automatically. Defaults to None.
- scope (str, optional) - The scope of the model. Defaults to 'mmseg'.
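
A minimal initialization sketch (the checkpoint path here is a hypothetical local file):
```
>>> from mmseg.apis import MMSegInferencer
>>> # Initialize from a config file plus an explicit checkpoint path
>>> inferencer = MMSegInferencer(
...     model='configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py',
...     weights='work_dirs/fcn_cityscapes/iter_40000.pth',  # hypothetical path
...     dataset_name='cityscapes',  # take classes and palette from the dataset meta
...     device='cuda:0')
```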
### Visualize prediction
`MMSegInferencer` supports 4 parameters for visualizing predictions, which you can use when calling the initialized inferencer:
- show (bool) - Whether to display the image in a popup window. Defaults to False.
- wait_time (float) - The interval of show (s). Defaults to 0.
- img_out_dir (str) - Subdirectory of `out_dir`, used to save the rendered color segmentation masks, so `out_dir` must be defined
if you would like to save the predicted masks. Defaults to 'vis'.
- opacity (int, float) - The transparency of the segmentation mask. Defaults to 0.8.
Examples of these parameters are in [Basic Usage](#basic-usage); a combined sketch follows.
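
Putting them together (a sketch; the image path matches the earlier examples):
```
>>> # Show the result for 2 seconds, blend the mask at 50% opacity,
>>> # and save the rendered map under 'outputs/vis'
>>> inferencer('demo/demo.png', show=True, wait_time=2, opacity=0.5, out_dir='outputs')
```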
### List model
There is a very easy way to list all model names in MMSegmentation:
```
>>> from mmseg.apis import MMSegInferencer
# models is a list of model names, and they will be printed automatically
>>> models = MMSegInferencer.list_models('mmseg')
```
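
Since `list_models` returns an ordinary list of strings, it can be filtered with plain Python, e.g. to find all PSPNet variants:
```
>>> pspnet_models = [m for m in models if 'pspnet' in m]
```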
## Inference API
### mmseg.apis.init_model

View File

@@ -1,15 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import List, Optional, Sequence, Union
import mmcv
import mmengine
import numpy as np
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from PIL import Image
from mmseg.structures import SegDataSample
from mmseg.utils import ConfigType, SampleList, register_all_modules
from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer
InputType = Union[str, np.ndarray]
@@ -23,90 +30,164 @@ class MMSegInferencer(BaseInferencer):
Args:
model (str, optional): Path to the config file or the model name
defined in metafile. For example, it could be
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024" or
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py"
defined in metafile. Take the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/fcn/fcn.yml>`_
as an example, the `model` could be
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of the model
will be downloaded automatically. If a config file is used, like
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", the
`weights` should be defined.
weights (str, optional): Path to the checkpoint. If it is not specified
and model is a model name of metafile, the weights will be loaded
from metafile. Defaults to None.
palette (List[List[int]], optional): The palette of
segmentation map.
classes (Tuple[str], optional): Category information.
dataset_name (str, optional): Name of the datasets supported in mmseg.
classes (list, optional): Input classes for result rendering, as the
prediction of the segmentation model is a segmentation map with label
indices, `classes` is a list whose items correspond to the
label indices. If classes is not defined, the visualizer will take
`cityscapes` classes by default. Defaults to None.
palette (list, optional): Input palette for result rendering, which is
a list of colors corresponding to the classes. If palette is
not defined, the visualizer will take the `cityscapes` palette by
default. Defaults to None.
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_
The visualizer will use the meta information of the dataset i.e. classes
and palette, but `classes` and `palette` have higher priority.
Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to None.
"""
scope (str, optional): The scope of the model. Defaults to 'mmseg'.
""" # noqa
preprocess_kwargs: set = set()
forward_kwargs: set = {'mode', 'out_dir'}
visualize_kwargs: set = {
'show', 'wait_time', 'draw_pred', 'img_out_dir', 'opacity'
}
postprocess_kwargs: set = {
'pred_out_dir', 'return_datasample', 'save_mask', 'mask_dir'
}
visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
def __init__(self,
model: Union[ModelType, str],
weights: Optional[str] = None,
palette: Optional[Union[str, List]] = None,
classes: Optional[Union[str, List]] = None,
palette: Optional[Union[str, List]] = None,
dataset_name: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmseg') -> None:
# A global counter tracking the number of images processed, for
# naming the output images
self.num_visualized_imgs = 0
register_all_modules()
self.num_pred_imgs = 0
init_default_scope(scope if scope else 'mmseg')
super().__init__(
model=model, weights=weights, device=device, scope=scope)
if device == 'cpu' or not torch.cuda.is_available():
self.model = revert_sync_batchnorm(self.model)
assert isinstance(self.visualizer, SegLocalVisualizer)
self.visualizer.set_dataset_meta(palette, classes, dataset_name)
def _load_weights_to_model(self, model: nn.Module,
checkpoint: Optional[dict],
cfg: Optional[ConfigType]) -> None:
"""Loading model weights and meta information from cfg and checkpoint.
Subclasses could override this method to load extra meta information
from ``checkpoint`` and ``cfg`` to model.
Args:
model (nn.Module): Model to load weights and meta information.
checkpoint (dict, optional): The loaded checkpoint.
cfg (Config or ConfigDict, optional): The loaded config.
"""
if checkpoint is not None:
_load_checkpoint_to_model(model, checkpoint)
checkpoint_meta = checkpoint.get('meta', {})
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint_meta:
# mmsegmentation 1.x
model.dataset_meta = {
'classes': checkpoint_meta['dataset_meta'].get('classes'),
'palette': checkpoint_meta['dataset_meta'].get('palette')
}
elif 'CLASSES' in checkpoint_meta:
# mmsegmentation 0.x
classes = checkpoint_meta['CLASSES']
palette = checkpoint_meta.get('PALETTE', None)
model.dataset_meta = {'classes': classes, 'palette': palette}
else:
warnings.warn(
'dataset_meta or class names are not saved in the '
'checkpoint\'s meta data, use classes of Cityscapes by '
'default.')
model.dataset_meta = {
'classes': get_classes('cityscapes'),
'palette': get_palette('cityscapes')
}
else:
warnings.warn('Checkpoint is not loaded, and the inference '
'result is calculated by the randomly initialized '
'model!')
warnings.warn(
'weights is None, use cityscapes classes by default.')
model.dataset_meta = {
'classes': get_classes('cityscapes'),
'palette': get_palette('cityscapes')
}
def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
out_dir: str = '',
save_mask: bool = False,
mask_dir: str = 'mask',
img_out_dir: str = 'vis',
pred_out_dir: str = 'pred',
**kwargs) -> dict:
"""Call the inferencer.
Args:
inputs (Union[str, np.ndarray]): Inputs for the inferencer.
inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
:obj:`SegDataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
show (bool): Whether to display the image in a popup window.
Defaults to False.
show (bool): Whether to display the rendering color segmentation
mask in a popup window. Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw Prediction SegDataSample.
Defaults to True.
out_dir (str): Output directory of inference results. Defaults: ''.
save_mask (bool): Whether save pred mask as a file.
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
mask file.
out_dir (str): Output directory of inference results. Defaults
to ''.
img_out_dir (str): Subdirectory of `out_dir`, used to save
rendering color segmentation mask, so `out_dir` must be defined
if you would like to save predicted mask. Defaults to 'vis'.
pred_out_dir (str): Subdirectory of `out_dir`, used to save
predicted mask file, so `out_dir` must be defined if you would
like to save predicted mask. Defaults to 'pred'.
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
"""
if out_dir != '':
pred_out_dir = osp.join(out_dir, pred_out_dir)
img_out_dir = osp.join(out_dir, img_out_dir)
else:
pred_out_dir = ''
img_out_dir = ''
return super().__call__(
inputs=inputs,
return_datasamples=return_datasamples,
batch_size=batch_size,
show=show,
wait_time=wait_time,
draw_pred=draw_pred,
img_out_dir=out_dir,
pred_out_dir=out_dir,
save_mask=save_mask,
mask_dir=mask_dir,
img_out_dir=img_out_dir,
pred_out_dir=pred_out_dir,
**kwargs)
def visualize(self,
@@ -114,7 +195,6 @@ class MMSegInferencer(BaseInferencer):
preds: List[dict],
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
img_out_dir: str = '',
opacity: float = 0.8) -> List[np.ndarray]:
"""Visualize predictions.
@@ -125,9 +205,8 @@ class MMSegInferencer(BaseInferencer):
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw Prediction SegDataSample.
Defaults to True.
img_out_dir (str): Output directory of drawn images. Defaults: ''
img_out_dir (str): Output directory of rendering prediction i.e.
color segmentation mask. Defaults: ''
opacity (int, float): The transparency of segmentation mask.
Defaults to 0.8.
@@ -140,7 +219,7 @@ class MMSegInferencer(BaseInferencer):
if getattr(self, 'visualizer') is None:
raise ValueError('Visualization needs the "visualizer" term '
'defined in the config, but got None')
self.visualizer.set_dataset_meta(**self.model.dataset_meta)
self.visualizer.alpha = opacity
results = []
@@ -153,7 +232,7 @@ class MMSegInferencer(BaseInferencer):
img_name = osp.basename(single_input)
elif isinstance(single_input, np.ndarray):
img = single_input.copy()
img_num = str(self.num_visualized_imgs).zfill(8)
img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
img_name = f'{img_num}.jpg'
else:
raise ValueError('Unsupported input type:'
@@ -169,7 +248,7 @@ class MMSegInferencer(BaseInferencer):
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=draw_pred,
draw_pred=True,
out_file=out_file)
results.append(self.visualizer.get_image())
self.num_visualized_imgs += 1
@@ -180,62 +259,65 @@ class MMSegInferencer(BaseInferencer):
preds: PredType,
visualization: List[np.ndarray],
return_datasample: bool = False,
mask_dir: str = 'mask',
save_mask: bool = True,
pred_out_dir: str = '') -> dict:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.
This method should be responsible for the following tasks:
1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.
1. Pack the predictions and visualization results and return them.
2. Save the predictions, if needed.
Args:
preds (List[Dict]): Predictions of the model.
visualization (np.ndarray): Visualized predictions.
visualization (List[np.ndarray]): The list of rendering color
segmentation mask.
return_datasample (bool): Whether to return results as datasamples.
Defaults to False.
pred_out_dir: Directory to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
mask file.
save_mask (bool): Whether save pred mask as a file.
Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``
- ``visualization (Any)``: Returned by :meth:`visualize`
- ``predictions`` (dict or DataSample): Returned by
- ``predictions`` (List[np.ndarray], np.ndarray): Returned by
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
If ``return_datasample=False``, it will be the segmentation mask
with label indices.
"""
if return_datasample:
if len(preds) == 1:
return preds[0]
else:
return preds
results_dict = {}
results_dict['predictions'] = preds
results_dict['visualization'] = visualization
results_dict['predictions'] = []
results_dict['visualization'] = []
if pred_out_dir != '':
mmengine.mkdir_or_exist(pred_out_dir)
if save_mask:
preds = [preds] if isinstance(preds, SegDataSample) else preds
for pred in preds:
mmcv.imwrite(
pred.pred_sem_seg.numpy().data[0],
osp.join(pred_out_dir, mask_dir,
osp.basename(pred.metainfo['img_path'])))
else:
mmengine.dump(results_dict,
osp.join(pred_out_dir, 'results.pkl'))
if return_datasample:
return preds
for i, pred in enumerate(preds):
pred_data = pred.pred_sem_seg.numpy().data[0]
results_dict['predictions'].append(pred_data)
if visualization is not None:
vis = visualization[i]
results_dict['visualization'].append(vis)
if pred_out_dir != '':
mmengine.mkdir_or_exist(pred_out_dir)
img_name = str(self.num_pred_imgs).zfill(8) + '_pred.png'
img_path = osp.join(pred_out_dir, img_name)
output = Image.fromarray(pred_data.astype(np.uint8))
output.save(img_path)
self.num_pred_imgs += 1
if len(results_dict['predictions']) == 1:
results_dict['predictions'] = results_dict['predictions'][0]
if visualization is not None:
results_dict['visualization'] = \
results_dict['visualization'][0]
return results_dict
def _init_pipeline(self, cfg: ConfigType) -> Compose:

View File

@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional
import mmcv
import numpy as np
@@ -24,6 +24,17 @@ class SegLocalVisualizer(Visualizer):
Defaults to None.
save_dir (str, optional): Save file dir for all storage backends.
If it is None, the backend storage will not save any data.
classes (list, optional): Input classes for result rendering, as the
prediction of the segmentation model is a segmentation map with label
indices, `classes` is a list whose items correspond to the
label indices. If classes is not defined, the visualizer will take
`cityscapes` classes by default. Defaults to None.
palette (list, optional): Input palette for result rendering, which is
a list of colors corresponding to the classes. Defaults to None.
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_
The visualizer will use the meta information of the dataset i.e. classes
and palette, but `classes` and `palette` have higher priority.
Defaults to None.
alpha (int, float): The transparency of segmentation mask.
Defaults to 0.8.
@@ -49,15 +60,15 @@ class SegLocalVisualizer(Visualizer):
>>> seg_local_visualizer.add_datasample(
... 'visualizer_example', image,
... gt_seg_data_sample, show=True)
"""
""" # noqa
def __init__(self,
name: str = 'visualizer',
image: Optional[np.ndarray] = None,
vis_backends: Optional[Dict] = None,
save_dir: Optional[str] = None,
palette: Optional[Union[str, List]] = None,
classes: Optional[Union[str, List]] = None,
classes: Optional[List] = None,
palette: Optional[List] = None,
dataset_name: Optional[str] = None,
alpha: float = 0.8,
**kwargs):
@@ -66,17 +77,23 @@ class SegLocalVisualizer(Visualizer):
self.set_dataset_meta(palette, classes, dataset_name)
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
classes: Optional[Tuple[str]],
palette: Optional[List[List[int]]]) -> np.ndarray:
classes: Optional[List],
palette: Optional[List]) -> np.ndarray:
"""Draw semantic seg of GT or prediction.
Args:
image (np.ndarray): The image to draw.
sem_seg (:obj:`PixelData`): Data structure for
pixel-level annotations or predictions.
classes (Tuple[str], optional): Category information.
palette (List[List[int]], optional): The palette of
segmentation map.
sem_seg (:obj:`PixelData`): Data structure for pixel-level
annotations or predictions.
classes (list, optional): Input classes for result rendering, as
the prediction of the segmentation model is a segmentation map
with label indices, `classes` is a list whose items
correspond to the label indices. If classes is not defined,
the visualizer will take `cityscapes` classes by default.
Defaults to None.
palette (list, optional): Input palette for result rendering, which
is a list of colors corresponding to the classes.
Defaults to None.
Returns:
np.ndarray: the drawn image which channel is RGB.
@@ -101,9 +118,26 @@ class SegLocalVisualizer(Visualizer):
return self.get_image()
def set_dataset_meta(self,
palette: Optional[Union[str, List]] = None,
classes: Optional[Union[str, List]] = None,
classes: Optional[List] = None,
palette: Optional[List] = None,
dataset_name: Optional[str] = None) -> None:
"""Set meta information to visualizer.
Args:
classes (list, optional): Input classes for result rendering, as
the prediction of the segmentation model is a segmentation map
with label indices, `classes` is a list whose items
correspond to the label indices. If classes is not defined,
the visualizer will take `cityscapes` classes by default.
Defaults to None.
palette (list, optional): Input palette for result rendering, which
is a list of colors corresponding to the classes.
Defaults to None.
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/class_names.py#L302-L317>`_
The visualizer will use the meta information of the dataset i.e.
classes and palette, but `classes` and `palette` have
higher priority. Defaults to None.
""" # noqa
# Set default value. When calling
# `SegLocalVisualizer().dataset_meta=xxx`,
# it will override the default value.

View File

@@ -104,12 +104,10 @@ def test_inferencer():
imgs = [img, img]
infer(imgs)
results = infer(imgs, out_dir=tempfile.gettempdir(), draw_pred=True)
results = infer(imgs, out_dir=tempfile.gettempdir())
# test results
assert 'predictions' in results
assert 'visualization' in results
assert len(results['predictions']) == 2
assert results['predictions'][0].seg_logits.data.shape == torch.Size(
(19, 4, 4))
assert results['predictions'][0].pred_sem_seg.shape == torch.Size((4, 4))
assert results['predictions'][0].shape == (4, 4)