diff --git a/maskrcnn_benchmark/data/datasets/__init__.py b/maskrcnn_benchmark/data/datasets/__init__.py index f6727e9..95c0438 100644 --- a/maskrcnn_benchmark/data/datasets/__init__.py +++ b/maskrcnn_benchmark/data/datasets/__init__.py @@ -16,10 +16,10 @@ from .caption import CaptionTSV from .lvis import LvisDetection from .pseudo_data import PseudoData from .phrasecut import PhrasecutDetection -# from .modulated_coco_new import CocoGrounding_New +from .modulated_coco_new import CocoGrounding_New __all__ = ["COCODataset", "TSVDataset", "ODTSVDataset", "ConcatDataset", "PascalVOCDataset", "Background", "ModulatedDataset", "MixedDataset", "CocoDetection", "FlickrDataset", "RefExpDataset", "GQADataset", "CocoDetectionTSV", "CocoGrounding", "CaptionTSV", "LvisDetection", "PseudoData", "PhrasecutDetection", - # "CocoGrounding_New" + "CocoGrounding_New", ] diff --git a/maskrcnn_benchmark/data/datasets/modulated_coco_new.py b/maskrcnn_benchmark/data/datasets/modulated_coco_new.py new file mode 100644 index 0000000..1af48c9 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/modulated_coco_new.py @@ -0,0 +1,721 @@ +# Delete some ununsed functions from modulated_coco. +# Suit for object365 pre-training +import logging +import os +import os.path +import math +from PIL import Image, ImageDraw + +import random +import numpy as np + +import torch +import torchvision +import torch.utils.data as data +from pycocotools import mask as coco_mask + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask +from maskrcnn_benchmark.data.datasets.coco import has_valid_annotation +from .od_to_grounding import convert_od_to_grounding_simple, check_for_positive_overflow, sanity_check_target_after_processing, convert_object_detection_to_grounding_optimized_for_od +import pdb +import json +from tqdm import tqdm +from groundingdino_new.util.inference import preprocess_caption +from groundingdino_new.util.box_ops import box_xyxy_to_cxcywh + +import copy + +def _has_only_crowd_bbox(anno): + return all(obj["iscrowd"] == 1 for obj in anno) + +class CocoGrounding_New(torchvision.datasets.CocoDetection): + def __init__(self, + img_folder, + ann_file, + transforms, + return_masks, + return_tokens, + is_train=False, + tokenizer=None, + disable_shuffle=False, + add_detection_prompt=False, + add_detection_prompt_advanced=False, + control_probabilities={}, + one_hot=False, + disable_clip_to_image=False, + no_minus_one_for_one_hot=False, + separation_tokens=" ", + few_shot=0, + no_mask_for_od=False, + override_category=None, + use_caption_prompt=False, + caption_prompt=None, + max_num_labels=-1, + max_query_len=256, + special_safeguard_for_coco_grounding=False, + random_sample_negative=-1, + cumtom_ids=None, + exclude_crowd=False, + sep_at_last = False, + add_normed_cxcy = False, + custom_category_ids=None, + **kwargs + ): + super(CocoGrounding_New, self).__init__(img_folder, ann_file) + self.ids = sorted(self.ids) + + self.sep_at_last = sep_at_last + self.add_normed_cxcy = add_normed_cxcy + + self.iscrowd = False if exclude_crowd else None + + ids = [] + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=self.iscrowd) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=self.iscrowd) + anno = self.coco.loadAnns(ann_ids) + + if has_valid_annotation(anno): + ids.append(img_id) + + self.ids = ids + + # self.ids = self.remove_invalid_images(self.ids) + + if few_shot: + ids = [] 
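+            # Keep an image only while at least one of its categories still has remaining few-shot budget; each kept image decrements the budget of every category it contains.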
+ # cats_freq = [few_shot]*len(self.coco.cats.keys()) + cats_freq = [few_shot]*max(list(self.coco.cats.keys())) + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=self.iscrowd) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=self.iscrowd) + anno = self.coco.loadAnns(ann_ids) + cat = set([ann['category_id'] for ann in anno]) #set/tuple corresponde to instance/image level + is_needed = sum([cats_freq[c-1]>0 for c in cat]) + if is_needed: + ids.append(img_id) + for c in cat: + cats_freq[c-1] -= 1 + # print(cat, cats_freq) + self.ids = ids + if cumtom_ids is not None: + self.ids = cumtom_ids + + if custom_category_ids is not None: + new_ids = [] + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], catIds=custom_category_ids, iscrowd=self.iscrowd) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=custom_category_ids, iscrowd=self.iscrowd) + if len(ann_ids) > 0: + new_ids.append(img_id) + self.ids = new_ids + + + self.json_category_id_to_contiguous_id = { + v: i + 1 for i, v in enumerate(self.coco.getCatIds()) + } + self.contiguous_category_id_to_json_id = { + v: k for k, v in self.json_category_id_to_contiguous_id.items() + } + + if override_category is not None: + self.coco.dataset["categories"] = override_category + self.max_num_labels=max_num_labels + self.control_probabilities=control_probabilities + self.use_caption_prompt = use_caption_prompt + self.caption_prompt = caption_prompt + self.special_safeguard_for_coco_grounding = special_safeguard_for_coco_grounding + self.random_sample_negative = random_sample_negative + self.ind_to_class = self.categories(no_background=False) + self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} + self._transforms = transforms + self.max_query_len = max_query_len + self.prepare = ConvertCocoPolysToMask(False, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len, ind_to_class=self.ind_to_class) + self.tokenizer = tokenizer + self.is_train = is_train + + self.ind_to_class = self.categories(no_background=False) + + self.disable_shuffle = disable_shuffle + self.add_detection_prompt = add_detection_prompt + self.add_detection_prompt_advanced=add_detection_prompt_advanced + self.one_hot = one_hot + self.no_minus_one_for_one_hot = no_minus_one_for_one_hot + + self.disable_clip_to_image = disable_clip_to_image + self.separation_tokens = separation_tokens + self.no_mask_for_od = no_mask_for_od + self.return_masks = return_masks + + def remove_invalid_images(self, ids): + print('removing non-exist images from dataset...') + new_ids=[] + invalid_num=0 + for id in tqdm(ids): + path = self.coco.loadImgs(id)[0]["file_name"] + if os.path.exists(os.path.join(self.root, path)): + new_ids.append(id) + else: + invalid_num+=1 + print('removed {} non-exist images from dataset'.format(invalid_num)) + return new_ids + + def categories(self, no_background=True): + categories = self.coco.dataset["categories"] + label_list = {} + for index, i in enumerate(categories): + # assert(index + 1 == i["id"]) + if not no_background or (i["name"] != "__background__" and i['id'] != 0): + label_list[self.json_category_id_to_contiguous_id[i["id"]]] = i["name"] + return label_list + + def get_box_mask(self, rect, img_size, mode="poly"): + assert mode=="poly", "Only support poly mask right now!" 
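+        # Represent the axis-aligned box as a single polygon with corners (x1, y1), (x1, y2), (x2, y2), (x2, y1).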
+ x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3] + return [[x1, y1, x1, y2, x2, y2, x2, y1]] + + def __getitem__(self, idx): + img, tgt = super(CocoGrounding_New, self).__getitem__(idx) + image_id = self.ids[idx] + tgt = [obj for obj in tgt if obj["iscrowd"] == 0] + boxes = [obj["bbox"] for obj in tgt] + boxes = torch.as_tensor(boxes).reshape(-1, 4) # guard against no boxes + target = BoxList(boxes, img.size, mode="xywh").convert("xyxy") + classes = [obj["category_id"] for obj in tgt] + classes = [self.json_category_id_to_contiguous_id[c] for c in classes] + classes = torch.tensor(classes) + target.add_field("labels", classes) + + if not self.disable_clip_to_image: + target = target.clip_to_image(remove_empty=True) + + if self.special_safeguard_for_coco_grounding: + # Intended for LVIS and Object365 + assert(not self.use_caption_prompt) + + original_box_num = len(target) + target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len-2) # leave some space for the special tokens + if len(target) < original_box_num: + print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target))) + + annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od( + target=target, + image_id=image_id, + ind_to_class=self.ind_to_class, + disable_shuffle=self.disable_shuffle, + add_detection_prompt=self.add_detection_prompt, + add_detection_prompt_advanced=self.add_detection_prompt_advanced, + random_sample_negative=self.random_sample_negative, + control_probabilities=self.control_probabilities, # always try to add a lot of negatives + restricted_negative_list=None, + separation_tokens=self.separation_tokens, + max_num_labels=self.max_num_labels, + positive_caption_length=positive_caption_length, + tokenizer=self.tokenizer, + max_seq_length=self.max_query_len-2, + obj356_debug=True + ) + else: + # Intended for COCO / ODinW + annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_od_to_grounding_simple( + target=target, + image_id=image_id, + ind_to_class=self.ind_to_class, + disable_shuffle=self.disable_shuffle, + add_detection_prompt=self.add_detection_prompt, + separation_tokens=self.separation_tokens, + caption_prompt=self.caption_prompt if self.use_caption_prompt else None, + ) + + # if self.sep_at_last: + # caption = preprocess_caption(caption) + anno = {"image_id": image_id, "annotations": annotations, "caption": caption, "label_to_positions_caption": label_to_positions} + anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective + if self.no_mask_for_od: + anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1)) + img, anno = self.prepare(img, anno, box_format="xyxy") + + # for equivalence check + if self.one_hot: + logging.info("using one hot for equivalence check.") + one_hot_map = torch.zeros_like(anno["positive_map"], dtype=torch.float) + text_mask = torch.zeros(anno["positive_map"].shape[1], dtype=torch.int64) + # create one hot mapping + for ii, cls in enumerate(classes): + if self.no_minus_one_for_one_hot: + one_hot_map[ii, cls] = 1.0 + else: + one_hot_map[ii, cls - 1] = 1.0 + if self.no_minus_one_for_one_hot: + text_mask[:] = 1 + else: + text_mask[:len(self.ind_to_class)] = 1 + anno["positive_map"] = one_hot_map + anno["text_mask"] = text_mask + + if self._transforms is not None: + img, target = self._transforms(img, target) + + # 
add additional property + for ann in anno: + target.add_field(ann, anno[ann]) + + if self.add_normed_cxcy: + bbox = target.bbox + H, W = target.size + normed_bbox = bbox / torch.Tensor([[H,W,H,W]]) + normed_cxcy = box_xyxy_to_cxcywh(normed_bbox) + target.add_field('normed_cxcy_boxes', normed_cxcy) + + sanity_check_target_after_processing(target) + + return img, target, idx + + def get_img_info(self, index): + img_id = self.id_to_img_map[index] + img_data = self.coco.imgs[img_id] + return img_data + + def get_raw_image(self, idx): + image, *_ = super(CocoGrounding_New, self).__getitem__(idx) + return image + + +class ModulatedDataset(torchvision.datasets.CocoDetection): + def __init__(self, + img_folder, + ann_file, + transforms, + return_masks, + return_tokens, + is_train=False, + tokenizer=None, + disable_clip_to_image=False, + no_mask_for_gold=False, + max_query_len=256, + **kwargs): + super(ModulatedDataset, self).__init__(img_folder, ann_file) + self.ids = sorted(self.ids) + + ids = [] + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = self.coco.loadAnns(ann_ids) + if has_valid_annotation(anno): + ids.append(img_id) + self.ids = ids + + self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} + self._transforms = transforms + self.max_query_len = max_query_len + self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len) + self.is_train = is_train + self.disable_clip_to_image = disable_clip_to_image + self.no_mask_for_gold = no_mask_for_gold + + def __getitem__(self, idx): + img, target = super(ModulatedDataset, self).__getitem__(idx) + image_id = self.ids[idx] + coco_img = self.coco.loadImgs(image_id)[0] + caption = coco_img["caption"] + dataset_name = coco_img["dataset_name"] if "dataset_name" in coco_img else None + anno = {"image_id": image_id, "annotations": target, "caption": caption} + + # This dataset is used for Flickr & Mixed, so the sequence is maskable + anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))] + if self.no_mask_for_gold: + anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1)) + img, anno = self.prepare(img, anno) + + # convert to BoxList (bboxes, labels) + boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4) # guard against no boxes + target = BoxList(boxes, img.size, mode="xyxy") + classes = anno["labels"] + target.add_field("labels", classes) + if self.prepare.return_masks: + target.add_field("masks", anno.pop("masks")) + target.add_field("is_box_mask", anno.pop("is_box_mask")) + if not self.disable_clip_to_image: + num_boxes = len(target.bbox) + target = target.clip_to_image(remove_empty=True) + assert num_boxes == len(target.bbox), "Box got removed in MixedDataset!!!" 
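+            # A removed box here would desynchronize target.bbox from the box-aligned fields (e.g. tokens_positive / positive_map) attached below, hence the hard assert above.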
+ + # Check if bboxes are correct + # draw = ImageDraw.Draw(img) + # boxes = target.bbox + # for box in boxes: + # draw.rectangle([box[0], box[1], box[2], box[3]]) + # img.save('OUTPUT/images/{}.jpg'.format(idx)) + + if self._transforms is not None: + img, target = self._transforms(img, target) + + # add additional property + for ann in anno: + target.add_field(ann, anno[ann]) + + target.add_field("dataset_name", dataset_name) + for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]: + if extra_key in coco_img: + target.add_field(extra_key, coco_img[extra_key]) + + if "tokens_positive_eval" in coco_img and not self.is_train: + tokenized = self.prepare.tokenizer(caption, return_tensors="pt") + target.add_field("positive_map_eval", create_positive_map(tokenized, coco_img["tokens_positive_eval"])) + target.add_field("nb_eval", len(target.get_field("positive_map_eval"))) + + sanity_check_target_after_processing(target) + return img, target, idx + + def get_img_info(self, index): + img_id = self.id_to_img_map[index] + img_data = self.coco.imgs[img_id] + return img_data + + +class CocoDetection(data.Dataset): + """`MS Coco Detection `_ Dataset. + + Args: + root (string): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + def __init__(self, root, annFile, transform=None, target_transform=None): + from pycocotools.coco import COCO + self.root = root + self.coco = COCO(annFile) + self.ids = list(self.coco.imgs.keys()) + self.transform = transform + self.target_transform = target_transform + + def __getitem__(self, index, return_meta=False): + """ + Args: + index (int): Index + + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. + """ + coco = self.coco + img_id = self.ids[index] + if isinstance(img_id, str): + img_id = [img_id] + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + + meta = coco.loadImgs(img_id)[0] + path = meta['file_name'] + img = pil_loader(os.path.join(self.root, path)) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + if return_meta: + return img, target, meta + else: + return img, target + + def __len__(self): + return len(self.ids) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + +class ConvertCocoPolysToMask(object): + def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256, ind_to_class=None): + self.return_masks = return_masks + self.return_tokens = return_tokens + self.tokenizer = tokenizer + self.max_query_len = max_query_len + self.ind_to_class=ind_to_class + + def get_box_mask(self, rect, img_size, mode="poly"): + assert mode=="poly", "Only support poly mask right now!" 
+ x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3] + return [[x1, y1, x1, y2, x2, y2, x2, y1]] + + def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + caption = target["caption"] if "caption" in target else None + label_to_positions = target.get("label_to_positions", {}) + label_to_positions_caption = target.get("label_to_positions_caption", {}) + + greenlight_span_for_masked_lm_objective = target.get("greenlight_span_for_masked_lm_objective", None) + + anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + if box_format == "xywh": + boxes[:, 2:] += boxes[:, :2] - 1 # TO_REMOVE = 1 + boxes[:, 0::2].clamp_(min=0, max=w-1) # TO_REMOVE = 1 + boxes[:, 1::2].clamp_(min=0, max=h-1) # TO_REMOVE = 1 + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + if self.return_masks: + masks = [] + is_box_mask = [] + for obj, bbox in zip(anno, boxes): + if "segmentation" in obj: + masks.append(obj["segmentation"]) + is_box_mask.append(0) + else: + masks.append(self.get_box_mask(bbox, image.size, mode='poly')) + is_box_mask.append(1) + masks = SegmentationMask(masks, image.size, mode='poly') + is_box_mask = torch.tensor(is_box_mask) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + isfinal = None + if anno and "isfinal" in anno[0]: + isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float) + + tokens_positive = [] if self.return_tokens else None + if self.return_tokens and anno and "tokens" in anno[0]: + tokens_positive = [obj["tokens"] for obj in anno] + elif self.return_tokens and anno and "tokens_positive" in anno[0]: + tokens_positive = [obj["tokens_positive"] for obj in anno] + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + if self.return_masks: + masks = masks[keep] + is_box_mask = is_box_mask[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + if caption is not None: + target["caption"] = caption + if self.return_masks: + target["masks"] = masks + target["is_box_mask"] = is_box_mask + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + if tokens_positive is not None: + target["tokens_positive"] = [] + + for i, k in enumerate(keep): + if k or ignore_box_screen: + target["tokens_positive"].append(tokens_positive[i]) + + if isfinal is not None: + target["isfinal"] = isfinal + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) + target["area"] = area[keep] + target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(h), int(w)]) + target["size"] = torch.as_tensor([int(h), int(w)]) + + if self.return_tokens and self.tokenizer is not None: + if not ignore_box_screen: + assert len(target["boxes"]) == len(target["tokens_positive"]) + + tokenized = 
self.tokenizer(caption, return_tensors="pt", + max_length=self.max_query_len, + truncation=True) + + target["positive_map"] = create_positive_map(tokenized, target["tokens_positive"]) + # target['greenlight_map'] = create_greenlight_map(greenlight_span_for_masked_lm_objective,tokenized) + # target["positive_map_for_od_labels"] = create_positive_map_for_od_labels(tokenized, label_to_positions) + + all_tokens = [[v] for k,v in label_to_positions_caption.items()] + target["all_map"] = create_positive_map(tokenized, all_tokens) + target["labels_in_caption"]=[k for k,v in label_to_positions_caption.items()] + + pos_label_set=list(set(target['labels'].tolist())) + pos_category_tokens = [[v] for k,v in label_to_positions_caption.items() if k in pos_label_set] + target["positive_category_map"] = create_positive_map(tokenized, pos_category_tokens) + target["positive_category_map"][target["positive_category_map"]!=0]=1 + + + original_od_label = [] + for obj in anno: + original_od_label.append( + obj.get("original_od_label", -10)) # NOTE: The padding value has to be not the same as -1 or -100 + target["original_od_label"] = torch.as_tensor(original_od_label) + + return image, target + +def create_greenlight_map(tok_list, tokenized): + # An example tok_list: + # [(0, 5), (10, 13), (-1, -1, -1)] + # The last one is a special indicator.. + + greenlight_map = torch.zeros(256, dtype=torch.float) + for item in tok_list: + if len(item) != 2: + assert(len(item) == 3) + # Make everything unmakable + greenlight_map[:] = -1 + break + + beg, end = item + beg_pos = tokenized.char_to_token(beg) + end_pos = tokenized.char_to_token(end - 1) + if beg_pos is None: + try: + beg_pos = tokenized.char_to_token(beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(beg + 2) + except: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(end - 3) + except: + end_pos = None + if beg_pos is None or end_pos is None: + continue + + assert beg_pos is not None and end_pos is not None + greenlight_map[beg_pos: end_pos + 1].fill_(1) + return greenlight_map + + +def create_positive_map_for_od_labels(tokenized, label_to_positions): + """construct a map such that positive_map[i] = j, where j is the object detection label of the token i""" + """ + {3: [1: 5)} + 256 : -1 3 3 3 3 -1 .. 8 8 .. 
+ the woman in the garden + -1 -1 -1 -1 -1 + """ + positive_map = torch.ones(256, dtype=torch.float) * -1 # -1 means no match + keys = list(label_to_positions.keys()) + for j, key in enumerate(keys): + tok_list = label_to_positions[key] + # one label only mapps to one location + beg, end = tok_list + beg_pos = tokenized.char_to_token(beg) + end_pos = tokenized.char_to_token(end - 1) + if beg_pos is None: + try: + beg_pos = tokenized.char_to_token(beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(beg + 2) + except: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(end - 3) + except: + end_pos = None + if beg_pos is None or end_pos is None: + continue + assert beg_pos is not None and end_pos is not None + positive_map[beg_pos: end_pos + 1].fill_(key) + return positive_map + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +def create_positive_map(tokenized, tokens_positive): + """construct a map such that positive_map[i,j] = True iff box i is associated to token j""" + positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float) + + for j, tok_list in enumerate(tokens_positive): + for (beg, end) in tok_list: + beg_pos = tokenized.char_to_token(beg) + end_pos = tokenized.char_to_token(end - 1) + if beg_pos is None: + try: + beg_pos = tokenized.char_to_token(beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(beg + 2) + except: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(end - 3) + except: + end_pos = None + if beg_pos is None or end_pos is None: + continue + + assert beg_pos is not None and end_pos is not None + positive_map[j, beg_pos: end_pos + 1].fill_(1) + return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) + + +def pil_loader(path, retry=5): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + ri = 0 + while ri < retry: + try: + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + except: + ri += 1 diff --git a/maskrcnn_benchmark/engine/trainer.py b/maskrcnn_benchmark/engine/trainer.py deleted file mode 100644 index b725ded..0000000 --- a/maskrcnn_benchmark/engine/trainer.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
-import datetime -import logging -import sys -import os -import math -import time - -import torch -import torch.distributed as dist - -from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank -from maskrcnn_benchmark.utils.metric_logger import MetricLogger -from maskrcnn_benchmark.utils.ema import ModelEma -from maskrcnn_benchmark.utils.amp import autocast, GradScaler -from maskrcnn_benchmark.data.datasets.evaluation import evaluate -from .inference import inference -import pdb - -def reduce_loss_dict(loss_dict): - """ - Reduce the loss dictionary from all processes so that process with rank - 0 has the averaged results. Returns a dict with the same fields as - loss_dict, after reduction. - """ - world_size = get_world_size() - if world_size < 2: - return loss_dict - with torch.no_grad(): - loss_names = [] - all_losses = [] - for k in sorted(loss_dict.keys()): - loss_names.append(k) - all_losses.append(loss_dict[k]) - all_losses = torch.stack(all_losses, dim=0) - dist.reduce(all_losses, dst=0) - if dist.get_rank() == 0: - # only main process gets accumulated, so only divide by - # world_size in this case - all_losses /= world_size - reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} - return reduced_losses - - -def do_train( - cfg, - model, - data_loader, - optimizer, - scheduler, - checkpointer, - device, - checkpoint_period, - arguments, - val_data_loader=None, - meters=None, - zero_shot=False, -): - logger = logging.getLogger("maskrcnn_benchmark.trainer") - logger.info("Start training") - # meters = MetricLogger(delimiter=" ") - max_iter = len(data_loader) - start_iter = arguments["iteration"] - model.train() - model_ema = None - if cfg.SOLVER.MODEL_EMA > 0: - model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA) - start_training_time = time.time() - end = time.time() - - if cfg.SOLVER.USE_AMP: - scaler = GradScaler() - - global_rank = get_rank() - - if cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1 and cfg.SOLVER.MAX_EPOCH >= 1: - checkpoint_period = len(data_loader) // cfg.SOLVER.CHECKPOINT_PER_EPOCH // cfg.SOLVER.MAX_EPOCH - - if global_rank <= 0 and cfg.SOLVER.MAX_EPOCH >= 1: - print("Iter per epoch ", len(data_loader) // cfg.SOLVER.MAX_EPOCH ) - - if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1: - patience_counter = 0 - previous_best = 0.0 - - # Adapt the weight decay - if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'): - milestone_target = 0 - for i, milstone in enumerate(list(scheduler.milestones)): - if scheduler.last_epoch >= milstone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO: - milestone_target = i+1 - for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter): - nnegative = sum(len(target) < 1 for target in targets) - nsample = len(targets) - if nsample == nnegative or nnegative > nsample * cfg.SOLVER.MAX_NEG_PER_BATCH: - logger.info('[WARNING] Sampled {} negative in {} in a batch, greater the allowed ratio {}, skip'. 
- format(nnegative, nsample, cfg.SOLVER.MAX_NEG_PER_BATCH)) - continue - - data_time = time.time() - end - iteration = iteration + 1 - arguments["iteration"] = iteration - - images = images.to(device) - captions = None - try: - targets = [target.to(device) for target in targets] - captions = [t.get_field("caption") for t in targets if "caption" in t.fields()] - except: - pass - # Freeze language backbone - if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE: - if hasattr(model, "module"): - model.module.language_backbone.eval() - else: - model.language_backbone.eval() - - if cfg.SOLVER.USE_AMP: - with autocast(): - if len(captions) > 0: - loss_dict = model(images, targets, captions=captions, positive_map=positive_map, greenlight_map = greenlight_map) - else: - loss_dict = model(images, targets) - losses = sum(loss for loss in loss_dict.values()) - - # save checkpoints for further debug if nan happens - # loss_value = losses.item() - # if not math.isfinite(loss_value): - # logging.error(f'=> loss is {loss_value}, stopping training') - # logging.error("Losses are : {}".format(loss_dict)) - # time_str = time.strftime('%Y-%m-%d-%H-%M') - # fname = os.path.join(checkpointer.save_dir, f'{time_str}_states.pth') - # logging.info(f'=> save error state to {fname}') - # dict_to_save = { - # 'x': images, - # 'y': targets, - # 'loss': losses, - # 'states': model.module.state_dict() if hasattr(model, 'module') else model.state_dict() - # } - # if len(captions) > 0: - # dict_to_save['captions'] = captions - # dict_to_save['positive_map'] = positive_map - # torch.save( - # dict_to_save, - # fname - # ) - - - if torch.isnan(losses) or torch.isinf(losses): - logging.error("NaN encountered, ignoring") - losses[losses != losses] = 0 - optimizer.zero_grad() - scaler.scale(losses).backward() - scaler.step(optimizer) - scaler.update() - scheduler.step() - else: - if len(captions) > 0: - loss_dict = model(images, targets, captions, positive_map) - else: - loss_dict = model(images, targets) - losses = sum(loss for loss in loss_dict.values()) - - # loss_value = losses.item() - # if not math.isfinite(loss_value): - # logging.error(f'=> loss is {loss_value}, stopping training') - # time_str = time.strftime('%Y-%m-%d-%H-%M') - # fname = os.path.join(checkpointer.save_dir, f'{time_str}_states.pth') - # logging.info(f'=> save error state to {fname}') - # dict_to_save = { - # 'x': images, - # 'y': targets, - # 'loss': losses, - # 'states': model.module.state_dict() if hasattr(model, 'module') else model.state_dict() - # } - # if len(captions) > 0: - # dict_to_save['captions'] = captions - # dict_to_save['positive_map'] = positive_map - # torch.save( - # dict_to_save, - # fname - # ) - - - if torch.isnan(losses) or torch.isinf(losses): - losses[losses != losses] = 0 - optimizer.zero_grad() - losses.backward() - optimizer.step() - scheduler.step() - - # Adapt the weight decay: only support multiStepLR - if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'): - if milestone_target < len(scheduler.milestones): - next_milestone = list(scheduler.milestones)[milestone_target] - else: - next_milestone = float('inf') - if scheduler.last_epoch >= next_milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO: - gamma = scheduler.gamma - logger.info("Drop the weight decay by {}!".format(gamma)) - for param in optimizer.param_groups: - if 'weight_decay' in param: - param['weight_decay'] *= gamma - # move the target forward - milestone_target += 1 - - # reduce losses over all GPUs for logging purposes - loss_dict_reduced = 
reduce_loss_dict(loss_dict) - losses_reduced = sum(loss for loss in loss_dict_reduced.values()) - meters.update(loss=losses_reduced, **loss_dict_reduced) - if model_ema is not None: - model_ema.update(model) - arguments["model_ema"] = model_ema.state_dict() - - batch_time = time.time() - end - end = time.time() - meters.update(time=batch_time, data=data_time) - eta_seconds = meters.time.global_avg * (max_iter - iteration) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - - if iteration % 20 == 0 or iteration == max_iter: - # if iteration % 1 == 0 or iteration == max_iter: - #logger.info( - if global_rank <= 0: - print( - meters.delimiter.join( - [ - "eta: {eta}", - "iter: {iter}", - "{meters}", - "lr: {lr:.6f}", - "wd: {wd:.6f}", - "max mem: {memory:.0f}", - ] - ).format( - eta=eta_string, - iter=iteration, - meters=str(meters), - lr=optimizer.param_groups[0]["lr"], - wd=optimizer.param_groups[0]["weight_decay"], - memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, - ) - ) - if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter): - if is_main_process(): - print("Evaluating") - eval_result = 0.0 - model.eval() - - if cfg.SOLVER.TEST_WITH_INFERENCE: - with torch.no_grad(): - try: - _model = model.module - except: - _model = model - _result = inference( - model = _model, - data_loader = val_data_loader, - dataset_name="val", - device=device, - expected_results=cfg.TEST.EXPECTED_RESULTS, - expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, - output_folder=None, - cfg=cfg, - verbose=False, - disable_print=True - ) - if is_main_process(): - eval_result = _result[0].results['bbox']['AP'] - else: - results_dict = {} - cpu_device = torch.device("cpu") - for i, batch in enumerate(val_data_loader): - images, targets, image_ids, positive_map, *_ = batch - with torch.no_grad(): - images = images.to(device) - if positive_map is None: - output = model(images) - else: - captions = [t.get_field("caption") for t in targets if "caption" in t.fields()] - output = model(images, captions, positive_map) - output = [o.to(cpu_device) for o in output] - results_dict.update( - {img_id: result for img_id, result in zip(image_ids, output)} - ) - all_predictions = all_gather(results_dict) - if is_main_process(): - predictions = {} - for p in all_predictions: - predictions.update(p) - predictions = [predictions[i] for i in list(sorted(predictions.keys()))] - eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None, - box_only=cfg.DATASETS.CLASS_AGNOSTIC) - if cfg.DATASETS.CLASS_AGNOSTIC: - eval_result = eval_result.results['box_proposal']['AR@100'] - else: - eval_result = eval_result.results['bbox']['AP'] - model.train() - - if model_ema is not None and cfg.SOLVER.USE_EMA_FOR_MONITOR: - model_ema.ema.eval() - results_dict = {} - cpu_device = torch.device("cpu") - for i, batch in enumerate(val_data_loader): - images, targets, image_ids, positive_map, positive_map_eval = batch - with torch.no_grad(): - images = images.to(device) - if positive_map is None: - output = model_ema.ema(images) - else: - captions = [t.get_field("caption") for t in targets if "caption" in t.fields()] - output = model_ema.ema(images, captions, positive_map) - output = [o.to(cpu_device) for o in output] - results_dict.update( - {img_id: result for img_id, result in zip(image_ids, output)} - ) - all_predictions = all_gather(results_dict) - if is_main_process(): - predictions = {} - for p in all_predictions: - predictions.update(p) - predictions = 
[predictions[i] for i in list(sorted(predictions.keys()))] - eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None, - box_only=cfg.DATASETS.CLASS_AGNOSTIC) - if cfg.DATASETS.CLASS_AGNOSTIC: - eval_result = eval_result.results['box_proposal']['AR@100'] - else: - eval_result = eval_result.results['bbox']['AP'] - - arguments.update(eval_result=eval_result) - - if cfg.SOLVER.USE_AUTOSTEP: - eval_result = all_gather(eval_result)[0] #broadcast_data([eval_result])[0] - # print("Rank {} eval result gathered".format(cfg.local_rank), eval_result) - scheduler.step(eval_result) - - if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1: - if eval_result < previous_best: - patience_counter += 1 - else: - patience_counter = 0 - previous_best = eval_result - checkpointer.save("model_best", **arguments) - print("Previous Best", previous_best, "Patience Counter", patience_counter, "Eval Result", eval_result) - if patience_counter >= cfg.SOLVER.AUTO_TERMINATE_PATIENCE: - if is_main_process(): - print("\n\n\n\nAuto Termination at {}, current best {}\n\n\n".format(iteration, previous_best)) - break - - if iteration % checkpoint_period == 0: - checkpointer.save("model_{:07d}".format(iteration), **arguments) - if iteration == max_iter: - checkpointer.save("model_final", **arguments) - break - - total_training_time = time.time() - start_training_time - total_time_str = str(datetime.timedelta(seconds=total_training_time)) - logger.info( - "Total training time: {} ({:.4f} s / it)".format( - total_time_str, total_training_time / (max_iter) - ) - ) diff --git a/tools/train_net.py b/tools/train_net.py new file mode 100644 index 0000000..65b1d0a --- /dev/null +++ b/tools/train_net.py @@ -0,0 +1,287 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+r""" +Basic training script for PyTorch +""" + +# Set up custom environment before nearly anything else is imported +# NOTE: this should be the first import (no not reorder) +from maskrcnn_benchmark.utils.env import setup_environment # noqa F401 isort:skip + +import argparse +import os + +import numpy as np +import torch +from maskrcnn_benchmark.config import cfg, try_to_find +from maskrcnn_benchmark.data import make_data_loader +from maskrcnn_benchmark.solver import make_lr_scheduler +from maskrcnn_benchmark.solver import make_optimizer +from maskrcnn_benchmark.engine.inference import inference +from maskrcnn_benchmark.modeling.detector import build_detection_model +from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer +from maskrcnn_benchmark.utils.collect_env import collect_env_info +from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, all_gather +from maskrcnn_benchmark.utils.imports import import_file +from maskrcnn_benchmark.utils.logger import setup_logger +from maskrcnn_benchmark.utils.metric_logger import (MetricLogger, TensorboardLogger) +from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config +import random +from maskrcnn_benchmark.utils.amp import autocast, GradScaler + +from pathlib import Path +from tqdm import tqdm +from collections import defaultdict + +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' +# os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def tuning_highlevel_override(cfg,): + if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vision_query": + cfg.MODEL.BACKBONE.FREEZE = True + cfg.MODEL.FPN.FREEZE = True + cfg.MODEL.RPN.FREEZE = True if not cfg.VISION_QUERY.QUERY_FUSION else False + cfg.MODEL.LINEAR_PROB = False + cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER = False + cfg.MODEL.LANGUAGE_BACKBONE.FREEZE = False + cfg.MODEL.DYHEAD.USE_CHECKPOINT = False # Disable checkpoint + cfg.VISION_QUERY.ENABLED = True + if cfg.SOLVER.TUNING_HIGHLEVEL_OVERRIDE == "vs_with_txt_enc": + cfg.MODEL.BACKBONE.FREEZE = True + cfg.MODEL.FPN.FREEZE = True + cfg.MODEL.RPN.FREEZE = True if not cfg.VISION_QUERY.QUERY_FUSION else False + cfg.MODEL.LINEAR_PROB = False + cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER = False + cfg.MODEL.LANGUAGE_BACKBONE.FREEZE = False + cfg.MODEL.DYHEAD.USE_CHECKPOINT = False # Disable checkpoint + cfg.VISION_QUERY.ENABLED = True + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + +def extract_query(cfg): + if cfg.DATASETS.FEW_SHOT: + assert cfg.DATASETS.FEW_SHOT == cfg.VISION_QUERY.MAX_QUERY_NUMBER, 'To extract the right query instances, set VISION_QUERY.MAX_QUERY_NUMBER = DATASETS.FEW_SHOT.' 
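+    # Build the detector, restore the pretrained weights, then run one caching pass over the dataset so that per-class query features can be written out as a query bank.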
+ # if cfg.num_gpus > 1: + # max_query_number = cfg.VISION_QUERY.MAX_QUERY_NUMBER + # cfg.defrost() + # cfg.VISION_QUERY.MAX_QUERY_NUMBER = int(cfg.VISION_QUERY.MAX_QUERY_NUMBER/cfg.num_gpus) + # cfg.freeze() + + model = build_detection_model(cfg) + device = torch.device(cfg.MODEL.DEVICE) + model.to(device) + + + checkpointer = DetectronCheckpointer( + cfg, model + ) + checkpointer.load(try_to_find(cfg.MODEL.WEIGHT)) + + data_loader = make_data_loader( + cfg, + is_train=False, + is_cache=True, + is_distributed= cfg.num_gpus > 1, + ) + assert isinstance(data_loader, list) and len(data_loader)==1 + data_loader=data_loader[0] + + # if cfg.VISION_QUERY.CUSTOM_DATA_IDS is not None: + # data_loader.dataset.ids = cfg.VISION_QUERY.CUSTOM_DATA_IDS + + if cfg.num_gpus > 1: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[cfg.local_rank], output_device=cfg.local_rank, + broadcast_buffers=cfg.MODEL.BACKBONE.USE_BN, + find_unused_parameters=cfg.SOLVER.FIND_UNUSED_PARAMETERS + ) + + query_images=defaultdict(list) + _iterator = tqdm(data_loader) + # _iterator = data_loader # for debug + model.eval() + for i, batch in enumerate(_iterator): + images, targets, *_ = batch + if cfg.num_gpus > 1: + query_images = model.module.extract_query(images.to(device), targets, query_images) + else: + query_images = model.extract_query(images.to(device), targets, query_images) + + if cfg.num_gpus > 1: + ## not stable when using all_gather, easy to OOM. + # all_query_images = all_gather(query_images) + # if is_main_process(): + # accumulated_query_images = defaultdict(list) + # for r, query_images_dict in enumerate(all_query_images): + # print('accumulating results: {}/{}'.format(r, len(all_query_images))) + # for label, feat in query_images_dict.items(): + # num_queries=len(accumulated_query_images[label]) + # if num_queries >= cfg.VISION_QUERY.MAX_QUERY_NUMBER: + # continue + # if num_queries==0: + # accumulated_query_images[label] = feat.to(device) + # else: + # accumulated_query_images[label] = torch.cat([accumulated_query_images[label].to(device), feat.to(device)]) + + # save_name = 'MODEL/{}_query_{}_pool{}_{}{}_multi-node.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0] , cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION ,'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME) + # print('saving to ', save_name) + # torch.save(accumulated_query_images, save_name) + if cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH != '': + raise NotImplementedError + global_rank = get_rank() + save_name = 'MODEL/{}_query_{}_pool{}_{}{}_rank{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0] , cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION ,'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME, global_rank) + print('saving to ', save_name) + torch.save(query_images, save_name) + else: + if cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH != '': + save_name = cfg.VISION_QUERY.QUERY_BANK_SAVE_PATH + else: + save_name = 'MODEL/{}_query_{}_pool{}_{}{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0] , cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION ,'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME) + print('saving to ', save_name) + 
torch.save(query_images, save_name) + # if cfg.num_gpus > 1: + # # + # world_size = torch.distributed.dist.get_world_size() + # if is_main_process(): + # query_images_list = [] + # for r in range(world_size): + # saved_path = 'MODEL/{}_query_{}_pool{}_{}{}_rank{}.pth'.format(cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TRAIN[0].split('_')[0] , cfg.VISION_QUERY.MAX_QUERY_NUMBER, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION ,'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all', cfg.VISION_QUERY.QUERY_ADDITION_NAME, r) + # query_images_list.append(torch.load(saved_path, map_location='cpu')) + + # for s in query_images_list + + +def main(): + parser = argparse.ArgumentParser(description="PyTorch Object Detection Training") + parser.add_argument( + "--config-file", + default="", + metavar="FILE", + help="path to config file", + type=str, + ) + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument( + "--skip-test", + dest="skip_test", + help="Do not test the final model", + action="store_true", + ) + + parser.add_argument("--use-tensorboard", + dest="use_tensorboard", + help="Use tensorboardX logger (Requires tensorboardX installed)", + action="store_true", + default=False + ) + + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + + parser.add_argument("--save_original_config", action="store_true") + parser.add_argument("--disable_output_distributed", action="store_true") + parser.add_argument("--override_output_dir", default=None) + parser.add_argument("--custom_shot_and_epoch_and_general_copy", default=None, type=str) + parser.add_argument("--resume", action="store_true", default=False) + parser.add_argument("--extract_query", action="store_true", default=False) + parser.add_argument( + "--task_config", + default="", + metavar="FILE", + help="path to config file", + type=str, + ) + parser.add_argument( + "--additional_model_config", + default="", + metavar="FILE", + help="path to config file", + type=str, + ) + + + + args = parser.parse_args() + + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + args.distributed = num_gpus > 1 + + if args.distributed: + import datetime + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group( + backend="nccl", init_method="env://", + timeout=datetime.timedelta(0, 7200) + ) + + if args.disable_output_distributed: + setup_for_distributed(args.local_rank <= 0) + + cfg.local_rank = args.local_rank + cfg.num_gpus = num_gpus + + cfg.merge_from_file(args.config_file) + if args.task_config: + cfg.merge_from_file(args.task_config) + if args.additional_model_config: + cfg.merge_from_file(args.additional_model_config) + cfg.merge_from_list(args.opts) + # specify output dir for models + if args.override_output_dir: + cfg.OUTPUT_DIR = args.override_output_dir + tuning_highlevel_override(cfg) + cfg.freeze() + + seed = cfg.SOLVER.SEED + args.local_rank + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + output_dir = cfg.OUTPUT_DIR + if output_dir: + mkdir(output_dir) + + logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank()) + logger.info(args) + logger.info("Using {} GPUs".format(num_gpus)) + + logger.info("Loaded configuration file {}".format(args.config_file)) + with open(args.config_file, "r") as cf: + config_str = "\n" + cf.read() + logger.info(config_str) + logger.info("Running with config:\n{}".format(cfg)) + + output_config_path = 
os.path.join(cfg.OUTPUT_DIR, 'config.yml') + logger.info("Saving config into: {}".format(output_config_path)) + # save overloaded model config in the output directory + if args.save_original_config: + import shutil + shutil.copy(args.config_file, os.path.join(cfg.OUTPUT_DIR, 'config_original.yml')) + + save_config(cfg, output_config_path) + + if args.extract_query: + extract_query(cfg) + else: + raise NotImplementedError + + +if __name__ == "__main__": + main() \ No newline at end of file
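
For reference, a minimal sketch of how the newly exported CocoGrounding_New dataset could be constructed directly; the annotation paths and the bert-base-uncased tokenizer below are placeholder assumptions rather than values taken from this patch:

    from transformers import AutoTokenizer
    from maskrcnn_benchmark.data.datasets import CocoGrounding_New

    # Placeholder tokenizer and paths; in the training pipeline these come from the config.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = CocoGrounding_New(
        img_folder="DATASET/objects365/train",                            # assumed layout
        ann_file="DATASET/objects365/annotations/objects365_train.json",  # assumed layout
        transforms=None,
        return_masks=False,
        return_tokens=True,
        is_train=True,
        tokenizer=tokenizer,
        max_query_len=256,
    )
    img, target, idx = dataset[0]  # target is a BoxList carrying caption, labels and positive_map fields

In practice the dataset is built through make_data_loader from the merged config (see tools/train_net.py above) rather than instantiated by hand.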