[Fix] Fix demo scripts (#1815)
* [Feature] Add SegVisualizer * change name to visualizer_example * fix inference api * fix video demo and refine inference api * fix * mmseg compose * set default device to cuda:0 * fix import * update dir * rm engine/visualizer ut * refine inference api and docs * rename Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>pull/1850/head
parent
3bbdd6dc4a
commit
5d9650838e
|
@ -1,8 +1,10 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
from mmengine.utils import revert_sync_batchnorm
|
||||||
|
|
||||||
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
||||||
from mmseg.utils import get_palette
|
from mmseg.utils import register_all_modules
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -13,27 +15,35 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--device', default='cuda:0', help='Device used for inference')
|
'--device', default='cuda:0', help='Device used for inference')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--palette',
|
'--save-dir',
|
||||||
default='cityscapes',
|
default=None,
|
||||||
help='Color palette used for segmentation map')
|
help='Save file dir for all storage backends.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--opacity',
|
'--opacity',
|
||||||
type=float,
|
type=float,
|
||||||
default=0.5,
|
default=0.5,
|
||||||
help='Opacity of painted segmentation map. In (0, 1] range.')
|
help='Opacity of painted segmentation map. In (0, 1] range.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--title', default='result', help='The image identifier.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
register_all_modules()
|
||||||
|
|
||||||
# build the model from a config file and a checkpoint file
|
# build the model from a config file and a checkpoint file
|
||||||
model = init_model(args.config, args.checkpoint, device=args.device)
|
model = init_model(args.config, args.checkpoint, device=args.device)
|
||||||
|
if args.device == 'cpu':
|
||||||
|
model = revert_sync_batchnorm(model)
|
||||||
# test a single image
|
# test a single image
|
||||||
result = inference_model(model, args.img)
|
result = inference_model(model, args.img)
|
||||||
# show the results
|
# show the results
|
||||||
show_result_pyplot(
|
show_result_pyplot(
|
||||||
model,
|
model,
|
||||||
args.img,
|
args.img, [result],
|
||||||
result,
|
title=args.title,
|
||||||
get_palette(args.palette),
|
opacity=args.opacity,
|
||||||
opacity=args.opacity)
|
draw_gt=False,
|
||||||
|
show=False if args.save_dir is not None else True,
|
||||||
|
save_dir=args.save_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -20,8 +20,11 @@
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"from mmengine import revert_sync_batchnorm\n",
|
||||||
"from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
|
"from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
|
||||||
"from mmseg.utils import get_palette"
|
"from mmseg.utils import register_all_modules\n",
|
||||||
|
"register_all_modules()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -56,6 +59,8 @@
|
||||||
"source": [
|
"source": [
|
||||||
"# test a single image\n",
|
"# test a single image\n",
|
||||||
"img = 'demo.png'\n",
|
"img = 'demo.png'\n",
|
||||||
|
"if not torch.cuda.is_available():\n",
|
||||||
|
" model = revert_sync_batchnorm(model)\n",
|
||||||
"result = inference_model(model, img)"
|
"result = inference_model(model, img)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -66,7 +71,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# show the results\n",
|
"# show the results\n",
|
||||||
"show_result_pyplot(model, img, result, get_palette('cityscapes'))"
|
"show_result_pyplot(model, img, result)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -79,7 +84,7 @@
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3",
|
"display_name": "Python 3.10.4 ('pt1.11-v2')",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
|
@ -93,7 +98,7 @@
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.7.0"
|
"version": "3.10.4"
|
||||||
},
|
},
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"stem_cell": {
|
"stem_cell": {
|
||||||
|
@ -103,6 +108,11 @@
|
||||||
},
|
},
|
||||||
"source": []
|
"source": []
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "fdab7187f8cbd4ce42bbf864ddb4c4693e7329271a15a7fa96e4bdb82b9302c9"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|
|
@ -2,9 +2,11 @@
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
from mmengine.utils import revert_sync_batchnorm
|
||||||
|
|
||||||
from mmseg.apis import inference_model, init_model
|
from mmseg.apis import inference_model, init_model
|
||||||
from mmseg.utils import get_palette
|
from mmseg.apis.inference import show_result_pyplot
|
||||||
|
from mmseg.utils import register_all_modules
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -51,10 +53,16 @@ def main():
|
||||||
assert args.show or args.output_file, \
|
assert args.show or args.output_file, \
|
||||||
'At least one output should be enabled.'
|
'At least one output should be enabled.'
|
||||||
|
|
||||||
|
register_all_modules()
|
||||||
|
|
||||||
# build the model from a config file and a checkpoint file
|
# build the model from a config file and a checkpoint file
|
||||||
model = init_model(args.config, args.checkpoint, device=args.device)
|
model = init_model(args.config, args.checkpoint, device=args.device)
|
||||||
|
if args.device == 'cpu':
|
||||||
|
model = revert_sync_batchnorm(model)
|
||||||
|
|
||||||
# build input video
|
# build input video
|
||||||
|
if args.video.isdigit():
|
||||||
|
args.video = int(args.video)
|
||||||
cap = cv2.VideoCapture(args.video)
|
cap = cv2.VideoCapture(args.video)
|
||||||
assert (cap.isOpened())
|
assert (cap.isOpened())
|
||||||
input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||||
|
@ -86,12 +94,7 @@ def main():
|
||||||
result = inference_model(model, frame)
|
result = inference_model(model, frame)
|
||||||
|
|
||||||
# blend raw image and prediction
|
# blend raw image and prediction
|
||||||
draw_img = model.show_result(
|
draw_img = show_result_pyplot(model, frame, [result])
|
||||||
frame,
|
|
||||||
result,
|
|
||||||
palette=get_palette(args.palette),
|
|
||||||
show=False,
|
|
||||||
opacity=args.opacity)
|
|
||||||
|
|
||||||
if args.show:
|
if args.show:
|
||||||
cv2.imshow('video_demo', draw_img)
|
cv2.imshow('video_demo', draw_img)
|
||||||
|
|
|
@ -1,12 +1,18 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import matplotlib.pyplot as plt
|
from typing import Sequence, Union
|
||||||
import mmcv
|
|
||||||
import torch
|
|
||||||
from mmcv.parallel import collate, scatter
|
|
||||||
from mmcv.runner import load_checkpoint
|
|
||||||
|
|
||||||
from mmseg.datasets.transforms import Compose
|
import mmcv
|
||||||
from mmseg.models import build_segmentor
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from mmcv.runner import load_checkpoint
|
||||||
|
from mmengine import Config
|
||||||
|
from mmengine.dataset import Compose
|
||||||
|
|
||||||
|
from mmseg.data import SegDataSample
|
||||||
|
from mmseg.models import BaseSegmentor
|
||||||
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.utils import SampleList
|
||||||
|
from mmseg.visualization import SegLocalVisualizer
|
||||||
|
|
||||||
|
|
||||||
def init_model(config, checkpoint=None, device='cuda:0'):
|
def init_model(config, checkpoint=None, device='cuda:0'):
|
||||||
|
@ -23,13 +29,13 @@ def init_model(config, checkpoint=None, device='cuda:0'):
|
||||||
nn.Module: The constructed segmentor.
|
nn.Module: The constructed segmentor.
|
||||||
"""
|
"""
|
||||||
if isinstance(config, str):
|
if isinstance(config, str):
|
||||||
config = mmcv.Config.fromfile(config)
|
config = Config.fromfile(config)
|
||||||
elif not isinstance(config, mmcv.Config):
|
elif not isinstance(config, mmcv.Config):
|
||||||
raise TypeError('config must be a filename or Config object, '
|
raise TypeError('config must be a filename or Config object, '
|
||||||
'but got {}'.format(type(config)))
|
'but got {}'.format(type(config)))
|
||||||
config.model.pretrained = None
|
config.model.pretrained = None
|
||||||
config.model.train_cfg = None
|
config.model.train_cfg = None
|
||||||
model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
|
model = MODELS.build(config.model)
|
||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
||||||
model.CLASSES = checkpoint['meta']['CLASSES']
|
model.CLASSES = checkpoint['meta']['CLASSES']
|
||||||
|
@ -40,34 +46,41 @@ def init_model(config, checkpoint=None, device='cuda:0'):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class LoadImage:
|
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
||||||
"""A simple pipeline to load image."""
|
|
||||||
|
|
||||||
def __call__(self, results):
|
|
||||||
"""Call function to load images into results.
|
|
||||||
|
|
||||||
Args:
|
def _preprare_data(imgs: ImageType, model: BaseSegmentor):
|
||||||
results (dict): A result dict contains the file name
|
|
||||||
of the image to be read.
|
|
||||||
|
|
||||||
Returns:
|
cfg = model.cfg
|
||||||
dict: ``results`` will be returned containing loaded image.
|
if dict(type='LoadAnnotations') in cfg.test_pipeline:
|
||||||
"""
|
cfg.test_pipeline.remove(dict(type='LoadAnnotations'))
|
||||||
|
|
||||||
if isinstance(results['img'], str):
|
is_batch = True
|
||||||
results['filename'] = results['img']
|
if not isinstance(imgs, (list, tuple)):
|
||||||
results['ori_filename'] = results['img']
|
imgs = [imgs]
|
||||||
|
is_batch = False
|
||||||
|
|
||||||
|
if isinstance(imgs[0], np.ndarray):
|
||||||
|
cfg.test_pipeline[0].type = 'LoadImageFromNDArray'
|
||||||
|
|
||||||
|
# TODO: Consider using the singleton pattern to avoid building
|
||||||
|
# a pipeline for each inference
|
||||||
|
pipeline = Compose(cfg.test_pipeline)
|
||||||
|
|
||||||
|
data = []
|
||||||
|
for img in imgs:
|
||||||
|
if isinstance(img, np.ndarray):
|
||||||
|
data_ = dict(img=img)
|
||||||
else:
|
else:
|
||||||
results['filename'] = None
|
data_ = dict(img_path=img)
|
||||||
results['ori_filename'] = None
|
data_ = pipeline(data_)
|
||||||
img = mmcv.imread(results['img'])
|
data.append(data_)
|
||||||
results['img'] = img
|
|
||||||
results['img_shape'] = img.shape
|
return data, is_batch
|
||||||
results['ori_shape'] = img.shape
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def inference_model(model, img):
|
def inference_model(model: BaseSegmentor,
|
||||||
|
img: ImageType) -> Union[SegDataSample, SampleList]:
|
||||||
"""Inference image(s) with the segmentor.
|
"""Inference image(s) with the segmentor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -76,61 +89,70 @@ def inference_model(model, img):
|
||||||
images.
|
images.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(list[Tensor]): The segmentation result.
|
:obj:`SegDataSample` or list[:obj:`SegDataSample`]:
|
||||||
|
If imgs is a list or tuple, the same length list type results
|
||||||
|
will be returned, otherwise return the segmentation results directly.
|
||||||
"""
|
"""
|
||||||
cfg = model.cfg
|
|
||||||
device = next(model.parameters()).device # model device
|
|
||||||
# build the data pipeline
|
|
||||||
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
|
|
||||||
test_pipeline = Compose(test_pipeline)
|
|
||||||
# prepare data
|
# prepare data
|
||||||
data = dict(img=img)
|
data, is_batch = _preprare_data(img, model)
|
||||||
data = test_pipeline(data)
|
|
||||||
data = collate([data], samples_per_gpu=1)
|
|
||||||
if next(model.parameters()).is_cuda:
|
|
||||||
# scatter to specified GPU
|
|
||||||
data = scatter(data, [device])[0]
|
|
||||||
else:
|
|
||||||
data['img_metas'] = [i.data[0] for i in data['img_metas']]
|
|
||||||
|
|
||||||
# forward the model
|
# forward the model
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
result = model(return_loss=False, rescale=True, **data)
|
results = model.test_step(data)
|
||||||
return result
|
|
||||||
|
return results if is_batch else results[0]
|
||||||
|
|
||||||
|
|
||||||
def show_result_pyplot(model,
|
def show_result_pyplot(model: BaseSegmentor,
|
||||||
img,
|
img: Union[str, np.ndarray],
|
||||||
result,
|
result: SampleList,
|
||||||
palette=None,
|
opacity: float = 0.5,
|
||||||
fig_size=(15, 10),
|
title: str = '',
|
||||||
opacity=0.5,
|
draw_gt: bool = True,
|
||||||
title='',
|
draw_pred: bool = True,
|
||||||
block=True):
|
wait_time: float = 0,
|
||||||
|
show: bool = True,
|
||||||
|
save_dir=None):
|
||||||
"""Visualize the segmentation results on the image.
|
"""Visualize the segmentation results on the image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The loaded segmentor.
|
model (nn.Module): The loaded segmentor.
|
||||||
img (str or np.ndarray): Image filename or loaded image.
|
img (str or np.ndarray): Image filename or loaded image.
|
||||||
result (list): The segmentation result.
|
result (list): The prediction SegDataSample result.
|
||||||
palette (list[list[int]]] | None): The palette of segmentation
|
|
||||||
map. If None is given, random palette will be generated.
|
|
||||||
Default: None
|
|
||||||
fig_size (tuple): Figure size of the pyplot figure.
|
|
||||||
opacity(float): Opacity of painted segmentation map.
|
opacity(float): Opacity of painted segmentation map.
|
||||||
Default 0.5.
|
Default 0.5. Must be in (0, 1] range.
|
||||||
Must be in (0, 1] range.
|
|
||||||
title (str): The title of pyplot figure.
|
title (str): The title of pyplot figure.
|
||||||
Default is ''.
|
Default is ''.
|
||||||
block (bool): Whether to block the pyplot figure.
|
draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
|
||||||
Default is True.
|
draw_pred (bool): Whether to draw Prediction SegDataSample.
|
||||||
|
Defaults to True.
|
||||||
|
wait_time (float): The interval of show (s). Defaults to 0.
|
||||||
|
show (bool): Whether to display the drawn image.
|
||||||
|
Default to True.
|
||||||
|
save_dir (str, optional): Save file dir for all storage backends.
|
||||||
|
If it is None, the backend storage will not save any data.
|
||||||
"""
|
"""
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model = model.module
|
model = model.module
|
||||||
img = model.show_result(
|
if isinstance(img, str):
|
||||||
img, result, palette=palette, show=False, opacity=opacity)
|
image = mmcv.imread(img)
|
||||||
plt.figure(figsize=fig_size)
|
else:
|
||||||
plt.imshow(mmcv.bgr2rgb(img))
|
image = img
|
||||||
plt.title(title)
|
if save_dir is not None:
|
||||||
plt.tight_layout()
|
mmcv.mkdir_or_exist(save_dir)
|
||||||
plt.show(block=block)
|
# init visualizer
|
||||||
|
visualizer = SegLocalVisualizer(
|
||||||
|
vis_backends=[dict(type='LocalVisBackend')],
|
||||||
|
save_dir=save_dir,
|
||||||
|
alpha=opacity)
|
||||||
|
visualizer.dataset_meta = dict(
|
||||||
|
classes=model.CLASSES, palette=model.PALETTE)
|
||||||
|
visualizer.add_datasample(
|
||||||
|
name=title,
|
||||||
|
image=image,
|
||||||
|
pred_sample=result[0],
|
||||||
|
draw_gt=draw_gt,
|
||||||
|
draw_pred=draw_pred,
|
||||||
|
wait_time=wait_time,
|
||||||
|
show=show)
|
||||||
|
return visualizer.get_image()
|
||||||
|
|
|
@ -2,9 +2,8 @@
|
||||||
from .hooks import SegVisualizationHook
|
from .hooks import SegVisualizationHook
|
||||||
from .optimizers import (LayerDecayOptimizerConstructor,
|
from .optimizers import (LayerDecayOptimizerConstructor,
|
||||||
LearningRateDecayOptimizerConstructor)
|
LearningRateDecayOptimizerConstructor)
|
||||||
from .visualization import SegLocalVisualizer
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
|
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
|
||||||
'SegVisualizationHook', 'SegLocalVisualizer'
|
'SegVisualizationHook'
|
||||||
]
|
]
|
||||||
|
|
|
@ -8,8 +8,8 @@ from mmengine.hooks import Hook
|
||||||
from mmengine.runner import Runner
|
from mmengine.runner import Runner
|
||||||
|
|
||||||
from mmseg.data import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.engine.visualization import SegLocalVisualizer
|
|
||||||
from mmseg.registry import HOOKS
|
from mmseg.registry import HOOKS
|
||||||
|
from mmseg.visualization import SegLocalVisualizer
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module()
|
||||||
|
|
|
@ -7,7 +7,7 @@ from mmengine.data import PixelData
|
||||||
|
|
||||||
from mmseg.data import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.engine.hooks import SegVisualizationHook
|
from mmseg.engine.hooks import SegVisualizationHook
|
||||||
from mmseg.engine.visualization import SegLocalVisualizer
|
from mmseg.visualization import SegLocalVisualizer
|
||||||
|
|
||||||
|
|
||||||
class TestVisualizationHook(TestCase):
|
class TestVisualizationHook(TestCase):
|
||||||
|
|
|
@ -10,7 +10,7 @@ import torch
|
||||||
from mmengine.data import PixelData
|
from mmengine.data import PixelData
|
||||||
|
|
||||||
from mmseg.data import SegDataSample
|
from mmseg.data import SegDataSample
|
||||||
from mmseg.engine.visualization import SegLocalVisualizer
|
from mmseg.visualization import SegLocalVisualizer
|
||||||
|
|
||||||
|
|
||||||
class TestSegLocalVisualizer(TestCase):
|
class TestSegLocalVisualizer(TestCase):
|
Loading…
Reference in New Issue