Add opacity option to show_result (#425)

This commit is contained in:
David de la Iglesia Castro 2021-03-23 04:34:38 +01:00 committed by GitHub
parent 1722010396
commit 9cbb4b1288
6 changed files with 44 additions and 11 deletions

View File

@ -15,6 +15,11 @@ def main():
'--palette', '--palette',
default='cityscapes', default='cityscapes',
help='Color palette used for segmentation map') help='Color palette used for segmentation map')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
args = parser.parse_args() args = parser.parse_args()
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
@ -22,7 +27,12 @@ def main():
# test a single image # test a single image
result = inference_segmentor(model, args.img) result = inference_segmentor(model, args.img)
# show the results # show the results
show_result_pyplot(model, args.img, result, get_palette(args.palette)) show_result_pyplot(
model,
args.img,
result,
get_palette(args.palette),
opacity=args.opacity)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -166,7 +166,8 @@ result = inference_segmentor(model, img)
# visualize the results in a new window # visualize the results in a new window
model.show_result(img, result, show=True) model.show_result(img, result, show=True)
# or save the visualization results to image files # or save the visualization results to image files
model.show_result(img, result, out_file='result.jpg') # you can change the opacity of the painted segmentation map in (0, 1].
model.show_result(img, result, out_file='result.jpg', opacity=0.5)
# test a video and show the results # test a video and show the results
video = mmcv.VideoReader('video.mp4') video = mmcv.VideoReader('video.mp4')

View File

@ -98,7 +98,12 @@ def inference_segmentor(model, img):
return result return result
def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)): def show_result_pyplot(model,
img,
result,
palette=None,
fig_size=(15, 10),
opacity=0.5):
"""Visualize the segmentation results on the image. """Visualize the segmentation results on the image.
Args: Args:
@ -109,10 +114,14 @@ def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)):
map. If None is given, random palette will be generated. map. If None is given, random palette will be generated.
Default: None Default: None
fig_size (tuple): Figure size of the pyplot figure. fig_size (tuple): Figure size of the pyplot figure.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
""" """
if hasattr(model, 'module'): if hasattr(model, 'module'):
model = model.module model = model.module
img = model.show_result(img, result, palette=palette, show=False) img = model.show_result(
img, result, palette=palette, show=False, opacity=opacity)
plt.figure(figsize=fig_size) plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img)) plt.imshow(mmcv.bgr2rgb(img))
plt.show() plt.show()

View File

@ -35,7 +35,8 @@ def single_gpu_test(model,
data_loader, data_loader,
show=False, show=False,
out_dir=None, out_dir=None,
efficient_test=False): efficient_test=False,
opacity=0.5):
"""Test with single GPU. """Test with single GPU.
Args: Args:
@ -46,7 +47,9 @@ def single_gpu_test(model,
the directory to save output results. the directory to save output results.
efficient_test (bool): Whether save the results as local numpy files to efficient_test (bool): Whether save the results as local numpy files to
save CPU memory during evaluation. Default: False. save CPU memory during evaluation. Default: False.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Returns: Returns:
list: The prediction results. list: The prediction results.
""" """
@ -82,7 +85,8 @@ def single_gpu_test(model,
result, result,
palette=dataset.PALETTE, palette=dataset.PALETTE,
show=show, show=show,
out_file=out_file) out_file=out_file,
opacity=opacity)
if isinstance(result, list): if isinstance(result, list):
if efficient_test: if efficient_test:

View File

@ -212,7 +212,8 @@ class BaseSegmentor(nn.Module):
win_name='', win_name='',
show=False, show=False,
wait_time=0, wait_time=0,
out_file=None): out_file=None,
opacity=0.5):
"""Draw `result` over `img`. """Draw `result` over `img`.
Args: Args:
@ -229,7 +230,9 @@ class BaseSegmentor(nn.Module):
Default: False. Default: False.
out_file (str or None): The filename to write the image. out_file (str or None): The filename to write the image.
Default: None. Default: None.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Returns: Returns:
img (Tensor): Only if not `show` or `out_file` img (Tensor): Only if not `show` or `out_file`
""" """
@ -246,13 +249,14 @@ class BaseSegmentor(nn.Module):
assert palette.shape[0] == len(self.CLASSES) assert palette.shape[0] == len(self.CLASSES)
assert palette.shape[1] == 3 assert palette.shape[1] == 3
assert len(palette.shape) == 2 assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette): for label, color in enumerate(palette):
color_seg[seg == label, :] = color color_seg[seg == label, :] = color
# convert to BGR # convert to BGR
color_seg = color_seg[..., ::-1] color_seg = color_seg[..., ::-1]
img = img * 0.5 + color_seg * 0.5 img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8) img = img.astype(np.uint8)
# if out_file specified, do not show image in window # if out_file specified, do not show image in window
if out_file is not None: if out_file is not None:

View File

@ -55,6 +55,11 @@ def parse_args():
choices=['none', 'pytorch', 'slurm', 'mpi'], choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none', default='none',
help='job launcher') help='job launcher')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ: if 'LOCAL_RANK' not in os.environ:
@ -123,7 +128,7 @@ def main():
if not distributed: if not distributed:
model = MMDataParallel(model, device_ids=[0]) model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
efficient_test) efficient_test, args.opacity)
else: else:
model = MMDistributedDataParallel( model = MMDistributedDataParallel(
model.cuda(), model.cuda(),