# Mirror of https://github.com/YifanXu74/MQ-Det.git
# Delete some unused functions from modulated_coco.
# Suitable for Objects365 pre-training.
import logging
import os
import os.path
import math
from PIL import Image, ImageDraw

import random
import numpy as np

import torch
import torchvision
import torch.utils.data as data
from pycocotools import mask as coco_mask

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.data.datasets.coco import has_valid_annotation
from .od_to_grounding import convert_od_to_grounding_simple, check_for_positive_overflow, sanity_check_target_after_processing, convert_object_detection_to_grounding_optimized_for_od
import pdb
import json
from tqdm import tqdm
from groundingdino_new.util.inference import preprocess_caption
from groundingdino_new.util.box_ops import box_xyxy_to_cxcywh

import copy

def _has_only_crowd_bbox(anno):
    return all(obj["iscrowd"] == 1 for obj in anno)

class CocoGrounding_New(torchvision.datasets.CocoDetection):
    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_shuffle=False,
                 add_detection_prompt=False,
                 add_detection_prompt_advanced=False,
                 control_probabilities={},
                 one_hot=False,
                 disable_clip_to_image=False,
                 no_minus_one_for_one_hot=False,
                 separation_tokens=" ",
                 few_shot=0,
                 no_mask_for_od=False,
                 override_category=None,
                 use_caption_prompt=False,
                 caption_prompt=None,
                 max_num_labels=-1,
                 max_query_len=256,
                 special_safeguard_for_coco_grounding=False,
                 random_sample_negative=-1,
                 custom_ids=None,  # renamed from the misspelled `cumtom_ids`
                 exclude_crowd=False,
                 sep_at_last=False,
                 add_normed_cxcy=False,
                 custom_category_ids=None,
                 **kwargs
                 ):
        super(CocoGrounding_New, self).__init__(img_folder, ann_file)
        self.ids = sorted(self.ids)

        self.sep_at_last = sep_at_last
        self.add_normed_cxcy = add_normed_cxcy

        # iscrowd=False filters crowd annotations out of getAnnIds; None keeps everything.
        self.iscrowd = False if exclude_crowd else None

        # Keep only images that carry at least one valid annotation.
        ids = []
        for img_id in self.ids:
            if isinstance(img_id, str):
                ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=self.iscrowd)
            else:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=self.iscrowd)
            anno = self.coco.loadAnns(ann_ids)

            if has_valid_annotation(anno):
                ids.append(img_id)

        self.ids = ids

        # self.ids = self.remove_invalid_images(self.ids)

        if few_shot:
            ids = []
            # cats_freq = [few_shot]*len(self.coco.cats.keys())
            # NOTE: assumes category ids run from 1 to max, so they can be indexed by c-1.
            cats_freq = [few_shot] * max(list(self.coco.cats.keys()))
            for img_id in self.ids:
                if isinstance(img_id, str):
                    ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=self.iscrowd)
                else:
                    ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=self.iscrowd)
                anno = self.coco.loadAnns(ann_ids)
                cat = set([ann['category_id'] for ann in anno])  # set/tuple corresponds to instance/image level
                is_needed = sum([cats_freq[c - 1] > 0 for c in cat])
                if is_needed:
                    ids.append(img_id)
                    for c in cat:
                        cats_freq[c - 1] -= 1
                    # print(cat, cats_freq)
            self.ids = ids
        if custom_ids is not None:
            self.ids = custom_ids

        if custom_category_ids is not None:
            new_ids = []
            for img_id in self.ids:
                if isinstance(img_id, str):
                    ann_ids = self.coco.getAnnIds(imgIds=[img_id], catIds=custom_category_ids, iscrowd=self.iscrowd)
                else:
                    ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=custom_category_ids, iscrowd=self.iscrowd)
                if len(ann_ids) > 0:
                    new_ids.append(img_id)
            self.ids = new_ids


        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }

        if override_category is not None:
            self.coco.dataset["categories"] = override_category
        self.max_num_labels = max_num_labels
        self.control_probabilities = control_probabilities
        self.use_caption_prompt = use_caption_prompt
        self.caption_prompt = caption_prompt
        self.special_safeguard_for_coco_grounding = special_safeguard_for_coco_grounding
        self.random_sample_negative = random_sample_negative
        # Build the label index after any category override so names stay in sync.
        # (The original assigned self.ind_to_class twice; the duplicate is dropped.)
        self.ind_to_class = self.categories(no_background=False)
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(False, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len, ind_to_class=self.ind_to_class)
        self.tokenizer = tokenizer
        self.is_train = is_train

        self.disable_shuffle = disable_shuffle
        self.add_detection_prompt = add_detection_prompt
        self.add_detection_prompt_advanced = add_detection_prompt_advanced
        self.one_hot = one_hot
        self.no_minus_one_for_one_hot = no_minus_one_for_one_hot

        self.disable_clip_to_image = disable_clip_to_image
        self.separation_tokens = separation_tokens
        self.no_mask_for_od = no_mask_for_od
        self.return_masks = return_masks

    def remove_invalid_images(self, ids):
        print('removing nonexistent images from dataset...')
        new_ids = []
        invalid_num = 0
        for img_id in tqdm(ids):
            path = self.coco.loadImgs(img_id)[0]["file_name"]
            if os.path.exists(os.path.join(self.root, path)):
                new_ids.append(img_id)
            else:
                invalid_num += 1
        print('removed {} nonexistent images from dataset'.format(invalid_num))
        return new_ids

    def categories(self, no_background=True):
        categories = self.coco.dataset["categories"]
        label_list = {}
        for index, i in enumerate(categories):
            # assert(index + 1 == i["id"])
            if not no_background or (i["name"] != "__background__" and i['id'] != 0):
                label_list[self.json_category_id_to_contiguous_id[i["id"]]] = i["name"]
        return label_list

    def get_box_mask(self, rect, img_size, mode="poly"):
        assert mode == "poly", "Only support poly mask right now!"
        # Turn an xyxy box into a 4-point polygon (a degenerate "mask").
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __getitem__(self, idx):
        img, tgt = super(CocoGrounding_New, self).__getitem__(idx)
        image_id = self.ids[idx]
        tgt = [obj for obj in tgt if obj["iscrowd"] == 0]
        boxes = [obj["bbox"] for obj in tgt]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        classes = [obj["category_id"] for obj in tgt]
        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
        classes = torch.tensor(classes)
        target.add_field("labels", classes)

        if not self.disable_clip_to_image:
            target = target.clip_to_image(remove_empty=True)

        if self.special_safeguard_for_coco_grounding:
            # Intended for LVIS and Objects365
            assert not self.use_caption_prompt

            original_box_num = len(target)
            target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len - 2)  # leave some space for the special tokens
            if len(target) < original_box_num:
                print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target)))

            annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=self.add_detection_prompt,
                add_detection_prompt_advanced=self.add_detection_prompt_advanced,
                random_sample_negative=self.random_sample_negative,
                control_probabilities=self.control_probabilities,  # always try to add a lot of negatives
                restricted_negative_list=None,
                separation_tokens=self.separation_tokens,
                max_num_labels=self.max_num_labels,
                positive_caption_length=positive_caption_length,
                tokenizer=self.tokenizer,
                max_seq_length=self.max_query_len - 2,
                obj356_debug=True
            )
        else:
            # Intended for COCO / ODinW
            annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_od_to_grounding_simple(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=self.add_detection_prompt,
                separation_tokens=self.separation_tokens,
                caption_prompt=self.caption_prompt if self.use_caption_prompt else None,
            )

        # if self.sep_at_last:
        #     caption = preprocess_caption(caption)
        anno = {"image_id": image_id, "annotations": annotations, "caption": caption, "label_to_positions_caption": label_to_positions}
        anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
        if self.no_mask_for_od:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
        img, anno = self.prepare(img, anno, box_format="xyxy")

        # for equivalence check
        if self.one_hot:
            logging.info("using one hot for equivalence check.")
            one_hot_map = torch.zeros_like(anno["positive_map"], dtype=torch.float)
            text_mask = torch.zeros(anno["positive_map"].shape[1], dtype=torch.int64)
            # create one hot mapping
            for ii, cls in enumerate(classes):
                if self.no_minus_one_for_one_hot:
                    one_hot_map[ii, cls] = 1.0
                else:
                    one_hot_map[ii, cls - 1] = 1.0
            if self.no_minus_one_for_one_hot:
                text_mask[:] = 1
            else:
                text_mask[:len(self.ind_to_class)] = 1
            anno["positive_map"] = one_hot_map
            anno["text_mask"] = text_mask

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for ann in anno:
            target.add_field(ann, anno[ann])

        if self.add_normed_cxcy:
            bbox = target.bbox
            # BoxList.size is (width, height); the original unpacked it as (H, W),
            # which mislabeled the values even though the division was correct.
            W, H = target.size
            normed_bbox = bbox / torch.Tensor([[W, H, W, H]])
            normed_cxcy = box_xyxy_to_cxcywh(normed_bbox)
            target.add_field('normed_cxcy_boxes', normed_cxcy)

        sanity_check_target_after_processing(target)

        return img, target, idx

    def get_img_info(self, index):
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data

    def get_raw_image(self, idx):
        image, *_ = super(CocoGrounding_New, self).__getitem__(idx)
        return image


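# A minimal usage sketch (not part of the original file): the paths and the
# tokenizer below are placeholders, assuming a HuggingFace-style tokenizer
# (the code in this module relies on `char_to_token` and `return_tensors="pt"`).
#
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# dataset = CocoGrounding_New(
#     img_folder="datasets/coco/train2017",                            # hypothetical path
#     ann_file="datasets/coco/annotations/instances_train2017.json",   # hypothetical path
#     transforms=None,
#     return_masks=False,
#     return_tokens=True,
#     tokenizer=tokenizer,
#     add_detection_prompt=True,
# )
# img, target, idx = dataset[0]
# print(target.get_field("caption"))   # e.g. "person. bicycle. car. ..."

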
class ModulatedDataset(torchvision.datasets.CocoDetection):
    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_clip_to_image=False,
                 no_mask_for_gold=False,
                 max_query_len=256,
                 **kwargs):
        super(ModulatedDataset, self).__init__(img_folder, ann_file)
        self.ids = sorted(self.ids)

        # Keep only images that carry at least one valid annotation.
        ids = []
        for img_id in self.ids:
            if isinstance(img_id, str):
                ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
            else:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
            anno = self.coco.loadAnns(ann_ids)
            if has_valid_annotation(anno):
                ids.append(img_id)
        self.ids = ids

        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
        self.is_train = is_train
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold

    def __getitem__(self, idx):
        img, target = super(ModulatedDataset, self).__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]
        caption = coco_img["caption"]
        dataset_name = coco_img["dataset_name"] if "dataset_name" in coco_img else None
        anno = {"image_id": image_id, "annotations": target, "caption": caption}

        # This dataset is used for Flickr & Mixed, so the whole caption is maskable.
        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
        if self.no_mask_for_gold:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
        img, anno = self.prepare(img, anno)

        # convert to BoxList (bboxes, labels)
        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xyxy")
        classes = anno["labels"]
        target.add_field("labels", classes)
        if self.prepare.return_masks:
            target.add_field("masks", anno.pop("masks"))
            target.add_field("is_box_mask", anno.pop("is_box_mask"))
        if not self.disable_clip_to_image:
            num_boxes = len(target.bbox)
            target = target.clip_to_image(remove_empty=True)
            assert num_boxes == len(target.bbox), "Box got removed in MixedDataset!!!"

        # Check if bboxes are correct
        # draw = ImageDraw.Draw(img)
        # boxes = target.bbox
        # for box in boxes:
        #     draw.rectangle([box[0], box[1], box[2], box[3]])
        # img.save('OUTPUT/images/{}.jpg'.format(idx))

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for ann in anno:
            target.add_field(ann, anno[ann])

        target.add_field("dataset_name", dataset_name)
        for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]:
            if extra_key in coco_img:
                target.add_field(extra_key, coco_img[extra_key])

        if "tokens_positive_eval" in coco_img and not self.is_train:
            tokenized = self.prepare.tokenizer(caption, return_tensors="pt")
            target.add_field("positive_map_eval", create_positive_map(tokenized, coco_img["tokens_positive_eval"]))
            target.add_field("nb_eval", len(target.get_field("positive_map_eval")))

        sanity_check_target_after_processing(target)
        return img, target, idx

    def get_img_info(self, index):
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data


class CocoDetection(data.Dataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index, return_meta=False):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        coco = self.coco
        img_id = self.ids[index]
        if isinstance(img_id, str):
            img_id = [img_id]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        meta = coco.loadImgs(img_id)[0]
        path = meta['file_name']
        img = pil_loader(os.path.join(self.root, path))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if return_meta:
            return img, target, meta
        else:
            return img, target

    def __len__(self):
        return len(self.ids)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256, ind_to_class=None):
        self.return_masks = return_masks
        self.return_tokens = return_tokens
        self.tokenizer = tokenizer
        self.max_query_len = max_query_len
        self.ind_to_class = ind_to_class

    def get_box_mask(self, rect, img_size, mode="poly"):
        assert mode == "poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]
        caption = target["caption"] if "caption" in target else None
        label_to_positions = target.get("label_to_positions", {})
        label_to_positions_caption = target.get("label_to_positions_caption", {})

        greenlight_span_for_masked_lm_objective = target.get("greenlight_span_for_masked_lm_objective", None)

        anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        if box_format == "xywh":
            boxes[:, 2:] += boxes[:, :2] - 1  # TO_REMOVE = 1
        boxes[:, 0::2].clamp_(min=0, max=w - 1)  # TO_REMOVE = 1
        boxes[:, 1::2].clamp_(min=0, max=h - 1)  # TO_REMOVE = 1

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            masks = []
            is_box_mask = []
            for obj, bbox in zip(anno, boxes):
                if "segmentation" in obj:
                    masks.append(obj["segmentation"])
                    is_box_mask.append(0)
                else:
                    # No polygon available: fall back to a box-shaped mask.
                    masks.append(self.get_box_mask(bbox, image.size, mode='poly'))
                    is_box_mask.append(1)
            masks = SegmentationMask(masks, image.size, mode='poly')
            is_box_mask = torch.tensor(is_box_mask)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        isfinal = None
        if anno and "isfinal" in anno[0]:
            isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float)

        tokens_positive = [] if self.return_tokens else None
        if self.return_tokens and anno and "tokens" in anno[0]:
            tokens_positive = [obj["tokens"] for obj in anno]
        elif self.return_tokens and anno and "tokens_positive" in anno[0]:
            tokens_positive = [obj["tokens_positive"] for obj in anno]

        # Screen out degenerate boxes (non-positive width or height).
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
            is_box_mask = is_box_mask[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if caption is not None:
            target["caption"] = caption
        if self.return_masks:
            target["masks"] = masks
            target["is_box_mask"] = is_box_mask
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        if tokens_positive is not None:
            target["tokens_positive"] = []

            for i, k in enumerate(keep):
                if k or ignore_box_screen:
                    target["tokens_positive"].append(tokens_positive[i])

        if isfinal is not None:
            target["isfinal"] = isfinal

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        if self.return_tokens and self.tokenizer is not None:
            if not ignore_box_screen:
                assert len(target["boxes"]) == len(target["tokens_positive"])

            tokenized = self.tokenizer(caption, return_tensors="pt",
                                       max_length=self.max_query_len,
                                       truncation=True)

            target["positive_map"] = create_positive_map(tokenized, target["tokens_positive"])
            # target['greenlight_map'] = create_greenlight_map(greenlight_span_for_masked_lm_objective, tokenized)
            # target["positive_map_for_od_labels"] = create_positive_map_for_od_labels(tokenized, label_to_positions)

            all_tokens = [[v] for k, v in label_to_positions_caption.items()]
            target["all_map"] = create_positive_map(tokenized, all_tokens)
            target["labels_in_caption"] = [k for k, v in label_to_positions_caption.items()]

            pos_label_set = list(set(target['labels'].tolist()))
            pos_category_tokens = [[v] for k, v in label_to_positions_caption.items() if k in pos_label_set]
            target["positive_category_map"] = create_positive_map(tokenized, pos_category_tokens)
            # Binarize: any nonzero weight counts as a positive token.
            target["positive_category_map"][target["positive_category_map"] != 0] = 1

        original_od_label = []
        for obj in anno:
            original_od_label.append(
                obj.get("original_od_label", -10))  # NOTE: the padding value must differ from -1 and -100
        target["original_od_label"] = torch.as_tensor(original_od_label)

        return image, target

def create_greenlight_map(tok_list, tokenized):
    # An example tok_list:
    # [(0, 5), (10, 13), (-1, -1, -1)]
    # The last one is a special indicator.

    greenlight_map = torch.zeros(256, dtype=torch.float)
    for item in tok_list:
        if len(item) != 2:
            assert len(item) == 3
            # Make everything unmaskable
            greenlight_map[:] = -1
            break

        beg, end = item
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        if beg_pos is None:
            try:
                beg_pos = tokenized.char_to_token(beg + 1)
                if beg_pos is None:
                    beg_pos = tokenized.char_to_token(beg + 2)
            except Exception:
                beg_pos = None
        if end_pos is None:
            try:
                end_pos = tokenized.char_to_token(end - 2)
                if end_pos is None:
                    end_pos = tokenized.char_to_token(end - 3)
            except Exception:
                end_pos = None
        if beg_pos is None or end_pos is None:
            continue

        greenlight_map[beg_pos: end_pos + 1].fill_(1)
    return greenlight_map


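# Sketch of the semantics (illustrative; exact token positions depend on the
# tokenizer): each (beg, end) pair is a character span of the caption whose
# tokens may be masked for the masked-LM objective; appending the (-1, -1, -1)
# sentinel disables masking for the whole caption, returning a map of all -1.
#
# tokenized = tokenizer("the woman in the garden", return_tensors="pt")
# create_greenlight_map([(0, 9)], tokenized)                # tokens of "the woman" -> 1
# create_greenlight_map([(0, 9), (-1, -1, -1)], tokenized)  # all positions -> -1

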
def create_positive_map_for_od_labels(tokenized, label_to_positions):
    """Construct a map such that positive_map[i] = j, where j is the object
    detection label of token i (or -1 if token i matches no label).

    Example: label_to_positions = {3: (1, 5)} marks characters [1, 5) of the
    caption; the token positions covering that span are filled with label 3,
    and everything else stays -1.
    """
    positive_map = torch.ones(256, dtype=torch.float) * -1  # -1 means no match
    keys = list(label_to_positions.keys())
    for j, key in enumerate(keys):
        tok_list = label_to_positions[key]
        # one label only maps to one location
        beg, end = tok_list
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        if beg_pos is None:
            try:
                beg_pos = tokenized.char_to_token(beg + 1)
                if beg_pos is None:
                    beg_pos = tokenized.char_to_token(beg + 2)
            except Exception:
                beg_pos = None
        if end_pos is None:
            try:
                end_pos = tokenized.char_to_token(end - 2)
                if end_pos is None:
                    end_pos = tokenized.char_to_token(end - 3)
            except Exception:
                end_pos = None
        if beg_pos is None or end_pos is None:
            continue
        positive_map[beg_pos: end_pos + 1].fill_(key)
    return positive_map


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


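# Quick sanity sketch (values assumed for illustration): one square polygon in
# COCO format rasterized into a binary mask tensor.
#
# polys = [[[10.0, 10.0, 10.0, 20.0, 20.0, 20.0, 20.0, 10.0]]]
# m = convert_coco_poly_to_mask(polys, height=32, width=32)
# m.shape   # -> torch.Size([1, 32, 32]); dtype is bool (uint8 on older PyTorch)

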
def create_positive_map(tokenized, tokens_positive):
    """Construct a map such that positive_map[i, j] = True iff box i is associated to token j."""
    positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float)

    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            beg_pos = tokenized.char_to_token(beg)
            end_pos = tokenized.char_to_token(end - 1)
            if beg_pos is None:
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except Exception:
                    beg_pos = None
            if end_pos is None:
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except Exception:
                    end_pos = None
            if beg_pos is None or end_pos is None:
                continue

            positive_map[j, beg_pos: end_pos + 1].fill_(1)
    # Normalize each row so the weight over matched tokens sums to 1.
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)


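# Worked example (illustrative; exact token indices depend on the tokenizer):
# for the caption "cat. dog", tokens_positive [[(0, 3)], [(5, 8)]] marks the
# character spans of "cat" and "dog"; each returned row distributes its weight
# uniformly over the token positions covering its span.
#
# tokenized = tokenizer("cat. dog", return_tensors="pt")
# pm = create_positive_map(tokenized, [[(0, 3)], [(5, 8)]])
# pm.shape   # -> torch.Size([2, 256])

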
def pil_loader(path, retry=5):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    ri = 0
    while ri < retry:
        try:
            with open(path, 'rb') as f:
                img = Image.open(f)
                return img.convert('RGB')
        except Exception:
            ri += 1
    # The original fell through and implicitly returned None; fail loudly instead.
    raise OSError('Failed to load image after {} retries: {}'.format(retry, path))