# mirror of https://github.com/YifanXu74/MQ-Det.git
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time
import os
import re
from collections import defaultdict
from io import BytesIO   # used by load()

import numpy as np        # used by load()
import requests           # used by load()
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
from PIL import Image

from maskrcnn_benchmark.data.datasets.evaluation import evaluate, im_detect_bbox_aug
from maskrcnn_benchmark.data.datasets.evaluation.flickr.flickr_eval import FlickrEvaluator
from maskrcnn_benchmark.data.datasets.tsv import load_from_yaml_file
from maskrcnn_benchmark.structures.bounding_box import BoxList
from ..utils.comm import is_main_process, all_gather, synchronize

import pdb

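# Small I/O helpers used by the visualization path: imshow() writes an image given in
# BGR channel order to disk via matplotlib, and load() reads an image from a URL or a
# local path and returns it as a BGR numpy array.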
def imshow(img, file_name="tmp.jpg"):
    plt.imshow(img[:, :, [2, 1, 0]])
    plt.axis("off")
    # plt.figtext(0.5, 0.09, "test", wrap=True, horizontalalignment='center', fontsize=20)
    plt.savefig(file_name)


def load(url_or_file_name):
    try:
        response = requests.get(url_or_file_name)
    except:
        response = None

    if response is None:
        pil_image = Image.open(url_or_file_name).convert("RGB")
    else:
        pil_image = Image.open(BytesIO(response.content)).convert("RGB")

    # convert to BGR format
    image = np.array(pil_image)[:, :, [2, 1, 0]]
    return image

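# Default evaluation path (used when cfg.TEST.EVAL_TASK is unset): run the model over
# the dataloader, collect per-image results on the CPU, gather predictions from all
# ranks, and hand them to the dataset-specific evaluate() on the main process.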
def inference_default(
    model,
    data_loader,
    dataset_name,
    iou_types=("bbox",),
    box_only=False,
    device="cuda",
    expected_results=(),
    expected_results_sigma_tol=4,
    output_folder=None,
    cfg=None,
):
    # convert to a torch.device for efficiency
    device = torch.device(device)
    num_devices = (
        torch.distributed.get_world_size()
        if torch.distributed.is_initialized()
        else 1
    )
    logger = logging.getLogger("maskrcnn_benchmark.inference")
    dataset = data_loader.dataset
    logger.info("Start evaluation on {} dataset ({} images).".format(dataset_name, len(dataset)))
    start_time = time.time()

    model.eval()
    results_dict = {}
    cpu_device = torch.device("cpu")
    for i, batch in enumerate(tqdm(data_loader)):
        images, targets, image_ids, *_ = batch
        with torch.no_grad():
            if cfg.TEST.USE_MULTISCALE:
                output = im_detect_bbox_aug(model, images, device)
            else:
                output = model(images.to(device))
            output = [o.to(cpu_device) for o in output]
        results_dict.update(
            {img_id: result for img_id, result in zip(image_ids, output)}
        )
    predictions = results_dict
    # wait for all processes to complete before measuring the time
    synchronize()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    logger.info(
        "Total inference time: {} ({} s / img per device, on {} devices)".format(
            total_time_str, total_time * num_devices / len(dataset), num_devices
        )
    )

    predictions = _accumulate_predictions_from_multiple_gpus(predictions)
    if not is_main_process():
        return None

    if output_folder:
        torch.save(predictions, os.path.join(output_folder, "predictions.pth"))

    extra_args = dict(
        box_only=box_only,
        iou_types=iou_types,
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )
    return evaluate(dataset=dataset, predictions=predictions, output_folder=output_folder, **extra_args)

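# Category names are lightly normalized before being concatenated into the text prompt:
# parenthesized qualifiers are dropped, underscores become spaces, and double spaces
# left over from the removals are collapsed.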
def clean_name(name):
    name = re.sub(r"\(.*\)", "", name)
    name = re.sub(r"_", " ", name)
    name = re.sub(r"  ", " ", name)
    return name

def create_one_hot_dict(labels, no_minus_one_for_one_hot=False):
    positive_map_token_to_label = defaultdict(int)
    positive_map_label_to_token = defaultdict(int)

    for i in range(len(labels)):
        positive_map_token_to_label[i] = labels[i]
        positive_map_label_to_token[labels[i]] = i

    if no_minus_one_for_one_hot:
        positive_map_token_to_label = defaultdict(int)
        positive_map_label_to_token = defaultdict(int)

        for i in range(len(labels)):
            positive_map_token_to_label[i + 1] = labels[i]
            positive_map_label_to_token[labels[i]] = i + 1

    return positive_map_token_to_label, positive_map_label_to_token

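# create_positive_dict() maps tokenized caption positions back to category labels.
# For every (char_begin, char_end) span in tokens_positive it looks up the matching
# token positions with tokenized.char_to_token() (retrying nearby characters when the
# exact boundary does not map to a token), then records token -> label and
# label -> [token positions].
# Illustrative example (token indices are an assumption and depend on the tokenizer):
#   caption "person. bicycle", tokens_positive [[(0, 6)], [(8, 15)]], labels [1, 2]
#   could yield positive_map_label_to_token == {1: [1], 2: [3]}.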
def create_positive_dict(tokenized, tokens_positive, labels):
    """construct a dictionary such that positive_map[i] = j iff token i is mapped to label j"""
    positive_map = defaultdict(int)

    # Additionally, have positive_map_label_to_token
    positive_map_label_to_token = defaultdict(list)

    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            beg_pos = tokenized.char_to_token(beg)
            end_pos = tokenized.char_to_token(end - 1)
            if beg_pos is None:
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except:
                    beg_pos = None
            if end_pos is None:
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except:
                    end_pos = None
            if beg_pos is None or end_pos is None:
                continue

            assert beg_pos is not None and end_pos is not None
            for i in range(beg_pos, end_pos + 1):
                positive_map[i] = labels[j]  # because the labels start from 1
                positive_map_label_to_token[labels[j]].append(i)
            # positive_map[j, beg_pos : end_pos + 1].fill_(1)
    return positive_map, positive_map_label_to_token  # / (positive_map.sum(-1)[:, None] + 1e-6)

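# chunks() splits the sorted label list into fixed-size groups so that datasets with
# large vocabularies (e.g. LVIS) can be evaluated with several shorter text prompts;
# cfg.TEST.CHUNKED_EVALUATION controls the chunk size below.
# Example: chunks([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]].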
def chunks(lst, n):
    """Split lst into successive n-sized chunks."""
    all_ = []
    for i in range(0, len(lst), n):
        data_index = lst[i:i + n]
        all_.append(data_index)
    counter = 0
    for i in all_:
        counter += len(i)
    assert counter == len(lst)

    return all_

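# Builds the detection prompts for a dataset: category names are sorted by id,
# optionally chunked, and each chunk is turned into one caption plus a
# label -> token-positions map via create_queries_and_maps().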
def create_queries_and_maps_from_dataset(dataset, cfg, disable_print=False):
    categories = dataset.categories()
    # one_hot = dataset.one_hot

    labels = []
    label_list = []
    keys = list(categories.keys())
    keys.sort()
    for i in keys:
        labels.append(i)
        label_list.append(categories[i])

    if cfg.TEST.CHUNKED_EVALUATION != -1:
        labels = chunks(labels, cfg.TEST.CHUNKED_EVALUATION)
        label_list = chunks(label_list, cfg.TEST.CHUNKED_EVALUATION)
    else:
        labels = [labels]
        label_list = [label_list]

    all_queries = []
    all_positive_map_label_to_token = []

    for i in range(len(labels)):
        labels_i = labels[i]
        label_list_i = label_list[i]
        query_i, positive_map_label_to_token_i = create_queries_and_maps(
            labels_i, label_list_i,
            additional_labels=cfg.DATASETS.SUPRESS_QUERY if cfg.DATASETS.USE_SUPRESS_QUERY else None,
            cfg=cfg, disable_print=disable_print)

        all_queries.append(query_i)
        all_positive_map_label_to_token.append(positive_map_label_to_token_i)
    if not disable_print:
        print("All queries", all_queries)
    return all_queries, all_positive_map_label_to_token

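# create_queries_and_maps() concatenates the cleaned category names into a single
# caption, separated by cfg.DATASETS.SEPARATION_TOKENS and optionally wrapped with
# per-category prefix/name/suffix strings from a caption-prompt YAML, then tokenizes
# the caption and maps each label to its token positions.
# Usage sketch (illustrative only; the label ids, names and separator are assumptions):
#   query, label_to_token = create_queries_and_maps([1, 2], ["person", "bicycle"], cfg=cfg)
#   # with cfg.DATASETS.SEPARATION_TOKENS == ". " this gives query == "person. bicycle"
#   # and label_to_token maps 1 and 2 to the token positions covering each name.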
def create_queries_and_maps(labels, label_list, additional_labels=None, cfg=None, disable_print=False):

    # Clean label list
    original_label_list = label_list.copy()
    label_list = [clean_name(i) for i in label_list]
    # Form the query and get the mapping
    tokens_positive = []
    start_i = 0
    end_i = 0
    objects_query = ""

    # sep between tokens, follow training
    separation_tokens = cfg.DATASETS.SEPARATION_TOKENS

    caption_prompt = cfg.DATASETS.CAPTION_PROMPT
    if caption_prompt is not None and isinstance(caption_prompt, str):
        caption_prompt = load_from_yaml_file(caption_prompt)
    use_caption_prompt = cfg.DATASETS.USE_CAPTION_PROMPT and caption_prompt is not None
    for _index, label in enumerate(label_list):
        if use_caption_prompt:
            objects_query += caption_prompt[_index]["prefix"]

        start_i = len(objects_query)

        if use_caption_prompt:
            objects_query += caption_prompt[_index]["name"]
        else:
            objects_query += label

        end_i = len(objects_query)
        tokens_positive.append([(start_i, end_i)])  # Every label has a [(start, end)]

        if use_caption_prompt:
            objects_query += caption_prompt[_index]["suffix"]

        if _index != len(label_list) - 1:
            objects_query += separation_tokens

    if additional_labels is not None:
        objects_query += separation_tokens
        for _index, label in enumerate(additional_labels):
            objects_query += label
            if _index != len(additional_labels) - 1:
                objects_query += separation_tokens
    if not disable_print:
        print(objects_query)

    from transformers import AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    if os.path.basename(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) == "bert-base-uncased":
        tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
        tokenized = tokenizer(objects_query, return_tensors="pt")
    elif os.path.basename(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) == "clip":
        from transformers import CLIPTokenizerFast
        if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
            tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
                                                          from_slow=True, mask_token='ðŁĴij</w>')
        else:
            tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
                                                          from_slow=True)
        tokenized = tokenizer(objects_query,
                              max_length=cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
                              truncation=True,
                              return_tensors="pt")
    else:
        tokenizer = None
        raise NotImplementedError

    # Create the mapping between the tokenized sentence and the original labels
    positive_map_token_to_label, positive_map_label_to_token = create_positive_dict(
        tokenized, tokens_positive, labels=labels)  # from token position to original label
    return objects_query, positive_map_label_to_token

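# Converts a dense positive map (one row per label, one column per token) into the
# {label: [token indices]} form used at inference time; `plus` shifts the label ids
# (the grounding path below uses plus=1 for the VLDYHEAD architecture).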
def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0):
    positive_map_label_to_token = {}
    for i in range(len(positive_map)):
        positive_map_label_to_token[i + plus] = torch.nonzero(positive_map[i], as_tuple=True)[0].tolist()
    return positive_map_label_to_token

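# Gathers the per-rank {image_id: prediction} dicts onto the main process, merges them,
# warns if the gathered image ids are not a contiguous range, and returns the
# predictions as a list ordered by image id (non-main ranks return None).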
def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
    all_predictions = all_gather(predictions_per_gpu)
    if not is_main_process():
        return
    # merge the list of dicts
    predictions = {}
    for p in all_predictions:
        predictions.update(p)
    # convert a dict where the key is the index in a list
    image_ids = list(sorted(predictions.keys()))
    if len(image_ids) != image_ids[-1] + 1:
        logger = logging.getLogger("maskrcnn_benchmark.inference")
        logger.warning(
            "Number of images that were gathered from multiple processes is not "
            "a contiguous set. Some images might be missing from the evaluation"
        )

    # convert to a list
    predictions = [predictions[i] for i in image_ids]
    return predictions

def resize_box(output, targets):
    if isinstance(targets[0], dict):
        orig_target_sizes = targets[0]["orig_size"].unsqueeze(0)
    else:
        orig_target_sizes = torch.stack([targets[0].extra_fields["orig_size"] for _ in range(1)], dim=0)
    img_h, img_w = orig_target_sizes.unbind(1)
    return output.resize((img_w, img_h))

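# Converts a BoxList prediction into the MDETR-style Flickr30k grounding format:
# boxes are sorted by score, grouped per phrase (label), a dummy [0, 0, 0, 0] box is
# appended to every phrase's list, and the result is returned together with the image
# and sentence ids.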
def flickr_post_process(output, targets, positive_map_label_to_token, plus):
    output = resize_box(output, targets)
    scores, indices = torch.topk(output.extra_fields["scores"], k=len(output.extra_fields["scores"]), sorted=True)
    boxes = output.bbox.tolist()
    boxes = [boxes[i] for i in indices]
    labels = [output.extra_fields["labels"][i] for i in indices]
    output_boxes = [[] for i in range(len(positive_map_label_to_token))]
    output_scores = [[] for i in range(len(positive_map_label_to_token))]
    for i in range(len(boxes)):
        output_boxes[labels[i] - plus].append(boxes[i])
        output_scores[labels[i] - plus].append(scores[i])
    for i in output_boxes:
        i.append([0.0, 0.0, 0.0, 0.0])
    image_ids = [t.extra_fields["original_img_id"] for t in targets]
    sentence_ids = [t.extra_fields["sentence_id"] for t in targets]

    return {"image_id": image_ids[0], "sentence_id": sentence_ids[0], "boxes": output_boxes, "scores": output_scores}

def build_flickr_evaluator(cfg):
    evaluator = FlickrEvaluator(
        "DATASET/flickr30k/flickr30k/",  # hard-coded dataset root
        subset="test" if "test" in cfg.DATASETS.TEST[0] else "val",
        merge_boxes=cfg.DATASETS.FLICKR_GT_TYPE == "merged")
    return evaluator

def build_lvis_evaluator(ann_file, fixed_ap=True):
    from maskrcnn_benchmark.data.datasets.evaluation.lvis.lvis import LVIS
    from maskrcnn_benchmark.data.datasets.evaluation.lvis.lvis_eval import LvisEvaluatorFixedAP, LvisEvaluator
    evaluator = LvisEvaluatorFixedAP(LVIS(ann_file), fixed_ap=fixed_ap)
    # evaluator = LvisEvaluator(LVIS(ann_file), iou_types=['segm', 'bbox'])
    return evaluator

def write_lvis_results(results, output_file_name):
    lines = []
    lines.append("metric, avg ")
    for each_result in results:
        metric_string = " ".join(each_result.split(" ")[:-2])
        number = each_result.split(" ")[-1]
        each_result = metric_string + ", " + number + " "
        lines.append(each_result)

    string_to_write = "\n".join(lines) + "\n"
    with open(output_file_name, "w") as f:
        f.write(string_to_write)
    return

def write_flickr_results(results, output_file_name):
    '''
    Example input:
    {'Recall@1_all': 0.8394651146677753, 'Recall@1_animals': 0.9177820267686424, 'Recall@1_bodyparts': 0.7097966728280961, ...}
    '''
    lines = []
    lines.append("metric, avg ")
    for each_metric, number in results.items():
        each_result = each_metric + ", " + str(number) + " "
        lines.append(each_result)

    string_to_write = "\n".join(lines) + "\n"
    with open(output_file_name, "w") as f:
        f.write(string_to_write)
    return

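# MQ-Det vision-query extraction ("online update"): the model is run over the support
# data, detections scoring above cfg.VISION_QUERY.SCORE_THRESHOLD are kept, and
# model.extract_query() pools their backbone features into a query bank that is saved
# to disk at the end of every round (and reloaded with model.load_query_bank() at the
# start of each additional round when num_turns > 1).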
def online_update(
    model,
    data_loader,
    device="cuda",
    cfg=None,
    num_turns=1,
    save_name=None,
):
    try:
        device = torch.device(device)
    except:
        device = device

    dataset = data_loader.dataset
    task = cfg.TEST.EVAL_TASK
    if os.path.exists(cfg.VISION_QUERY.QUERY_BANK_PATH):
        query_images = torch.load(cfg.VISION_QUERY.QUERY_BANK_PATH, map_location=device)
    else:
        query_images = defaultdict(list)
    if save_name is None:
        save_name = 'MODEL/{}_val_query_{}_pool{}_loop{}_{}.pth'.format(
            cfg.VISION_QUERY.DATASET_NAME if cfg.VISION_QUERY.DATASET_NAME else cfg.DATASETS.TEST[0].split('_')[0],
            cfg.VISION_QUERY.MAX_TEST_QUERY_NUMBER,
            cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION,
            num_turns,
            'sel' if cfg.VISION_QUERY.SELECT_FPN_LEVEL else 'all')

    for turns in range(num_turns):
        print('online update round {} / {}'.format(turns + 1, num_turns))

        if turns > 0:
            print('load new query banks...')
            model.load_query_bank(save_name)

        if task == "detection":
            all_queries, all_positive_map_label_to_token = create_queries_and_maps_from_dataset(dataset, cfg, disable_print=True)
        else:
            raise NotImplementedError

        '''
        Build Dataset Specific Evaluator
        '''
        if "flickr" in cfg.DATASETS.TEST[0]:
            evaluator = None
        elif "lvis" in cfg.DATASETS.TEST[0]:
            evaluator = None
        else:
            evaluator = None

        model.eval()
        results_dict = {}
        cpu_device = torch.device("cpu")

        _iterator = tqdm(data_loader)
        # _iterator = data_loader
        for i, batch in enumerate(_iterator):
            if i == cfg.TEST.SUBSET:
                break
            images, targets, image_ids, *_ = batch

            all_output = []
            mdetr_style_output = []
            with torch.no_grad():
                if cfg.TEST.USE_MULTISCALE:
                    raise NotImplementedError
                else:
                    images = images.to(device)
                    query_time = len(all_queries)

                    for query_i in range(query_time):
                        if not isinstance(targets[0], dict):  # For LVIS dataset and datasets directly copied from MDETR
                            targets = [target.to(device) for target in targets]
                        '''
                        Different datasets use different data formats: for the LVIS dataset the
                        target is a dictionary, while for modulated datasets such as COCO/Flickr
                        the target is a BoxList.
                        '''

                        if task == "detection":
                            captions = [all_queries[query_i] for ii in range(len(targets))]
                            positive_map_label_to_token = all_positive_map_label_to_token[query_i]
                        elif task == "grounding":
                            raise NotImplementedError
                        output, backbone_features = model(images, captions=captions, positive_map=positive_map_label_to_token, return_backbone_features=True)
                        output = [o.to(cpu_device) for o in output]

                        if "flickr" in cfg.DATASETS.TEST[0]:
                            raise NotImplementedError
                        elif "lvis" in cfg.DATASETS.TEST[0]:
                            output = [[row[_i] for row in all_output] for _i in range(len([output][0]))]
                            for index, i in enumerate(output):
                                output[index] = i[0].concate_box_list(i)
                            assert len(output) == 1  # TODO: support batched outputs
                            kept_instances = []
                            for o in output:
                                keep = o.get_field('scores') > cfg.VISION_QUERY.SCORE_THRESHOLD
                                kept_instances.append(o[keep])
                            query_images = model.extract_query(targets=kept_instances, query_images=query_images, visual_features=backbone_features, exclude_similar=True, device=device, max_query_number=cfg.VISION_QUERY.MAX_TEST_QUERY_NUMBER)

                        else:
                            all_output.append(output)
            if "lvis" in cfg.DATASETS.TEST[0]:
                continue

            if evaluator is not None:
                evaluator.update(mdetr_style_output)
            else:
                output = [[row[_i] for row in all_output] for _i in range(len(all_output[0]))]
                for index, i in enumerate(output):
                    output[index] = i[0].concate_box_list(i)
                assert len(output) == 1  # TODO: support batched outputs
                kept_instances = []
                for o in output:
                    keep = o.get_field('scores') > cfg.VISION_QUERY.SCORE_THRESHOLD
                    kept_instances.append(o[keep])
                query_images = model.extract_query(targets=kept_instances, query_images=query_images, visual_features=backbone_features, exclude_similar=True, device=device, max_query_number=cfg.VISION_QUERY.MAX_TEST_QUERY_NUMBER)
                # query_images = model.extract_query(targets=targets, query_images=query_images, visual_features=backbone_features, exclude_similar=True, device=device, max_query_number=cfg.VISION_QUERY.MAX_TEST_QUERY_NUMBER)  # only for debug

        print('saving to ', save_name)
        torch.save(query_images, save_name)
        print('succeeded')
    return model

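# Main evaluation entry point. When cfg.TEST.EVAL_TASK is unset it falls back to
# inference_default(); otherwise it builds the text queries / positive maps and the
# dataset-specific evaluator (Flickr30k or LVIS). Note that in this mirror the `model`
# argument is a (processor, model) pair and each image is scored with an OWL-ViT-style
# image-guided detection call against a fixed query image, rather than with the
# text-prompted forward pass that the surrounding code was originally written for.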
def inference(
    model,
    data_loader,
    dataset_name,
    iou_types=("bbox",),
    box_only=False,
    device="cuda",
    expected_results=(),
    expected_results_sigma_tol=4,
    output_folder=None,
    cfg=None,
    verbose=True,
    visualizer=None,
    disable_print=False,
):
    processor, model = model
    # convert to a torch.device for efficiency
    try:
        device = torch.device(device)
    except:
        device = device
    num_devices = (
        torch.distributed.get_world_size()
        if torch.distributed.is_initialized()
        else 1
    )
    logger = logging.getLogger("maskrcnn_benchmark.inference")
    dataset = data_loader.dataset
    if verbose:
        logger.info("Start evaluation on {} dataset ({} images).".format(dataset_name, len(dataset)))
    start_time = time.time()

    task = cfg.TEST.EVAL_TASK

    if not task:
        return inference_default(model, data_loader, dataset_name, iou_types, box_only, device, expected_results, expected_results_sigma_tol, output_folder, cfg)

    if cfg.GLIPKNOW.PARALLEL_LANGUAGE_INPUT:
        assert task == 'detection'
        categories = dataset.categories()

        keys = list(categories.keys())
        keys.sort()
        all_queries = [[categories[k] for k in keys]]
        all_positive_map_label_to_token = [{k: [i] for i, k in enumerate(keys)}]
    elif task == "detection":
        all_queries, all_positive_map_label_to_token = create_queries_and_maps_from_dataset(dataset, cfg, disable_print=disable_print)
    elif task == "grounding":
        all_queries = [None]
        all_positive_map_label_to_token = [None]
    else:
        assert 0

    '''
    Build Dataset Specific Evaluator
    '''
    if "flickr" in cfg.DATASETS.TEST[0]:
        evaluator = build_flickr_evaluator(cfg)
    elif "lvis" in cfg.DATASETS.TEST[0]:
        evaluator = build_lvis_evaluator(dataset.ann_file, fixed_ap=not cfg.DATASETS.LVIS_USE_NORMAL_AP)
    else:
        evaluator = None

    model.eval()
    results_dict = {}
    cpu_device = torch.device("cpu")
    if verbose:
        _iterator = tqdm(data_loader)
    else:
        _iterator = data_loader
    if cfg.VISION_QUERY.RETURN_ATTN_GATE_VALUE:
        attn_values = 0.
    # if cfg.VISION_QUERY.DEBUG:
    #     for i, batch in enumerate(_iterator):
    #         images, targets, image_ids, *_ = batch
    #         pass

    query_image = Image.open('DATASET/odinw/CottontailRabbits/train/06-07-rabbits_jpg.rf.103f96525bc1f708c04ba4aea42eaf45.jpg')

    for i, batch in enumerate(_iterator):
        if cfg.VISION_QUERY.DEBUG:
            if i > 0:
                break
        if i == cfg.TEST.SUBSET:
            break
        images, targets, image_ids, *_ = batch
        # the image path is taken from the last element of the batch
        path = batch[-1]
        path = path[0]

        image = Image.open(path)

        inputs = processor(images=image, query_images=query_image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        all_output = []
        mdetr_style_output = []
        with torch.no_grad():
            if cfg.TEST.USE_MULTISCALE:
                raise NotImplementedError
                # unreachable: original multi-scale augmentation path
                query_time = len(all_queries)
                for query_i in range(query_time):
                    if task == "detection":
                        captions = [all_queries[query_i] for ii in range(len(targets))]
                        positive_map_label_to_token = all_positive_map_label_to_token[query_i]
                    else:
                        captions = None
                        positive_map_label_to_token = None

                output = im_detect_bbox_aug(model, images, device, captions, positive_map_label_to_token)
                output = [o.to(cpu_device) for o in output]
                all_output.append(output)
            else:
                images = images.to(device)
                query_time = len(all_queries)

                for query_i in range(query_time):
                    if not isinstance(targets[0], dict):  # For LVIS dataset and datasets directly copied from MDETR
                        targets = [target.to(device) for target in targets]
                    '''
                    Different datasets use different data formats: for the LVIS dataset the
                    target is a dictionary, while for modulated datasets such as COCO/Flickr
                    the target is a BoxList.
                    '''

                    if task == "detection":
                        captions = [all_queries[query_i] for ii in range(len(targets))]
                        positive_map_label_to_token = all_positive_map_label_to_token[query_i]
                    elif task == "grounding":
                        captions = [t.get_field("caption") for t in targets]
                        positive_map_eval = [t.get_field("positive_map_eval") for t in targets]
                        if cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
                            plus = 1
                        else:
                            plus = 0
                        assert len(positive_map_eval) == 1  # Let's just use one image per batch
                        positive_map_eval = positive_map_eval[0]
                        positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map_eval, plus=plus)

                    # output = model(images, captions=captions, positive_map=positive_map_label_to_token)
                    outputs = model.image_guided_detection(**inputs)
                    target_sizes = torch.Tensor([image.size[::-1]])
                    output_ = processor.post_process_image_guided_detection(
                        outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
                    )
                    # output_ = processor.post_process_image_guided_detection(
                    #     outputs=outputs, threshold=0.3, nms_threshold=0.3, target_sizes=target_sizes
                    # )
                    output_ = output_[0]
                    output = BoxList(bbox=output_['boxes'], image_size=(image.size[::-1][1], image.size[::-1][0]))
                    output.add_field('scores', output_['scores'])
                    lab = torch.full([len(output_['scores'])], 1, device=output_['scores'].device)
                    output.add_field('labels', lab)
                    output = [output]

                    if cfg.VISION_QUERY.RETURN_ATTN_GATE_VALUE:
                        output, attn_values_list = output
                        attn_values = attn_values + torch.tensor(attn_values_list)

                    output = [o.to(cpu_device) for o in output]

                    if "flickr" in cfg.DATASETS.TEST[0]:
                        output = output[0]
                        new_output = flickr_post_process(
                            output,
                            targets,
                            positive_map_label_to_token,
                            plus  # This is only used in Flickr
                        )
                        mdetr_style_output.append(new_output)
                    elif "lvis" in cfg.DATASETS.TEST[0]:
                        output = output[0]
                        output = resize_box(output, targets)
                        scores = output.extra_fields["scores"]
                        labels = output.extra_fields["labels"]
                        boxes = output.bbox
                        mdetr_style_output.append((targets[0]["image_id"].item(), {"scores": scores, "labels": labels, "boxes": boxes}))
                    else:
                        all_output.append(output)
        if visualizer is not None:
            assert len(all_output) == 1
            if "lvis" in cfg.DATASETS.TEST[0]:
                scores = [o[1]["scores"] for o in mdetr_style_output]
                labels = [o[1]["labels"] for o in mdetr_style_output]
                boxes = [o[1]["boxes"] for o in mdetr_style_output]
                scores = torch.cat(scores, dim=0)
                labels = torch.cat(labels, dim=0)
                boxes = torch.cat(boxes, dim=0)
                visualizer_input = BoxList(boxes, output.size)
                visualizer_input.add_field("scores", scores)
                visualizer_input.add_field("labels", labels)
            else:
                visualizer_input = all_output[0][0]  # single image visualize

            image_id = dataset.ids[i]
            try:
                image_path = os.path.join(dataset.root, dataset.coco.loadImgs(image_id)[0]["file_name"])
                categories = dataset.coco.dataset["categories"]
            except:
                lvis = dataset.lvis
                img_id = dataset.ids[i]
                ann_ids = lvis.get_ann_ids(img_ids=img_id)
                target = lvis.load_anns(ann_ids)

                image_path = "DATASET/coco/" + "/".join(dataset.lvis.load_imgs(img_id)[0]["coco_url"].split("/")[-2:])
                categories = dataset.lvis.dataset["categories"]

            image = load(image_path)
            no_background = True
            label_list = []
            for index, i in enumerate(categories):
                if not no_background or (i["name"] != "__background__" and i['id'] != 0):
                    label_list.append(i["name"])
            visualizer.entities = label_list

            # NOTE: the drawing parameters below (threshold, alpha, box_pixel, text_size,
            # text_pixel, text_offset, text_offset_original, color) are not defined in this
            # file; they must be supplied as module-level globals for this branch to run.
            result, _ = visualizer.visualize_with_predictions(
                image,
                visualizer_input,
                threshold,
                alpha=alpha,
                box_pixel=box_pixel,
                text_size=text_size,
                text_pixel=text_pixel,
                text_offset=text_offset,
                text_offset_original=text_offset_original,
                color=color,
            )
            imshow(result, "./visualize/img_{}.jpg".format(i))

        if evaluator is not None:
            evaluator.update(mdetr_style_output)
        else:
            output = [[row[_i] for row in all_output] for _i in range(len(all_output[0]))]
            for index, i in enumerate(output):
                output[index] = i[0].concate_box_list(i)

            results_dict.update({img_id: result for img_id, result in zip(image_ids, output)})

    if cfg.VISION_QUERY.RETURN_ATTN_GATE_VALUE:
        attn_values = attn_values / (len(_iterator) * len(all_queries))
        print('attn_values: ', attn_values)
    if evaluator is not None:
        evaluator.synchronize_between_processes()
        try:
            evaluator.accumulate()
        except:
            print("Evaluator has no accumulation, skipped...")
        score = evaluator.summarize()
        print(score)
        import maskrcnn_benchmark.utils.mdetr_dist as dist
        if is_main_process():
            if "flickr" in cfg.DATASETS.TEST[0]:
                write_flickr_results(score, output_file_name=os.path.join(output_folder, "bbox.csv"))
            elif "lvis" in cfg.DATASETS.TEST[0]:
                write_lvis_results(score, output_file_name=os.path.join(output_folder, "bbox.csv"))
                if cfg.TEST.LVIS_RESULTS_SAVE_PATH != '':
                    evaluator.save_instances_results(cfg.TEST.LVIS_RESULTS_SAVE_PATH)
        try:
            torch.distributed.barrier()
        except:
            print("Default process group is not initialized")
        return

    if evaluator is not None:
        predictions = mdetr_style_output
    else:
        predictions = results_dict
    # wait for all processes to complete before measuring the time
    synchronize()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    logger.info(
        "Total inference time: {} ({} s / img per device, on {} devices)".format(
            total_time_str, total_time * num_devices / len(dataset), num_devices
        )
    )

    predictions = _accumulate_predictions_from_multiple_gpus(predictions)
    print("Accumulated results")
    if not is_main_process():
        return None

    if output_folder:
        torch.save(predictions, os.path.join(output_folder, "predictions.pth"))

    extra_args = dict(
        box_only=box_only,
        iou_types=iou_types,
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )
    return evaluate(dataset=dataset, predictions=predictions, output_folder=output_folder, **extra_args)