[Fix] Fix demo scripts (#1815)
* [Feature] Add SegVisualizer * change name to visualizer_example * fix inference api * fix video demo and refine inference api * fix * mmseg compose * set default device to cuda:0 * fix import * update dir * rm engine/visualizer ut * refine inference api and docs * rename Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>pull/1850/head
parent
3bbdd6dc4a
commit
5d9650838e
|
@ -1,8 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from mmengine.utils import revert_sync_batchnorm
|
||||
|
||||
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
||||
from mmseg.utils import get_palette
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -13,27 +15,35 @@ def main():
|
|||
parser.add_argument(
|
||||
'--device', default='cuda:0', help='Device used for inference')
|
||||
parser.add_argument(
|
||||
'--palette',
|
||||
default='cityscapes',
|
||||
help='Color palette used for segmentation map')
|
||||
'--save-dir',
|
||||
default=None,
|
||||
help='Save file dir for all storage backends.')
|
||||
parser.add_argument(
|
||||
'--opacity',
|
||||
type=float,
|
||||
default=0.5,
|
||||
help='Opacity of painted segmentation map. In (0, 1] range.')
|
||||
parser.add_argument(
|
||||
'--title', default='result', help='The image identifier.')
|
||||
args = parser.parse_args()
|
||||
|
||||
register_all_modules()
|
||||
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_model(args.config, args.checkpoint, device=args.device)
|
||||
if args.device == 'cpu':
|
||||
model = revert_sync_batchnorm(model)
|
||||
# test a single image
|
||||
result = inference_model(model, args.img)
|
||||
# show the results
|
||||
show_result_pyplot(
|
||||
model,
|
||||
args.img,
|
||||
result,
|
||||
get_palette(args.palette),
|
||||
opacity=args.opacity)
|
||||
args.img, [result],
|
||||
title=args.title,
|
||||
opacity=args.opacity,
|
||||
draw_gt=False,
|
||||
show=False if args.save_dir is not None else True,
|
||||
save_dir=args.save_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -20,8 +20,11 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from mmengine import revert_sync_batchnorm\n",
|
||||
"from mmseg.apis import init_model, inference_model, show_result_pyplot\n",
|
||||
"from mmseg.utils import get_palette"
|
||||
"from mmseg.utils import register_all_modules\n",
|
||||
"register_all_modules()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -56,6 +59,8 @@
|
|||
"source": [
|
||||
"# test a single image\n",
|
||||
"img = 'demo.png'\n",
|
||||
"if not torch.cuda.is_available():\n",
|
||||
" model = revert_sync_batchnorm(model)\n",
|
||||
"result = inference_model(model, img)"
|
||||
]
|
||||
},
|
||||
|
@ -66,7 +71,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# show the results\n",
|
||||
"show_result_pyplot(model, img, result, get_palette('cityscapes'))"
|
||||
"show_result_pyplot(model, img, result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -79,7 +84,7 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3.10.4 ('pt1.11-v2')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -93,7 +98,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.0"
|
||||
"version": "3.10.4"
|
||||
},
|
||||
"pycharm": {
|
||||
"stem_cell": {
|
||||
|
@ -103,6 +108,11 @@
|
|||
},
|
||||
"source": []
|
||||
}
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "fdab7187f8cbd4ce42bbf864ddb4c4693e7329271a15a7fa96e4bdb82b9302c9"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -2,9 +2,11 @@
|
|||
from argparse import ArgumentParser
|
||||
|
||||
import cv2
|
||||
from mmengine.utils import revert_sync_batchnorm
|
||||
|
||||
from mmseg.apis import inference_model, init_model
|
||||
from mmseg.utils import get_palette
|
||||
from mmseg.apis.inference import show_result_pyplot
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -51,10 +53,16 @@ def main():
|
|||
assert args.show or args.output_file, \
|
||||
'At least one output should be enabled.'
|
||||
|
||||
register_all_modules()
|
||||
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_model(args.config, args.checkpoint, device=args.device)
|
||||
if args.device == 'cpu':
|
||||
model = revert_sync_batchnorm(model)
|
||||
|
||||
# build input video
|
||||
if args.video.isdigit():
|
||||
args.video = int(args.video)
|
||||
cap = cv2.VideoCapture(args.video)
|
||||
assert (cap.isOpened())
|
||||
input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
|
@ -86,12 +94,7 @@ def main():
|
|||
result = inference_model(model, frame)
|
||||
|
||||
# blend raw image and prediction
|
||||
draw_img = model.show_result(
|
||||
frame,
|
||||
result,
|
||||
palette=get_palette(args.palette),
|
||||
show=False,
|
||||
opacity=args.opacity)
|
||||
draw_img = show_result_pyplot(model, frame, [result])
|
||||
|
||||
if args.show:
|
||||
cv2.imshow('video_demo', draw_img)
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv.parallel import collate, scatter
|
||||
from mmcv.runner import load_checkpoint
|
||||
from typing import Sequence, Union
|
||||
|
||||
from mmseg.datasets.transforms import Compose
|
||||
from mmseg.models import build_segmentor
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint
|
||||
from mmengine import Config
|
||||
from mmengine.dataset import Compose
|
||||
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.models import BaseSegmentor
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
|
||||
def init_model(config, checkpoint=None, device='cuda:0'):
|
||||
|
@ -23,13 +29,13 @@ def init_model(config, checkpoint=None, device='cuda:0'):
|
|||
nn.Module: The constructed segmentor.
|
||||
"""
|
||||
if isinstance(config, str):
|
||||
config = mmcv.Config.fromfile(config)
|
||||
config = Config.fromfile(config)
|
||||
elif not isinstance(config, mmcv.Config):
|
||||
raise TypeError('config must be a filename or Config object, '
|
||||
'but got {}'.format(type(config)))
|
||||
config.model.pretrained = None
|
||||
config.model.train_cfg = None
|
||||
model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
|
||||
model = MODELS.build(config.model)
|
||||
if checkpoint is not None:
|
||||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
||||
model.CLASSES = checkpoint['meta']['CLASSES']
|
||||
|
@ -40,34 +46,41 @@ def init_model(config, checkpoint=None, device='cuda:0'):
|
|||
return model
|
||||
|
||||
|
||||
class LoadImage:
|
||||
"""A simple pipeline to load image."""
|
||||
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to load images into results.
|
||||
|
||||
Args:
|
||||
results (dict): A result dict contains the file name
|
||||
of the image to be read.
|
||||
def _preprare_data(imgs: ImageType, model: BaseSegmentor):
|
||||
|
||||
Returns:
|
||||
dict: ``results`` will be returned containing loaded image.
|
||||
"""
|
||||
cfg = model.cfg
|
||||
if dict(type='LoadAnnotations') in cfg.test_pipeline:
|
||||
cfg.test_pipeline.remove(dict(type='LoadAnnotations'))
|
||||
|
||||
if isinstance(results['img'], str):
|
||||
results['filename'] = results['img']
|
||||
results['ori_filename'] = results['img']
|
||||
is_batch = True
|
||||
if not isinstance(imgs, (list, tuple)):
|
||||
imgs = [imgs]
|
||||
is_batch = False
|
||||
|
||||
if isinstance(imgs[0], np.ndarray):
|
||||
cfg.test_pipeline[0].type = 'LoadImageFromNDArray'
|
||||
|
||||
# TODO: Consider using the singleton pattern to avoid building
|
||||
# a pipeline for each inference
|
||||
pipeline = Compose(cfg.test_pipeline)
|
||||
|
||||
data = []
|
||||
for img in imgs:
|
||||
if isinstance(img, np.ndarray):
|
||||
data_ = dict(img=img)
|
||||
else:
|
||||
results['filename'] = None
|
||||
results['ori_filename'] = None
|
||||
img = mmcv.imread(results['img'])
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
return results
|
||||
data_ = dict(img_path=img)
|
||||
data_ = pipeline(data_)
|
||||
data.append(data_)
|
||||
|
||||
return data, is_batch
|
||||
|
||||
|
||||
def inference_model(model, img):
|
||||
def inference_model(model: BaseSegmentor,
|
||||
img: ImageType) -> Union[SegDataSample, SampleList]:
|
||||
"""Inference image(s) with the segmentor.
|
||||
|
||||
Args:
|
||||
|
@ -76,61 +89,70 @@ def inference_model(model, img):
|
|||
images.
|
||||
|
||||
Returns:
|
||||
(list[Tensor]): The segmentation result.
|
||||
:obj:`SegDataSample` or list[:obj:`SegDataSample`]:
|
||||
If imgs is a list or tuple, the same length list type results
|
||||
will be returned, otherwise return the segmentation results directly.
|
||||
"""
|
||||
cfg = model.cfg
|
||||
device = next(model.parameters()).device # model device
|
||||
# build the data pipeline
|
||||
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
|
||||
test_pipeline = Compose(test_pipeline)
|
||||
# prepare data
|
||||
data = dict(img=img)
|
||||
data = test_pipeline(data)
|
||||
data = collate([data], samples_per_gpu=1)
|
||||
if next(model.parameters()).is_cuda:
|
||||
# scatter to specified GPU
|
||||
data = scatter(data, [device])[0]
|
||||
else:
|
||||
data['img_metas'] = [i.data[0] for i in data['img_metas']]
|
||||
data, is_batch = _preprare_data(img, model)
|
||||
|
||||
# forward the model
|
||||
with torch.no_grad():
|
||||
result = model(return_loss=False, rescale=True, **data)
|
||||
return result
|
||||
results = model.test_step(data)
|
||||
|
||||
return results if is_batch else results[0]
|
||||
|
||||
|
||||
def show_result_pyplot(model,
|
||||
img,
|
||||
result,
|
||||
palette=None,
|
||||
fig_size=(15, 10),
|
||||
opacity=0.5,
|
||||
title='',
|
||||
block=True):
|
||||
def show_result_pyplot(model: BaseSegmentor,
|
||||
img: Union[str, np.ndarray],
|
||||
result: SampleList,
|
||||
opacity: float = 0.5,
|
||||
title: str = '',
|
||||
draw_gt: bool = True,
|
||||
draw_pred: bool = True,
|
||||
wait_time: float = 0,
|
||||
show: bool = True,
|
||||
save_dir=None):
|
||||
"""Visualize the segmentation results on the image.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The loaded segmentor.
|
||||
img (str or np.ndarray): Image filename or loaded image.
|
||||
result (list): The segmentation result.
|
||||
palette (list[list[int]]] | None): The palette of segmentation
|
||||
map. If None is given, random palette will be generated.
|
||||
Default: None
|
||||
fig_size (tuple): Figure size of the pyplot figure.
|
||||
result (list): The prediction SegDataSample result.
|
||||
opacity(float): Opacity of painted segmentation map.
|
||||
Default 0.5.
|
||||
Must be in (0, 1] range.
|
||||
Default 0.5. Must be in (0, 1] range.
|
||||
title (str): The title of pyplot figure.
|
||||
Default is ''.
|
||||
block (bool): Whether to block the pyplot figure.
|
||||
Default is True.
|
||||
draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
|
||||
draw_pred (bool): Whether to draw Prediction SegDataSample.
|
||||
Defaults to True.
|
||||
wait_time (float): The interval of show (s). Defaults to 0.
|
||||
show (bool): Whether to display the drawn image.
|
||||
Default to True.
|
||||
save_dir (str, optional): Save file dir for all storage backends.
|
||||
If it is None, the backend storage will not save any data.
|
||||
"""
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
img = model.show_result(
|
||||
img, result, palette=palette, show=False, opacity=opacity)
|
||||
plt.figure(figsize=fig_size)
|
||||
plt.imshow(mmcv.bgr2rgb(img))
|
||||
plt.title(title)
|
||||
plt.tight_layout()
|
||||
plt.show(block=block)
|
||||
if isinstance(img, str):
|
||||
image = mmcv.imread(img)
|
||||
else:
|
||||
image = img
|
||||
if save_dir is not None:
|
||||
mmcv.mkdir_or_exist(save_dir)
|
||||
# init visualizer
|
||||
visualizer = SegLocalVisualizer(
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
save_dir=save_dir,
|
||||
alpha=opacity)
|
||||
visualizer.dataset_meta = dict(
|
||||
classes=model.CLASSES, palette=model.PALETTE)
|
||||
visualizer.add_datasample(
|
||||
name=title,
|
||||
image=image,
|
||||
pred_sample=result[0],
|
||||
draw_gt=draw_gt,
|
||||
draw_pred=draw_pred,
|
||||
wait_time=wait_time,
|
||||
show=show)
|
||||
return visualizer.get_image()
|
||||
|
|
|
@ -2,9 +2,8 @@
|
|||
from .hooks import SegVisualizationHook
|
||||
from .optimizers import (LayerDecayOptimizerConstructor,
|
||||
LearningRateDecayOptimizerConstructor)
|
||||
from .visualization import SegLocalVisualizer
|
||||
|
||||
__all__ = [
|
||||
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
|
||||
'SegVisualizationHook', 'SegLocalVisualizer'
|
||||
'SegVisualizationHook'
|
||||
]
|
||||
|
|
|
@ -8,8 +8,8 @@ from mmengine.hooks import Hook
|
|||
from mmengine.runner import Runner
|
||||
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.engine.visualization import SegLocalVisualizer
|
||||
from mmseg.registry import HOOKS
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmengine.data import PixelData
|
|||
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.engine.hooks import SegVisualizationHook
|
||||
from mmseg.engine.visualization import SegLocalVisualizer
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
|
||||
class TestVisualizationHook(TestCase):
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch
|
|||
from mmengine.data import PixelData
|
||||
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.engine.visualization import SegLocalVisualizer
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
|
||||
class TestSegLocalVisualizer(TestCase):
|
Loading…
Reference in New Issue