support coco evaluation

parent 4e6f23d35c
commit ade33b7b47

LICENSE (2 changes)

@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.
 
-   Copyright 2020 - present, Facebook, Inc
+   Copyright 2023 - present, IDEA Research.
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
README.md (30 changes)

@@ -68,6 +68,7 @@ PyTorch implementation and pretrained models for Grounding DINO. For details, se
 
 
 ## :fire: News
+- **`2023/06/17`**: We provide an example to evaluate Grounding DINO's zero-shot performance on COCO.
 - **`2023/04/15`**: Refer to [CV in the Wild Readings](https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings) for those who are interested in open-set recognition!
 - **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editing.
 - **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editing.

@@ -129,24 +130,16 @@ cd GroundingDINO/
 Install the required dependencies in the current directory.
 
 ```bash
-pip3 install -q -e .
+pip install -e .
 ```
-Create a new directory called "weights" to store the model weights.
+
+Download pre-trained model weights.
 
 ```bash
 mkdir weights
-```
-
-Change the current directory to the "weights" folder.
-
-```bash
 cd weights
-```
-
-Download the model weights file.
-
-```bash
 wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
+cd ..
 ```
 
 ## :arrow_forward: Demo
@@ -201,6 +194,19 @@ We also provide a demo code to integrate Grounding DINO with Gradio Web UI. See
 - We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editing.
 - We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editing.
 
+
+## COCO Zero-shot Evaluations
+
+We provide an example to evaluate Grounding DINO's zero-shot performance on COCO. The expected result is **48.5** (COCO mAP).
+
+```bash
+CUDA_VISIBLE_DEVICES=0 \
+python demo/test_ap_on_coco.py \
+ -c groundingdino/config/GroundingDINO_SwinT_OGC.py \
+ -p weights/groundingdino_swint_ogc.pth \
+ --anno_path /path/to/annotations/instances_val2017.json \
+ --image_dir /path/to/images/val2017
+```
 
 
 ## :luggage: Checkpoints

demo/test_ap_on_coco.py (new file)

@@ -0,0 +1,233 @@
import argparse
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler

from groundingdino.models import build_model
import groundingdino.datasets.transforms as T
from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import clean_state_dict, collate_fn
from groundingdino.util.slconfig import SLConfig

# from torchvision.datasets import CocoDetection
import torchvision

from groundingdino.util.vl_utils import build_captions_and_token_span, create_positive_map_from_span
from groundingdino.datasets.cocogrounding_eval import CocoGroundingEvaluator


def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    model.load_state_dict(clean_state_dict(checkpoint["ema_model"]), strict=False)
    model.eval()
    return model


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms

    def __getitem__(self, idx):
        img, target = super().__getitem__(idx)  # target: list

        # import ipdb; ipdb.set_trace()

        w, h = img.size
        boxes = [obj["bbox"] for obj in target]
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]  # xywh -> xyxy
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)
        # filter invalid boxes/masks/keypoints
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]

        target_new = {}
        image_id = self.ids[idx]
        target_new["image_id"] = image_id
        target_new["boxes"] = boxes
        target_new["orig_size"] = torch.as_tensor([int(h), int(w)])

        if self._transforms is not None:
            img, target = self._transforms(img, target_new)

        return img, target


class PostProcessCocoGrounding(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""

    def __init__(self, num_select=300, coco_api=None, tokenlizer=None) -> None:
        super().__init__()
        self.num_select = num_select

        assert coco_api is not None
        category_dict = coco_api.dataset['categories']
        cat_list = [item['name'] for item in category_dict]
        captions, cat2tokenspan = build_captions_and_token_span(cat_list, True)
        tokenspanlist = [cat2tokenspan[cat] for cat in cat_list]
        positive_map = create_positive_map_from_span(
            tokenlizer(captions), tokenspanlist)  # 80, 256. normed

        id_map = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 13, 12: 14, 13: 15, 14: 16, 15: 17, 16: 18, 17: 19, 18: 20, 19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 27, 25: 28, 26: 31, 27: 32, 28: 33, 29: 34, 30: 35, 31: 36, 32: 37, 33: 38, 34: 39, 35: 40, 36: 41, 37: 42, 38: 43, 39: 44, 40: 46,
                  41: 47, 42: 48, 43: 49, 44: 50, 45: 51, 46: 52, 47: 53, 48: 54, 49: 55, 50: 56, 51: 57, 52: 58, 53: 59, 54: 60, 55: 61, 56: 62, 57: 63, 58: 64, 59: 65, 60: 67, 61: 70, 62: 72, 63: 73, 64: 74, 65: 75, 66: 76, 67: 77, 68: 78, 69: 79, 70: 80, 71: 81, 72: 82, 73: 84, 74: 85, 75: 86, 76: 87, 77: 88, 78: 89, 79: 90}
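        # id_map translates the contiguous 0-79 indices of cat_list / positive_map rows
        # into the original, non-contiguous COCO category ids (1-90).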

        # build a mapping from label_id to pos_map
        new_pos_map = torch.zeros((91, 256))
        for k, v in id_map.items():
            new_pos_map[v] = positive_map[k]
        self.positive_map = new_pos_map

    @torch.no_grad()
    def forward(self, outputs, target_sizes, not_to_xyxy=False):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        num_select = self.num_select
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        # pos map to logit
        prob_to_token = out_logits.sigmoid()  # bs, 100, 256
        pos_maps = self.positive_map.to(prob_to_token.device)
        # (bs, 100, 256) @ (91, 256).T -> (bs, 100, 91)
        prob_to_label = prob_to_token @ pos_maps.T
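        # The positive map is normalized over each category's tokens, so this product pools
        # the token-level probabilities into one score per COCO category.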

        # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
        # import ipdb; ipdb.set_trace()

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = prob_to_label
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1), num_select, dim=1)
        scores = topk_values
        topk_boxes = topk_indexes // prob.shape[2]
        labels = topk_indexes % prob.shape[2]

        if not_to_xyxy:
            boxes = out_bbox
        else:
            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)

        boxes = torch.gather(
            boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [{'scores': s, 'labels': l, 'boxes': b}
                   for s, l, b in zip(scores, labels, boxes)]

        return results


def main(args):
    # config
    cfg = SLConfig.fromfile(args.config_file)

    # build model
    model = load_model(args.config_file, args.checkpoint_path)
    model = model.to(args.device)
    model = model.eval()

    # build dataloader
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    dataset = CocoDetection(
        args.image_dir, args.anno_path, transforms=transform)
    data_loader = DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn)

    # build post processor
    tokenlizer = get_tokenlizer.get_tokenlizer(cfg.text_encoder_type)
    postprocessor = PostProcessCocoGrounding(
        coco_api=dataset.coco, tokenlizer=tokenlizer)

    # build evaluator
    evaluator = CocoGroundingEvaluator(
        dataset.coco, iou_types=("bbox",), useCats=True)

    # build captions
    category_dict = dataset.coco.dataset['categories']
    cat_list = [item['name'] for item in category_dict]
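    # Grounding DINO is prompted with a single caption: all 80 COCO category names joined
    # with " . ", the same caption format the post-processor's positive map is built from.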
    caption = " . ".join(cat_list) + ' .'
    print("Input text prompt:", caption)

    # run inference
    start = time.time()
    for i, (images, targets) in enumerate(data_loader):
        # get images and captions
        images = images.tensors.to(args.device)
        bs = images.shape[0]
        input_captions = [caption] * bs

        # feed to the model
        outputs = model(images, captions=input_captions)

        orig_target_sizes = torch.stack(
            [t["orig_size"] for t in targets], dim=0).to(images.device)
        results = postprocessor(outputs, orig_target_sizes)
        cocogrounding_res = {
            target["image_id"]: output for target, output in zip(targets, results)}
        evaluator.update(cocogrounding_res)

        if (i+1) % 30 == 0:
            used_time = time.time() - start
            eta = len(data_loader) / (i+1e-5) * used_time - used_time
            print(
                f"processed {i}/{len(data_loader)} images. time: {used_time:.2f}s, ETA: {eta:.2f}s")

    evaluator.synchronize_between_processes()
    evaluator.accumulate()
    evaluator.summarize()

    print("Final results:", evaluator.coco_eval["bbox"].stats.tolist())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        "Grounding DINO eval on COCO", add_help=True)
    # load model
    parser.add_argument("--config_file", "-c", type=str,
                        required=True, help="path to config file")
    parser.add_argument(
        "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument("--device", type=str, default="cuda",
                        help="running device (default: cuda)")

    # post processing
    parser.add_argument("--num_select", type=int, default=300,
                        help="number of topk to select")

    # coco info
    parser.add_argument("--anno_path", type=str,
                        required=True, help="coco annotation path (e.g. instances_val2017.json)")
    parser.add_argument("--image_dir", type=str,
                        required=True, help="coco image dir")
    parser.add_argument("--num_workers", type=int, default=4,
                        help="number of workers for dataloader")
    args = parser.parse_args()

    main(args)
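
The flow in `main()` above can also be driven from a notebook or another script. The sketch below mirrors it under a couple of assumptions: `demo/test_ap_on_coco.py` is importable (e.g. the interpreter is started from the repository root) and the dataset paths are placeholders to be replaced.

```python
import torch
from torch.utils.data import DataLoader

import groundingdino.datasets.transforms as T
from groundingdino.datasets.cocogrounding_eval import CocoGroundingEvaluator
from groundingdino.util import get_tokenlizer
from groundingdino.util.misc import collate_fn
from groundingdino.util.slconfig import SLConfig

# Helpers defined in demo/test_ap_on_coco.py above; assumes the repo root is on sys.path.
from demo.test_ap_on_coco import CocoDetection, PostProcessCocoGrounding, load_model

config_path = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
ckpt_path = "weights/groundingdino_swint_ogc.pth"
anno_path = "/path/to/annotations/instances_val2017.json"  # placeholder
image_dir = "/path/to/images/val2017"                      # placeholder

cfg = SLConfig.fromfile(config_path)
model = load_model(config_path, ckpt_path).to("cuda").eval()

transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
dataset = CocoDetection(image_dir, anno_path, transforms=transform)
loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

# One text prompt containing all 80 COCO category names, separated by " . ".
cat_list = [c["name"] for c in dataset.coco.dataset["categories"]]
caption = " . ".join(cat_list) + " ."

tokenlizer = get_tokenlizer.get_tokenlizer(cfg.text_encoder_type)
postprocessor = PostProcessCocoGrounding(coco_api=dataset.coco, tokenlizer=tokenlizer)
evaluator = CocoGroundingEvaluator(dataset.coco, iou_types=("bbox",), useCats=True)

with torch.no_grad():
    for images, targets in loader:
        images = images.tensors.to("cuda")
        outputs = model(images, captions=[caption] * images.shape[0])
        sizes = torch.stack([t["orig_size"] for t in targets], dim=0).to(images.device)
        results = postprocessor(outputs, sizes)
        evaluator.update({t["image_id"]: r for t, r in zip(targets, results)})

evaluator.synchronize_between_processes()
evaluator.accumulate()
evaluator.summarize()  # prints the COCO AP table; ~48.5 mAP per the README above
```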

groundingdino/datasets/cocogrounding_eval.py (new file)

@@ -0,0 +1,269 @@
# ------------------------------------------------------------------------
# Grounding DINO. Modified by Shilong Liu.
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO evaluator that works in distributed mode.

Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
The difference is that there is less copy-pasting from pycocotools
at the end of the file, as python3 can suppress prints with contextlib
"""
import contextlib
import copy
import os

import numpy as np
import pycocotools.mask as mask_util
import torch
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from groundingdino.util.misc import all_gather


class CocoGroundingEvaluator(object):
    def __init__(self, coco_gt, iou_types, useCats=True):
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
            self.coco_eval[iou_type].useCats = useCats

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}
        self.useCats = useCats

    def update(self, predictions):
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)

            # suppress pycocotools prints
            with open(os.devnull, "w") as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()

            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            coco_eval.params.useCats = self.useCats
            img_ids, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        for iou_type, coco_eval in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"]
            labels = prediction["labels"]
            masks = prediction["masks"]

            masks = masks > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "keypoints": keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
    all_img_ids = all_gather(img_ids)
    all_eval_imgs = all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################
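# Note: unlike COCOeval.evaluate, this standalone copy is called as evaluate(coco_eval),
# with the COCOeval instance passed explicitly as `self`, and it returns (imgIds, evalImgs)
# so the per-image results can be gathered and merged across processes.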


def evaluate(self):
    """
    Run per image evaluation on given images and return the per-image results.
    :return: (imgIds, evalImgs)
    """
    # tic = time.time()
    # print('Running per image evaluation...')
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = "segm" if p.useSegm == 1 else "bbox"
        print("useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType))
    # print('Evaluate annotation type *{}*'.format(p.iouType))
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType == "segm" or p.iouType == "bbox":
        computeIoU = self.computeIoU
    elif p.iouType == "keypoints":
        computeIoU = self.computeOks
    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds}

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
    evalImgs = [
        evaluateImg(imgId, catId, areaRng, maxDet)
        for catId in catIds
        for areaRng in p.areaRng
        for imgId in p.imgIds
    ]
    # this is NOT in the pycocotools code, but could be done outside
    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
    self._paramsEval = copy.deepcopy(self.params)
    # toc = time.time()
    # print('DONE (t={:0.2f}s).'.format(toc-tic))
    return p.imgIds, evalImgs


#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################

@@ -228,7 +228,6 @@ class GroundingDINO(nn.Module):
             captions = kw["captions"]
         else:
             captions = [t["caption"] for t in targets]
-        len(captions)
 
         # encoder texts
         tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(

@@ -1 +0,0 @@
-__version__ = "0.1.0"