[Enhancement]Add out_file in add_datasample to directly save image (#2090)

* [Enhancement]Add `out_file` in add_datasample to for save vis image directly

* comments

* ut
This commit is contained in:
Miao Zheng 2022-09-20 15:23:13 +08:00 committed by GitHub
parent 230246f557
commit 2a183283f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 24 deletions

View File

@ -199,9 +199,8 @@ def show_result_pyplot(model: BaseSegmentor,
draw_gt=draw_gt, draw_gt=draw_gt,
draw_pred=draw_pred, draw_pred=draw_pred,
wait_time=wait_time, wait_time=wait_time,
out_file=out_file,
show=show) show=show)
vis_img = visualizer.get_image() vis_img = visualizer.get_image()
if out_file is not None:
mmcv.imwrite(vis_img, out_file)
return vis_img return vis_img

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import mmcv
import numpy as np import numpy as np
from mmengine.dist import master_only from mmengine.dist import master_only
from mmengine.structures import PixelData from mmengine.structures import PixelData
@ -99,15 +100,18 @@ class SegLocalVisualizer(Visualizer):
return self.get_image() return self.get_image()
@master_only @master_only
def add_datasample(self, def add_datasample(
name: str, self,
image: np.ndarray, name: str,
data_sample: Optional[SegDataSample] = None, image: np.ndarray,
draw_gt: bool = True, data_sample: Optional[SegDataSample] = None,
draw_pred: bool = True, draw_gt: bool = True,
show: bool = False, draw_pred: bool = True,
wait_time: float = 0, show: bool = False,
step: int = 0) -> None: wait_time: float = 0,
# TODO: Supported in mmengine's Viusalizer.
out_file: Optional[str] = None,
step: int = 0) -> None:
"""Draw datasample and save to all backends. """Draw datasample and save to all backends.
- If GT and prediction are plotted at the same time, they are - If GT and prediction are plotted at the same time, they are
@ -115,6 +119,9 @@ class SegLocalVisualizer(Visualizer):
ground truth and the right image is the prediction. ground truth and the right image is the prediction.
- If ``show`` is True, all storage backends are ignored, and - If ``show`` is True, all storage backends are ignored, and
the images will be displayed in a local window. the images will be displayed in a local window.
- If ``out_file`` is specified, the drawn image will be
saved to ``out_file``. it is usually used when the display
is not available.
Args: Args:
name (str): The image identifier. name (str): The image identifier.
@ -128,6 +135,7 @@ class SegLocalVisualizer(Visualizer):
Defaults to True. Defaults to True.
show (bool): Whether to display the drawn image. Default to False. show (bool): Whether to display the drawn image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0. wait_time (float): The interval of show (s). Defaults to 0.
out_file (str): Path to output file. Defaults to None.
step (int): Global step value to record. Defaults to 0. step (int): Global step value to record. Defaults to 0.
""" """
classes = self.dataset_meta.get('classes', None) classes = self.dataset_meta.get('classes', None)
@ -166,5 +174,8 @@ class SegLocalVisualizer(Visualizer):
if show: if show:
self.show(drawn_img, win_name=name, wait_time=wait_time) self.show(drawn_img, win_name=name, wait_time=wait_time)
if out_file is not None:
mmcv.imwrite(drawn_img, out_file)
else: else:
self.add_image(name, drawn_img, step) self.add_image(name, drawn_img, step)

View File

@ -118,19 +118,14 @@ class TestSegLocalVisualizer(TestCase):
[255, 0, 0], [0, 0, 142], [0, 0, 70], [255, 0, 0], [0, 0, 142], [0, 0, 70],
[0, 60, 100], [0, 80, 100], [0, 0, 230], [0, 60, 100], [0, 80, 100], [0, 0, 230],
[119, 11, 32]]) [119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
data_sample)
# test out_file # test out_file
seg_local_visualizer.add_datasample(out_file, image, seg_local_visualizer.add_datasample(
data_sample) out_file,
assert os.path.exists( image,
osp.join(tmp_dir, 'vis_data', 'vis_image', data_sample,
out_file + '_0.png')) out_file=osp.join(tmp_dir, 'test.png'))
drawn_img = cv2.imread( self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image', osp.join(tmp_dir, 'test.png'), (h, w, 3))
out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
# test gt_instances and pred_instances # test gt_instances and pred_instances
pred_sem_seg_data = dict( pred_sem_seg_data = dict(
@ -139,12 +134,13 @@ class TestSegLocalVisualizer(TestCase):
data_sample.pred_sem_seg = pred_sem_seg data_sample.pred_sem_seg = pred_sem_seg
# test draw prediction with gt
seg_local_visualizer.add_datasample(out_file, image, seg_local_visualizer.add_datasample(out_file, image,
data_sample) data_sample)
self._assert_image_and_shape( self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image', osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3)) out_file + '_0.png'), (h, w * 2, 3))
# test draw prediction without gt
seg_local_visualizer.add_datasample( seg_local_visualizer.add_datasample(
out_file, image, data_sample, draw_gt=False) out_file, image, data_sample, draw_gt=False)
self._assert_image_and_shape( self._assert_image_and_shape(