[Fix] Fix demo scripts (#1815)

* [Feature] Add SegVisualizer

* change name to visualizer_example

* fix inference api

* fix video demo and refine inference api

* fix

* mmseg compose

* set default device to cuda:0

* fix import

* update dir

* rm engine/visualizer ut

* refine inference api and docs

* rename

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
谢昕辰 2022-07-29 18:37:20 +08:00 committed by GitHub
parent 3bbdd6dc4a
commit 5d9650838e
10 changed files with 138 additions and 94 deletions

View File

@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from argparse import ArgumentParser
 
+from mmengine.utils import revert_sync_batchnorm
+
 from mmseg.apis import inference_model, init_model, show_result_pyplot
-from mmseg.utils import get_palette
+from mmseg.utils import register_all_modules
 
 
 def main():
@@ -13,27 +15,35 @@ def main():
     parser.add_argument(
         '--device', default='cuda:0', help='Device used for inference')
     parser.add_argument(
-        '--palette',
-        default='cityscapes',
-        help='Color palette used for segmentation map')
+        '--save-dir',
+        default=None,
+        help='Save file dir for all storage backends.')
     parser.add_argument(
         '--opacity',
         type=float,
         default=0.5,
         help='Opacity of painted segmentation map. In (0, 1] range.')
+    parser.add_argument(
+        '--title', default='result', help='The image identifier.')
     args = parser.parse_args()
 
+    register_all_modules()
+
     # build the model from a config file and a checkpoint file
     model = init_model(args.config, args.checkpoint, device=args.device)
+    if args.device == 'cpu':
+        model = revert_sync_batchnorm(model)
     # test a single image
     result = inference_model(model, args.img)
     # show the results
     show_result_pyplot(
         model,
-        args.img,
-        result,
-        get_palette(args.palette),
-        opacity=args.opacity)
+        args.img, [result],
+        title=args.title,
+        opacity=args.opacity,
+        draw_gt=False,
+        show=False if args.save_dir is not None else True,
+        save_dir=args.save_dir)
 
 
 if __name__ == '__main__':
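
For reference, the updated image demo reduces to the sketch below. The config
and checkpoint paths and the save dir are placeholders, not files this PR
ships; note that show_result_pyplot now takes the result wrapped in a list.

from mmengine.utils import revert_sync_batchnorm

from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules

register_all_modules()  # register mmseg modules with the mmengine registries

# placeholder paths, for illustration only
model = init_model('configs/some_config.py', 'checkpoint.pth', device='cpu')
model = revert_sync_batchnorm(model)  # SyncBN layers cannot run on CPU

result = inference_model(model, 'demo/demo.png')
vis_image = show_result_pyplot(
    model,
    'demo/demo.png', [result],  # the new API reads result[0]
    title='result',
    opacity=0.5,
    draw_gt=False,
    show=False,  # save to disk instead of opening a window
    save_dir='work_dirs/vis')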

View File

@@ -20,8 +20,11 @@
    },
    "outputs": [],
    "source": [
+    "import torch\n",
+    "from mmengine import revert_sync_batchnorm\n",
     "from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
-    "from mmseg.utils import get_palette"
+    "from mmseg.utils import register_all_modules\n",
+    "register_all_modules()"
    ]
   },
   {
@@ -56,6 +59,8 @@
    "source": [
     "# test a single image\n",
     "img = 'demo.png'\n",
+    "if not torch.cuda.is_available():\n",
+    "    model = revert_sync_batchnorm(model)\n",
     "result = inference_model(model, img)"
    ]
   },
@@ -66,7 +71,7 @@
    "outputs": [],
    "source": [
     "# show the results\n",
-    "show_result_pyplot(model, img, result, get_palette('cityscapes'))"
+    "show_result_pyplot(model, img, result)"
    ]
   },
   {
@@ -79,7 +84,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3.10.4 ('pt1.11-v2')",
    "language": "python",
    "name": "python3"
   },
@@ -93,7 +98,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.0"
+   "version": "3.10.4"
   },
   "pycharm": {
    "stem_cell": {
@@ -103,6 +108,11 @@
    },
    "source": []
   }
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "fdab7187f8cbd4ce42bbf864ddb4c4693e7329271a15a7fa96e4bdb82b9302c9"
+   }
  }
 },
 "nbformat": 4,

View File

@@ -2,9 +2,11 @@
 from argparse import ArgumentParser
 
 import cv2
+from mmengine.utils import revert_sync_batchnorm
 
 from mmseg.apis import inference_model, init_model
-from mmseg.utils import get_palette
+from mmseg.apis.inference import show_result_pyplot
+from mmseg.utils import register_all_modules
 
 
 def main():
@@ -51,10 +53,16 @@ def main():
     assert args.show or args.output_file, \
         'At least one output should be enabled.'
 
+    register_all_modules()
+
     # build the model from a config file and a checkpoint file
     model = init_model(args.config, args.checkpoint, device=args.device)
+    if args.device == 'cpu':
+        model = revert_sync_batchnorm(model)
 
     # build input video
+    if args.video.isdigit():
+        args.video = int(args.video)
     cap = cv2.VideoCapture(args.video)
     assert (cap.isOpened())
     input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
@@ -86,12 +94,7 @@ def main():
         result = inference_model(model, frame)
 
         # blend raw image and prediction
-        draw_img = model.show_result(
-            frame,
-            result,
-            palette=get_palette(args.palette),
-            show=False,
-            opacity=args.opacity)
+        draw_img = show_result_pyplot(model, frame, [result])
 
         if args.show:
             cv2.imshow('video_demo', draw_img)
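
With the isdigit() branch, the same script now accepts a webcam index passed
as a string (e.g. '0'). A condensed sketch of the loop, with placeholder
config/checkpoint paths; show=False is added here, beyond what the diff
passes, so the visualizer does not open its own window on top of the cv2 one:

import cv2

from mmseg.apis import inference_model, init_model
from mmseg.apis.inference import show_result_pyplot
from mmseg.utils import register_all_modules

register_all_modules()
# placeholder paths, for illustration only
model = init_model('configs/some_config.py', 'checkpoint.pth', device='cuda:0')

source = '0'              # a video file path, or a camera index as a string
if source.isdigit():
    source = int(source)  # cv2.VideoCapture needs an int to open a camera
cap = cv2.VideoCapture(source)

while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    result = inference_model(model, frame)
    # blend the prediction onto the frame; the drawn image comes back as array
    draw_img = show_result_pyplot(model, frame, [result], show=False)
    cv2.imshow('video_demo', draw_img)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()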

View File

@@ -1,12 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import matplotlib.pyplot as plt
+from typing import Sequence, Union
+
 import mmcv
+import numpy as np
 import torch
-from mmcv.parallel import collate, scatter
 from mmcv.runner import load_checkpoint
+from mmengine import Config
+from mmengine.dataset import Compose
 
-from mmseg.datasets.transforms import Compose
-from mmseg.models import build_segmentor
+from mmseg.data import SegDataSample
+from mmseg.models import BaseSegmentor
+from mmseg.registry import MODELS
+from mmseg.utils import SampleList
+from mmseg.visualization import SegLocalVisualizer
 
 
 def init_model(config, checkpoint=None, device='cuda:0'):
@@ -23,13 +29,13 @@ def init_model(config, checkpoint=None, device='cuda:0'):
         nn.Module: The constructed segmentor.
     """
     if isinstance(config, str):
-        config = mmcv.Config.fromfile(config)
+        config = Config.fromfile(config)
     elif not isinstance(config, mmcv.Config):
         raise TypeError('config must be a filename or Config object, '
                         'but got {}'.format(type(config)))
     config.model.pretrained = None
     config.model.train_cfg = None
-    model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
+    model = MODELS.build(config.model)
     if checkpoint is not None:
         checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
         model.CLASSES = checkpoint['meta']['CLASSES']
@@ -40,34 +46,41 @@ def init_model(config, checkpoint=None, device='cuda:0'):
     return model
 
 
-class LoadImage:
-    """A simple pipeline to load image."""
-
-    def __call__(self, results):
-        """Call function to load images into results.
-
-        Args:
-            results (dict): A result dict contains the file name
-                of the image to be read.
-
-        Returns:
-            dict: ``results`` will be returned containing loaded image.
-        """
-
-        if isinstance(results['img'], str):
-            results['filename'] = results['img']
-            results['ori_filename'] = results['img']
-        else:
-            results['filename'] = None
-            results['ori_filename'] = None
-        img = mmcv.imread(results['img'])
-        results['img'] = img
-        results['img_shape'] = img.shape
-        results['ori_shape'] = img.shape
-        return results
-
-
-def inference_model(model, img):
+ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
+
+
+def _preprare_data(imgs: ImageType, model: BaseSegmentor):
+
+    cfg = model.cfg
+    if dict(type='LoadAnnotations') in cfg.test_pipeline:
+        cfg.test_pipeline.remove(dict(type='LoadAnnotations'))
+
+    is_batch = True
+    if not isinstance(imgs, (list, tuple)):
+        imgs = [imgs]
+        is_batch = False
+
+    if isinstance(imgs[0], np.ndarray):
+        cfg.test_pipeline[0].type = 'LoadImageFromNDArray'
+
+    # TODO: Consider using the singleton pattern to avoid building
+    # a pipeline for each inference
+    pipeline = Compose(cfg.test_pipeline)
+
+    data = []
+    for img in imgs:
+        if isinstance(img, np.ndarray):
+            data_ = dict(img=img)
+        else:
+            data_ = dict(img_path=img)
+        data_ = pipeline(data_)
+        data.append(data_)
+
+    return data, is_batch
+
+
+def inference_model(model: BaseSegmentor,
+                    img: ImageType) -> Union[SegDataSample, SampleList]:
     """Inference image(s) with the segmentor.
 
     Args:
@@ -76,61 +89,70 @@ def inference_model(model, img):
             images.
 
     Returns:
-        (list[Tensor]): The segmentation result.
+        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
+        If imgs is a list or tuple, the same length list type results
+        will be returned, otherwise return the segmentation results directly.
     """
-    cfg = model.cfg
-    device = next(model.parameters()).device  # model device
-    # build the data pipeline
-    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
-    test_pipeline = Compose(test_pipeline)
     # prepare data
-    data = dict(img=img)
-    data = test_pipeline(data)
-    data = collate([data], samples_per_gpu=1)
-    if next(model.parameters()).is_cuda:
-        # scatter to specified GPU
-        data = scatter(data, [device])[0]
-    else:
-        data['img_metas'] = [i.data[0] for i in data['img_metas']]
+    data, is_batch = _preprare_data(img, model)
+
     # forward the model
     with torch.no_grad():
-        result = model(return_loss=False, rescale=True, **data)
-    return result
+        results = model.test_step(data)
+
+    return results if is_batch else results[0]
 
 
-def show_result_pyplot(model,
-                       img,
-                       result,
-                       palette=None,
-                       fig_size=(15, 10),
-                       opacity=0.5,
-                       title='',
-                       block=True):
+def show_result_pyplot(model: BaseSegmentor,
+                       img: Union[str, np.ndarray],
+                       result: SampleList,
+                       opacity: float = 0.5,
+                       title: str = '',
+                       draw_gt: bool = True,
+                       draw_pred: bool = True,
+                       wait_time: float = 0,
+                       show: bool = True,
+                       save_dir=None):
     """Visualize the segmentation results on the image.
 
     Args:
         model (nn.Module): The loaded segmentor.
         img (str or np.ndarray): Image filename or loaded image.
-        result (list): The segmentation result.
-        palette (list[list[int]]] | None): The palette of segmentation
-            map. If None is given, random palette will be generated.
-            Default: None
-        fig_size (tuple): Figure size of the pyplot figure.
+        result (list): The prediction SegDataSample result.
         opacity(float): Opacity of painted segmentation map.
-            Default 0.5.
-            Must be in (0, 1] range.
+            Default 0.5. Must be in (0, 1] range.
         title (str): The title of pyplot figure.
             Default is ''.
-        block (bool): Whether to block the pyplot figure.
-            Default is True.
+        draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
+        draw_pred (bool): Whether to draw Prediction SegDataSample.
+            Defaults to True.
+        wait_time (float): The interval of show (s). Defaults to 0.
+        show (bool): Whether to display the drawn image.
+            Default to True.
+        save_dir (str, optional): Save file dir for all storage backends.
+            If it is None, the backend storage will not save any data.
     """
     if hasattr(model, 'module'):
        model = model.module
-    img = model.show_result(
-        img, result, palette=palette, show=False, opacity=opacity)
-    plt.figure(figsize=fig_size)
-    plt.imshow(mmcv.bgr2rgb(img))
-    plt.title(title)
-    plt.tight_layout()
-    plt.show(block=block)
+    if isinstance(img, str):
+        image = mmcv.imread(img)
+    else:
+        image = img
+    if save_dir is not None:
+        mmcv.mkdir_or_exist(save_dir)
+    # init visualizer
+    visualizer = SegLocalVisualizer(
+        vis_backends=[dict(type='LocalVisBackend')],
+        save_dir=save_dir,
+        alpha=opacity)
+    visualizer.dataset_meta = dict(
+        classes=model.CLASSES, palette=model.PALETTE)
+    visualizer.add_datasample(
+        name=title,
+        image=image,
+        pred_sample=result[0],
+        draw_gt=draw_gt,
+        draw_pred=draw_pred,
+        wait_time=wait_time,
+        show=show)
+    return visualizer.get_image()
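
The refactored inference_model mirrors its input shape: a single image path
or array yields one SegDataSample, a list yields a list of the same length.
A minimal sketch, with placeholder paths (pred_sem_seg is the attribute that
holds the predicted map on each data sample):

import numpy as np

from mmseg.apis import inference_model, init_model
from mmseg.utils import register_all_modules

register_all_modules()
# placeholder paths, for illustration only
model = init_model('configs/some_config.py', 'checkpoint.pth', device='cuda:0')

# single input -> a single SegDataSample
result = inference_model(model, 'demo/demo.png')
pred = result.pred_sem_seg

# list input -> a list of SegDataSample, same length and order
frames = [np.zeros((512, 512, 3), dtype=np.uint8) for _ in range(2)]
results = inference_model(model, frames)
assert len(results) == len(frames)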

View File

@@ -2,9 +2,8 @@
 from .hooks import SegVisualizationHook
 from .optimizers import (LayerDecayOptimizerConstructor,
                          LearningRateDecayOptimizerConstructor)
-from .visualization import SegLocalVisualizer
 
 __all__ = [
     'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
-    'SegVisualizationHook', 'SegLocalVisualizer'
+    'SegVisualizationHook'
 ]
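
SegLocalVisualizer now lives in mmseg.visualization instead of
mmseg.engine.visualization; the hook and tests below update their imports
accordingly. A standalone sketch of the new entry point, with a synthetic
image and prediction (shapes, class names and save dir are illustrative):

import numpy as np
import torch
from mmengine.data import PixelData

from mmseg.data import SegDataSample
from mmseg.visualization import SegLocalVisualizer

# synthetic 10x12 image and an all-zeros prediction, for illustration only
image = np.random.randint(0, 256, (10, 12, 3), dtype=np.uint8)
pred = SegDataSample()
pred.pred_sem_seg = PixelData(data=torch.zeros((1, 10, 12), dtype=torch.long))

visualizer = SegLocalVisualizer(
    vis_backends=[dict(type='LocalVisBackend')], save_dir='work_dirs/vis')
visualizer.dataset_meta = dict(classes=('background', ), palette=[[0, 0, 0]])
visualizer.add_datasample(
    name='example', image=image, pred_sample=pred, draw_gt=False, show=False)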

View File

@@ -8,8 +8,8 @@ from mmengine.hooks import Hook
 from mmengine.runner import Runner
 
 from mmseg.data import SegDataSample
-from mmseg.engine.visualization import SegLocalVisualizer
 from mmseg.registry import HOOKS
+from mmseg.visualization import SegLocalVisualizer
 
 
 @HOOKS.register_module()

View File

@@ -7,7 +7,7 @@ from mmengine.data import PixelData
 
 from mmseg.data import SegDataSample
 from mmseg.engine.hooks import SegVisualizationHook
-from mmseg.engine.visualization import SegLocalVisualizer
+from mmseg.visualization import SegLocalVisualizer
 
 
 class TestVisualizationHook(TestCase):

View File

@@ -10,7 +10,7 @@ import torch
 from mmengine.data import PixelData
 
 from mmseg.data import SegDataSample
-from mmseg.engine.visualization import SegLocalVisualizer
+from mmseg.visualization import SegLocalVisualizer
 
 
 class TestSegLocalVisualizer(TestCase):