EasyCV/easycv/predictors/segmentation.py

198 lines
7.2 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
import torch
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
from torchvision.transforms import Compose
from easycv.core.visualization.image import imshow_bboxes
from easycv.datasets.registry import PIPELINES
from easycv.file import io
from easycv.models import build_model
from easycv.predictors.builder import PREDICTORS
from easycv.predictors.interface import PredictorInterface
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.registry import build_from_cfg
@PREDICTORS.register_module()
class Mask2formerPredictor(PredictorInterface):
    """Predictor wrapping a Mask2Former checkpoint.

    The model definition, test pipeline, class names and palette are all
    read from the ``meta.config`` entry stored inside the checkpoint.
    """

    def __init__(self, model_path, model_config=None):
        """Load the checkpoint, build the model and the test pipeline.

        Args:
            model_path (str): Path of the checkpoint file.
            model_config (config, optional): Not used by this predictor; the
                config embedded in the checkpoint is always used.
                Defaults to None.

        Raises:
            AssertionError: If the checkpoint has no ``meta.config`` entry.
        """
        self.model_path = model_path
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = None

        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from trusted sources should be passed here.
        with io.open(self.model_path, 'rb') as infile:
            checkpoint = torch.load(infile, map_location='cpu')

        assert 'meta' in checkpoint and 'config' in checkpoint[
            'meta'], 'meta.config is missing from checkpoint'

        self.cfg = checkpoint['meta']['config']
        # Number of palette entries; show_panoptic() skips ids whose
        # ``id % 1000`` equals this value.
        self.classes = len(self.cfg.PALETTE)
        self.class_name = self.cfg.CLASSES

        # build model
        self.model = build_model(self.cfg.model)
        self.ckpt = load_checkpoint(
            self.model, self.model_path, map_location=self.device)
        self.model.to(self.device)
        self.model.eval()

        # build pipeline
        test_pipeline = self.cfg.test_pipeline
        pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
        self.pipeline = Compose(pipeline)

    def predict(self, input_data_list, mode='panoptic'):
        """Run the model on every sample and collect the requested outputs.

        Args:
            input_data_list: a list of numpy array(in rgb order), each array
                is a sample to be predicted. Items that are not ndarrays are
                converted with ``np.asarray``.
            mode (str): Either ``'panoptic'`` or ``'instance'``.
                Defaults to ``'panoptic'``.

        Returns:
            list[dict]: One dict per input sample. In ``'panoptic'`` mode the
            dict has the key ``'pan'``; in ``'instance'`` mode it has the keys
            ``'segms'``, ``'bboxes'``, ``'scores'`` and ``'labels'``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        # Fail before running the model instead of silently returning
        # empty result dicts for an unknown mode.
        if mode not in ('panoptic', 'instance'):
            raise ValueError(
                "mode must be 'panoptic' or 'instance', got %r" % (mode, ))

        output_list = []
        for img in input_data_list:
            if not isinstance(img, np.ndarray):
                img = np.asarray(img)
            # Remember the shape before the pipeline transforms the image.
            ori_shape = img.shape

            data_dict = self.pipeline({'img': img})
            img = data_dict['img']
            # Add a leading batch dimension and move the tensor to the device.
            img[0] = torch.unsqueeze(img[0], 0).to(self.device)
            img_metas = [[
                img_meta._data for img_meta in data_dict['img_metas']
            ]]
            img_metas[0][0]['ori_shape'] = ori_shape

            # Inference only: do not record autograd state.
            with torch.no_grad():
                res = self.model.forward_test(img, img_metas, encode=False)

            if mode == 'panoptic':
                output = {'pan': res['pan_results'][0]}
            else:
                output = {
                    'segms': res['detection_masks'][0],
                    'bboxes': res['detection_boxes'][0],
                    'scores': res['detection_scores'][0],
                    'labels': res['detection_classes'][0],
                }
            output_list.append(output)
        return output_list

    def show_panoptic(self, img, pan_mask):
        """Overlay the segments of a panoptic id map on an image.

        Args:
            img (ndarray): Image to draw on; ``draw_masks`` blends into it
                in place.
            pan_mask (ndarray): Panoptic id map. ``id % 1000`` is used as the
                palette index.

        Returns:
            ndarray: The value returned by ``draw_masks``.
        """
        pan_label = np.unique(pan_mask)
        # Skip ids whose class index equals the palette length
        # (presumably the void label -- confirm against the model).
        pan_label = pan_label[pan_label % 1000 != self.classes]
        masks = np.array([pan_mask == num for num in pan_label])

        palette = np.asarray(self.cfg.PALETTE)
        palette = palette[pan_label % 1000]
        panoptic_result = draw_masks(img, masks, palette)
        return panoptic_result

    def show_instance(self, img, segms, bboxes, scores, labels, score_thr=0.5):
        """Overlay instance masks and boxes on an image.

        Args:
            img (ndarray): Image to draw on; ``draw_masks`` blends into it
                in place.
            segms (ndarray): Instance masks, indexed along the first axis.
            bboxes (ndarray): Instance boxes, one row per instance.
            scores (ndarray): Instance scores used for thresholding.
            labels (ndarray): Class indices into the palette and class names.
            score_thr (float): Instances with ``score <= score_thr`` are
                dropped. No filtering happens when it is ``<= 0``.
                Defaults to 0.5.

        Returns:
            The value returned by ``imshow_bboxes``.
        """
        if score_thr > 0:
            inds = scores > score_thr
            bboxes = bboxes[inds, :]
            segms = segms[inds, ...]
            labels = labels[inds]

        palette = np.asarray(self.cfg.PALETTE)
        palette = palette[labels]
        instance_result = draw_masks(img, segms, palette)

        class_name = np.array(self.class_name)
        instance_result = imshow_bboxes(
            instance_result, bboxes, class_name[labels], show=False)
        return instance_result
def _get_bias_color(base, max_dist=30):
    """Return a randomly perturbed variant of a base color.

    Each channel of ``base`` is shifted by an independent random integer
    drawn from ``[-max_dist, max_dist]`` and the result is clamped to the
    range ``[0, 255]``.

    Args:
        base (ndarray): The base category color with the shape of (3, ).
        max_dist (int): Largest absolute shift applied to a channel.
            Default: 30.

    Returns:
        ndarray: The perturbed color with the shape of (3, ).
    """
    # randint's upper bound is exclusive, hence the +1.
    offset = np.random.randint(low=-max_dist, high=max_dist + 1, size=3)
    shifted = base + offset
    # Clamp in place; np.clip returns the ``out`` array.
    np.clip(shifted, 0, 255, out=shifted)
    return shifted
def draw_masks(img, masks, color=None, with_edge=True, alpha=0.8):
    """Blend colored masks onto an image, in place.

    Args:
        img (ndarray): The image to draw on. It is indexed with each boolean
            mask, so its leading dimensions must match the mask shape
            (presumably (h, w, 3) -- confirm against callers). Modified in
            place.
        masks (ndarray): The masks with the shape of (n, h, w).
        color (ndarray, optional): The colors for each mask with the shape
            of (n, 3). Random colors are generated when it is None.
        with_edge (bool): Accepted for backward compatibility. This function
            only returns the blended image and does not draw mask edges.
            Default: True.
        alpha (float): Blend weight of the mask color; the image keeps
            ``1 - alpha``. Default: 0.8.

    Returns:
        ndarray: The same ``img`` object with the masks blended in.
    """
    # Black is treated as already taken, so a black entry gets perturbed.
    taken_colors = {(0, 0, 0)}
    if color is None:
        # randint's upper bound is exclusive: 256 covers the full 0..255.
        color = np.random.randint(0, 256, (len(masks), 3)).astype(np.uint8)

    for i, mask in enumerate(masks):
        color_mask = color[i]
        # Perturb the color until it differs from every color used so far,
        # so that masks sharing a category color stay distinguishable.
        while tuple(int(c) for c in color_mask) in taken_colors:
            color_mask = _get_bias_color(color_mask)
        taken_colors.add(tuple(int(c) for c in color_mask))

        mask = mask.astype(bool)
        img[mask] = img[mask] * (1 - alpha) + color_mask * alpha

    return img
def bitmap_to_polygon(bitmap):
    """Extract the contour polygons of a bitmap mask.

    Args:
        bitmap (ndarray): masks in bitmap representation.

    Return:
        list[ndarray]: the contours of the mask, each reshaped to (k, 2).
        bool: whether the mask has holes.
    """
    binary = np.ascontiguousarray(bitmap).astype(np.uint8)

    # cv2.RETR_CCOMP arranges contours in two levels: outer boundaries of
    # components on top, boundaries of holes below. A contour nested inside
    # a hole is placed on the top level again.
    # cv2.CHAIN_APPROX_NONE keeps every contour point.
    found = cv2.findContours(binary, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    # Index from the end so this works whether findContours returns two or
    # three values.
    raw_contours, hierarchy = found[-2], found[-1]

    if hierarchy is None:
        return [], False

    # Each hierarchy row holds four indexes (next, previous, parent or
    # nested contour), with -1 meaning "none". A non-negative last column
    # marks a contour that belongs to a hole.
    last_column = hierarchy.reshape(-1, 4)[:, 3]
    with_hole = (last_column >= 0).any()
    polygons = [contour.reshape(-1, 2) for contour in raw_contours]
    return polygons, with_hole