mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Thanks for your contribution and we appreciate it a lot. The following instructions will make your pull request healthier and help it get feedback more easily. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation Support inference and visualization of VPD ## Modification 1. add a new VPD model that does not generate black borders in predictions 2. update `SegLocalVisualizer` to support depth visualization 3. update `MMSegInferencer` to support saving predictions of depth estimation in method `postprocess` ## BC-breaking (Optional) Does the modification introduce changes that break the backward-compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) Run inference with VPD using this command ```sh python demo/image_demo_with_inferencer.py demo/classroom__rgb_00283.jpg vpd_depth --out-dir vis_results ``` The following image will be saved under `vis_results/vis`  ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness. 3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 4. The documentation has been modified accordingly, like docstring or example tutorials.
380 lines
16 KiB
Python
380 lines
16 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import warnings
|
|
from typing import List, Optional, Sequence, Union
|
|
|
|
import mmcv
|
|
import mmengine
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.transforms import Compose
|
|
from mmengine.infer.infer import BaseInferencer, ModelType
|
|
from mmengine.model import revert_sync_batchnorm
|
|
from mmengine.registry import init_default_scope
|
|
from mmengine.runner.checkpoint import _load_checkpoint_to_model
|
|
from PIL import Image
|
|
|
|
from mmseg.structures import SegDataSample
|
|
from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
|
|
from mmseg.visualization import SegLocalVisualizer
|
|
|
|
# One raw inference input: either an image path or an image array.
InputType = Union[str, np.ndarray]
# A single input or a sequence of inputs. ``Union`` flattens nested unions,
# so listing the members of ``InputType`` inline yields the same type.
InputsType = Union[str, np.ndarray, Sequence[InputType]]
# Model output: a single data sample or a list of data samples.
PredType = Union[SegDataSample, SampleList]
|
|
|
|
|
|
class MMSegInferencer(BaseInferencer):
    """Semantic segmentation inferencer, provides inference and visualization
    interfaces. Note: MMEngine >= 0.5.0 is required.

    Predictions carrying a depth map (``pred_depth_map``) instead of a
    segmentation map (``pred_sem_seg``) are also handled by
    :meth:`postprocess`.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. Take the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/main/configs/fcn/metafile.yaml>`_
            as an example the `model` could be
            "fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of model
            will be downloaded automatically. If use config file, like
            "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", the
            `weights` should be defined.
        weights (str, optional): Path to the checkpoint. If it is not specified
            and model is a model name of metafile, the weights will be loaded
            from metafile. Defaults to None.
        classes (list, optional): Input classes for result rendering, as the
            prediction of segmentation model is a segment map with label
            indices, `classes` is a list which includes items responding to the
            label indices. If classes is not defined, visualizer will take
            `cityscapes` classes by default. Defaults to None.
        palette (list, optional): Input palette for result rendering, which is
            a list of color palette responding to the classes. If palette is
            not defined, visualizer will take `cityscapes` palette by default.
            Defaults to None.
        dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
            visualizer will use the meta information of the dataset i.e.
            classes and palette, but the `classes` and `palette` have higher
            priority. Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to 'mmseg'.
    """  # noqa

    preprocess_kwargs: set = set()
    forward_kwargs: set = {'mode', 'out_dir'}
    visualize_kwargs: set = {
        'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis'
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 model: Union[ModelType, str],
                 weights: Optional[str] = None,
                 classes: Optional[Union[str, List]] = None,
                 palette: Optional[Union[str, List]] = None,
                 dataset_name: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmseg') -> None:
        # Running counters used to name output files: ndarray inputs get
        # numbered visualization files, and every saved prediction is
        # numbered by ``num_pred_imgs``.
        self.num_visualized_imgs = 0
        self.num_pred_imgs = 0
        init_default_scope(scope if scope else 'mmseg')
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)

        # Without CUDA, SyncBatchNorm layers are reverted to plain BN.
        if device == 'cpu' or not torch.cuda.is_available():
            self.model = revert_sync_batchnorm(self.model)

        assert isinstance(self.visualizer, SegLocalVisualizer)
        # Keyword arguments so `classes` and `palette` cannot be swapped.
        self.visualizer.set_dataset_meta(
            classes=classes, palette=palette, dataset_name=dataset_name)

    @staticmethod
    def _cityscapes_dataset_meta() -> dict:
        """Return the Cityscapes classes/palette used as fallback meta."""
        return {
            'classes': get_classes('cityscapes'),
            'palette': get_palette('cityscapes')
        }

    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Loading model weights and meta information from cfg and checkpoint.

        The dataset meta (classes and palette) is looked up in the checkpoint
        meta in mmsegmentation 1.x format (``dataset_meta``) first, then in
        0.x format (``CLASSES``/``PALETTE``); if neither is present, or no
        checkpoint is given, Cityscapes classes and palette are used and a
        warning is emitted.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint.
            cfg (Config or ConfigDict, optional): The loaded config.
        """
        if checkpoint is None:
            warnings.warn('Checkpoint is not loaded, and the inference '
                          'result is calculated by the randomly initialized '
                          'model!')
            warnings.warn(
                'weights is None, use cityscapes classes by default.')
            model.dataset_meta = self._cityscapes_dataset_meta()
            return

        _load_checkpoint_to_model(model, checkpoint)
        checkpoint_meta = checkpoint.get('meta', {})
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in checkpoint_meta:
            # mmsegmentation 1.x
            model.dataset_meta = {
                'classes': checkpoint_meta['dataset_meta'].get('classes'),
                'palette': checkpoint_meta['dataset_meta'].get('palette')
            }
        elif 'CLASSES' in checkpoint_meta:
            # mmsegmentation 0.x
            classes = checkpoint_meta['CLASSES']
            palette = checkpoint_meta.get('PALETTE', None)
            model.dataset_meta = {'classes': classes, 'palette': palette}
        else:
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, use classes of Cityscapes by '
                'default.')
            model.dataset_meta = self._cityscapes_dataset_meta()

    def __call__(self,
                 inputs: InputsType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 return_vis: bool = False,
                 show: bool = False,
                 wait_time: int = 0,
                 out_dir: str = '',
                 img_out_dir: str = 'vis',
                 pred_out_dir: str = 'pred',
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
            return_datasamples (bool): Whether to return results as
                :obj:`SegDataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            return_vis (bool): Whether to return the rendered visualization
                images. Defaults to False.
            show (bool): Whether to display the rendering color segmentation
                mask in a popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_dir (str): Output directory of inference results. Defaults
                to ''.
            img_out_dir (str): Subdirectory of `out_dir`, used to save
                rendering color segmentation mask, so `out_dir` must be defined
                if you would like to save predicted mask. Defaults to 'vis'.
            pred_out_dir (str): Subdirectory of `out_dir`, used to save
                predicted mask file, so `out_dir` must be defined if you would
                like to save predicted mask. Defaults to 'pred'.

            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
                and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        # The sub-directories are only meaningful under `out_dir`; without
        # it nothing is written to disk.
        if out_dir != '':
            pred_out_dir = osp.join(out_dir, pred_out_dir)
            img_out_dir = osp.join(out_dir, img_out_dir)
        else:
            pred_out_dir = ''
            img_out_dir = ''

        return super().__call__(
            inputs=inputs,
            return_datasamples=return_datasamples,
            batch_size=batch_size,
            show=show,
            wait_time=wait_time,
            img_out_dir=img_out_dir,
            pred_out_dir=pred_out_dir,
            return_vis=return_vis,
            **kwargs)

    def visualize(self,
                  inputs: list,
                  preds: List[dict],
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  img_out_dir: str = '',
                  opacity: float = 0.8) -> Optional[List[np.ndarray]]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            return_vis (bool): Whether to return the rendered images.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            img_out_dir (str): Output directory of rendering prediction i.e.
                color segmentation mask. Defaults: ''
            opacity (int, float): The transparency of segmentation mask.
                Defaults to 0.8.

        Returns:
            List[np.ndarray] or None: Visualization results when
            ``return_vis`` is True, otherwise None.

        Raises:
            ValueError: If no visualizer is configured, or an input is
                neither a ``str`` nor an ``np.ndarray``.
        """
        # Nothing to show, save or return: skip rendering entirely.
        if not show and img_out_dir == '' and not return_vis:
            return None
        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        # NOTE(review): this replaces the meta set in ``__init__`` from the
        # user's ``classes``/``palette``/``dataset_name`` with the model's
        # ``dataset_meta``; confirm that this override is intended.
        self.visualizer.set_dataset_meta(**self.model.dataset_meta)
        self.visualizer.alpha = opacity

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img_bytes = mmengine.fileio.get(single_input)
                img = mmcv.imfrombytes(img_bytes)
                # Reverse the channel order of the decoded image.
                img = img[:, :, ::-1]
                img_name = osp.basename(single_input)
            elif isinstance(single_input, np.ndarray):
                # Arrays have no file name, so number them with a counter.
                img = single_input.copy()
                img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            out_file = osp.join(img_out_dir, img_name) \
                if img_out_dir != '' else None

            self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=True,
                out_file=out_file)
            if return_vis:
                results.append(self.visualizer.get_image())
            self.num_visualized_imgs += 1

        return results if return_vis else None

    def postprocess(self,
                    preds: PredType,
                    visualization: Optional[List[np.ndarray]],
                    return_datasample: bool = False,
                    pred_out_dir: str = '') -> dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Pack the predictions and visualization results and return them.
        2. Save the predictions, if it needed. Segmentation maps are saved as
           ``<index>_pred.png`` (cast to ``uint8``), depth maps as
           ``<index>_pred.npy``.

        Args:
            preds (List[Dict]): Predictions of the model.
            visualization (List[np.ndarray], optional): The list of rendering
                color segmentation mask.
            return_datasample (bool): Whether to return results as datasamples.
                Defaults to False.
            pred_out_dir: File to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Inference and visualization results with key ``predictions``
            and ``visualization``

            - ``visualization (Any)``: Returned by :meth:`visualize`
            - ``predictions`` (List[np.ndarray], np.ndarray): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it will be the segmentation mask
              with label indices, or the depth map for depth predictions.

        Raises:
            ValueError: If a prediction has neither ``pred_sem_seg`` nor
                ``pred_depth_map``.
        """
        if return_datasample:
            if len(preds) == 1:
                return preds[0]
            else:
                return preds

        results_dict = {'predictions': [], 'visualization': []}

        for i, pred in enumerate(preds):
            if 'pred_sem_seg' in pred.keys():
                is_sem_seg = True
                data = pred.pred_sem_seg.numpy().data[0]
            elif 'pred_depth_map' in pred.keys():
                is_sem_seg = False
                data = pred.pred_depth_map.numpy().data[0]
            else:
                raise ValueError(
                    'Each prediction must contain `pred_sem_seg` or '
                    f'`pred_depth_map`, but got keys {list(pred.keys())}.')

            if visualization is not None:
                results_dict['visualization'].append(visualization[i])
            if pred_out_dir != '':
                mmengine.mkdir_or_exist(pred_out_dir)
                post_fix = '_pred.png' if is_sem_seg else '_pred.npy'
                img_name = str(self.num_pred_imgs).zfill(8) + post_fix
                img_path = osp.join(pred_out_dir, img_name)
                if is_sem_seg:
                    # NOTE: label indices above 255 would wrap in uint8.
                    output = Image.fromarray(data.astype(np.uint8))
                    output.save(img_path)
                else:
                    np.save(img_path, data)
            results_dict['predictions'].append(data)
            self.num_pred_imgs += 1

        # A single input returns bare results instead of one-element lists.
        if len(results_dict['predictions']) == 1:
            results_dict['predictions'] = results_dict['predictions'][0]
            if visualization is not None:
                results_dict['visualization'] = \
                    results_dict['visualization'][0]
        return results_dict

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline.

        Return a pipeline to handle various input data, such as ``str``,
        ``np.ndarray``. It is an abstract method in BaseInferencer, and should
        be implemented in subclasses.

        The returned pipeline will be used to process a single data.
        It will be used in :meth:`preprocess` like this:

        .. code-block:: python
            def preprocess(self, inputs, batch_size, **kwargs):
                ...
                dataset = map(self.pipeline, dataset)
                ...

        The config's own pipeline is left unmodified, so this method can be
        called repeatedly on the same ``cfg``.

        Raises:
            ValueError: If the test pipeline has no ``LoadImageFromFile``.
        """
        # Loading annotations is not applicable at inference time. Build a
        # new list rather than deleting from the config in place.
        skipped = ('LoadAnnotations', 'LoadDepthAnnotation')
        pipeline_cfg = [
            transform for transform in cfg.test_dataloader.dataset.pipeline
            if transform['type'] not in skipped
        ]

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')
        # Copy the entry before changing its type so the config is untouched.
        pipeline_cfg[load_img_idx] = {
            **pipeline_cfg[load_img_idx], 'type': 'InferencerLoader'
        }
        return Compose(pipeline_cfg)

    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
        """Returns the index of the transform in a pipeline.

        If the transform is not found, returns -1.
        """
        for i, transform in enumerate(pipeline_cfg):
            if transform['type'] == name:
                return i
        return -1
|