# mirror of https://github.com/YifanXu74/MQ-Det.git
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import timeit

import cv2
import numpy as np
import torch
from torchvision import transforms as T

from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
from maskrcnn_benchmark import layers as L
from maskrcnn_benchmark.utils import cv2_util

class COCODemo(object):
    # COCO categories for pretty print
    CATEGORIES = [
        "__background",
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
        "truck", "boat", "traffic light", "fire hydrant", "stop sign",
        "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep",
        "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
        "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
        "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
        "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
        "couch", "potted plant", "bed", "dining table", "toilet", "tv",
        "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
        "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
        "scissors", "teddy bear", "hair drier", "toothbrush",
    ]

    def __init__(
        self,
        cfg,
        confidence_threshold=0.7,
        show_mask_heatmaps=False,
        masks_per_dim=2,
        min_image_size=None,
        exclude_region=None,
    ):
        self.cfg = cfg.clone()
        self.model = build_detection_model(cfg)
        self.model.eval()
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.model.to(self.device)
        self.min_image_size = min_image_size

        save_dir = cfg.OUTPUT_DIR
        checkpointer = DetectronCheckpointer(cfg, self.model, save_dir=save_dir)
        _ = checkpointer.load(cfg.MODEL.WEIGHT)

        self.transforms = self.build_transform()

        mask_threshold = -1 if show_mask_heatmaps else 0.5
        self.masker = Masker(threshold=mask_threshold, padding=1)

        # used to make colors for each class
        self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])

        self.cpu_device = torch.device("cpu")
        self.confidence_threshold = confidence_threshold
        self.show_mask_heatmaps = show_mask_heatmaps
        self.masks_per_dim = masks_per_dim
        self.exclude_region = exclude_region

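    # Editor's note (inferred from the usage in `select_top_predictions`, not
    # stated in the original): `exclude_region`, when given, should be an
    # N x 4 sequence of xyxy boxes, e.g.
    #     COCODemo(cfg, exclude_region=[[0, 0, 100, 100]])
    # Detections overlapping such a box with IoU > 0.5 are suppressed.
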
    def build_transform(self):
        """
        Creates a basic transformation that was used to train the models
        """
        cfg = self.cfg

        # we are loading images with OpenCV, so we don't need to convert them
        # to BGR, they are already! So all we need to do is to normalize
        # by 255 if we want to convert to BGR255 format, or flip the channels
        # if we want it to be in RGB in [0-1] range.
        if cfg.INPUT.TO_BGR255:
            to_bgr_transform = T.Lambda(lambda x: x * 255)
        else:
            to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]])

        normalize_transform = T.Normalize(
            mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
        )

        transform = T.Compose(
            [
                T.ToPILImage(),
                T.Resize(self.min_image_size)
                if self.min_image_size is not None
                else lambda x: x,
                T.ToTensor(),
                to_bgr_transform,
                normalize_transform,
            ]
        )
        return transform

    def inference(self, image, debug=False):
        """
        Arguments:
            image (np.ndarray): an image as returned by OpenCV

        Returns:
            prediction (BoxList): the detected objects. Additional information
                of the detection properties can be found in the fields of
                the BoxList via `prediction.fields()`. If `debug` is True,
                the model's debug info dict is returned as well.
        """
        predictions, debug_info = self.compute_prediction(image)
        top_predictions = self.select_top_predictions(predictions)

        if debug:
            return top_predictions, debug_info
        else:
            return top_predictions

    def run_on_opencv_image(self, image):
        """
        Arguments:
            image (np.ndarray): an image as returned by OpenCV

        Returns:
            result (np.ndarray): the image with the detections drawn on top,
                followed by the model's debug info dict and the top-scoring
                predictions (BoxList)
        """
        predictions, debug_info = self.compute_prediction(image)
        top_predictions = self.select_top_predictions(predictions)

        result = image.copy()
        if self.show_mask_heatmaps:
            return self.create_mask_montage(result, top_predictions)
        result = self.overlay_boxes(result, top_predictions)
        if self.cfg.MODEL.MASK_ON:
            result = self.overlay_mask(result, top_predictions)
        if self.cfg.MODEL.KEYPOINT_ON:
            result = self.overlay_keypoints(result, top_predictions)
        result = self.overlay_class_names(result, top_predictions)

        return result, debug_info, top_predictions

    def compute_prediction(self, original_image):
        """
        Arguments:
            original_image (np.ndarray): an image as returned by OpenCV

        Returns:
            prediction (BoxList): the detected objects. Additional information
                of the detection properties can be found in the fields of
                the BoxList via `prediction.fields()`
            debug_info (dict): timing and other debug information from the model
        """
        # apply pre-processing to image
        # if self.exclude_region:
        #     for region in self.exclude_region:
        #         original_image[region[1]:region[3], region[0]:region[2], :] = 255
        image = self.transforms(original_image)

        # convert to an ImageList, padded so that it is divisible by
        # cfg.DATALOADER.SIZE_DIVISIBILITY
        image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
        image_list = image_list.to(self.device)
        tic = timeit.time.perf_counter()

        # compute predictions
        with torch.no_grad():
            predictions, debug_info = self.model(image_list)
        predictions = [o.to(self.cpu_device) for o in predictions]
        debug_info['total_time'] = timeit.time.perf_counter() - tic

        # always single image is passed at a time
        prediction = predictions[0]

        # reshape prediction (a BoxList) into the original image size
        height, width = original_image.shape[:-1]
        prediction = prediction.resize((width, height))

        if prediction.has_field("mask"):
            # if we have masks, paste the masks in the right position
            # in the image, as defined by the bounding boxes
            masks = prediction.get_field("mask")
            # always single image is passed at a time
            masks = self.masker([masks], [prediction])[0]
            prediction.add_field("mask", masks)

        return prediction, debug_info

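    # Illustrative note (editor's addition): the returned BoxList is consumed
    # downstream via, e.g.,
    #     prediction.fields()             # e.g. ["scores", "labels", "mask"]
    #     prediction.get_field("scores")  # per-detection confidence tensor
    #     prediction.bbox                 # N x 4 tensor of xyxy boxes
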
    def select_top_predictions(self, predictions):
        """
        Select only predictions which have a `score` > self.confidence_threshold,
        and return the predictions sorted in descending order of score

        Arguments:
            predictions (BoxList): the result of the computation by the model.
                It should contain the field `scores`.

        Returns:
            prediction (BoxList): the detected objects. Additional information
                of the detection properties can be found in the fields of
                the BoxList via `prediction.fields()`
        """
        scores = predictions.get_field("scores")
        labels = predictions.get_field("labels").tolist()

        # build a per-detection threshold: a scalar threshold applies to every
        # class, otherwise the per-class thresholds are indexed by `label - 1`
        thresh = scores.clone()
        for i, lb in enumerate(labels):
            if isinstance(self.confidence_threshold, float):
                thresh[i] = self.confidence_threshold
            elif len(self.confidence_threshold) == 1:
                thresh[i] = self.confidence_threshold[0]
            else:
                thresh[i] = self.confidence_threshold[lb - 1]
        keep = torch.nonzero(scores > thresh).squeeze(1)
        predictions = predictions[keep]

        # drop detections that overlap any user-supplied exclusion box
        if self.exclude_region:
            exclude = BoxList(self.exclude_region, predictions.size)
            iou = boxlist_iou(exclude, predictions)
            keep = torch.nonzero(torch.sum(iou > 0.5, dim=0) == 0).squeeze(1)
            # (if every detection would be excluded, keep the original set)
            if len(keep) > 0:
                predictions = predictions[keep]

        scores = predictions.get_field("scores")
        _, idx = scores.sort(0, descending=True)
        return predictions[idx]

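    # Editor's note (illustration, not in the original): `confidence_threshold`
    # may also be a per-class sequence indexed by `label - 1`, e.g.
    #     COCODemo(cfg, confidence_threshold=[0.5] * 80)
    # which behaves like the scalar form confidence_threshold=0.5.
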
    def compute_colors_for_labels(self, labels):
        """
        Simple function that adds fixed colors depending on the class
        """
        colors = (30 * (labels[:, None] - 1) + 1) * self.palette
        colors = (colors % 255).numpy().astype("uint8")
        return colors

    def overlay_boxes(self, image, predictions):
        """
        Adds the predicted boxes on top of the image

        Arguments:
            image (np.ndarray): an image as returned by OpenCV
            predictions (BoxList): the result of the computation by the model.
                It should contain the field `labels`.
        """
        labels = predictions.get_field("labels")
        boxes = predictions.bbox

        colors = self.compute_colors_for_labels(labels).tolist()

        for box, color in zip(boxes, colors):
            box = box.to(torch.int64)
            top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
            image = cv2.rectangle(
                image, tuple(top_left), tuple(bottom_right), tuple(color), 2
            )

        return image

    def overlay_scores(self, image, predictions):
        """
        Adds the predicted scores next to each box on top of the image

        Arguments:
            image (np.ndarray): an image as returned by OpenCV
            predictions (BoxList): the result of the computation by the model.
                It should contain the field `scores`.
        """
        scores = predictions.get_field("scores")
        boxes = predictions.bbox

        for box, score in zip(boxes, scores):
            box = box.to(torch.int64)
            # cv2.putText expects plain integer pixel coordinates
            image = cv2.putText(image, '%.3f' % score,
                                (int(box[0]), int((box[1] + box[3]) // 2)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                (255, 255, 255), 1)

        return image

    def overlay_cboxes(self, image, predictions):
        """
        Adds the predicted boxes, each annotated with its score, on top of
        the image

        Arguments:
            image (np.ndarray): an image as returned by OpenCV
            predictions (BoxList): the result of the computation by the model.
                It should contain the field `scores`.
        """
        scores = predictions.get_field("scores")
        boxes = predictions.bbox
        for box, score in zip(boxes, scores):
            box = box.to(torch.int64)
            top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
            image = cv2.rectangle(
                image, tuple(top_left), tuple(bottom_right), (255, 0, 0), 2
            )
            # cv2.putText expects plain integer pixel coordinates
            image = cv2.putText(image, '%.3f' % score,
                                (int(box[0]), int((box[1] + box[3]) // 2)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                (255, 0, 0), 1)
        return image

    def overlay_centers(self, image, predictions):
        """
        Adds the predicted object centers on top of the image

        Arguments:
            image (np.ndarray): an image as returned by OpenCV
            predictions (BoxList): the result of the computation by the model.
                It should contain the field `centers`.
        """
        centers = predictions.get_field("centers")

        for cord in centers:
            cord = cord.to(torch.int64)
            image = cv2.circle(image, (cord[0].item(), cord[1].item()),
                               2, (255, 0, 0), 20)

        return image

    def overlay_count(self, image, predictions):
        """
        Adds the detection count on top of the image

        Arguments:
            image (np.ndarray): an image as returned by OpenCV
            predictions (BoxList or int): the result of the computation by the
                model, or a precomputed count.
        """
        if isinstance(predictions, int):
            count = predictions
        else:
            count = len(predictions)
        image = cv2.putText(image, 'Count: %d' % count, (0, 100),
                            cv2.FONT_HERSHEY_SIMPLEX, 3, (255, 0, 0), 3)

        return image

    def overlay_mask(self, image, predictions):
        """
        Adds the instances contours for each predicted object.
        Each label has a different color.

        Arguments:
            image (np.ndarray): an image as returned by OpenCV
            predictions (BoxList): the result of the computation by the model.
                It should contain the fields `mask` and `labels`.
        """
        masks = predictions.get_field("mask").numpy()
        labels = predictions.get_field("labels")

        colors = self.compute_colors_for_labels(labels).tolist()

        for mask, color in zip(masks, colors):
            thresh = mask[0, :, :, None].astype(np.uint8)
            contours, hierarchy = cv2_util.findContours(
                thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            image = cv2.drawContours(image, contours, -1, color, 3)

        composite = image

        return composite

    def overlay_keypoints(self, image, predictions):
        keypoints = predictions.get_field("keypoints")
        kps = keypoints.keypoints
        scores = keypoints.get_field("logits")
        kps = torch.cat((kps[:, :, 0:2], scores[:, :, None]), dim=2).numpy()
        for region in kps:
            image = vis_keypoints(image, region.transpose((1, 0)),
                                  names=keypoints.NAMES,
                                  connections=keypoints.CONNECTIONS)
        return image

    def create_mask_montage(self, image, predictions):
        """
        Create a montage showing the probability heatmaps for each one of the
        detected objects

        Arguments:
            image (np.ndarray): an image as returned by OpenCV
            predictions (BoxList): the result of the computation by the model.
                It should contain the field `mask`.
        """
        masks = predictions.get_field("mask")
        masks_per_dim = self.masks_per_dim
        masks = L.interpolate(
            masks.float(), scale_factor=1 / masks_per_dim
        ).byte()
        height, width = masks.shape[-2:]
        max_masks = masks_per_dim ** 2
        masks = masks[:max_masks]
        # handle case where we have fewer detections than max_masks
        if len(masks) < max_masks:
            masks_padded = torch.zeros(max_masks, 1, height, width, dtype=torch.uint8)
            masks_padded[: len(masks)] = masks
            masks = masks_padded
        masks = masks.reshape(masks_per_dim, masks_per_dim, height, width)
        result = torch.zeros(
            (masks_per_dim * height, masks_per_dim * width), dtype=torch.uint8
        )
        for y in range(masks_per_dim):
            start_y = y * height
            end_y = (y + 1) * height
            for x in range(masks_per_dim):
                start_x = x * width
                end_x = (x + 1) * width
                result[start_y:end_y, start_x:end_x] = masks[y, x]
        return cv2.applyColorMap(result.numpy(), cv2.COLORMAP_JET)

    def overlay_class_names(self, image, predictions, names=None):
        """
        Adds detected class names and scores in the positions defined by the
        top-left corner of the predicted bounding box

        Arguments:
            image (np.ndarray): an image as returned by OpenCV
            predictions (BoxList): the result of the computation by the model.
                It should contain the fields `scores` and `labels`.
        """
        scores = predictions.get_field("scores").tolist()
        labels = predictions.get_field("labels").tolist()
        if names:
            labels = [names[i - 1] for i in labels]
        else:
            labels = [self.CATEGORIES[i] for i in labels]
        boxes = predictions.bbox

        template = "{}: {:.2f}"
        for box, score, label in zip(boxes, scores, labels):
            # cv2.putText expects plain integer pixel coordinates
            x, y = box[:2].to(torch.int64).tolist()
            s = template.format(label, score)
            cv2.putText(
                image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
            )

        return image

def vis_keypoints(img, kps, kp_thresh=0, alpha=0.7, names=None, connections=None):
    """Visualizes keypoints (adapted from vis_one_image).
    kps has shape (3, #keypoints) where the rows are (x, y, logit).
    `names` is the list of keypoint names and `connections` the list of
    (index, index) pairs to draw as limbs.
    """
    dataset_keypoints = names
    kp_lines = connections

    # simple rainbow color map implementation
    # (255, not 256, keeps channel values in the valid 8-bit range)
    blue_red_ratio = 0.8
    gx = lambda x: (6 - 2 * blue_red_ratio) * x + blue_red_ratio
    colors = [[255 * max(0, (3 - abs(gx(i) - 4) - abs(gx(i) - 5)) / 2),
               255 * max(0, (3 - abs(gx(i) - 2) - abs(gx(i) - 4)) / 2),
               255 * max(0, (3 - abs(gx(i) - 1) - abs(gx(i) - 2)) / 2)]
              for i in np.linspace(0, 1, len(kp_lines) + 2)]

    # Perform the drawing on a copy of the image, to allow for blending.
    kp_mask = np.copy(img)

    # Draw mid shoulder / mid hip first for better visualization.
    mid_shoulder = (
        kps[:2, dataset_keypoints.index('right_shoulder')] +
        kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
    sc_mid_shoulder = np.minimum(
        kps[2, dataset_keypoints.index('right_shoulder')],
        kps[2, dataset_keypoints.index('left_shoulder')])
    nose_idx = dataset_keypoints.index('nose')
    if sc_mid_shoulder > kp_thresh and kps[2, nose_idx] > kp_thresh:
        # cv2.line expects integer pixel coordinates
        cv2.line(
            kp_mask, tuple(mid_shoulder.astype(np.int32)),
            tuple(kps[:2, nose_idx].astype(np.int32)),
            color=colors[len(kp_lines)], thickness=2, lineType=cv2.LINE_AA)

    if 'right_hip' in names and 'left_hip' in names:
        mid_hip = (
            kps[:2, dataset_keypoints.index('right_hip')] +
            kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
        sc_mid_hip = np.minimum(
            kps[2, dataset_keypoints.index('right_hip')],
            kps[2, dataset_keypoints.index('left_hip')])
        if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
            cv2.line(
                kp_mask, tuple(mid_shoulder.astype(np.int32)),
                tuple(mid_hip.astype(np.int32)),
                color=colors[len(kp_lines) + 1], thickness=2, lineType=cv2.LINE_AA)

    # Draw the keypoints.
    for l in range(len(kp_lines)):
        i1 = kp_lines[l][0]
        i2 = kp_lines[l][1]
        p1 = int(kps[0, i1]), int(kps[1, i1])
        p2 = int(kps[0, i2]), int(kps[1, i2])
        if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
            cv2.line(
                kp_mask, p1, p2,
                color=colors[l], thickness=2, lineType=cv2.LINE_AA)
        if kps[2, i1] > kp_thresh:
            cv2.circle(
                kp_mask, p1,
                radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
        if kps[2, i2] > kp_thresh:
            cv2.circle(
                kp_mask, p2,
                radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)

    # Blend the keypoints.
    return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
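

# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original file). It shows one plausible
# way to drive COCODemo end to end; the config, weight, and image paths are
# hypothetical placeholders for whatever your MQ-Det/GLIP setup provides, and
# cfg.MODEL.WEIGHT in the config must point at trained weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from maskrcnn_benchmark.config import cfg

    cfg.merge_from_file("configs/demo_config.yaml")  # hypothetical path
    cfg.merge_from_list(["MODEL.DEVICE", "cpu"])

    demo = COCODemo(cfg, confidence_threshold=0.7, min_image_size=800)

    img = cv2.imread("demo.jpg")  # hypothetical BGR image, as loaded by OpenCV
    result, debug_info, top_predictions = demo.run_on_opencv_image(img)
    print("inference took %.3fs" % debug_info["total_time"])
    cv2.imwrite("demo_out.jpg", result)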