[Fix] Fix demo scripts (#1815)

* [Feature] Add SegVisualizer

* change name to visualizer_example

* fix inference api

* fix video demo and refine inference api

* fix

* mmseg compose

* set default device to cuda:0

* fix import

* update dir

* rm engine/visualizer ut

* refine inference api and docs

* rename

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
谢昕辰 2022-07-29 18:37:20 +08:00 committed by GitHub
parent 3bbdd6dc4a
commit 5d9650838e
10 changed files with 138 additions and 94 deletions
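Taken together, the diffs below settle on the new 1.x-style inference flow. A minimal sketch of that flow as this PR leaves it; the config and checkpoint paths are placeholders, not files touched by this commit.

from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules

# New in 1.x: register mmseg modules into the mmengine registries
# before building anything from a config.
register_all_modules()

# Build the segmentor from a config file and a checkpoint file.
model = init_model('pspnet_config.py', 'pspnet_checkpoint.pth',
                   device='cuda:0')

# Single-image inference now returns a SegDataSample directly.
result = inference_model(model, 'demo.png')

# show_result_pyplot takes a list of SegDataSamples, drives
# SegLocalVisualizer under the hood, and returns the drawn image.
vis_image = show_result_pyplot(model, 'demo.png', [result],
                               show=False, save_dir='work_dirs/vis')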

View File

@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
from mmengine.utils import revert_sync_batchnorm
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import get_palette
from mmseg.utils import register_all_modules
def main():
@@ -13,27 +15,35 @@ def main():
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--palette',
default='cityscapes',
help='Color palette used for segmentation map')
'--save-dir',
default=None,
help='Directory to save visualization results for all storage backends.')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
parser.add_argument(
'--title', default='result', help='The image identifier.')
args = parser.parse_args()
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu':
model = revert_sync_batchnorm(model)
# test a single image
result = inference_model(model, args.img)
# show the results
show_result_pyplot(
model,
args.img,
result,
get_palette(args.palette),
opacity=args.opacity)
args.img, [result],
title=args.title,
opacity=args.opacity,
draw_gt=False,
show=False if args.save_dir is not None else True,
save_dir=args.save_dir)
if __name__ == '__main__':
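For context on the CPU branch above: models trained with SyncBN cannot run a forward pass without a distributed process group, so the demo converts those layers before inference. A minimal sketch with placeholder paths:

import torch
from mmengine.utils import revert_sync_batchnorm
from mmseg.apis import inference_model, init_model
from mmseg.utils import register_all_modules

register_all_modules()
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = init_model('config.py', 'checkpoint.pth', device=device)
if device == 'cpu':
    # Swap SyncBatchNorm for plain BatchNorm so the forward pass does
    # not try to all-reduce batch statistics across GPUs.
    model = revert_sync_batchnorm(model)
result = inference_model(model, 'demo.png')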

View File

@@ -20,8 +20,11 @@
},
"outputs": [],
"source": [
"import torch\n",
"from mmengine import revert_sync_batchnorm\n",
"from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
"from mmseg.utils import get_palette"
"from mmseg.utils import register_all_modules\n",
"register_all_modules()"
]
},
{
@@ -56,6 +59,8 @@
"source": [
"# test a single image\n",
"img = 'demo.png'\n",
"if not torch.cuda.is_available():\n",
" model = revert_sync_batchnorm(model)\n",
"result = inference_model(model, img)"
]
},
@@ -66,7 +71,7 @@
"outputs": [],
"source": [
"# show the results\n",
"show_result_pyplot(model, img, result, get_palette('cityscapes'))"
"show_result_pyplot(model, img, result)"
]
},
{
@@ -79,7 +84,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.10.4 ('pt1.11-v2')",
"language": "python",
"name": "python3"
},
@@ -93,7 +98,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.10.4"
},
"pycharm": {
"stem_cell": {
@@ -103,6 +108,11 @@
},
"source": []
}
},
"vscode": {
"interpreter": {
"hash": "fdab7187f8cbd4ce42bbf864ddb4c4693e7329271a15a7fa96e4bdb82b9302c9"
}
}
},
"nbformat": 4,

View File

@@ -2,9 +2,11 @@
from argparse import ArgumentParser
import cv2
from mmengine.utils import revert_sync_batchnorm
from mmseg.apis import inference_model, init_model
from mmseg.utils import get_palette
from mmseg.apis.inference import show_result_pyplot
from mmseg.utils import register_all_modules
def main():
@@ -51,10 +53,16 @@ def main():
assert args.show or args.output_file, \
'At least one output should be enabled.'
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu':
model = revert_sync_batchnorm(model)
# build input video
if args.video.isdigit():
args.video = int(args.video)
cap = cv2.VideoCapture(args.video)
assert (cap.isOpened())
input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
@@ -86,12 +94,7 @@ def main():
result = inference_model(model, frame)
# blend raw image and prediction
draw_img = model.show_result(
frame,
result,
palette=get_palette(args.palette),
show=False,
opacity=args.opacity)
draw_img = show_result_pyplot(model, frame, [result])
if args.show:
cv2.imshow('video_demo', draw_img)
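The loop above now takes its composited frame from show_result_pyplot rather than the removed model.show_result. A self-contained sketch of the same loop, with two hedged adjustments: show=False is passed so the visualizer does not open its own window per frame, and the returned image is assumed to be RGB while cv2.imshow expects BGR (verify against your mmseg version); paths are placeholders.

import cv2
import mmcv
from mmseg.apis import inference_model, init_model
from mmseg.apis.inference import show_result_pyplot
from mmseg.utils import register_all_modules

register_all_modules()
model = init_model('config.py', 'checkpoint.pth')  # placeholder paths
cap = cv2.VideoCapture(0)  # webcam index, or a video file path
assert cap.isOpened()
while True:
    flag, frame = cap.read()
    if not flag:
        break
    result = inference_model(model, frame)
    # show=False: only retrieve the drawn frame, do not display it here
    draw_img = show_result_pyplot(model, frame, [result], show=False)
    # Assumed RGB -> BGR conversion for OpenCV display
    cv2.imshow('video_demo', mmcv.rgb2bgr(draw_img))
    if cv2.waitKey(1) == 27:  # Esc quits
        break
cap.release()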

View File

@@ -1,12 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import matplotlib.pyplot as plt
import mmcv
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmseg.datasets.transforms import Compose
from mmseg.models import build_segmentor
from typing import Sequence, Union
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmengine import Config
from mmengine.dataset import Compose
from mmseg.data import SegDataSample
from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.utils import SampleList
from mmseg.visualization import SegLocalVisualizer
def init_model(config, checkpoint=None, device='cuda:0'):
@@ -23,13 +29,13 @@ def init_model(config, checkpoint=None, device='cuda:0'):
nn.Module: The constructed segmentor.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
config = Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
'but got {}'.format(type(config)))
config.model.pretrained = None
config.model.train_cfg = None
model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
model = MODELS.build(config.model)
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
model.CLASSES = checkpoint['meta']['CLASSES']
@@ -40,34 +46,41 @@ def init_model(config, checkpoint=None, device='cuda:0'):
return model
class LoadImage:
    """A simple pipeline to load image."""

    def __call__(self, results):
        """Call function to load images into results.
        Args:
            results (dict): A result dict that contains the file name
                of the image to be read.
        Returns:
            dict: ``results`` will be returned containing the loaded image.
        """
        if isinstance(results['img'], str):
            results['filename'] = results['img']
            results['ori_filename'] = results['img']
        else:
            results['filename'] = None
            results['ori_filename'] = None
        img = mmcv.imread(results['img'])
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        return results

ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]

def _prepare_data(imgs: ImageType, model: BaseSegmentor):
    cfg = model.cfg
    if dict(type='LoadAnnotations') in cfg.test_pipeline:
        cfg.test_pipeline.remove(dict(type='LoadAnnotations'))
    is_batch = True
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
        is_batch = False
    if isinstance(imgs[0], np.ndarray):
        cfg.test_pipeline[0].type = 'LoadImageFromNDArray'
    # TODO: Consider using the singleton pattern to avoid building
    # a pipeline for each inference
    pipeline = Compose(cfg.test_pipeline)
    data = []
    for img in imgs:
        if isinstance(img, np.ndarray):
            data_ = dict(img=img)
        else:
            data_ = dict(img_path=img)
        data_ = pipeline(data_)
        data.append(data_)
    return data, is_batch
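A hedged illustration of the dispatch above (_prepare_data is a private helper, so its behavior may change), assuming model was built with init_model as elsewhere in this file; the image path is a placeholder.

import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)

# A string is handed to the pipeline as an image path.
data, is_batch = _prepare_data('demo.png', model)
assert not is_batch and len(data) == 1

# An ndarray list is fed directly; the first transform is swapped to
# LoadImageFromNDArray and is_batch reports True.
data, is_batch = _prepare_data([frame, frame], model)
assert is_batch and len(data) == 2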
def inference_model(model, img):
def inference_model(model: BaseSegmentor,
img: ImageType) -> Union[SegDataSample, SampleList]:
"""Inference image(s) with the segmentor.
Args:
@@ -76,61 +89,70 @@ def inference_model(model, img):
images.
Returns:
(list[Tensor]): The segmentation result.
:obj:`SegDataSample` or list[:obj:`SegDataSample`]:
If imgs is a list or tuple, a list of results of the same length
will be returned; otherwise the segmentation result is returned directly.
"""
cfg = model.cfg
device = next(model.parameters()).device # model device
# build the data pipeline
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
else:
data['img_metas'] = [i.data[0] for i in data['img_metas']]
data, is_batch = _prepare_data(img, model)
# forward the model
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
return result
results = model.test_step(data)
return results if is_batch else results[0]
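Illustrating the return convention documented above; the image paths are placeholders, and the pred_sem_seg attribute name is an assumption based on SegDataSample's usual prediction field.

single = inference_model(model, 'demo.png')         # -> one SegDataSample
batch = inference_model(model, ['a.png', 'b.png'])  # -> list of two SegDataSamples
pred = single.pred_sem_seg  # assumed field carrying the predicted map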
def show_result_pyplot(model,
                       img,
                       result,
                       palette=None,
                       fig_size=(15, 10),
                       opacity=0.5,
                       title='',
                       block=True):
    """Visualize the segmentation results on the image.
    Args:
        model (nn.Module): The loaded segmentor.
        img (str or np.ndarray): Image filename or loaded image.
        result (list): The segmentation result.
        palette (list[list[int]] | None): The palette of the segmentation
            map. If None is given, a random palette will be generated.
            Default: None.
        fig_size (tuple): Figure size of the pyplot figure.
        opacity (float): Opacity of painted segmentation map.
            Default 0.5. Must be in (0, 1] range.
        title (str): The title of the pyplot figure.
            Default is ''.
        block (bool): Whether to block the pyplot figure.
            Default is True.
    """
    if hasattr(model, 'module'):
        model = model.module
    img = model.show_result(
        img, result, palette=palette, show=False, opacity=opacity)
    plt.figure(figsize=fig_size)
    plt.imshow(mmcv.bgr2rgb(img))
    plt.title(title)
    plt.tight_layout()
    plt.show(block=block)

def show_result_pyplot(model: BaseSegmentor,
                       img: Union[str, np.ndarray],
                       result: SampleList,
                       opacity: float = 0.5,
                       title: str = '',
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       wait_time: float = 0,
                       show: bool = True,
                       save_dir=None):
    """Visualize the segmentation results on the image.
    Args:
        model (nn.Module): The loaded segmentor.
        img (str or np.ndarray): Image filename or loaded image.
        result (list): The prediction SegDataSample result.
        opacity (float): Opacity of the painted segmentation map.
            Default 0.5. Must be in (0, 1] range.
        title (str): The title of the pyplot figure.
            Default is ''.
        draw_gt (bool): Whether to draw the GT SegDataSample.
            Defaults to True.
        draw_pred (bool): Whether to draw the prediction SegDataSample.
            Defaults to True.
        wait_time (float): The interval of show (s). Defaults to 0.
        show (bool): Whether to display the drawn image.
            Defaults to True.
        save_dir (str, optional): Directory to save visualization results
            for all storage backends. If None, the backend storage will
            not save any data.
    """
if isinstance(img, str):
image = mmcv.imread(img)
else:
image = img
if save_dir is not None:
mmcv.mkdir_or_exist(save_dir)
# init visualizer
visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir=save_dir,
alpha=opacity)
visualizer.dataset_meta = dict(
classes=model.CLASSES, palette=model.PALETTE)
visualizer.add_datasample(
name=title,
image=image,
pred_sample=result[0],
draw_gt=draw_gt,
draw_pred=draw_pred,
wait_time=wait_time,
show=show)
return visualizer.get_image()
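An example call of the rewritten helper, assuming model and result from inference_model above; with show=False and save_dir set, the LocalVisBackend writes the drawn image under that directory instead of opening a window. The output directory is a placeholder.

drawn = show_result_pyplot(
    model,
    'demo.png',
    [result],
    title='result',
    opacity=0.5,
    draw_gt=False,  # no ground truth is attached at inference time
    show=False,     # save to disk instead of popping up a window
    save_dir='work_dirs/vis_demo')  # placeholder output directory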

View File

@@ -2,9 +2,8 @@
from .hooks import SegVisualizationHook
from .optimizers import (LayerDecayOptimizerConstructor,
LearningRateDecayOptimizerConstructor)
from .visualization import SegLocalVisualizer
__all__ = [
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
'SegVisualizationHook', 'SegLocalVisualizer'
'SegVisualizationHook'
]

View File

@@ -8,8 +8,8 @@ from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmseg.data import SegDataSample
from mmseg.engine.visualization import SegLocalVisualizer
from mmseg.registry import HOOKS
from mmseg.visualization import SegLocalVisualizer
@HOOKS.register_module()

View File

@@ -7,7 +7,7 @@ from mmengine.data import PixelData
from mmseg.data import SegDataSample
from mmseg.engine.hooks import SegVisualizationHook
from mmseg.engine.visualization import SegLocalVisualizer
from mmseg.visualization import SegLocalVisualizer
class TestVisualizationHook(TestCase):

View File

@@ -10,7 +10,7 @@ import torch
from mmengine.data import PixelData
from mmseg.data import SegDataSample
from mmseg.engine.visualization import SegLocalVisualizer
from mmseg.visualization import SegLocalVisualizer
class TestSegLocalVisualizer(TestCase):
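With this move, SegLocalVisualizer is imported from mmseg.visualization instead of mmseg.engine.visualization. Below is a standalone usage sketch mirroring what show_result_pyplot does internally; the image, prediction, and class/palette metadata are placeholders, and in practice the metadata comes from model.CLASSES / model.PALETTE.

import numpy as np
from mmseg.visualization import SegLocalVisualizer

visualizer = SegLocalVisualizer(
    vis_backends=[dict(type='LocalVisBackend')],
    save_dir='work_dirs/vis',  # placeholder output directory
    alpha=0.5)
visualizer.dataset_meta = dict(
    classes=('background', 'foreground'),
    palette=[[0, 0, 0], [255, 0, 0]])
image = np.zeros((240, 320, 3), dtype=np.uint8)  # placeholder image
visualizer.add_datasample(
    name='demo',
    image=image,
    pred_sample=pred,  # a SegDataSample prediction (placeholder)
    draw_gt=False,
    show=False)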