2023-10-07 23:02:26 +08:00
# Delete some unused functions from modulated_coco.
2023-10-07 23:20:40 +08:00
# Suitable for object365 pre-training
2023-10-07 23:02:26 +08:00
import logging
import os
import os . path
import math
from PIL import Image , ImageDraw
import random
import numpy as np
import torch
import torchvision
import torch . utils . data as data
from pycocotools import mask as coco_mask
from maskrcnn_benchmark . structures . bounding_box import BoxList
from maskrcnn_benchmark . structures . segmentation_mask import SegmentationMask
from maskrcnn_benchmark . data . datasets . coco import has_valid_annotation
from . od_to_grounding import convert_od_to_grounding_simple , check_for_positive_overflow , sanity_check_target_after_processing , convert_object_detection_to_grounding_optimized_for_od
import pdb
import json
from tqdm import tqdm
from groundingdino_new . util . inference import preprocess_caption
from groundingdino_new . util . box_ops import box_xyxy_to_cxcywh
import copy
def _has_only_crowd_bbox ( anno ) :
return all ( obj [ " iscrowd " ] == 1 for obj in anno )
class CocoGrounding_New(torchvision.datasets.CocoDetection):
    """COCO-format detection dataset that is served as a grounding dataset.

    Each sample's boxes and category labels are converted into a text caption
    plus per-box token spans (via the ``od_to_grounding`` helpers) and then
    run through ``ConvertCocoPolysToMask`` to obtain the positive maps.

    Args:
        img_folder: image root, forwarded to ``torchvision.datasets.CocoDetection``.
        ann_file: COCO annotation json, forwarded to the base class.
        transforms: optional callable ``(img, target) -> (img, target)``.
        return_masks: stored on the instance; the internal converter is always
            built with masks disabled.
        return_tokens: forwarded to ``ConvertCocoPolysToMask``.
        control_probabilities: forwarded to the OD-to-grounding conversion;
            ``None`` means an empty dict.
        few_shot: when non-zero, keep only images that still contribute to a
            per-category budget of ``few_shot`` images.
        cumtom_ids: (sic, name kept for callers) when given, replaces the
            filtered image id list outright.
        custom_category_ids: when given, keep only images that have at least
            one annotation in these categories.
        exclude_crowd: when True, annotation queries use ``iscrowd=False``.
        add_normed_cxcy: when True, a ``normed_cxcy_boxes`` field is added to
            each returned target.
        **kwargs: accepted and ignored.

    The remaining keyword arguments are stored as attributes of the same name
    and consumed in ``__getitem__``.
    """

    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_shuffle=False,
                 add_detection_prompt=False,
                 add_detection_prompt_advanced=False,
                 control_probabilities=None,
                 one_hot=False,
                 disable_clip_to_image=False,
                 no_minus_one_for_one_hot=False,
                 separation_tokens=" ",
                 few_shot=0,
                 no_mask_for_od=False,
                 override_category=None,
                 use_caption_prompt=False,
                 caption_prompt=None,
                 max_num_labels=-1,
                 max_query_len=256,
                 special_safeguard_for_coco_grounding=False,
                 random_sample_negative=-1,
                 cumtom_ids=None,
                 exclude_crowd=False,
                 sep_at_last=False,
                 add_normed_cxcy=False,
                 custom_category_ids=None,
                 **kwargs
                 ):
        super(CocoGrounding_New, self).__init__(img_folder, ann_file)
        self.ids = sorted(self.ids)
        self.sep_at_last = sep_at_last
        self.add_normed_cxcy = add_normed_cxcy
        self.iscrowd = False if exclude_crowd else None

        # Drop images without a usable annotation.
        self.ids = [
            img_id for img_id in self.ids
            if has_valid_annotation(self.coco.loadAnns(self._get_ann_ids(img_id)))
        ]
        # self.ids = self.remove_invalid_images(self.ids)

        if few_shot:
            self.ids = self._select_few_shot(self.ids, few_shot)

        if cumtom_ids is not None:
            self.ids = cumtom_ids

        if custom_category_ids is not None:
            self.ids = [
                img_id for img_id in self.ids
                if len(self._get_ann_ids(img_id, catIds=custom_category_ids)) > 0
            ]

        # Contiguous ids start at 1 and follow the order of coco.getCatIds().
        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }

        if override_category is not None:
            self.coco.dataset["categories"] = override_category

        self.max_num_labels = max_num_labels
        # A fresh dict per instance: a literal ``{}`` default would be shared
        # by every dataset object created without this argument.
        self.control_probabilities = {} if control_probabilities is None else control_probabilities
        self.use_caption_prompt = use_caption_prompt
        self.caption_prompt = caption_prompt
        self.special_safeguard_for_coco_grounding = special_safeguard_for_coco_grounding
        self.random_sample_negative = random_sample_negative
        self.ind_to_class = self.categories(no_background=False)
        self.id_to_img_map = dict(enumerate(self.ids))
        self._transforms = transforms
        self.max_query_len = max_query_len
        # Masks are always disabled for the converter, regardless of return_masks.
        self.prepare = ConvertCocoPolysToMask(False, return_tokens, tokenizer=tokenizer,
                                              max_query_len=max_query_len,
                                              ind_to_class=self.ind_to_class)
        self.tokenizer = tokenizer
        self.is_train = is_train
        self.disable_shuffle = disable_shuffle
        self.add_detection_prompt = add_detection_prompt
        self.add_detection_prompt_advanced = add_detection_prompt_advanced
        self.one_hot = one_hot
        self.no_minus_one_for_one_hot = no_minus_one_for_one_hot
        self.disable_clip_to_image = disable_clip_to_image
        self.separation_tokens = separation_tokens
        self.no_mask_for_od = no_mask_for_od
        self.return_masks = return_masks

    def _get_ann_ids(self, img_id, **filters):
        """Return annotation ids of ``img_id`` under the instance's crowd policy.

        String image ids are wrapped in a list before being handed to
        ``coco.getAnnIds``; other ids are passed through unchanged.
        """
        img_ids = [img_id] if isinstance(img_id, str) else img_id
        return self.coco.getAnnIds(imgIds=img_ids, iscrowd=self.iscrowd, **filters)

    def _select_few_shot(self, img_ids, few_shot):
        """Greedily pick images until each category has been seen ``few_shot`` times.

        An image is kept if at least one of its categories still has budget
        left; every category present in a kept image is then charged once.
        The budget is keyed by category id, so ids of 0 or non-contiguous ids
        get their own counter.
        """
        remaining = {}
        selected = []
        for img_id in img_ids:
            anno = self.coco.loadAnns(self._get_ann_ids(img_id))
            # set: each category counts once per image, however many instances it has
            cats = {ann["category_id"] for ann in anno}
            if any(remaining.get(c, few_shot) > 0 for c in cats):
                selected.append(img_id)
                for c in cats:
                    remaining[c] = remaining.get(c, few_shot) - 1
        return selected

    def remove_invalid_images(self, ids):
        """Return the subset of ``ids`` whose image file exists under ``self.root``."""
        print('removing non-exist images from dataset...')
        new_ids = []
        invalid_num = 0
        for img_id in tqdm(ids):
            path = self.coco.loadImgs(img_id)[0]["file_name"]
            if os.path.exists(os.path.join(self.root, path)):
                new_ids.append(img_id)
            else:
                invalid_num += 1
        print('removed {} non-exist images from dataset'.format(invalid_num))
        return new_ids

    def categories(self, no_background=True):
        """Map contiguous category id -> category name.

        With ``no_background=True`` any category named ``__background__`` or
        having json id 0 is left out.
        """
        label_list = {}
        for cat in self.coco.dataset["categories"]:
            is_background = cat["name"] == "__background__" or cat["id"] == 0
            if no_background and is_background:
                continue
            label_list[self.json_category_id_to_contiguous_id[cat["id"]]] = cat["name"]
        return label_list

    def get_box_mask(self, rect, img_size, mode="poly"):
        """Return the four corners of ``rect`` (x1, y1, x2, y2) as a one-polygon mask."""
        assert mode == "poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __getitem__(self, idx):
        """Return ``(img, target, idx)`` for sample ``idx``.

        ``target`` is the BoxList built from the non-crowd annotations, with
        every entry of the converted annotation dict attached as an extra field.
        """
        img, tgt = super(CocoGrounding_New, self).__getitem__(idx)
        image_id = self.ids[idx]
        tgt = [obj for obj in tgt if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in tgt]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")

        classes = [self.json_category_id_to_contiguous_id[obj["category_id"]] for obj in tgt]
        classes = torch.tensor(classes)
        target.add_field("labels", classes)

        if not self.disable_clip_to_image:
            target = target.clip_to_image(remove_empty=True)

        if self.special_safeguard_for_coco_grounding:
            # Intended for LVIS and Object365
            assert not self.use_caption_prompt
            original_box_num = len(target)
            # leave some space for the special tokens
            target, positive_caption_length = check_for_positive_overflow(
                target, self.ind_to_class, self.tokenizer, self.max_query_len - 2)
            if len(target) < original_box_num:
                print("WARNING: removed {} boxes due to positive caption overflow".format(
                    original_box_num - len(target)))

            annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = \
                convert_object_detection_to_grounding_optimized_for_od(
                    target=target,
                    image_id=image_id,
                    ind_to_class=self.ind_to_class,
                    disable_shuffle=self.disable_shuffle,
                    add_detection_prompt=self.add_detection_prompt,
                    add_detection_prompt_advanced=self.add_detection_prompt_advanced,
                    random_sample_negative=self.random_sample_negative,
                    control_probabilities=self.control_probabilities,  # always try to add a lot of negatives
                    restricted_negative_list=None,
                    separation_tokens=self.separation_tokens,
                    max_num_labels=self.max_num_labels,
                    positive_caption_length=positive_caption_length,
                    tokenizer=self.tokenizer,
                    max_seq_length=self.max_query_len - 2,
                    obj356_debug=True
                )
        else:
            # Intended for COCO / ODinW
            annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = \
                convert_od_to_grounding_simple(
                    target=target,
                    image_id=image_id,
                    ind_to_class=self.ind_to_class,
                    disable_shuffle=self.disable_shuffle,
                    add_detection_prompt=self.add_detection_prompt,
                    separation_tokens=self.separation_tokens,
                    caption_prompt=self.caption_prompt if self.use_caption_prompt else None,
                )

        # if self.sep_at_last:
        #     caption = preprocess_caption(caption)
        anno = {
            "image_id": image_id,
            "annotations": annotations,
            "caption": caption,
            "label_to_positions_caption": label_to_positions,
        }
        anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
        if self.no_mask_for_od:
            # 3-tuple sentinel, as opposed to the regular (begin, end) spans
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))

        img, anno = self.prepare(img, anno, box_format="xyxy")

        # for equivalence check
        if self.one_hot:
            logging.info("using one hot for equivalence check.")
            one_hot_map = torch.zeros_like(anno["positive_map"], dtype=torch.float)
            text_mask = torch.zeros(anno["positive_map"].shape[1], dtype=torch.int64)
            # create one hot mapping
            # NOTE(review): ``classes`` was captured before clip_to_image / the
            # overflow check, so its rows may not line up with positive_map if
            # boxes were dropped in between - confirm.
            for ii, cls in enumerate(classes):
                if self.no_minus_one_for_one_hot:
                    one_hot_map[ii, cls] = 1.0
                else:
                    one_hot_map[ii, cls - 1] = 1.0
            if self.no_minus_one_for_one_hot:
                text_mask[:] = 1
            else:
                text_mask[:len(self.ind_to_class)] = 1
            anno["positive_map"] = one_hot_map
            anno["text_mask"] = text_mask

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for ann in anno:
            target.add_field(ann, anno[ann])

        if self.add_normed_cxcy:
            bbox = target.bbox
            # The BoxList above is built from PIL's img.size, i.e. (width, height);
            # presumably the same order holds after the transforms - TODO confirm.
            width, height = target.size
            normed_bbox = bbox / torch.Tensor([[width, height, width, height]])
            normed_cxcy = box_xyxy_to_cxcywh(normed_bbox)
            target.add_field('normed_cxcy_boxes', normed_cxcy)

        sanity_check_target_after_processing(target)
        return img, target, idx

    def get_img_info(self, index):
        """Return the COCO image record for dataset position ``index``."""
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data

    def get_raw_image(self, idx):
        """Return only the image of sample ``idx`` as loaded by the base class."""
        image, *_ = super(CocoGrounding_New, self).__getitem__(idx)
        return image
class ModulatedDataset(torchvision.datasets.CocoDetection):
    """COCO-format dataset whose image records carry a free-form ``caption``.

    Each sample is converted by ``ConvertCocoPolysToMask`` and returned as a
    BoxList with the converter's outputs attached as extra fields.

    Args:
        img_folder: image root, forwarded to ``torchvision.datasets.CocoDetection``.
        ann_file: COCO annotation json, forwarded to the base class.
        transforms: optional callable ``(img, target) -> (img, target)``.
        return_masks: forwarded to ``ConvertCocoPolysToMask``.
        return_tokens: forwarded to ``ConvertCocoPolysToMask``.
        is_train: when False, ``positive_map_eval`` / ``nb_eval`` fields are
            added for images that provide ``tokens_positive_eval``.
        tokenizer: forwarded to ``ConvertCocoPolysToMask``.
        disable_clip_to_image: skip clipping boxes to the image.
        no_mask_for_gold: append the ``(-1, -1, -1)`` sentinel to the
            greenlight span list.
        max_query_len: forwarded to ``ConvertCocoPolysToMask``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_clip_to_image=False,
                 no_mask_for_gold=False,
                 max_query_len=256,
                 **kwargs):
        super(ModulatedDataset, self).__init__(img_folder, ann_file)
        self.ids = sorted(self.ids)

        # Keep only images with a usable annotation.
        kept_ids = []
        for img_id in self.ids:
            # String ids are wrapped in a list before being handed to getAnnIds.
            img_ids = [img_id] if isinstance(img_id, str) else img_id
            ann_ids = self.coco.getAnnIds(imgIds=img_ids, iscrowd=None)
            if has_valid_annotation(self.coco.loadAnns(ann_ids)):
                kept_ids.append(img_id)
        self.ids = kept_ids

        self.id_to_img_map = dict(enumerate(self.ids))
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer,
                                              max_query_len=max_query_len)
        self.is_train = is_train
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold

    def __getitem__(self, idx):
        """Return ``(img, target, idx)`` where ``target`` is a BoxList with extra fields."""
        img, target = super(ModulatedDataset, self).__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]
        caption = coco_img["caption"]
        dataset_name = coco_img.get("dataset_name")
        anno = {"image_id": image_id, "annotations": target, "caption": caption}

        # This dataset is used for Flickr & Mixed, so the sequence is maskable
        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
        if self.no_mask_for_gold:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))

        img, anno = self.prepare(img, anno)

        # convert to BoxList (bboxes, labels)
        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xyxy")
        target.add_field("labels", anno["labels"])
        if self.prepare.return_masks:
            target.add_field("masks", anno.pop("masks"))
            target.add_field("is_box_mask", anno.pop("is_box_mask"))

        if not self.disable_clip_to_image:
            num_boxes = len(target.bbox)
            target = target.clip_to_image(remove_empty=True)
            # Clipping must not drop a box here: the per-box fields added below
            # would no longer line up with the boxes.
            assert num_boxes == len(target.bbox), "Box got removed in MixedDataset!!!"

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for ann in anno:
            target.add_field(ann, anno[ann])

        target.add_field("dataset_name", dataset_name)
        for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]:
            if extra_key in coco_img:
                target.add_field(extra_key, coco_img[extra_key])

        if "tokens_positive_eval" in coco_img and not self.is_train:
            tokenized = self.prepare.tokenizer(caption, return_tensors="pt")
            target.add_field("positive_map_eval",
                             create_positive_map(tokenized, coco_img["tokens_positive_eval"]))
            target.add_field("nb_eval", len(target.get_field("positive_map_eval")))

        sanity_check_target_after_processing(target)
        return img, target, idx

    def get_img_info(self, index):
        """Return the COCO image record for dataset position ``index``."""
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data
class CocoDetection(data.Dataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        # Imported lazily so the module can be imported without pycocotools.
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index, return_meta=False):
        """
        Args:
            index (int): Index
            return_meta (bool): also return the COCO image record.

        Returns:
            tuple: Tuple (image, target), or (image, target, meta) when
            ``return_meta`` is set. target is the object returned by ``coco.loadAnns``.
        """
        coco = self.coco
        img_id = self.ids[index]
        if isinstance(img_id, str):
            # String ids are wrapped in a list before being handed to the COCO API.
            img_id = [img_id]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        meta = coco.loadImgs(img_id)[0]
        img = pil_loader(os.path.join(self.root, meta['file_name']))

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        if return_meta:
            return img, target, meta
        return img, target

    def __len__(self):
        return len(self.ids)

    def __repr__(self):
        # NOTE(review): whitespace inside these format strings was rebuilt from a
        # whitespace-mangled source, following the layout of torchvision's
        # CocoDetection.__repr__ which this class mirrors - confirm against upstream.
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(len(self))
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        # Continuation lines of a multi-line transform repr are aligned under the label.
        fmt_str += '{0}{1}\n'.format(
            tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
class ConvertCocoPolysToMask(object):
    """Turn a raw annotation dict into the tensor-valued target dict used downstream.

    Args:
        return_masks: also emit ``masks`` / ``is_box_mask`` entries.
        return_tokens: also emit ``tokens_positive`` and, when a tokenizer is
            set, the token-level positive maps.
        tokenizer: callable applied to the caption when ``return_tokens`` is set.
        max_query_len: ``max_length`` passed to the tokenizer (with truncation).
        ind_to_class: stored on the instance; not read by this class.
    """

    def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256,
                 ind_to_class=None):
        self.return_masks = return_masks
        self.return_tokens = return_tokens
        self.tokenizer = tokenizer
        self.max_query_len = max_query_len
        self.ind_to_class = ind_to_class

    def get_box_mask(self, rect, img_size, mode="poly"):
        """Return the four corners of ``rect`` (x1, y1, x2, y2) as a one-polygon mask."""
        assert mode == "poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"):
        """Convert ``target`` for ``image``.

        Args:
            image: object exposing ``size`` as ``(width, height)``.
            target: dict with ``image_id`` and ``annotations``; ``caption`` and
                ``label_to_positions_caption`` are optional.
            ignore_box_screen: keep ``tokens_positive`` entries even for boxes
                that are dropped as degenerate.
            box_format: ``"xywh"`` boxes are converted to xyxy first; any other
                value leaves the coordinates as given.

        Returns:
            tuple: ``(image, new_target)`` where ``new_target`` is a fresh dict.
        """
        w, h = image.size

        image_id = torch.tensor([target["image_id"]])
        anno = target["annotations"]
        caption = target.get("caption")
        label_to_positions_caption = target.get("label_to_positions_caption", {})

        # Crowd annotations are discarded; a missing flag counts as non-crowd.
        anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        if box_format == "xywh":
            boxes[:, 2:] += boxes[:, :2] - 1  # TO_REMOVE = 1
        boxes[:, 0::2].clamp_(min=0, max=w - 1)  # TO_REMOVE = 1
        boxes[:, 1::2].clamp_(min=0, max=h - 1)  # TO_REMOVE = 1

        classes = torch.tensor([obj["category_id"] for obj in anno], dtype=torch.int64)

        if self.return_masks:
            masks = []
            is_box_mask = []
            for obj, bbox in zip(anno, boxes):
                if "segmentation" in obj:
                    masks.append(obj["segmentation"])
                    is_box_mask.append(0)
                else:
                    # No segmentation available: fall back to the box outline.
                    masks.append(self.get_box_mask(bbox, image.size, mode='poly'))
                    is_box_mask.append(1)
            masks = SegmentationMask(masks, image.size, mode='poly')
            is_box_mask = torch.tensor(is_box_mask)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor([obj["keypoints"] for obj in anno], dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        isfinal = None
        if anno and "isfinal" in anno[0]:
            isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float)

        tokens_positive = [] if self.return_tokens else None
        if self.return_tokens and anno and "tokens" in anno[0]:
            tokens_positive = [obj["tokens"] for obj in anno]
        elif self.return_tokens and anno and "tokens_positive" in anno[0]:
            tokens_positive = [obj["tokens_positive"] for obj in anno]

        # Drop boxes with non-positive width or height.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
            is_box_mask = is_box_mask[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if caption is not None:
            target["caption"] = caption
        if self.return_masks:
            target["masks"] = masks
            target["is_box_mask"] = is_box_mask
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        if tokens_positive is not None:
            target["tokens_positive"] = []
            for i, k in enumerate(keep):
                if k or ignore_box_screen:
                    target["tokens_positive"].append(tokens_positive[i])

        if isfinal is not None:
            target["isfinal"] = isfinal

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]
        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        if self.return_tokens and self.tokenizer is not None:
            if not ignore_box_screen:
                assert len(target["boxes"]) == len(target["tokens_positive"])
            tokenized = self.tokenizer(caption, return_tensors="pt",
                                       max_length=self.max_query_len,
                                       truncation=True)
            target["positive_map"] = create_positive_map(tokenized, target["tokens_positive"])

            # One row per label that appears in the caption, in dict order.
            all_tokens = [[span] for span in label_to_positions_caption.values()]
            target["all_map"] = create_positive_map(tokenized, all_tokens)
            target["labels_in_caption"] = list(label_to_positions_caption)

            # Same, restricted to labels that have at least one kept box.
            pos_labels = set(target['labels'].tolist())
            pos_category_tokens = [
                [span] for label, span in label_to_positions_caption.items() if label in pos_labels
            ]
            target["positive_category_map"] = create_positive_map(tokenized, pos_category_tokens)
            # Binarize: create_positive_map returns row-normalized weights.
            target["positive_category_map"][target["positive_category_map"] != 0] = 1

        # NOTE: The padding value has to be not the same as -1 or -100
        original_od_label = [obj.get("original_od_label", -10) for obj in anno]
        target["original_od_label"] = torch.as_tensor(original_od_label)

        return image, target
def create_greenlight_map(tok_list, tokenized, max_len=256):
    """Build a per-token map of positions that may be masked.

    Args:
        tok_list: list of ``(begin, end)`` character spans. A 3-element item
            such as ``(-1, -1, -1)`` is a sentinel: the whole map is set to -1
            and processing stops. Example: ``[(0, 5), (10, 13), (-1, -1, -1)]``.
        tokenized: object exposing ``char_to_token(char_index)``, which returns
            a token index or None.
        max_len: length of the returned map (defaults to the previously
            hard-coded 256).

    Returns:
        torch.FloatTensor of shape ``(max_len,)``: 1 on tokens covered by a
        span, 0 elsewhere, or all -1 when the sentinel was encountered.
    """

    def _first_hit(char_indices):
        # Best-effort lookup: any failure while probing neighbours means "not found".
        try:
            for char_idx in char_indices:
                token_idx = tokenized.char_to_token(char_idx)
                if token_idx is not None:
                    return token_idx
        except Exception:
            pass
        return None

    greenlight_map = torch.zeros(max_len, dtype=torch.float)
    for item in tok_list:
        if len(item) != 2:
            assert len(item) == 3
            # Make everything unmaskable
            greenlight_map[:] = -1
            break

        beg, end = item
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        # A boundary may sit on a character that maps to no token (None);
        # probe up to two characters inwards before giving up on the span.
        if beg_pos is None:
            beg_pos = _first_hit((beg + 1, beg + 2))
        if end_pos is None:
            end_pos = _first_hit((end - 2, end - 3))
        if beg_pos is None or end_pos is None:
            continue

        greenlight_map[beg_pos: end_pos + 1].fill_(1)
    return greenlight_map
def create_positive_map_for_od_labels(tokenized, label_to_positions, max_len=256):
    """Construct a map such that positive_map[i] = j, where j is the object detection label of token i.

    Example::

        label_to_positions = {3: [1, 5)}
        map (length 256):  -1  3  3  3  3  -1 ...  8  8 ...
        caption:           the woman in the garden
        no match at all:   -1 -1 -1 -1 -1

    Args:
        tokenized: object exposing ``char_to_token(char_index)``, which returns
            a token index or None.
        label_to_positions: dict mapping a label to one ``(begin, end)``
            character span (one label only maps to one location).
        max_len: length of the returned map (defaults to the previously
            hard-coded 256).

    Returns:
        torch.FloatTensor of shape ``(max_len,)`` holding the label on covered
        tokens and -1 elsewhere.
    """

    def _first_hit(char_indices):
        # Best-effort lookup: any failure while probing neighbours means "not found".
        try:
            for char_idx in char_indices:
                token_idx = tokenized.char_to_token(char_idx)
                if token_idx is not None:
                    return token_idx
        except Exception:
            pass
        return None

    positive_map = torch.ones(max_len, dtype=torch.float) * -1  # -1 means no match
    for key, span in label_to_positions.items():
        # one label only maps to one location
        beg, end = span
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        # A boundary may sit on a character that maps to no token (None);
        # probe up to two characters inwards before giving up on the label.
        if beg_pos is None:
            beg_pos = _first_hit((beg + 1, beg + 2))
        if end_pos is None:
            end_pos = _first_hit((end - 2, end - 3))
        if beg_pos is None or end_pos is None:
            continue

        positive_map[beg_pos: end_pos + 1].fill_(key)
    return positive_map
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize COCO polygon segmentations into one mask per object.

    Args:
        segmentations: iterable with one entry per object; each entry is passed
            to ``pycocotools.mask.frPyObjects``.
        height: mask height.
        width: mask width.

    Returns:
        torch.Tensor: the per-object masks stacked along dim 0, or an empty
        ``(0, height, width)`` uint8 tensor when there are no segmentations.
    """
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            # Add a trailing axis so the reduction over dim 2 below always applies.
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        # Merge all decoded parts of this object into a single mask.
        masks.append(mask.any(dim=2))

    if not masks:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(masks, dim=0)
def create_positive_map(tokenized, tokens_positive, max_len=256):
    """Construct a map such that positive_map[i, j] is non-zero iff box i is associated to token j.

    Args:
        tokenized: object exposing ``char_to_token(char_index)``, which returns
            a token index or None.
        tokens_positive: one list of ``(begin, end)`` character spans per box.
        max_len: number of token columns (defaults to the previously
            hard-coded 256).

    Returns:
        torch.FloatTensor of shape ``(len(tokens_positive), max_len)``. Each
        row is divided by its sum (plus 1e-6), so a row with matches sums to
        roughly 1 and a row without matches stays all zero.
    """

    def _first_hit(char_indices):
        # Best-effort lookup: any failure while probing neighbours means "not found".
        try:
            for char_idx in char_indices:
                token_idx = tokenized.char_to_token(char_idx)
                if token_idx is not None:
                    return token_idx
        except Exception:
            pass
        return None

    positive_map = torch.zeros((len(tokens_positive), max_len), dtype=torch.float)

    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            beg_pos = tokenized.char_to_token(beg)
            end_pos = tokenized.char_to_token(end - 1)
            # A boundary may sit on a character that maps to no token (None);
            # probe up to two characters inwards before giving up on the span.
            if beg_pos is None:
                beg_pos = _first_hit((beg + 1, beg + 2))
            if end_pos is None:
                end_pos = _first_hit((end - 2, end - 3))
            if beg_pos is None or end_pos is None:
                continue

            positive_map[j, beg_pos: end_pos + 1].fill_(1)
    # The epsilon keeps rows without any match from dividing by zero.
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
def pil_loader(path, retry=5):
    """Load the image at ``path`` and convert it to RGB, retrying on failure.

    Args:
        path: file system path of the image.
        retry: maximum number of attempts.

    Returns:
        The RGB-converted PIL image, or None when every attempt failed (or
        ``retry`` is not positive). The last error is logged as a warning
        before None is returned, so a missing or corrupt file is traceable.
    """
    last_error = None
    attempt = 0
    while attempt < retry:
        try:
            # open path as file to avoid ResourceWarning
            # (https://github.com/python-pillow/Pillow/issues/835)
            with open(path, 'rb') as f:
                img = Image.open(f)
                return img.convert('RGB')
        except Exception as err:
            # Exception rather than a bare except: KeyboardInterrupt and
            # SystemExit must not be swallowed by the retry loop.
            last_error = err
            attempt += 1
    if last_error is not None:
        logging.warning("pil_loader: could not load %s after %d attempts: %r",
                        path, attempt, last_error)
    return None