mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Add opacity option to show_result (#425)
This commit is contained in:
parent
1722010396
commit
9cbb4b1288
@ -15,6 +15,11 @@ def main():
|
|||||||
'--palette',
|
'--palette',
|
||||||
default='cityscapes',
|
default='cityscapes',
|
||||||
help='Color palette used for segmentation map')
|
help='Color palette used for segmentation map')
|
||||||
|
parser.add_argument(
|
||||||
|
'--opacity',
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help='Opacity of painted segmentation map. In (0, 1] range.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# build the model from a config file and a checkpoint file
|
# build the model from a config file and a checkpoint file
|
||||||
@ -22,7 +27,12 @@ def main():
|
|||||||
# test a single image
|
# test a single image
|
||||||
result = inference_segmentor(model, args.img)
|
result = inference_segmentor(model, args.img)
|
||||||
# show the results
|
# show the results
|
||||||
show_result_pyplot(model, args.img, result, get_palette(args.palette))
|
show_result_pyplot(
|
||||||
|
model,
|
||||||
|
args.img,
|
||||||
|
result,
|
||||||
|
get_palette(args.palette),
|
||||||
|
opacity=args.opacity)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -166,7 +166,8 @@ result = inference_segmentor(model, img)
|
|||||||
# visualize the results in a new window
|
# visualize the results in a new window
|
||||||
model.show_result(img, result, show=True)
|
model.show_result(img, result, show=True)
|
||||||
# or save the visualization results to image files
|
# or save the visualization results to image files
|
||||||
model.show_result(img, result, out_file='result.jpg')
|
# you can change the opacity of the painted segmentation map in (0, 1].
|
||||||
|
model.show_result(img, result, out_file='result.jpg', opacity=0.5)
|
||||||
|
|
||||||
# test a video and show the results
|
# test a video and show the results
|
||||||
video = mmcv.VideoReader('video.mp4')
|
video = mmcv.VideoReader('video.mp4')
|
||||||
|
@ -98,7 +98,12 @@ def inference_segmentor(model, img):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)):
|
def show_result_pyplot(model,
|
||||||
|
img,
|
||||||
|
result,
|
||||||
|
palette=None,
|
||||||
|
fig_size=(15, 10),
|
||||||
|
opacity=0.5):
|
||||||
"""Visualize the segmentation results on the image.
|
"""Visualize the segmentation results on the image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -109,10 +114,14 @@ def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)):
|
|||||||
map. If None is given, random palette will be generated.
|
map. If None is given, random palette will be generated.
|
||||||
Default: None
|
Default: None
|
||||||
fig_size (tuple): Figure size of the pyplot figure.
|
fig_size (tuple): Figure size of the pyplot figure.
|
||||||
|
opacity(float): Opacity of painted segmentation map.
|
||||||
|
Default 0.5.
|
||||||
|
Must be in (0, 1] range.
|
||||||
"""
|
"""
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model = model.module
|
model = model.module
|
||||||
img = model.show_result(img, result, palette=palette, show=False)
|
img = model.show_result(
|
||||||
|
img, result, palette=palette, show=False, opacity=opacity)
|
||||||
plt.figure(figsize=fig_size)
|
plt.figure(figsize=fig_size)
|
||||||
plt.imshow(mmcv.bgr2rgb(img))
|
plt.imshow(mmcv.bgr2rgb(img))
|
||||||
plt.show()
|
plt.show()
|
||||||
|
@ -35,7 +35,8 @@ def single_gpu_test(model,
|
|||||||
data_loader,
|
data_loader,
|
||||||
show=False,
|
show=False,
|
||||||
out_dir=None,
|
out_dir=None,
|
||||||
efficient_test=False):
|
efficient_test=False,
|
||||||
|
opacity=0.5):
|
||||||
"""Test with single GPU.
|
"""Test with single GPU.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -46,7 +47,9 @@ def single_gpu_test(model,
|
|||||||
the directory to save output results.
|
the directory to save output results.
|
||||||
efficient_test (bool): Whether save the results as local numpy files to
|
efficient_test (bool): Whether save the results as local numpy files to
|
||||||
save CPU memory during evaluation. Default: False.
|
save CPU memory during evaluation. Default: False.
|
||||||
|
opacity(float): Opacity of painted segmentation map.
|
||||||
|
Default 0.5.
|
||||||
|
Must be in (0, 1] range.
|
||||||
Returns:
|
Returns:
|
||||||
list: The prediction results.
|
list: The prediction results.
|
||||||
"""
|
"""
|
||||||
@ -82,7 +85,8 @@ def single_gpu_test(model,
|
|||||||
result,
|
result,
|
||||||
palette=dataset.PALETTE,
|
palette=dataset.PALETTE,
|
||||||
show=show,
|
show=show,
|
||||||
out_file=out_file)
|
out_file=out_file,
|
||||||
|
opacity=opacity)
|
||||||
|
|
||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
if efficient_test:
|
if efficient_test:
|
||||||
|
@ -212,7 +212,8 @@ class BaseSegmentor(nn.Module):
|
|||||||
win_name='',
|
win_name='',
|
||||||
show=False,
|
show=False,
|
||||||
wait_time=0,
|
wait_time=0,
|
||||||
out_file=None):
|
out_file=None,
|
||||||
|
opacity=0.5):
|
||||||
"""Draw `result` over `img`.
|
"""Draw `result` over `img`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -229,7 +230,9 @@ class BaseSegmentor(nn.Module):
|
|||||||
Default: False.
|
Default: False.
|
||||||
out_file (str or None): The filename to write the image.
|
out_file (str or None): The filename to write the image.
|
||||||
Default: None.
|
Default: None.
|
||||||
|
opacity(float): Opacity of painted segmentation map.
|
||||||
|
Default 0.5.
|
||||||
|
Must be in (0, 1] range.
|
||||||
Returns:
|
Returns:
|
||||||
img (Tensor): Only if not `show` or `out_file`
|
img (Tensor): Only if not `show` or `out_file`
|
||||||
"""
|
"""
|
||||||
@ -246,13 +249,14 @@ class BaseSegmentor(nn.Module):
|
|||||||
assert palette.shape[0] == len(self.CLASSES)
|
assert palette.shape[0] == len(self.CLASSES)
|
||||||
assert palette.shape[1] == 3
|
assert palette.shape[1] == 3
|
||||||
assert len(palette.shape) == 2
|
assert len(palette.shape) == 2
|
||||||
|
assert 0 < opacity <= 1.0
|
||||||
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
||||||
for label, color in enumerate(palette):
|
for label, color in enumerate(palette):
|
||||||
color_seg[seg == label, :] = color
|
color_seg[seg == label, :] = color
|
||||||
# convert to BGR
|
# convert to BGR
|
||||||
color_seg = color_seg[..., ::-1]
|
color_seg = color_seg[..., ::-1]
|
||||||
|
|
||||||
img = img * 0.5 + color_seg * 0.5
|
img = img * (1 - opacity) + color_seg * opacity
|
||||||
img = img.astype(np.uint8)
|
img = img.astype(np.uint8)
|
||||||
# if out_file specified, do not show image in window
|
# if out_file specified, do not show image in window
|
||||||
if out_file is not None:
|
if out_file is not None:
|
||||||
|
@ -55,6 +55,11 @@ def parse_args():
|
|||||||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||||
default='none',
|
default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
|
parser.add_argument(
|
||||||
|
'--opacity',
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help='Opacity of painted segmentation map. In (0, 1] range.')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if 'LOCAL_RANK' not in os.environ:
|
if 'LOCAL_RANK' not in os.environ:
|
||||||
@ -123,7 +128,7 @@ def main():
|
|||||||
if not distributed:
|
if not distributed:
|
||||||
model = MMDataParallel(model, device_ids=[0])
|
model = MMDataParallel(model, device_ids=[0])
|
||||||
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
|
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
|
||||||
efficient_test)
|
efficient_test, args.opacity)
|
||||||
else:
|
else:
|
||||||
model = MMDistributedDataParallel(
|
model = MMDistributedDataParallel(
|
||||||
model.cuda(),
|
model.cuda(),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user