# EasyCV/easycv/predictors/segmentation.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import mmcv
import numpy as np
import torch
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
from torchvision.transforms import Compose

from easycv.core.visualization.image import imshow_bboxes
from easycv.datasets.registry import PIPELINES
from easycv.file import io
from easycv.models import build_model
from easycv.predictors.builder import PREDICTORS
from easycv.predictors.interface import PredictorInterface
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.registry import build_from_cfg

@PREDICTORS.register_module()
class Mask2formerPredictor(PredictorInterface):

    def __init__(self, model_path, model_config=None):
        """Init model.

        Args:
            model_path (str): Path of the model checkpoint file.
            model_config (config, optional): Config string used to init the
                model. Defaults to None.
        """
        self.model_path = model_path
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = None
        with io.open(self.model_path, 'rb') as infile:
            checkpoint = torch.load(infile, map_location='cpu')
            assert 'meta' in checkpoint and 'config' in checkpoint[
                'meta'], 'meta.config is missing from checkpoint'
            self.cfg = checkpoint['meta']['config']
        # Number of classes; panoptic ids whose class component equals this
        # value mark void/background regions.
        self.classes = len(self.cfg.PALETTE)
        self.class_name = self.cfg.CLASSES

        # build model
        self.model = build_model(self.cfg.model)
        self.ckpt = load_checkpoint(
            self.model, self.model_path, map_location=self.device)
        self.model.to(self.device)
        self.model.eval()

        # build pipeline
        test_pipeline = self.cfg.test_pipeline
        pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
        self.pipeline = Compose(pipeline)

    def predict(self, input_data_list, mode='panoptic'):
        """Predict a list of images.

        Args:
            input_data_list: A list of numpy arrays (in RGB order), each
                array is a sample to be predicted.
            mode (str): Either 'panoptic' or 'instance'.
                Default: 'panoptic'.

        Returns:
            list[dict]: In 'panoptic' mode each dict holds a 'pan' mask; in
                'instance' mode it holds 'segms', 'bboxes', 'scores' and
                'labels'.
        """
        output_list = []
        for idx, img in enumerate(input_data_list):
            output = {}
            if not isinstance(img, np.ndarray):
                img = np.asarray(img)
            data_dict = {'img': img}
            ori_shape = img.shape
            data_dict = self.pipeline(data_dict)
            img = data_dict['img']
            img[0] = torch.unsqueeze(img[0], 0).to(self.device)
            img_metas = [[
                img_meta._data for img_meta in data_dict['img_metas']
            ]]
            img_metas[0][0]['ori_shape'] = ori_shape
            res = self.model.forward_test(img, img_metas, encode=False)
            if mode == 'panoptic':
                output['pan'] = res['pan_results'][0]
            elif mode == 'instance':
                output['segms'] = res['detection_masks'][0]
                output['bboxes'] = res['detection_boxes'][0]
                output['scores'] = res['detection_scores'][0]
                output['labels'] = res['detection_classes'][0]
            output_list.append(output)
        return output_list

    def show_panoptic(self, img, pan_mask):
        pan_label = np.unique(pan_mask)
        # Panoptic ids pack ``class_id + instance_id * 1000``; ids whose
        # class component equals ``self.classes`` denote void regions and
        # are filtered out.
        pan_label = pan_label[pan_label % 1000 != self.classes]
        masks = np.array([pan_mask == num for num in pan_label])
        palette = np.asarray(self.cfg.PALETTE)
        palette = palette[pan_label % 1000]
        panoptic_result = draw_masks(img, masks, palette)
        return panoptic_result

    def show_instance(self, img, segms, bboxes, scores, labels, score_thr=0.5):
        if score_thr > 0:
            inds = scores > score_thr
            bboxes = bboxes[inds, :]
            segms = segms[inds, ...]
            labels = labels[inds]
        palette = np.asarray(self.cfg.PALETTE)
        palette = palette[labels]
        instance_result = draw_masks(img, segms, palette)
        class_name = np.array(self.class_name)
        instance_result = imshow_bboxes(
            instance_result, bboxes, class_name[labels], show=False)
        return instance_result
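

# A minimal usage sketch for ``Mask2formerPredictor``; the checkpoint and
# image paths are hypothetical, and any panoptic Mask2Former checkpoint whose
# ``meta.config`` carries ``CLASSES``/``PALETTE`` is assumed to work the same:
#
#     predictor = Mask2formerPredictor(model_path='mask2former_pan.pth')
#     img = cv2.cvtColor(cv2.imread('demo.jpg'), cv2.COLOR_BGR2RGB)
#     out = predictor.predict([img], mode='panoptic')[0]
#     vis = predictor.show_panoptic(img, out['pan'])
#     cv2.imwrite('vis.jpg', cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))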


@PREDICTORS.register_module()
class SegFormerPredictor(PredictorInterface):

    def __init__(self, model_path, model_config):
        """Init model.

        Args:
            model_path (str): Path of the model checkpoint file.
            model_config (str): Path of the config file used to init the
                model.
        """
        self.model_path = model_path
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = None
        with io.open(self.model_path, 'rb') as infile:
            checkpoint = torch.load(infile, map_location='cpu')
        self.cfg = mmcv_config_fromfile(model_config)
        self.CLASSES = self.cfg.CLASSES
        self.PALETTE = self.cfg.PALETTE

        # build model
        self.model = build_model(self.cfg.model)
        self.ckpt = load_checkpoint(
            self.model, self.model_path, map_location=self.device)
        self.model.to(self.device)
        self.model.eval()

        # build pipeline
        test_pipeline = self.cfg.test_pipeline
        pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
        self.pipeline = Compose(pipeline)

    def predict(self, input_data_list):
        """Predict a list of samples.

        Args:
            input_data_list: A list of numpy arrays (in RGB order), each
                array is a sample to be predicted.
        """
        output_list = []
        for idx, img in enumerate(input_data_list):
            if not isinstance(img, np.ndarray):
                img = np.asarray(img)
            ori_img_shape = img.shape[:2]
            data_dict = {'img': img}
            data_dict['ori_shape'] = ori_img_shape
            data_dict = self.pipeline(data_dict)
            img = data_dict['img']
            img = torch.unsqueeze(img[0], 0).to(self.device)
            data_dict.pop('img')
            with torch.no_grad():
                out = self.model([img],
                                 mode='test',
                                 img_metas=[[data_dict['img_metas'][0]._data]])
            output_list.append(out)
        return output_list

    def show_result(self,
                    img,
                    result,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
        """Draw `result` over `img`.

        Args:
            img (str or ndarray): The image to be displayed.
            result (Tensor): The semantic segmentation results to draw over
                `img`.
            palette (list[list[int]] | np.ndarray | None): The palette of
                the segmentation map. If None is given, a random palette
                will be generated. Default: None.
            win_name (str): The window name. Default: ''.
            show (bool): Whether to show the image. Default: False.
            wait_time (int): Value of the waitKey param. Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.
            opacity (float): Opacity of the painted segmentation map.
                Must be in the (0, 1] range. Default: 0.5.

        Returns:
            ndarray: The visualized image, only if neither `show` nor
                `out_file` is set.
        """
        img = mmcv.imread(img)
        img = img.copy()
        seg = result[0]
        if palette is None:
            if self.PALETTE is None:
                # Get random state before set seed,
                # and restore random state later.
                # It will prevent loss of randomness, as the palette
                # may be different in each iteration if not specified.
                # See: https://github.com/open-mmlab/mmdetection/issues/5844
                state = np.random.get_state()
                np.random.seed(42)
                # random palette
                palette = np.random.randint(
                    0, 255, size=(len(self.CLASSES), 3))
                np.random.set_state(state)
            else:
                palette = self.PALETTE
        palette = np.array(palette)
        assert palette.shape[0] == len(self.CLASSES)
        assert palette.shape[1] == 3
        assert len(palette.shape) == 2
        assert 0 < opacity <= 1.0
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]

        img = img * (1 - opacity) + color_seg * opacity
        img = img.astype(np.uint8)
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        if show:
            mmcv.imshow(img, win_name, wait_time)
        if out_file is not None:
            mmcv.imwrite(img, out_file)

        if not (show or out_file):
            return img
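

# A minimal usage sketch for ``SegFormerPredictor``; the checkpoint, config
# and image paths are hypothetical, and the ``'seg_pred'`` output key is an
# assumption about the model's test-mode result dict:
#
#     predictor = SegFormerPredictor(
#         model_path='segformer_b0.pth',
#         model_config='configs/segmentation/segformer/segformer_b0.py')
#     img = cv2.cvtColor(cv2.imread('demo.jpg'), cv2.COLOR_BGR2RGB)
#     out = predictor.predict([img])[0]
#     predictor.show_result('demo.jpg', out['seg_pred'], out_file='vis.jpg')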


def _get_bias_color(base, max_dist=30):
    """Get a different color for each mask.

    Get a different color for each mask by adding a random bias
    color to the base category color.

    Args:
        base (ndarray): The base category color with the shape
            of (3, ).
        max_dist (int): The max distance of the bias. Default: 30.

    Returns:
        ndarray: The new color for a mask with the shape of (3, ).
    """
    new_color = base + np.random.randint(
        low=-max_dist, high=max_dist + 1, size=3)
    return np.clip(new_color, 0, 255, new_color)


def draw_masks(img, masks, color=None, with_edge=True, alpha=0.8):
    """Draw masks on the image.

    Args:
        img (ndarray): The image with the shape of (h, w, 3).
        masks (ndarray): The masks with the shape of (n, h, w).
        color (ndarray): The colors for each mask with the shape
            of (n, 3).
        with_edge (bool): Whether to compute mask edges. Default: True.
        alpha (float): Transparency of the masks. Default: 0.8.

    Returns:
        ndarray: The result image.
    """
    # Reserve black so every mask gets a distinct non-black color.
    taken_colors = {(0, 0, 0)}
    if color is None:
        # ``masks`` is a numpy array here, so use ``len`` rather than the
        # tensor-style ``masks.size(0)``.
        random_colors = np.random.randint(0, 255, (len(masks), 3))
        color = [tuple(c) for c in random_colors]
        color = np.array(color, dtype=np.uint8)
    polygons = []
    for i, mask in enumerate(masks):
        if with_edge:
            contours, _ = bitmap_to_polygon(mask)
            polygons += [Polygon(c) for c in contours]
        color_mask = color[i]
        while tuple(color_mask) in taken_colors:
            color_mask = _get_bias_color(color_mask)
        taken_colors.add(tuple(color_mask))
        mask = mask.astype(bool)
        img[mask] = img[mask] * (1 - alpha) + color_mask * alpha
    # The edge polygons are collected into a PatchCollection, but no axes is
    # attached here, so they are not rendered onto ``img``.
    p = PatchCollection(
        polygons, facecolor='none', edgecolors='w', linewidths=1, alpha=0.8)
    return img
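

# ``draw_masks`` can also be exercised standalone; a hedged sketch blending a
# single square mask into a blank canvas (all names are local to the example):
#
#     canvas = np.zeros((32, 32, 3), dtype=np.uint8)
#     mask = np.zeros((1, 32, 32), dtype=np.uint8)
#     mask[0, 8:24, 8:24] = 1
#     canvas = draw_masks(canvas, mask, color=np.array([[255, 0, 0]]))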


def bitmap_to_polygon(bitmap):
    """Convert a mask from bitmap representation to polygon representation.

    Args:
        bitmap (ndarray): A mask in bitmap representation.

    Return:
        list[ndarray]: The converted mask in polygon representation.
        bool: Whether the mask has holes.
    """
    bitmap = np.ascontiguousarray(bitmap).astype(np.uint8)
    # cv2.RETR_CCOMP: retrieves all of the contours and organizes them
    # into a two-level hierarchy. At the top level, there are external
    # boundaries of the components. At the second level, there are
    # boundaries of the holes. If there is another contour inside a hole
    # of a connected component, it is still put at the top level.
    # cv2.CHAIN_APPROX_NONE: stores absolutely all the contour points.
    outs = cv2.findContours(bitmap, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    contours = outs[-2]
    hierarchy = outs[-1]
    if hierarchy is None:
        return [], False
    # hierarchy[i]: 4 elements, for the indexes of next, previous,
    # parent, or nested contours. If there is no corresponding contour,
    # it will be -1.
    with_hole = (hierarchy.reshape(-1, 4)[:, 3] >= 0).any()
    contours = [c.reshape(-1, 2) for c in contours]
    return contours, with_hole
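

# A small self-contained sanity sketch for ``bitmap_to_polygon``: a square
# mask with a hole should yield an outer contour plus a hole contour, and
# ``with_hole`` should be True.
if __name__ == '__main__':
    demo_mask = np.zeros((16, 16), dtype=np.uint8)
    demo_mask[2:14, 2:14] = 1  # filled outer square
    demo_mask[6:10, 6:10] = 0  # punch a hole inside it
    demo_contours, demo_with_hole = bitmap_to_polygon(demo_mask)
    assert len(demo_contours) == 2
    assert demo_with_hole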