2022-08-08 18:17:01 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import cv2
|
2022-08-18 10:40:18 +08:00
|
|
|
import mmcv
|
2022-08-08 18:17:01 +08:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from matplotlib.collections import PatchCollection
|
|
|
|
from matplotlib.patches import Polygon
|
|
|
|
|
|
|
|
from easycv.core.visualization.image import imshow_bboxes
|
|
|
|
from easycv.predictors.builder import PREDICTORS
|
2023-02-01 12:14:44 +08:00
|
|
|
from .base import OutputProcessor, PredictorV2
|
2022-08-23 19:52:52 +08:00
|
|
|
|
|
|
|
|
|
|
|
@PREDICTORS.register_module()
class SegmentationPredictor(PredictorV2):
    """Predictor for Segmentation.

    Args:
        model_path (str): Path of model path.
        config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
        batch_size (int): batch size for forward.
        device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
        save_results (bool): Whether to save predict results.
        save_path (str): File path for saving results, only valid when `save_results` is True.
        pipelines (list[dict]): Data pipeline configs.
        input_processor_threads (int): Number of processes to process inputs.
        mode (str): The image mode into the model.
    """

    def __init__(self,
                 model_path,
                 config_file=None,
                 batch_size=1,
                 device=None,
                 save_results=False,
                 save_path=None,
                 pipelines=None,
                 input_processor_threads=8,
                 mode='BGR',
                 *args,
                 **kwargs):
        super(SegmentationPredictor, self).__init__(
            model_path,
            config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            pipelines=pipelines,
            input_processor_threads=input_processor_threads,
            mode=mode,
            *args,
            **kwargs)

        self.CLASSES = self.cfg.CLASSES
        # The palette is optional in the config; when it is missing,
        # `show_result` falls back to a seeded random palette.
        self.PALETTE = self.cfg.get('PALETTE', None)

    def show_result(self,
                    img,
                    result,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
        """Draw the segmentation map `result` over `img`.

        Args:
            img (str | ndarray): The image (or image path) to draw on; it is
                loaded with `mmcv.imread` and copied, so the caller's array is
                not modified.
            result (ndarray): The semantic segmentation map to draw over
                `img`, indexed as (h, w) with one class index per pixel.
            palette (list[list[int]] | np.ndarray | None): The palette of
                the segmentation map, one color per class. If None, the
                config palette is used, or a random (seeded) palette when the
                config has none. Default: None.
            win_name (str): The window name.
            show (bool): Whether to show the image in a window.
                Default: False.
            wait_time (int): Value of waitKey param.
                Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.
            opacity (float): Opacity of the painted segmentation map.
                Must be in (0, 1] range. Default: 0.5.

        Returns:
            ndarray: The blended uint8 image, only if neither `show` nor
            `out_file` is set.
        """
        img = mmcv.imread(img)
        img = img.copy()
        seg = result
        if palette is None:
            if self.PALETTE is None:
                # Get random state before set seed,
                # and restore random state later.
                # It will prevent loss of randomness, as the palette
                # may be different in each iteration if not specified.
                # See: https://github.com/open-mmlab/mmdetection/issues/5844
                state = np.random.get_state()
                np.random.seed(42)
                # random palette
                palette = np.random.randint(
                    0, 255, size=(len(self.CLASSES), 3))
                np.random.set_state(state)
            else:
                palette = self.PALETTE
        palette = np.array(palette)
        # Check the rank first so a malformed (e.g. 1-D) palette fails with an
        # AssertionError rather than an IndexError on `palette.shape[1]`.
        assert len(palette.shape) == 2
        assert palette.shape[0] == len(self.CLASSES)
        assert palette.shape[1] == 3
        assert 0 < opacity <= 1.0
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # Reverse the channel order: the palette is treated as RGB and the
        # image loaded by mmcv as BGR (assumption carried over from the
        # original comment).
        color_seg = color_seg[..., ::-1]

        img = img * (1 - opacity) + color_seg * opacity
        img = img.astype(np.uint8)
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        if show:
            mmcv.imshow(img, win_name, wait_time)
        if out_file is not None:
            mmcv.imwrite(img, out_file)

        if not (show or out_file):
            return img
|
2022-08-08 18:17:01 +08:00
|
|
|
|
|
|
|
|
2023-02-01 12:14:44 +08:00
|
|
|
class Mask2formerOutputProcessor(OutputProcessor):
    """Process the output of Mask2former.

    Args:
        task_mode (str): Support task in ['panoptic', 'instance', 'semantic'].
        classes (list): Classes name list.
    """

    def __init__(self, task_mode, classes):
        super(Mask2formerOutputProcessor, self).__init__()
        self.task_mode = task_mode
        self.classes = classes

    def process_single(self, inputs):
        """Convert one raw model output dict into a result dict.

        Args:
            inputs (dict): Raw model output. Reads `pan_results` for
                'panoptic'; `detection_masks`, `detection_boxes`,
                `detection_scores`, `detection_classes` for 'instance';
                `seg_pred` for 'semantic'.

        Returns:
            dict: For 'panoptic': `masks` (list of int arrays, one per
            segment), `labels` (class names) and `labels_ids` (int64 class
            indices). For 'instance': `segms`, `bboxes`, `scores`, `labels`.
            For 'semantic': `seg_pred`.

        Raises:
            ValueError: If `task_mode` is not one of the supported modes.
        """
        output = {}
        if self.task_mode == 'panoptic':
            pan_results = inputs['pan_results']
            # keep objects ahead: unique segment ids in descending order
            ids = np.unique(pan_results)[::-1]
            # `len(self.classes)` is used as the VOID label; drop it
            ids = ids[ids != len(self.classes)]
            # class index is the segment id modulo 1000
            labels = (ids % 1000).astype(np.int64)
            # one boolean (h, w) mask per remaining segment id
            segms = (pan_results[None] == ids[:, None, None])
            # `np.int` was removed in NumPy 1.24; it was an alias of the
            # builtin `int`, so this keeps the previous dtype.
            masks = [segm.astype(int) for segm in segms]
            labels_txt = np.array(self.classes)[labels].tolist()

            output['masks'] = masks
            output['labels'] = labels_txt
            output['labels_ids'] = labels
        elif self.task_mode == 'instance':
            output['segms'] = inputs['detection_masks']
            output['bboxes'] = inputs['detection_boxes']
            output['scores'] = inputs['detection_scores']
            output['labels'] = inputs['detection_classes']
        elif self.task_mode == 'semantic':
            output['seg_pred'] = inputs['seg_pred']
        else:
            raise ValueError(f'Not support model {self.task_mode}')
        return output
|
|
|
|
|
|
|
|
|
2022-08-08 18:17:01 +08:00
|
|
|
@PREDICTORS.register_module()
class Mask2formerPredictor(SegmentationPredictor):
    """Predictor for Mask2former.

    Args:
        model_path (str): Path of model path.
        config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
        batch_size (int): batch size for forward.
        device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
        save_results (bool): Whether to save predict results.
        save_path (str): File path for saving results, only valid when `save_results` is True.
        pipelines (list[dict]): Data pipeline configs.
        task_mode (str): One of 'panoptic', 'instance', 'semantic'.
            Default: 'panoptic'.
        input_processor_threads (int): Number of processes to process inputs.
        mode (str): The image mode into the model.
    """

    def __init__(self,
                 model_path,
                 config_file=None,
                 batch_size=1,
                 device=None,
                 save_results=False,
                 save_path=None,
                 pipelines=None,
                 task_mode='panoptic',
                 input_processor_threads=8,
                 mode='BGR',
                 *args,
                 **kwargs):
        super(Mask2formerPredictor, self).__init__(
            model_path,
            config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            pipelines=pipelines,
            input_processor_threads=input_processor_threads,
            mode=mode,
            *args,
            **kwargs)
        self.task_mode = task_mode
        self.class_name = self.cfg.CLASSES
        # Optional, like in the base class: a missing PALETTE falls back to
        # a seeded random palette in the `show_*` helpers instead of failing
        # at construction time.
        self.PALETTE = self.cfg.get('PALETTE', None)

    def get_output_processor(self):
        """Build the output processor for the configured task mode."""
        return Mask2formerOutputProcessor(self.task_mode, self.CLASSES)

    def model_forward(self, inputs):
        """Model forward.

        Runs the model in test mode without gradient tracking.
        """
        with torch.no_grad():
            outputs = self.model.forward(**inputs, mode='test', encode=False)
        return outputs

    def _resolve_palette(self, palette=None):
        """Return the palette to draw with as an ndarray.

        Uses `palette` if given, else `self.PALETTE`, else a random palette of
        `len(self.CLASSES)` colors generated with a fixed seed (the global
        NumPy random state is saved and restored around it).
        """
        if palette is None:
            palette = self.PALETTE
        if palette is None:
            # Seed for a stable palette across calls, but restore the global
            # random state afterwards so callers keep their randomness.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            palette = np.random.randint(0, 255, size=(len(self.CLASSES), 3))
            np.random.set_state(state)
        return np.asarray(palette)

    def show_panoptic(self, img, masks, labels):
        """Draw panoptic segments on `img` colored by their class.

        Args:
            img (ndarray): Image to draw on (modified in place by
                `draw_masks`).
            masks: Per-segment masks accepted by `draw_masks`.
            labels (ndarray): Integer label per segment; the class index is
                taken as `labels % 1000`.

        Returns:
            ndarray: The image with masks blended in.
        """
        palette = self._resolve_palette()
        palette = palette[labels % 1000]
        panoptic_result = draw_masks(img, masks, palette)
        return panoptic_result

    def show_instance(self, img, segms, bboxes, scores, labels, score_thr=0.5):
        """Draw instance masks and labeled boxes on `img`.

        Instances with `scores <= score_thr` are dropped when `score_thr > 0`.

        Args:
            img (ndarray): Image to draw on.
            segms (ndarray): Per-instance masks, indexed along the first axis.
            bboxes (ndarray): Per-instance boxes, one row per instance.
            scores (ndarray): Per-instance confidence scores.
            labels (ndarray): Per-instance integer class indices.
            score_thr (float): Score threshold. Default: 0.5.

        Returns:
            ndarray: The image with masks and boxes drawn.
        """
        if score_thr > 0:
            inds = scores > score_thr
            bboxes = bboxes[inds, :]
            segms = segms[inds, ...]
            labels = labels[inds]
        palette = self._resolve_palette()
        palette = palette[labels]

        instance_result = draw_masks(img, segms, palette)
        class_name = np.array(self.CLASSES)
        instance_result = imshow_bboxes(
            instance_result, bboxes, class_name[labels], show=False)
        return instance_result

    def show_semantic(self, img, seg_pred, alpha=0.5, palette=None):
        """Blend a semantic segmentation map over `img`.

        Args:
            img (ndarray): Image of shape (h, w, 3).
            seg_pred (ndarray): Class-index map indexed as (h, w).
            alpha (float): Opacity of the color map, in (0, 1].
                Default: 0.5.
            palette (array-like | None): One color per class; falls back to
                `self.PALETTE`, then to a seeded random palette.

        Returns:
            ndarray: The blended uint8 image.
        """
        palette = self._resolve_palette(palette)
        # Check the rank first so a malformed (e.g. 1-D) palette fails with an
        # AssertionError rather than an IndexError on `palette.shape[1]`.
        assert len(palette.shape) == 2
        assert palette.shape[0] == len(self.CLASSES)
        assert palette.shape[1] == 3
        assert 0 < alpha <= 1.0
        color_seg = np.zeros((seg_pred.shape[0], seg_pred.shape[1], 3),
                             dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg_pred == label, :] = color
        # convert to BGR (palette colors are treated as RGB)
        color_seg = color_seg[..., ::-1]

        img = img * (1 - alpha) + color_seg * alpha
        img = img.astype(np.uint8)

        return img
|
|
|
|
|
2022-08-08 18:17:01 +08:00
|
|
|
|
|
|
|
def _get_bias_color(base, max_dist=30):
    """Perturb a category color so each mask gets its own shade.

    Every channel of `base` is shifted by an independent random integer in
    the closed range [-max_dist, max_dist], and the result is clipped to the
    valid [0, 255] color range.

    Args:
        base (ndarray): The base category color with the shape of (3, ).
        max_dist (int): The max distance of bias. Default: 30.

    Returns:
        ndarray: The new color for a mask with the shape of (3, ).
    """
    offset = np.random.randint(low=-max_dist, high=max_dist + 1, size=3)
    shifted = base + offset
    # Clip into the freshly allocated `shifted` array and return it.
    return np.clip(shifted, 0, 255, out=shifted)
|
|
|
|
|
|
|
|
|
|
|
|
def draw_masks(img, masks, color=None, with_edge=True, alpha=0.8):
    """Blend masks into an image, giving every mask a distinct color.

    Masks that would reuse an already-taken color (or pure black) get a small
    random bias via `_get_bias_color`, so neighbouring instances of the same
    category stay distinguishable.

    Args:
        img (ndarray): The image, indexed as (h, w, 3). It is modified in
            place and also returned.
        masks (ndarray | list[ndarray]): n masks, each indexed as (h, w);
            non-zero pixels belong to the mask.
        color (ndarray | None): The colors for each mask with the shape
            of (n, 3). Random colors are used if None. Default: None.
        with_edge (bool): Kept for backward compatibility; mask edges are
            not drawn onto the returned image. Default: True.
        alpha (float): Opacity of the mask colors. Default: 0.8.

    Returns:
        ndarray: The result image (the same object as `img`).
    """
    # Mark black as taken so no mask is painted pure black. This must be a
    # set containing the tuple (0, 0, 0); `set([0, 0, 0])` would be {0},
    # which never matches a color tuple.
    taken_colors = {(0, 0, 0)}
    if color is None:
        # `len` works for ndarrays, tensors and plain lists of masks alike.
        color = np.random.randint(0, 255, (len(masks), 3))
    color = np.array(color, dtype=np.uint8)

    for i, mask in enumerate(masks):
        color_mask = color[i]
        while tuple(color_mask) in taken_colors:
            color_mask = _get_bias_color(color_mask)
        taken_colors.add(tuple(color_mask))

        mask = mask.astype(bool)
        img[mask] = img[mask] * (1 - alpha) + color_mask * alpha

    return img
|
|
|
|
|
|
|
|
|
|
|
|
def bitmap_to_polygon(bitmap):
    """Trace the contours of a bitmap mask as polygons.

    Args:
        bitmap (ndarray): masks in bitmap representation.

    Return:
        list[ndarray]: the converted mask in polygon representation, one
            (k, 2) array of points per contour.
        bool: whether the mask has holes.
    """
    mask_u8 = np.ascontiguousarray(bitmap).astype(np.uint8)
    # RETR_CCOMP organizes contours in a two-level hierarchy: outer
    # boundaries of components at the top level, boundaries of their holes
    # at the second level (a component nested inside a hole is put back at
    # the top level). CHAIN_APPROX_NONE keeps every contour point.
    found = cv2.findContours(mask_u8, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    # The tuple length differs between OpenCV versions; the contours and
    # the hierarchy are always its last two entries.
    raw_contours, hierarchy = found[-2], found[-1]
    if hierarchy is None:
        return [], False
    # Each hierarchy row holds [next, previous, first_child, parent]; a
    # parent index >= 0 (-1 means none) marks that contour as a hole.
    parent_ids = hierarchy.reshape(-1, 4)[:, 3]
    with_hole = (parent_ids >= 0).any()
    polygons = [contour.reshape(-1, 2) for contour in raw_contours]
    return polygons, with_hole
|