mirror of https://github.com/YifanXu74/MQ-Det.git
fix query extraction
parent e154671cc7
commit e12a5f3bac
@@ -16,10 +16,10 @@ from .caption import CaptionTSV
 from .lvis import LvisDetection
 from .pseudo_data import PseudoData
 from .phrasecut import PhrasecutDetection
-# from .modulated_coco_new import CocoGrounding_New
+from .modulated_coco_new import CocoGrounding_New

 __all__ = ["COCODataset", "TSVDataset", "ODTSVDataset", "ConcatDataset", "PascalVOCDataset", "Background",
            "ModulatedDataset", "MixedDataset", "CocoDetection", "FlickrDataset", "RefExpDataset", "GQADataset",
            "CocoDetectionTSV", "CocoGrounding", "CaptionTSV", "LvisDetection", "PseudoData", "PhrasecutDetection",
-           # "CocoGrounding_New"
+           "CocoGrounding_New",
            ]
@@ -0,0 +1,721 @@
# Delete some unused functions from modulated_coco.
# Suited for Objects365 pre-training
import logging
import os
import os.path
import math
from PIL import Image, ImageDraw

import random
import numpy as np

import torch
import torchvision
import torch.utils.data as data
from pycocotools import mask as coco_mask

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.data.datasets.coco import has_valid_annotation
from .od_to_grounding import convert_od_to_grounding_simple, check_for_positive_overflow, sanity_check_target_after_processing, convert_object_detection_to_grounding_optimized_for_od
import pdb
import json
from tqdm import tqdm
from groundingdino_new.util.inference import preprocess_caption
from groundingdino_new.util.box_ops import box_xyxy_to_cxcywh

import copy


def _has_only_crowd_bbox(anno):
    return all(obj["iscrowd"] == 1 for obj in anno)


class CocoGrounding_New(torchvision.datasets.CocoDetection):
    """COCO-style detection dataset converted on the fly to grounding format
    (caption plus per-box positive token spans)."""

    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_shuffle=False,
                 add_detection_prompt=False,
                 add_detection_prompt_advanced=False,
                 control_probabilities={},
                 one_hot=False,
                 disable_clip_to_image=False,
                 no_minus_one_for_one_hot=False,
                 separation_tokens=" ",
                 few_shot=0,
                 no_mask_for_od=False,
                 override_category=None,
                 use_caption_prompt=False,
                 caption_prompt=None,
                 max_num_labels=-1,
                 max_query_len=256,
                 special_safeguard_for_coco_grounding=False,
                 random_sample_negative=-1,
                 cumtom_ids=None,  # (sic) keyword name kept for config compatibility
                 exclude_crowd=False,
                 sep_at_last=False,
                 add_normed_cxcy=False,
                 custom_category_ids=None,
                 **kwargs
                 ):
        super(CocoGrounding_New, self).__init__(img_folder, ann_file)
        self.ids = sorted(self.ids)

        self.sep_at_last = sep_at_last
        self.add_normed_cxcy = add_normed_cxcy

        self.iscrowd = False if exclude_crowd else None

        # keep only images that have at least one valid annotation
        ids = []
        for img_id in self.ids:
            if isinstance(img_id, str):
                ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=self.iscrowd)
            else:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=self.iscrowd)
            anno = self.coco.loadAnns(ann_ids)

            if has_valid_annotation(anno):
                ids.append(img_id)

        self.ids = ids

        # self.ids = self.remove_invalid_images(self.ids)

        if few_shot:
            ids = []
            # cats_freq = [few_shot]*len(self.coco.cats.keys())
            cats_freq = [few_shot]*max(list(self.coco.cats.keys()))
            for img_id in self.ids:
                if isinstance(img_id, str):
                    ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=self.iscrowd)
                else:
                    ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=self.iscrowd)
                anno = self.coco.loadAnns(ann_ids)
                cat = set([ann['category_id'] for ann in anno])  # set/tuple correspond to instance/image level
                is_needed = sum([cats_freq[c-1] > 0 for c in cat])
                if is_needed:
                    ids.append(img_id)
                    for c in cat:
                        cats_freq[c-1] -= 1
                    # print(cat, cats_freq)
            self.ids = ids
        if cumtom_ids is not None:
            self.ids = cumtom_ids

        if custom_category_ids is not None:
            new_ids = []
            for img_id in self.ids:
                if isinstance(img_id, str):
                    ann_ids = self.coco.getAnnIds(imgIds=[img_id], catIds=custom_category_ids, iscrowd=self.iscrowd)
                else:
                    ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=custom_category_ids, iscrowd=self.iscrowd)
                if len(ann_ids) > 0:
                    new_ids.append(img_id)
            self.ids = new_ids

        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }

        if override_category is not None:
            self.coco.dataset["categories"] = override_category
        self.max_num_labels = max_num_labels
        self.control_probabilities = control_probabilities
        self.use_caption_prompt = use_caption_prompt
        self.caption_prompt = caption_prompt
        self.special_safeguard_for_coco_grounding = special_safeguard_for_coco_grounding
        self.random_sample_negative = random_sample_negative
        self.ind_to_class = self.categories(no_background=False)
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(False, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len, ind_to_class=self.ind_to_class)
        self.tokenizer = tokenizer
        self.is_train = is_train

        self.disable_shuffle = disable_shuffle
        self.add_detection_prompt = add_detection_prompt
        self.add_detection_prompt_advanced = add_detection_prompt_advanced
        self.one_hot = one_hot
        self.no_minus_one_for_one_hot = no_minus_one_for_one_hot

        self.disable_clip_to_image = disable_clip_to_image
        self.separation_tokens = separation_tokens
        self.no_mask_for_od = no_mask_for_od
        self.return_masks = return_masks

    def remove_invalid_images(self, ids):
        print('removing non-existent images from dataset...')
        new_ids = []
        invalid_num = 0
        for id in tqdm(ids):
            path = self.coco.loadImgs(id)[0]["file_name"]
            if os.path.exists(os.path.join(self.root, path)):
                new_ids.append(id)
            else:
                invalid_num += 1
        print('removed {} non-existent images from dataset'.format(invalid_num))
        return new_ids

    def categories(self, no_background=True):
        categories = self.coco.dataset["categories"]
        label_list = {}
        for index, i in enumerate(categories):
            # assert(index + 1 == i["id"])
            if not no_background or (i["name"] != "__background__" and i['id'] != 0):
                label_list[self.json_category_id_to_contiguous_id[i["id"]]] = i["name"]
        return label_list

    def get_box_mask(self, rect, img_size, mode="poly"):
        assert mode == "poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __getitem__(self, idx):
        img, tgt = super(CocoGrounding_New, self).__getitem__(idx)
        image_id = self.ids[idx]
        tgt = [obj for obj in tgt if obj["iscrowd"] == 0]
        boxes = [obj["bbox"] for obj in tgt]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        classes = [obj["category_id"] for obj in tgt]
        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
        classes = torch.tensor(classes)
        target.add_field("labels", classes)

        if not self.disable_clip_to_image:
            target = target.clip_to_image(remove_empty=True)

        if self.special_safeguard_for_coco_grounding:
            # Intended for LVIS and Object365
            assert(not self.use_caption_prompt)

            original_box_num = len(target)
            target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len-2)  # leave some space for the special tokens
            if len(target) < original_box_num:
                print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target)))

            annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=self.add_detection_prompt,
                add_detection_prompt_advanced=self.add_detection_prompt_advanced,
                random_sample_negative=self.random_sample_negative,
                control_probabilities=self.control_probabilities,  # always try to add a lot of negatives
                restricted_negative_list=None,
                separation_tokens=self.separation_tokens,
                max_num_labels=self.max_num_labels,
                positive_caption_length=positive_caption_length,
                tokenizer=self.tokenizer,
                max_seq_length=self.max_query_len-2,
                obj356_debug=True
            )
        else:
            # Intended for COCO / ODinW
            annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_od_to_grounding_simple(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=self.add_detection_prompt,
                separation_tokens=self.separation_tokens,
                caption_prompt=self.caption_prompt if self.use_caption_prompt else None,
            )

        # if self.sep_at_last:
        #     caption = preprocess_caption(caption)
        anno = {"image_id": image_id, "annotations": annotations, "caption": caption, "label_to_positions_caption": label_to_positions}
        anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
        if self.no_mask_for_od:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
        img, anno = self.prepare(img, anno, box_format="xyxy")

        # for equivalence check
        if self.one_hot:
            logging.info("using one hot for equivalence check.")
            one_hot_map = torch.zeros_like(anno["positive_map"], dtype=torch.float)
            text_mask = torch.zeros(anno["positive_map"].shape[1], dtype=torch.int64)
            # create one hot mapping
            for ii, cls in enumerate(classes):
                if self.no_minus_one_for_one_hot:
                    one_hot_map[ii, cls] = 1.0
                else:
                    one_hot_map[ii, cls - 1] = 1.0
            if self.no_minus_one_for_one_hot:
                text_mask[:] = 1
            else:
                text_mask[:len(self.ind_to_class)] = 1
            anno["positive_map"] = one_hot_map
            anno["text_mask"] = text_mask

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for ann in anno:
            target.add_field(ann, anno[ann])

        if self.add_normed_cxcy:
            bbox = target.bbox
            W, H = target.size  # BoxList.size is (width, height)
            normed_bbox = bbox / torch.Tensor([[W, H, W, H]])
            normed_cxcy = box_xyxy_to_cxcywh(normed_bbox)
            target.add_field('normed_cxcy_boxes', normed_cxcy)

        sanity_check_target_after_processing(target)

        return img, target, idx

    def get_img_info(self, index):
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data

    def get_raw_image(self, idx):
        image, *_ = super(CocoGrounding_New, self).__getitem__(idx)
        return image


class ModulatedDataset(torchvision.datasets.CocoDetection):
    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_clip_to_image=False,
                 no_mask_for_gold=False,
                 max_query_len=256,
                 **kwargs):
        super(ModulatedDataset, self).__init__(img_folder, ann_file)
        self.ids = sorted(self.ids)

        ids = []
        for img_id in self.ids:
            if isinstance(img_id, str):
                ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
            else:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
            anno = self.coco.loadAnns(ann_ids)
            if has_valid_annotation(anno):
                ids.append(img_id)
        self.ids = ids

        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
        self.is_train = is_train
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold

    def __getitem__(self, idx):
        img, target = super(ModulatedDataset, self).__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]
        caption = coco_img["caption"]
        dataset_name = coco_img["dataset_name"] if "dataset_name" in coco_img else None
        anno = {"image_id": image_id, "annotations": target, "caption": caption}

        # This dataset is used for Flickr & Mixed, so the sequence is maskable
        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
        if self.no_mask_for_gold:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
        img, anno = self.prepare(img, anno)

        # convert to BoxList (bboxes, labels)
        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xyxy")
        classes = anno["labels"]
        target.add_field("labels", classes)
        if self.prepare.return_masks:
            target.add_field("masks", anno.pop("masks"))
            target.add_field("is_box_mask", anno.pop("is_box_mask"))
        if not self.disable_clip_to_image:
            num_boxes = len(target.bbox)
            target = target.clip_to_image(remove_empty=True)
            assert num_boxes == len(target.bbox), "Box got removed in MixedDataset!!!"

        # Check if bboxes are correct
        # draw = ImageDraw.Draw(img)
        # boxes = target.bbox
        # for box in boxes:
        #     draw.rectangle([box[0], box[1], box[2], box[3]])
        # img.save('OUTPUT/images/{}.jpg'.format(idx))

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for ann in anno:
            target.add_field(ann, anno[ann])

        target.add_field("dataset_name", dataset_name)
        for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]:
            if extra_key in coco_img:
                target.add_field(extra_key, coco_img[extra_key])

        if "tokens_positive_eval" in coco_img and not self.is_train:
            tokenized = self.prepare.tokenizer(caption, return_tensors="pt")
            target.add_field("positive_map_eval", create_positive_map(tokenized, coco_img["tokens_positive_eval"]))
            target.add_field("nb_eval", len(target.get_field("positive_map_eval")))

        sanity_check_target_after_processing(target)
        return img, target, idx

    def get_img_info(self, index):
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data


class CocoDetection(data.Dataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index, return_meta=False):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        coco = self.coco
        img_id = self.ids[index]
        if isinstance(img_id, str):
            img_id = [img_id]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        meta = coco.loadImgs(img_id)[0]
        path = meta['file_name']
        img = pil_loader(os.path.join(self.root, path))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if return_meta:
            return img, target, meta
        else:
            return img, target

    def __len__(self):
        return len(self.ids)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256, ind_to_class=None):
        self.return_masks = return_masks
        self.return_tokens = return_tokens
        self.tokenizer = tokenizer
        self.max_query_len = max_query_len
        self.ind_to_class = ind_to_class

    def get_box_mask(self, rect, img_size, mode="poly"):
        assert mode == "poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]
        caption = target["caption"] if "caption" in target else None
        label_to_positions = target.get("label_to_positions", {})
        label_to_positions_caption = target.get("label_to_positions_caption", {})

        greenlight_span_for_masked_lm_objective = target.get("greenlight_span_for_masked_lm_objective", None)

        anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        if box_format == "xywh":
            boxes[:, 2:] += boxes[:, :2] - 1  # TO_REMOVE = 1
            boxes[:, 0::2].clamp_(min=0, max=w-1)  # TO_REMOVE = 1
            boxes[:, 1::2].clamp_(min=0, max=h-1)  # TO_REMOVE = 1

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            masks = []
            is_box_mask = []
            for obj, bbox in zip(anno, boxes):
                if "segmentation" in obj:
                    masks.append(obj["segmentation"])
                    is_box_mask.append(0)
                else:
                    masks.append(self.get_box_mask(bbox, image.size, mode='poly'))
                    is_box_mask.append(1)
            masks = SegmentationMask(masks, image.size, mode='poly')
            is_box_mask = torch.tensor(is_box_mask)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        isfinal = None
        if anno and "isfinal" in anno[0]:
            isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float)

        tokens_positive = [] if self.return_tokens else None
        if self.return_tokens and anno and "tokens" in anno[0]:
            tokens_positive = [obj["tokens"] for obj in anno]
        elif self.return_tokens and anno and "tokens_positive" in anno[0]:
            tokens_positive = [obj["tokens_positive"] for obj in anno]

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
            is_box_mask = is_box_mask[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if caption is not None:
            target["caption"] = caption
        if self.return_masks:
            target["masks"] = masks
            target["is_box_mask"] = is_box_mask
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        if tokens_positive is not None:
            target["tokens_positive"] = []

            for i, k in enumerate(keep):
                if k or ignore_box_screen:
                    target["tokens_positive"].append(tokens_positive[i])

        if isfinal is not None:
            target["isfinal"] = isfinal

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        if self.return_tokens and self.tokenizer is not None:
            if not ignore_box_screen:
                assert len(target["boxes"]) == len(target["tokens_positive"])

            tokenized = self.tokenizer(caption, return_tensors="pt",
                                       max_length=self.max_query_len,
                                       truncation=True)

            target["positive_map"] = create_positive_map(tokenized, target["tokens_positive"])
            # target['greenlight_map'] = create_greenlight_map(greenlight_span_for_masked_lm_objective, tokenized)
            # target["positive_map_for_od_labels"] = create_positive_map_for_od_labels(tokenized, label_to_positions)

            all_tokens = [[v] for k, v in label_to_positions_caption.items()]
            target["all_map"] = create_positive_map(tokenized, all_tokens)
            target["labels_in_caption"] = [k for k, v in label_to_positions_caption.items()]

            pos_label_set = list(set(target['labels'].tolist()))
            pos_category_tokens = [[v] for k, v in label_to_positions_caption.items() if k in pos_label_set]
            target["positive_category_map"] = create_positive_map(tokenized, pos_category_tokens)
            target["positive_category_map"][target["positive_category_map"] != 0] = 1

        original_od_label = []
        for obj in anno:
            original_od_label.append(
                obj.get("original_od_label", -10))  # NOTE: the padding value must differ from -1 and -100
        target["original_od_label"] = torch.as_tensor(original_od_label)

        return image, target


def create_greenlight_map(tok_list, tokenized):
    # An example tok_list:
    # [(0, 5), (10, 13), (-1, -1, -1)]
    # The last one is a special indicator.

    greenlight_map = torch.zeros(256, dtype=torch.float)
    for item in tok_list:
        if len(item) != 2:
            assert(len(item) == 3)
            # Make everything unmaskable
            greenlight_map[:] = -1
            break

        beg, end = item
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        if beg_pos is None:
            try:
                beg_pos = tokenized.char_to_token(beg + 1)
                if beg_pos is None:
                    beg_pos = tokenized.char_to_token(beg + 2)
            except:
                beg_pos = None
        if end_pos is None:
            try:
                end_pos = tokenized.char_to_token(end - 2)
                if end_pos is None:
                    end_pos = tokenized.char_to_token(end - 3)
            except:
                end_pos = None
        if beg_pos is None or end_pos is None:
            continue

        assert beg_pos is not None and end_pos is not None
        greenlight_map[beg_pos: end_pos + 1].fill_(1)
    return greenlight_map


def create_positive_map_for_od_labels(tokenized, label_to_positions):
    """Construct a map such that positive_map[i] = j, where j is the object
    detection label of token i (or -1 when token i matches no label).

    Example: label_to_positions = {3: (1, 5)} marks the tokens covering
    caption characters [1, 5) with label 3 in the 256-length map; for a
    caption like "the woman in the garden" with no mapped labels, every
    position stays -1.
    """
    positive_map = torch.ones(256, dtype=torch.float) * -1  # -1 means no match
    keys = list(label_to_positions.keys())
    for j, key in enumerate(keys):
        tok_list = label_to_positions[key]
        # one label only maps to one location
        beg, end = tok_list
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        if beg_pos is None:
            try:
                beg_pos = tokenized.char_to_token(beg + 1)
                if beg_pos is None:
                    beg_pos = tokenized.char_to_token(beg + 2)
            except:
                beg_pos = None
        if end_pos is None:
            try:
                end_pos = tokenized.char_to_token(end - 2)
                if end_pos is None:
                    end_pos = tokenized.char_to_token(end - 3)
            except:
                end_pos = None
        if beg_pos is None or end_pos is None:
            continue
        assert beg_pos is not None and end_pos is not None
        positive_map[beg_pos: end_pos + 1].fill_(key)
    return positive_map


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


def create_positive_map(tokenized, tokens_positive):
    """construct a map such that positive_map[i,j] = True iff box i is associated to token j"""
    positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float)

    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            beg_pos = tokenized.char_to_token(beg)
            end_pos = tokenized.char_to_token(end - 1)
            if beg_pos is None:
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except:
                    beg_pos = None
            if end_pos is None:
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except:
                    end_pos = None
            if beg_pos is None or end_pos is None:
                continue

            assert beg_pos is not None and end_pos is not None
            positive_map[j, beg_pos: end_pos + 1].fill_(1)
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)


def pil_loader(path, retry=5):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    ri = 0
    while ri < retry:
        try:
            with open(path, 'rb') as f:
                img = Image.open(f)
                return img.convert('RGB')
        except:
            ri += 1
    # implicitly returns None once all retries are exhausted
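An aside on `create_positive_map` above: it only needs an object whose `char_to_token` behaves like a HuggingFace `BatchEncoding`. A minimal sketch, assuming a BERT tokenizer and an illustrative caption (both are assumptions for the sketch, not part of the commit):

    # Illustrative only: map character spans of 'person' and 'dog' to token positions.
    from transformers import AutoTokenizer  # assumed tokenizer for this sketch

    tok = AutoTokenizer.from_pretrained('bert-base-uncased')
    caption = 'person. dog.'
    tokenized = tok(caption, return_tensors='pt')
    spans = [[(0, 6)], [(8, 11)]]  # char spans of 'person' and 'dog' in the caption
    pmap = create_positive_map(tokenized, spans)
    print(pmap.shape)  # torch.Size([2, 256]); each row sums to ~1 over its matched tokens
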
@@ -1,362 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import sys
import os
import math
import time

import torch
import torch.distributed as dist

from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.utils.ema import ModelEma
from maskrcnn_benchmark.utils.amp import autocast, GradScaler
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from .inference import inference
import pdb


def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses


def do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        val_data_loader=None,
        meters=None,
        zero_shot=False,
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    # meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    model_ema = None
    if cfg.SOLVER.MODEL_EMA > 0:
        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)
    start_training_time = time.time()
    end = time.time()

    if cfg.SOLVER.USE_AMP:
        scaler = GradScaler()

    global_rank = get_rank()

    if cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1 and cfg.SOLVER.MAX_EPOCH >= 1:
        checkpoint_period = len(data_loader) // cfg.SOLVER.CHECKPOINT_PER_EPOCH // cfg.SOLVER.MAX_EPOCH

    if global_rank <= 0 and cfg.SOLVER.MAX_EPOCH >= 1:
        print("Iter per epoch ", len(data_loader) // cfg.SOLVER.MAX_EPOCH)

    if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
        patience_counter = 0
        previous_best = 0.0

    # Adapt the weight decay
    if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'):
        milestone_target = 0
        for i, milestone in enumerate(list(scheduler.milestones)):
            if scheduler.last_epoch >= milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                milestone_target = i+1
    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        nnegative = sum(len(target) < 1 for target in targets)
        nsample = len(targets)
        if nsample == nnegative or nnegative > nsample * cfg.SOLVER.MAX_NEG_PER_BATCH:
            logger.info('[WARNING] Sampled {} negatives in a batch of {}, greater than the allowed ratio {}; skip'.
                        format(nnegative, nsample, cfg.SOLVER.MAX_NEG_PER_BATCH))
            continue

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        captions = None
        try:
            targets = [target.to(device) for target in targets]
            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
        except:
            pass
        # Freeze language backbone
        if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
            if hasattr(model, "module"):
                model.module.language_backbone.eval()
            else:
                model.language_backbone.eval()

        if cfg.SOLVER.USE_AMP:
            with autocast():
                if len(captions) > 0:
                    loss_dict = model(images, targets, captions=captions, positive_map=positive_map, greenlight_map=greenlight_map)
                else:
                    loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # save checkpoints for further debug if nan happens
            # loss_value = losses.item()
            # if not math.isfinite(loss_value):
            #     logging.error(f'=> loss is {loss_value}, stopping training')
            #     logging.error("Losses are : {}".format(loss_dict))
            #     time_str = time.strftime('%Y-%m-%d-%H-%M')
            #     fname = os.path.join(checkpointer.save_dir, f'{time_str}_states.pth')
            #     logging.info(f'=> save error state to {fname}')
            #     dict_to_save = {
            #         'x': images,
            #         'y': targets,
            #         'loss': losses,
            #         'states': model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            #     }
            #     if len(captions) > 0:
            #         dict_to_save['captions'] = captions
            #         dict_to_save['positive_map'] = positive_map
            #     torch.save(
            #         dict_to_save,
            #         fname
            #     )

            if torch.isnan(losses) or torch.isinf(losses):
                logging.error("NaN encountered, ignoring")
                losses[losses != losses] = 0
            optimizer.zero_grad()
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
        else:
            if len(captions) > 0:
                loss_dict = model(images, targets, captions, positive_map)
            else:
                loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # loss_value = losses.item()
            # if not math.isfinite(loss_value):
            #     logging.error(f'=> loss is {loss_value}, stopping training')
            #     time_str = time.strftime('%Y-%m-%d-%H-%M')
            #     fname = os.path.join(checkpointer.save_dir, f'{time_str}_states.pth')
            #     logging.info(f'=> save error state to {fname}')
            #     dict_to_save = {
            #         'x': images,
            #         'y': targets,
            #         'loss': losses,
            #         'states': model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            #     }
            #     if len(captions) > 0:
            #         dict_to_save['captions'] = captions
            #         dict_to_save['positive_map'] = positive_map
            #     torch.save(
            #         dict_to_save,
            #         fname
            #     )

            if torch.isnan(losses) or torch.isinf(losses):
                losses[losses != losses] = 0
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            scheduler.step()

        # Adapt the weight decay: only support multiStepLR
        if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'):
            if milestone_target < len(scheduler.milestones):
                next_milestone = list(scheduler.milestones)[milestone_target]
            else:
                next_milestone = float('inf')
            if scheduler.last_epoch >= next_milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                gamma = scheduler.gamma
                logger.info("Drop the weight decay by {}!".format(gamma))
                for param in optimizer.param_groups:
                    if 'weight_decay' in param:
                        param['weight_decay'] *= gamma
                # move the target forward
                milestone_target += 1

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)
        if model_ema is not None:
            model_ema.update(model)
            arguments["model_ema"] = model_ema.state_dict()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            # if iteration % 1 == 0 or iteration == max_iter:
            # logger.info(
            if global_rank <= 0:
                print(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "wd: {wd:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        wd=optimizer.param_groups[0]["weight_decay"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    )
                )
        if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
            if is_main_process():
                print("Evaluating")
            eval_result = 0.0
            model.eval()

            if cfg.SOLVER.TEST_WITH_INFERENCE:
                with torch.no_grad():
                    try:
                        _model = model.module
                    except:
                        _model = model
                    _result = inference(
                        model=_model,
                        data_loader=val_data_loader,
                        dataset_name="val",
                        device=device,
                        expected_results=cfg.TEST.EXPECTED_RESULTS,
                        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                        output_folder=None,
                        cfg=cfg,
                        verbose=False,
                        disable_print=True
                    )
                    if is_main_process():
                        eval_result = _result[0].results['bbox']['AP']
            else:
                results_dict = {}
                cpu_device = torch.device("cpu")
                for i, batch in enumerate(val_data_loader):
                    images, targets, image_ids, positive_map, *_ = batch
                    with torch.no_grad():
                        images = images.to(device)
                        if positive_map is None:
                            output = model(images)
                        else:
                            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
                            output = model(images, captions, positive_map)
                        output = [o.to(cpu_device) for o in output]
                    results_dict.update(
                        {img_id: result for img_id, result in zip(image_ids, output)}
                    )
                all_predictions = all_gather(results_dict)
                if is_main_process():
                    predictions = {}
                    for p in all_predictions:
                        predictions.update(p)
                    predictions = [predictions[i] for i in list(sorted(predictions.keys()))]
                    eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None,
                                              box_only=cfg.DATASETS.CLASS_AGNOSTIC)
                    if cfg.DATASETS.CLASS_AGNOSTIC:
                        eval_result = eval_result.results['box_proposal']['AR@100']
                    else:
                        eval_result = eval_result.results['bbox']['AP']
            model.train()

            if model_ema is not None and cfg.SOLVER.USE_EMA_FOR_MONITOR:
                model_ema.ema.eval()
                results_dict = {}
                cpu_device = torch.device("cpu")
                for i, batch in enumerate(val_data_loader):
                    images, targets, image_ids, positive_map, positive_map_eval = batch
                    with torch.no_grad():
                        images = images.to(device)
                        if positive_map is None:
                            output = model_ema.ema(images)
                        else:
                            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
                            output = model_ema.ema(images, captions, positive_map)
                        output = [o.to(cpu_device) for o in output]
                    results_dict.update(
                        {img_id: result for img_id, result in zip(image_ids, output)}
                    )
                all_predictions = all_gather(results_dict)
                if is_main_process():
                    predictions = {}
                    for p in all_predictions:
                        predictions.update(p)
                    predictions = [predictions[i] for i in list(sorted(predictions.keys()))]
                    eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None,
                                              box_only=cfg.DATASETS.CLASS_AGNOSTIC)
                    if cfg.DATASETS.CLASS_AGNOSTIC:
                        eval_result = eval_result.results['box_proposal']['AR@100']
                    else:
                        eval_result = eval_result.results['bbox']['AP']

            arguments.update(eval_result=eval_result)

            if cfg.SOLVER.USE_AUTOSTEP:
                eval_result = all_gather(eval_result)[0]  # broadcast_data([eval_result])[0]
                # print("Rank {} eval result gathered".format(cfg.local_rank), eval_result)
                scheduler.step(eval_result)

            if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
                if eval_result < previous_best:
                    patience_counter += 1
                else:
                    patience_counter = 0
                    previous_best = eval_result
                    checkpointer.save("model_best", **arguments)
                print("Previous Best", previous_best, "Patience Counter", patience_counter, "Eval Result", eval_result)
                if patience_counter >= cfg.SOLVER.AUTO_TERMINATE_PATIENCE:
                    if is_main_process():
                        print("\n\n\n\nAuto Termination at {}, current best {}\n\n\n".format(iteration, previous_best))
                    break

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )
@@ -0,0 +1,287 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
r"""
Basic training script for PyTorch
"""

# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import os

import numpy as np
import torch
from maskrcnn_benchmark.config import cfg, try_to_find
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.solver import make_lr_scheduler
from maskrcnn_benchmark.solver import make_optimizer
from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, all_gather
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.metric_logger import (MetricLogger, TensorboardLogger)
from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config
import random
from maskrcnn_benchmark.utils.amp import autocast, GradScaler

from pathlib import Path
from tqdm import tqdm
from collections import defaultdict

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ["TOKENIZERS_PARALLELISM"] = "false"


def tuning_highlevel_override(cfg,):
    if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vision_query":
        cfg.MODEL.BACKBONE.FREEZE = True
        cfg.MODEL.FPN.FREEZE = True
        cfg.MODEL.RPN.FREEZE = not cfg.VISION_QUERY.QUERY_FUSION
        cfg.MODEL.LINEAR_PROB = False
        cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER = False
        cfg.MODEL.LANGUAGE_BACKBONE.FREEZE = False
        cfg.MODEL.DYHEAD.USE_CHECKPOINT = False  # Disable checkpoint
        cfg.VISION_QUERY.ENABLED = True
    if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vs_with_txt_enc":
        cfg.MODEL.BACKBONE.FREEZE = True
        cfg.MODEL.FPN.FREEZE = True
        cfg.MODEL.RPN.FREEZE = not cfg.VISION_QUERY.QUERY_FUSION
        cfg.MODEL.LINEAR_PROB = False
        cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER = False
        cfg.MODEL.LANGUAGE_BACKBONE.FREEZE = False
        cfg.MODEL.DYHEAD.USE_CHECKPOINT = False  # Disable checkpoint
        cfg.VISION_QUERY.ENABLED = True


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def extract_query(cfg):
    # Run the detector over the caching data loader and accumulate
    # per-category vision-query features, then save them as a query bank.
    if cfg.DATASETS.FEW_SHOT:
        assert cfg.DATASETS.FEW_SHOT == cfg.VISION_QUERY.MAX_QUERY_NUMBER, 'To extract the right query instances, set VISION_QUERY.MAX_QUERY_NUMBER = DATASETS.FEW_SHOT.'
    # if cfg.num_gpus > 1:
    #     max_query_number = cfg.VISION_QUERY.MAX_QUERY_NUMBER
    #     cfg.defrost()
    #     cfg.VISION_QUERY.MAX_QUERY_NUMBER = int(cfg.VISION_QUERY.MAX_QUERY_NUMBER/cfg.num_gpus)
    #     cfg.freeze()

    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    checkpointer = DetectronCheckpointer(
        cfg, model
    )
    checkpointer.load(try_to_find(cfg.MODEL.WEIGHT))

    data_loader = make_data_loader(
        cfg,
        is_train=False,
        is_cache=True,
        is_distributed=cfg.num_gpus > 1,
    )
    assert isinstance(data_loader, list) and len(data_loader) == 1
    data_loader = data_loader[0]

    # if cfg.VISION_QUERY.CUSTOM_DATA_IDS is not None:
    #     data_loader.dataset.ids = cfg.VISION_QUERY.CUSTOM_DATA_IDS

    if cfg.num_gpus > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[cfg.local_rank], output_device=cfg.local_rank,
            broadcast_buffers=cfg.MODEL.BACKBONE.USE_BN,
            find_unused_parameters=cfg.SOLVER.FIND_UNUSED_PARAMETERS
        )

    query_images = defaultdict(list)
    _iterator = tqdm(data_loader)
    # _iterator = data_loader  # for debug
    model.eval()
    for i, batch in enumerate(_iterator):
        images, targets, *_ = batch
        if cfg.num_gpus > 1:
            query_images = model.module.extract_query(images.to(device), targets, query_images)
        else:
            query_images = model.extract_query(images.to(device), targets, query_images)

    if cfg.num_gpus > 1:
        ## not stable when using all_gather, easy to OOM.
        # all_query_images = all_gather(query_images)
        # if is_main_process():
        #     accumulated_query_images = defaultdict(list)
        #     for r, query_images_dict in enumerate(all_query_images):
        #         print('accumulating results: {}/{}'.format(r, len(all_query_images)))
        #         for label, feat in query_images_dict.items():
        #             num_queries = len(accumulated_query_images[label])
        #             if num_queries >= cfg.VISION_QUERY.MAX_QUERY_NUMBER:
        #                 continue
        #             if num_queries == 0:
        #                 accumulated_query_images[label] = feat.to(device)
        #             else:
        #                 accumulated_query_images[label] = torch.cat([accumulated_query_images[label].to(device), feat.to(device)])

        #     save_name = 'MODEL/{}_query_{}_pool{}_{}{}_multi-node.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0], cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION, 'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME)
        #     print('saving to ', save_name)
        #     torch.save(accumulated_query_images, save_name)
        if cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH != '':
            raise NotImplementedError
        global_rank = get_rank()
        save_name = 'MODEL/{}_query_{}_pool{}_{}{}_rank{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0], cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION, 'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME, global_rank)
        print('saving to ', save_name)
        torch.save(query_images, save_name)
    else:
        if cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH != '':
            save_name = cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH
        else:
            save_name = 'MODEL/{}_query_{}_pool{}_{}{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0], cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION, 'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME)
        print('saving to ', save_name)
        torch.save(query_images, save_name)
    # if cfg.num_gpus > 1:
    #     world_size = torch.distributed.dist.get_world_size()
    #     if is_main_process():
    #         query_images_list = []
    #         for r in range(world_size):
    #             saved_path = 'MODEL/{}_query_{}_pool{}_{}{}_rank{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0], cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION, 'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME, r)
    #             query_images_list.append(torch.load(saved_path, map_location='cpu'))

    #         # for s in query_images_list


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )

    parser.add_argument("--use-tensorboard",
                        dest="use_tensorboard",
                        help="Use tensorboardX logger (Requires tensorboardX installed)",
                        action="store_true",
                        default=False
                        )

    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    parser.add_argument("--save_original_config", action="store_true")
    parser.add_argument("--disable_output_distributed", action="store_true")
    parser.add_argument("--override_output_dir", default=None)
    parser.add_argument("--custom_shot_and_epoch_and_general_copy", default=None, type=str)
    parser.add_argument("--resume", action="store_true", default=False)
    parser.add_argument("--extract_query", action="store_true", default=False)
    parser.add_argument(
        "--task_config",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "--additional_model_config",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        import datetime
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
            timeout=datetime.timedelta(0, 7200)
        )

    if args.disable_output_distributed:
        setup_for_distributed(args.local_rank <= 0)

    cfg.local_rank = args.local_rank
    cfg.num_gpus = num_gpus

    cfg.merge_from_file(args.config_file)
    if args.task_config:
        cfg.merge_from_file(args.task_config)
    if args.additional_model_config:
        cfg.merge_from_file(args.additional_model_config)
    cfg.merge_from_list(args.opts)
    # specify output dir for models
    if args.override_output_dir:
        cfg.OUTPUT_DIR = args.override_output_dir
    tuning_highlevel_override(cfg)
    cfg.freeze()

    seed = cfg.SOLVER.SEED + args.local_rank
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info(args)
    logger.info("Using {} GPUs".format(num_gpus))

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
    logger.info("Saving config into: {}".format(output_config_path))
    # save overloaded model config in the output directory
    if args.save_original_config:
        import shutil
        shutil.copy(args.config_file, os.path.join(cfg.OUTPUT_DIR, 'config_original.yml'))

    save_config(cfg, output_config_path)

    if args.extract_query:
        extract_query(cfg)
    else:
        raise NotImplementedError


if __name__ == "__main__":
    main()
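The script above is driven by the `--extract_query` flag in `main()`. A minimal sketch of consuming its output, assuming the default naming scheme from `extract_query` (the exact file name depends on the VISION_QUERY settings, so the path below is an assumption, not a name from the commit):

    # Illustrative only: inspect a saved query bank. The path follows the
    # save_name pattern in extract_query above and is an assumption.
    import torch

    bank = torch.load('MODEL/coco_query_5000_pool7_sel.pth', map_location='cpu')
    for label, feats in bank.items():
        # each entry is expected to hold the query features accumulated by
        # model.extract_query for that category id
        print(label, getattr(feats, 'shape', type(feats)))
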