# mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from matplotlib.collections import PatchCollection
|
|
from matplotlib.patches import Polygon
|
|
from torchvision.transforms import Compose
|
|
|
|
from easycv.core.visualization.image import imshow_bboxes
|
|
from easycv.datasets.registry import PIPELINES
|
|
from easycv.file import io
|
|
from easycv.models import build_model
|
|
from easycv.predictors.builder import PREDICTORS
|
|
from easycv.predictors.interface import PredictorInterface
|
|
from easycv.utils.checkpoint import load_checkpoint
|
|
from easycv.utils.registry import build_from_cfg
|
|
|
|
|
|
@PREDICTORS.register_module()
class Mask2formerPredictor(PredictorInterface):
    """Predictor wrapping a Mask2Former segmentation model.

    Loads a checkpoint (whose ``meta.config`` carries the model/pipeline
    config), builds the model and the test-time preprocessing pipeline, and
    exposes ``predict`` for panoptic or instance segmentation plus two
    visualization helpers.
    """

    def __init__(self, model_path, model_config=None):
        """Init model from a checkpoint file.

        Args:
            model_path (str): Path of the checkpoint file. The checkpoint
                must contain ``meta.config`` (asserted below).
            model_config (config, optional): config string for model to init.
                Defaults to None. NOTE(review): currently unused — the config
                is always taken from the checkpoint's ``meta``.
        """
        self.model_path = model_path

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = None
        # Load on CPU first; weights are moved to self.device after build.
        with io.open(self.model_path, 'rb') as infile:
            checkpoint = torch.load(infile, map_location='cpu')

        assert 'meta' in checkpoint and 'config' in checkpoint[
            'meta'], 'meta.config is missing from checkpoint'

        self.cfg = checkpoint['meta']['config']
        # Number of classes is derived from the palette length; the palette
        # itself is also used for colorizing in show_panoptic/show_instance.
        self.classes = len(self.cfg.PALETTE)
        self.class_name = self.cfg.CLASSES
        # build model
        self.model = build_model(self.cfg.model)

        self.ckpt = load_checkpoint(
            self.model, self.model_path, map_location=self.device)
        self.model.to(self.device)
        self.model.eval()

        # build pipeline
        test_pipeline = self.cfg.test_pipeline
        pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
        self.pipeline = Compose(pipeline)

    def predict(self, input_data_list, mode='panoptic'):
        """Run inference on a list of images.

        Args:
            input_data_list: a list of numpy array(in rgb order), each array
                is a sample to be predicted
            mode (str): 'panoptic' or 'instance'; selects which result keys
                are extracted from the model output. Any other value yields
                an empty dict per sample.

        Returns:
            list[dict]: one dict per input image. For 'panoptic' it has key
            'pan'; for 'instance' it has keys 'segms', 'bboxes', 'scores',
            'labels'.
        """
        output_list = []
        for idx, img in enumerate(input_data_list):
            output = {}
            if not isinstance(img, np.ndarray):
                img = np.asarray(img)
            data_dict = {'img': img}
            # Remember the raw shape before the pipeline resizes/pads.
            ori_shape = img.shape
            data_dict = self.pipeline(data_dict)
            # NOTE(review): after the pipeline, data_dict['img'] appears to
            # be a list whose first element is a CHW tensor — confirm against
            # the configured test_pipeline.
            img = data_dict['img']
            # Add the batch dimension and move to the target device.
            img[0] = torch.unsqueeze(img[0], 0).to(self.device)
            # Unwrap DataContainer objects (`_data`) produced by the pipeline.
            img_metas = [[
                img_meta._data for img_meta in data_dict['img_metas']
            ]]
            # Restore the pre-pipeline shape so results map back to the
            # original image resolution.
            img_metas[0][0]['ori_shape'] = ori_shape
            res = self.model.forward_test(img, img_metas, encode=False)
            if mode == 'panoptic':
                output['pan'] = res['pan_results'][0]
            elif mode == 'instance':
                output['segms'] = res['detection_masks'][0]
                output['bboxes'] = res['detection_boxes'][0]
                output['scores'] = res['detection_scores'][0]
                output['labels'] = res['detection_classes'][0]
            output_list.append(output)
        return output_list

    def show_panoptic(self, img, pan_mask):
        """Overlay a panoptic segmentation mask on the image.

        Args:
            img (ndarray): image to draw on.
            pan_mask (ndarray): panoptic id map; ids encode the class as
                ``id % 1000`` (ids whose class equals ``self.classes`` are
                treated as void and skipped).

        Returns:
            ndarray: image with colored masks drawn on it.
        """
        pan_label = np.unique(pan_mask)
        # Drop the void label (class index == number of classes).
        pan_label = pan_label[pan_label % 1000 != self.classes]
        # One boolean mask per remaining panoptic id.
        masks = np.array([pan_mask == num for num in pan_label])

        palette = np.asarray(self.cfg.PALETTE)
        # Color each segment by its semantic class (id % 1000).
        palette = palette[pan_label % 1000]
        panoptic_result = draw_masks(img, masks, palette)
        return panoptic_result

    def show_instance(self, img, segms, bboxes, scores, labels, score_thr=0.5):
        """Overlay instance masks and labeled boxes on the image.

        Args:
            img (ndarray): image to draw on.
            segms (ndarray): per-instance masks, shape (n, h, w).
            bboxes (ndarray): per-instance boxes, shape (n, 4+).
            scores (ndarray): per-instance confidence scores, shape (n,).
            labels (ndarray): per-instance class indices, shape (n,).
            score_thr (float): instances with score <= threshold are hidden.
                Default: 0.5.

        Returns:
            ndarray: image with masks and class-named boxes drawn on it.
        """
        if score_thr > 0:
            inds = scores > score_thr
            bboxes = bboxes[inds, :]
            segms = segms[inds, ...]
            labels = labels[inds]
        palette = np.asarray(self.cfg.PALETTE)
        # Color each instance by its class palette entry.
        palette = palette[labels]
        instance_result = draw_masks(img, segms, palette)
        class_name = np.array(self.class_name)
        instance_result = imshow_bboxes(
            instance_result, bboxes, class_name[labels], show=False)
        return instance_result
|
|
|
|
|
|
def _get_bias_color(base, max_dist=30):
|
|
"""Get different colors for each masks.
|
|
|
|
Get different colors for each masks by adding a bias
|
|
color to the base category color.
|
|
Args:
|
|
base (ndarray): The base category color with the shape
|
|
of (3, ).
|
|
max_dist (int): The max distance of bias. Default: 30.
|
|
|
|
Returns:
|
|
ndarray: The new color for a mask with the shape of (3, ).
|
|
"""
|
|
new_color = base + np.random.randint(
|
|
low=-max_dist, high=max_dist + 1, size=3)
|
|
return np.clip(new_color, 0, 255, new_color)
|
|
|
|
|
|
def draw_masks(img, masks, color=None, with_edge=True, alpha=0.8):
    """Draw masks on the image and collect their edge polygons.

    Args:
        img (ndarray): The image with the shape of (h, w, 3); modified
            in place by alpha-blending each mask's color over it.
        masks (ndarray): The masks with the shape of (n, h, w).
        color (ndarray): The colors for each masks with the shape
            of (n, 3). If None, random colors are drawn.
        with_edge (bool): Whether to extract edge polygons. Default: True.
        alpha (float): Transparency of the mask overlay. Default: 0.8.

    Returns:
        ndarray: The result image.
    """
    # Reserve pure black so no mask is drawn in (0, 0, 0).
    # Bug fix: the original `set([0, 0, 0])` collapses to {0}, so the
    # reserved-color check could never match a color tuple.
    taken_colors = {(0, 0, 0)}
    if color is None:
        # Bug fix: `masks` is an ndarray, whose `.size` attribute is an int
        # and not callable; use len(masks) for the instance count.
        random_colors = np.random.randint(0, 255, (len(masks), 3))
        color = [tuple(c) for c in random_colors]
        color = np.array(color, dtype=np.uint8)
    polygons = []
    for i, mask in enumerate(masks):
        if with_edge:
            contours, _ = bitmap_to_polygon(mask)
            polygons += [Polygon(c) for c in contours]

        color_mask = color[i]
        # Nudge the color until it is unique among the drawn masks.
        while tuple(color_mask) in taken_colors:
            color_mask = _get_bias_color(color_mask)
        taken_colors.add(tuple(color_mask))

        mask = mask.astype(bool)
        # Alpha-blend the mask color into the image in place.
        img[mask] = img[mask] * (1 - alpha) + color_mask * alpha

    # NOTE(review): this PatchCollection is built but never added to any
    # matplotlib axes, so the edges are not actually rendered on `img` —
    # kept for parity with the original; confirm whether edge drawing
    # should be wired up or `with_edge` removed.
    p = PatchCollection(
        polygons, facecolor='none', edgecolors='w', linewidths=1, alpha=0.8)

    return img
|
|
|
|
|
|
def bitmap_to_polygon(bitmap):
    """Convert masks from the form of bitmaps to polygons.

    Args:
        bitmap (ndarray): masks in bitmap representation.

    Return:
        list[ndarray]: the converted mask in polygon representation.
        bool: whether the mask has holes.
    """
    mask_u8 = np.ascontiguousarray(bitmap).astype(np.uint8)
    # cv2.RETR_CCOMP: retrieves all of the contours and organizes them
    # into a two-level hierarchy. At the top level, there are external
    # boundaries of the components. At the second level, there are
    # boundaries of the holes. If there is another contour inside a hole
    # of a connected component, it is still put at the top level.
    # cv2.CHAIN_APPROX_NONE: stores absolutely all the contour points.
    result = cv2.findContours(mask_u8, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    # Take the last two entries so both the 3-tuple (OpenCV 3.x) and the
    # 2-tuple (OpenCV 4.x) return shapes are handled.
    contours, hierarchy = result[-2], result[-1]
    if hierarchy is None:
        return [], False
    # hierarchy row layout: [next, previous, first_child, parent]; a
    # non-negative parent index marks a hole boundary.
    has_hole = (hierarchy.reshape(-1, 4)[:, 3] >= 0).any()
    flat_contours = [c.reshape(-1, 2) for c in contours]
    return flat_contours, has_hole
|