mirror of https://github.com/YifanXu74/MQ-Det.git
472 lines
18 KiB
Python
472 lines
18 KiB
Python
import cv2
|
||
import torch
|
||
import re
|
||
import numpy as np
|
||
from typing import List, Union
|
||
import nltk
|
||
import inflect
|
||
from transformers import AutoTokenizer
|
||
from torchvision import transforms as T
|
||
import pdb
|
||
from maskrcnn_benchmark.modeling.detector import build_detection_model
|
||
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
|
||
from maskrcnn_benchmark.structures.image_list import to_image_list
|
||
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
|
||
from maskrcnn_benchmark.structures.bounding_box import BoxList
|
||
from maskrcnn_benchmark import layers as L
|
||
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
|
||
from maskrcnn_benchmark.utils import cv2_util
|
||
|
||
engine = inflect.engine()
|
||
nltk.download('punkt')
|
||
nltk.download('averaged_perceptron_tagger')
|
||
|
||
import timeit
|
||
import os
|
||
|
||
|
||
class GLIPDemo(object):
|
||
def __init__(self,
|
||
cfg,
|
||
confidence_threshold=0.7,
|
||
min_image_size=None,
|
||
show_mask_heatmaps=False,
|
||
masks_per_dim=5,
|
||
load_model=True
|
||
):
|
||
self.cfg = cfg.clone()
|
||
if load_model:
|
||
self.model = build_detection_model(cfg)
|
||
self.model.eval()
|
||
self.device = torch.device(cfg.MODEL.DEVICE)
|
||
self.model.to(self.device)
|
||
self.min_image_size = min_image_size
|
||
self.show_mask_heatmaps = show_mask_heatmaps
|
||
self.masks_per_dim = masks_per_dim
|
||
|
||
save_dir = cfg.OUTPUT_DIR
|
||
if load_model:
|
||
checkpointer = DetectronCheckpointer(cfg, self.model, save_dir=save_dir)
|
||
_ = checkpointer.load(cfg.MODEL.WEIGHT)
|
||
|
||
self.transforms = self.build_transform()
|
||
|
||
# used to make colors for each tokens
|
||
mask_threshold = -1 if show_mask_heatmaps else 0.5
|
||
self.masker = Masker(threshold=mask_threshold, padding=1)
|
||
self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
|
||
self.cpu_device = torch.device("cpu")
|
||
self.confidence_threshold = confidence_threshold
|
||
|
||
self.tokenizer = self.build_tokenizer()
|
||
|
||
def build_transform(self):
|
||
"""
|
||
Creates a basic transformation that was used to train the models
|
||
"""
|
||
cfg = self.cfg
|
||
|
||
# we are loading images with OpenCV, so we don't need to convert them
|
||
# to BGR, they are already! So all we need to do is to normalize
|
||
# by 255 if we want to convert to BGR255 format, or flip the channels
|
||
# if we want it to be in RGB in [0-1] range.
|
||
if cfg.INPUT.TO_BGR255:
|
||
to_bgr_transform = T.Lambda(lambda x: x * 255)
|
||
else:
|
||
to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]])
|
||
|
||
normalize_transform = T.Normalize(
|
||
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
|
||
)
|
||
|
||
transform = T.Compose(
|
||
[
|
||
T.ToPILImage(),
|
||
T.Resize(self.min_image_size) if self.min_image_size is not None else lambda x: x,
|
||
T.ToTensor(),
|
||
to_bgr_transform,
|
||
normalize_transform,
|
||
]
|
||
)
|
||
return transform
|
||
|
||
def build_tokenizer(self):
|
||
cfg = self.cfg
|
||
tokenizer = None
|
||
if os.path.basename(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) == "bert-base-uncased":
|
||
tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
|
||
elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
|
||
from transformers import CLIPTokenizerFast
|
||
if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
|
||
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
|
||
from_slow=True, mask_token='ðŁĴij</w>')
|
||
else:
|
||
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
|
||
from_slow=True)
|
||
return tokenizer
|
||
|
||
def run_ner(self, caption):
|
||
noun_phrases = find_noun_phrases(caption)
|
||
noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
|
||
noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
|
||
relevant_phrases = noun_phrases
|
||
labels = noun_phrases
|
||
self.entities = labels
|
||
|
||
tokens_positive = []
|
||
|
||
for entity, label in zip(relevant_phrases, labels):
|
||
try:
|
||
# search all occurrences and mark them as different entities
|
||
for m in re.finditer(entity, caption.lower()):
|
||
tokens_positive.append([[m.start(), m.end()]])
|
||
except:
|
||
print("noun entities:", noun_phrases)
|
||
print("entity:", entity)
|
||
print("caption:", caption.lower())
|
||
|
||
return tokens_positive
|
||
|
||
def inference(self, original_image, original_caption):
|
||
predictions = self.compute_prediction(original_image, original_caption)
|
||
top_predictions = self._post_process_fixed_thresh(predictions)
|
||
return top_predictions
|
||
|
||
def run_on_web_image(self,
|
||
original_image,
|
||
original_caption,
|
||
thresh=0.5,
|
||
custom_entity = None,
|
||
alpha = 0.0):
|
||
predictions = self.compute_prediction(original_image, original_caption, custom_entity)
|
||
top_predictions = self._post_process(predictions, thresh)
|
||
|
||
result = original_image.copy()
|
||
if self.show_mask_heatmaps:
|
||
return self.create_mask_montage(result, top_predictions)
|
||
result = self.overlay_boxes(result, top_predictions)
|
||
result = self.overlay_entity_names(result, top_predictions)
|
||
if self.cfg.MODEL.MASK_ON:
|
||
result = self.overlay_mask(result, top_predictions)
|
||
return result, top_predictions
|
||
|
||
def visualize_with_predictions(self,
|
||
original_image,
|
||
predictions,
|
||
thresh=0.5,
|
||
alpha=0.0,
|
||
box_pixel=3,
|
||
text_size = 1,
|
||
text_pixel = 2,
|
||
text_offset = 10,
|
||
text_offset_original = 4,
|
||
color = 255):
|
||
self.color = color
|
||
height, width = original_image.shape[:-1]
|
||
predictions = predictions.resize((width, height))
|
||
top_predictions = self._post_process(predictions, thresh)
|
||
|
||
result = original_image.copy()
|
||
if self.show_mask_heatmaps:
|
||
return self.create_mask_montage(result, top_predictions)
|
||
result = self.overlay_boxes(result, top_predictions, alpha=alpha, box_pixel=box_pixel)
|
||
result = self.overlay_entity_names(result, top_predictions, text_size=text_size, text_pixel=text_pixel, text_offset = text_offset, text_offset_original = text_offset_original)
|
||
if self.cfg.MODEL.MASK_ON:
|
||
result = self.overlay_mask(result, top_predictions)
|
||
return result, top_predictions
|
||
|
||
def compute_prediction(self, original_image, original_caption, custom_entity = None):
|
||
# image
|
||
image = self.transforms(original_image)
|
||
image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
|
||
image_list = image_list.to(self.device)
|
||
# caption
|
||
if isinstance(original_caption, list):
|
||
# we directly provided a list of category names
|
||
caption_string = ""
|
||
tokens_positive = []
|
||
seperation_tokens = " . "
|
||
for word in original_caption:
|
||
|
||
tokens_positive.append([len(caption_string), len(caption_string) + len(word)])
|
||
caption_string += word
|
||
caption_string += seperation_tokens
|
||
|
||
tokenized = self.tokenizer([caption_string], return_tensors="pt")
|
||
tokens_positive = [tokens_positive]
|
||
|
||
original_caption = caption_string
|
||
print(tokens_positive)
|
||
else:
|
||
tokenized = self.tokenizer([original_caption], return_tensors="pt")
|
||
if custom_entity is None:
|
||
tokens_positive = self.run_ner(original_caption)
|
||
print(tokens_positive)
|
||
# process positive map
|
||
positive_map = create_positive_map(tokenized, tokens_positive)
|
||
|
||
if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
|
||
plus = 1
|
||
else:
|
||
plus = 0
|
||
|
||
positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=plus)
|
||
self.plus = plus
|
||
self.positive_map_label_to_token = positive_map_label_to_token
|
||
tic = timeit.time.perf_counter()
|
||
|
||
# compute predictions
|
||
with torch.no_grad():
|
||
predictions = self.model(image_list, captions=[original_caption], positive_map=positive_map_label_to_token)
|
||
predictions = [o.to(self.cpu_device) for o in predictions]
|
||
print("inference time per image: {}".format(timeit.time.perf_counter() - tic))
|
||
|
||
# always single image is passed at a time
|
||
prediction = predictions[0]
|
||
|
||
# reshape prediction (a BoxList) into the original image size
|
||
height, width = original_image.shape[:-1]
|
||
prediction = prediction.resize((width, height))
|
||
|
||
if prediction.has_field("mask"):
|
||
# if we have masks, paste the masks in the right position
|
||
# in the image, as defined by the bounding boxes
|
||
masks = prediction.get_field("mask")
|
||
# always single image is passed at a time
|
||
masks = self.masker([masks], [prediction])[0]
|
||
prediction.add_field("mask", masks)
|
||
|
||
return prediction
|
||
|
||
def _post_process_fixed_thresh(self, predictions):
|
||
scores = predictions.get_field("scores")
|
||
labels = predictions.get_field("labels").tolist()
|
||
thresh = scores.clone()
|
||
for i, lb in enumerate(labels):
|
||
if isinstance(self.confidence_threshold, float):
|
||
thresh[i] = self.confidence_threshold
|
||
elif len(self.confidence_threshold) == 1:
|
||
thresh[i] = self.confidence_threshold[0]
|
||
else:
|
||
thresh[i] = self.confidence_threshold[lb - 1]
|
||
keep = torch.nonzero(scores > thresh).squeeze(1)
|
||
predictions = predictions[keep]
|
||
|
||
scores = predictions.get_field("scores")
|
||
_, idx = scores.sort(0, descending=True)
|
||
return predictions[idx]
|
||
|
||
def _post_process(self, predictions, threshold=0.5):
|
||
scores = predictions.get_field("scores")
|
||
labels = predictions.get_field("labels").tolist()
|
||
thresh = scores.clone()
|
||
for i, lb in enumerate(labels):
|
||
if isinstance(self.confidence_threshold, float):
|
||
thresh[i] = threshold
|
||
elif len(self.confidence_threshold) == 1:
|
||
thresh[i] = threshold
|
||
else:
|
||
thresh[i] = self.confidence_threshold[lb - 1]
|
||
keep = torch.nonzero(scores > thresh).squeeze(1)
|
||
predictions = predictions[keep]
|
||
|
||
scores = predictions.get_field("scores")
|
||
_, idx = scores.sort(0, descending=True)
|
||
return predictions[idx]
|
||
|
||
def compute_colors_for_labels(self, labels):
|
||
"""
|
||
Simple function that adds fixed colors depending on the class
|
||
"""
|
||
colors = (30 * (labels[:, None] - 1) + 1) * self.palette
|
||
colors = (colors % 255).numpy().astype("uint8")
|
||
try:
|
||
colors = (colors * 0 + self.color).astype("uint8")
|
||
except:
|
||
pass
|
||
return colors
|
||
|
||
def overlay_boxes(self, image, predictions, alpha=0.5, box_pixel = 3):
|
||
labels = predictions.get_field("labels")
|
||
boxes = predictions.bbox
|
||
|
||
colors = self.compute_colors_for_labels(labels).tolist()
|
||
new_image = image.copy()
|
||
for box, color in zip(boxes, colors):
|
||
box = box.to(torch.int64)
|
||
top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
|
||
new_image = cv2.rectangle(
|
||
new_image, tuple(top_left), tuple(bottom_right), tuple(color), box_pixel)
|
||
|
||
# Following line overlays transparent rectangle over the image
|
||
image = cv2.addWeighted(new_image, alpha, image, 1 - alpha, 0)
|
||
|
||
return image
|
||
|
||
def overlay_scores(self, image, predictions):
|
||
scores = predictions.get_field("scores")
|
||
boxes = predictions.bbox
|
||
|
||
for box, score in zip(boxes, scores):
|
||
box = box.to(torch.int64)
|
||
image = cv2.putText(image, '%.3f' % score,
|
||
(int(box[0]), int((box[1] + box[3]) / 2)),
|
||
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2, cv2.LINE_AA)
|
||
|
||
return image
|
||
|
||
def overlay_entity_names(self, image, predictions, names=None, text_size=1.0, text_pixel=2, text_offset = 10, text_offset_original = 4):
|
||
scores = predictions.get_field("scores").tolist()
|
||
labels = predictions.get_field("labels").tolist()
|
||
new_labels = []
|
||
if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
|
||
plus = 1
|
||
else:
|
||
plus = 0
|
||
self.plus = plus
|
||
if self.entities and self.plus:
|
||
for i in labels:
|
||
if i <= len(self.entities):
|
||
new_labels.append(self.entities[i - self.plus])
|
||
else:
|
||
new_labels.append('object')
|
||
# labels = [self.entities[i - self.plus] for i in labels ]
|
||
else:
|
||
new_labels = ['object' for i in labels]
|
||
boxes = predictions.bbox
|
||
|
||
template = "{}:{:.2f}"
|
||
previous_locations = []
|
||
for box, score, label in zip(boxes, scores, new_labels):
|
||
x, y = box[:2]
|
||
s = template.format(label, score).replace("_", " ").replace("(", "").replace(")", "")
|
||
for x_prev, y_prev in previous_locations:
|
||
if abs(x - x_prev) < abs(text_offset) and abs(y - y_prev) < abs(text_offset):
|
||
y -= text_offset
|
||
|
||
cv2.putText(
|
||
image, s, (int(x), int(y)-text_offset_original), cv2.FONT_HERSHEY_SIMPLEX, text_size, (self.color, self.color, self.color), text_pixel, cv2.LINE_AA
|
||
)
|
||
previous_locations.append((int(x), int(y)))
|
||
|
||
|
||
return image
|
||
|
||
def overlay_mask(self, image, predictions):
|
||
masks = predictions.get_field("mask").numpy()
|
||
labels = predictions.get_field("labels")
|
||
|
||
colors = self.compute_colors_for_labels(labels).tolist()
|
||
|
||
# import pdb
|
||
# pdb.set_trace()
|
||
# masks = masks > 0.1
|
||
|
||
for mask, color in zip(masks, colors):
|
||
thresh = mask[0, :, :, None].astype(np.uint8)
|
||
contours, hierarchy = cv2_util.findContours(
|
||
thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
||
)
|
||
image = cv2.drawContours(image, contours, -1, color, 2)
|
||
|
||
composite = image
|
||
|
||
return composite
|
||
|
||
def create_mask_montage(self, image, predictions):
|
||
masks = predictions.get_field("mask")
|
||
masks_per_dim = self.masks_per_dim
|
||
masks = L.interpolate(
|
||
masks.float(), scale_factor=1 / masks_per_dim
|
||
).byte()
|
||
height, width = masks.shape[-2:]
|
||
max_masks = masks_per_dim ** 2
|
||
masks = masks[:max_masks]
|
||
# handle case where we have less detections than max_masks
|
||
if len(masks) < max_masks:
|
||
masks_padded = torch.zeros(max_masks, 1, height, width, dtype=torch.uint8)
|
||
masks_padded[: len(masks)] = masks
|
||
masks = masks_padded
|
||
masks = masks.reshape(masks_per_dim, masks_per_dim, height, width)
|
||
result = torch.zeros(
|
||
(masks_per_dim * height, masks_per_dim * width), dtype=torch.uint8
|
||
)
|
||
for y in range(masks_per_dim):
|
||
start_y = y * height
|
||
end_y = (y + 1) * height
|
||
for x in range(masks_per_dim):
|
||
start_x = x * width
|
||
end_x = (x + 1) * width
|
||
result[start_y:end_y, start_x:end_x] = masks[y, x]
|
||
|
||
return cv2.applyColorMap(result.numpy(), cv2.COLORMAP_JET), None
|
||
|
||
|
||
def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0):
|
||
positive_map_label_to_token = {}
|
||
for i in range(len(positive_map)):
|
||
positive_map_label_to_token[i + plus] = torch.nonzero(positive_map[i], as_tuple=True)[0].tolist()
|
||
return positive_map_label_to_token
|
||
|
||
|
||
def create_positive_map(tokenized, tokens_positive):
|
||
"""construct a map such that positive_map[i,j] = True iff box i is associated to token j"""
|
||
positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float)
|
||
|
||
for j, tok_list in enumerate(tokens_positive):
|
||
for (beg, end) in tok_list:
|
||
try:
|
||
beg_pos = tokenized.char_to_token(beg)
|
||
end_pos = tokenized.char_to_token(end - 1)
|
||
except Exception as e:
|
||
print("beg:", beg, "end:", end)
|
||
print("token_positive:", tokens_positive)
|
||
# print("beg_pos:", beg_pos, "end_pos:", end_pos)
|
||
raise e
|
||
if beg_pos is None:
|
||
try:
|
||
beg_pos = tokenized.char_to_token(beg + 1)
|
||
if beg_pos is None:
|
||
beg_pos = tokenized.char_to_token(beg + 2)
|
||
except:
|
||
beg_pos = None
|
||
if end_pos is None:
|
||
try:
|
||
end_pos = tokenized.char_to_token(end - 2)
|
||
if end_pos is None:
|
||
end_pos = tokenized.char_to_token(end - 3)
|
||
except:
|
||
end_pos = None
|
||
if beg_pos is None or end_pos is None:
|
||
continue
|
||
|
||
assert beg_pos is not None and end_pos is not None
|
||
positive_map[j, beg_pos: end_pos + 1].fill_(1)
|
||
return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
|
||
|
||
|
||
def find_noun_phrases(caption: str) -> List[str]:
|
||
caption = caption.lower()
|
||
tokens = nltk.word_tokenize(caption)
|
||
pos_tags = nltk.pos_tag(tokens)
|
||
|
||
grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}"
|
||
cp = nltk.RegexpParser(grammar)
|
||
result = cp.parse(pos_tags)
|
||
|
||
noun_phrases = list()
|
||
for subtree in result.subtrees():
|
||
if subtree.label() == 'NP':
|
||
noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))
|
||
|
||
return noun_phrases
|
||
|
||
|
||
def remove_punctuation(text: str) -> str:
|
||
punct = ['|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^',
|
||
'\'', '\"', '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.'
|
||
]
|
||
for p in punct:
|
||
text = text.replace(p, '')
|
||
return text.strip()
|