support coco evaluation

pull/149/head
SlongLiu 2023-06-17 17:30:22 +08:00
parent 4e6f23d35c
commit ade33b7b47
6 changed files with 521 additions and 15 deletions

View File

@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020 - present, Facebook, Inc
Copyright 2023 - present, IDEA Research.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -68,6 +68,7 @@ PyTorch implementation and pretrained models for Grounding DINO. For details, se
## :fire: News
- **`2023/06/17`**: We provide an example to evaluat Grounding DINO on COCO zero-shot performance.
- **`2023/04/15`**: Refer to [CV in the Wild Readings](https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings) for those who are interested in open-set recognition!
- **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editings.
- **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editings.
@ -129,24 +130,16 @@ cd GroundingDINO/
Install the required dependencies in the current directory.
```bash
pip3 install -q -e .
pip install -e .
```
Create a new directory called "weights" to store the model weights.
Download pre-trained model weights.
```bash
mkdir weights
```
Change the current directory to the "weights" folder.
```bash
cd weights
```
Download the model weights file.
```bash
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
cd ..
```
## :arrow_forward: Demo
@ -201,6 +194,19 @@ We also provide a demo code to integrate Grounding DINO with Gradio Web UI. See
- We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editings.
- We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editings.
## COCO Zero-shot Evaluations
We provide an example to evaluate Grounding DINO zero-shot performance on COCO. The results should be **48.5**.
```bash
CUDA_VISIBLE_DEVICES=0 \
python demo/test_ap_on_coco.py \
-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
-p weights/groundingdino_swint_ogc.pth \
--anno_path /path/to/annoataions/ie/instances_val2017.json \
--image_dir /path/to/imagedir/ie/val2017
```
## :luggage: Checkpoints

View File

@ -0,0 +1,233 @@
import argparse
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler
from groundingdino.models import build_model
import groundingdino.datasets.transforms as T
from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import clean_state_dict, collate_fn
from groundingdino.util.slconfig import SLConfig
# from torchvision.datasets import CocoDetection
import torchvision
from groundingdino.util.vl_utils import build_captions_and_token_span, create_positive_map_from_span
from groundingdino.datasets.cocogrounding_eval import CocoGroundingEvaluator
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["ema_model"]), strict=False)
model.eval()
return model
class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms):
super().__init__(img_folder, ann_file)
self._transforms = transforms
def __getitem__(self, idx):
img, target = super().__getitem__(idx) # target: list
# import ipdb; ipdb.set_trace()
w, h = img.size
boxes = [obj["bbox"] for obj in target]
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
boxes[:, 2:] += boxes[:, :2] # xywh -> xyxy
boxes[:, 0::2].clamp_(min=0, max=w)
boxes[:, 1::2].clamp_(min=0, max=h)
# filt invalid boxes/masks/keypoints
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
target_new = {}
image_id = self.ids[idx]
target_new["image_id"] = image_id
target_new["boxes"] = boxes
target_new["orig_size"] = torch.as_tensor([int(h), int(w)])
if self._transforms is not None:
img, target = self._transforms(img, target_new)
return img, target
class PostProcessCocoGrounding(nn.Module):
""" This module converts the model's output into the format expected by the coco api"""
def __init__(self, num_select=300, coco_api=None, tokenlizer=None) -> None:
super().__init__()
self.num_select = num_select
assert coco_api is not None
category_dict = coco_api.dataset['categories']
cat_list = [item['name'] for item in category_dict]
captions, cat2tokenspan = build_captions_and_token_span(cat_list, True)
tokenspanlist = [cat2tokenspan[cat] for cat in cat_list]
positive_map = create_positive_map_from_span(
tokenlizer(captions), tokenspanlist) # 80, 256. normed
id_map = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 13, 12: 14, 13: 15, 14: 16, 15: 17, 16: 18, 17: 19, 18: 20, 19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 27, 25: 28, 26: 31, 27: 32, 28: 33, 29: 34, 30: 35, 31: 36, 32: 37, 33: 38, 34: 39, 35: 40, 36: 41, 37: 42, 38: 43, 39: 44, 40: 46,
41: 47, 42: 48, 43: 49, 44: 50, 45: 51, 46: 52, 47: 53, 48: 54, 49: 55, 50: 56, 51: 57, 52: 58, 53: 59, 54: 60, 55: 61, 56: 62, 57: 63, 58: 64, 59: 65, 60: 67, 61: 70, 62: 72, 63: 73, 64: 74, 65: 75, 66: 76, 67: 77, 68: 78, 69: 79, 70: 80, 71: 81, 72: 82, 73: 84, 74: 85, 75: 86, 76: 87, 77: 88, 78: 89, 79: 90}
# build a mapping from label_id to pos_map
new_pos_map = torch.zeros((91, 256))
for k, v in id_map.items():
new_pos_map[v] = positive_map[k]
self.positive_map = new_pos_map
@torch.no_grad()
def forward(self, outputs, target_sizes, not_to_xyxy=False):
""" Perform the computation
Parameters:
outputs: raw outputs of the model
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
For evaluation, this must be the original image size (before any data augmentation)
For visualization, this should be the image size after data augment, but before padding
"""
num_select = self.num_select
out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
# pos map to logit
prob_to_token = out_logits.sigmoid() # bs, 100, 256
pos_maps = self.positive_map.to(prob_to_token.device)
# (bs, 100, 256) @ (91, 256).T -> (bs, 100, 91)
prob_to_label = prob_to_token @ pos_maps.T
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2
prob = prob_to_label
topk_values, topk_indexes = torch.topk(
prob.view(out_logits.shape[0], -1), num_select, dim=1)
scores = topk_values
topk_boxes = topk_indexes // prob.shape[2]
labels = topk_indexes % prob.shape[2]
if not_to_xyxy:
boxes = out_bbox
else:
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
boxes = torch.gather(
boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
# and from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
results = [{'scores': s, 'labels': l, 'boxes': b}
for s, l, b in zip(scores, labels, boxes)]
return results
def main(args):
# config
cfg = SLConfig.fromfile(args.config_file)
# build model
model = load_model(args.config_file, args.checkpoint_path)
model = model.to(args.device)
model = model.eval()
# build dataloader
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
dataset = CocoDetection(
args.image_dir, args.anno_path, transforms=transform)
data_loader = DataLoader(
dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn)
# build post processor
tokenlizer = get_tokenlizer.get_tokenlizer(cfg.text_encoder_type)
postprocessor = PostProcessCocoGrounding(
coco_api=dataset.coco, tokenlizer=tokenlizer)
# build evaluator
evaluator = CocoGroundingEvaluator(
dataset.coco, iou_types=("bbox",), useCats=True)
# build captions
category_dict = dataset.coco.dataset['categories']
cat_list = [item['name'] for item in category_dict]
caption = " . ".join(cat_list) + ' .'
print("Input text prompt:", caption)
# run inference
start = time.time()
for i, (images, targets) in enumerate(data_loader):
# get images and captions
images = images.tensors.to(args.device)
bs = images.shape[0]
input_captions = [caption] * bs
# feed to the model
outputs = model(images, captions=input_captions)
orig_target_sizes = torch.stack(
[t["orig_size"] for t in targets], dim=0).to(images.device)
results = postprocessor(outputs, orig_target_sizes)
cocogrounding_res = {
target["image_id"]: output for target, output in zip(targets, results)}
evaluator.update(cocogrounding_res)
if (i+1) % 30 == 0:
used_time = time.time() - start
eta = len(data_loader) / (i+1e-5) * used_time - used_time
print(
f"processed {i}/{len(data_loader)} images. time: {used_time:.2f}s, ETA: {eta:.2f}s")
evaluator.synchronize_between_processes()
evaluator.accumulate()
evaluator.summarize()
print("Final results:", evaluator.coco_eval["bbox"].stats.tolist())
if __name__ == "__main__":
parser = argparse.ArgumentParser(
"Grounding DINO eval on COCO", add_help=True)
# load model
parser.add_argument("--config_file", "-c", type=str,
required=True, help="path to config file")
parser.add_argument(
"--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
)
parser.add_argument("--device", type=str, default="cuda",
help="running device (default: cuda)")
# post processing
parser.add_argument("--num_select", type=int, default=300,
help="number of topk to select")
# coco info
parser.add_argument("--anno_path", type=str,
required=True, help="coco root")
parser.add_argument("--image_dir", type=str,
required=True, help="coco image dir")
parser.add_argument("--num_workers", type=int, default=4,
help="number of workers for dataloader")
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,269 @@
# ------------------------------------------------------------------------
# Grounding DINO. Midified by Shilong Liu.
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO evaluator that works in distributed mode.
Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
The difference is that there is less copy-pasting from pycocotools
in the end of the file, as python3 can suppress prints with contextlib
"""
import contextlib
import copy
import os
import numpy as np
import pycocotools.mask as mask_util
import torch
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from groundingdino.util.misc import all_gather
class CocoGroundingEvaluator(object):
def __init__(self, coco_gt, iou_types, useCats=True):
assert isinstance(iou_types, (list, tuple))
coco_gt = copy.deepcopy(coco_gt)
self.coco_gt = coco_gt
self.iou_types = iou_types
self.coco_eval = {}
for iou_type in iou_types:
self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
self.coco_eval[iou_type].useCats = useCats
self.img_ids = []
self.eval_imgs = {k: [] for k in iou_types}
self.useCats = useCats
def update(self, predictions):
img_ids = list(np.unique(list(predictions.keys())))
self.img_ids.extend(img_ids)
for iou_type in self.iou_types:
results = self.prepare(predictions, iou_type)
# suppress pycocotools prints
with open(os.devnull, "w") as devnull:
with contextlib.redirect_stdout(devnull):
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
coco_eval = self.coco_eval[iou_type]
coco_eval.cocoDt = coco_dt
coco_eval.params.imgIds = list(img_ids)
coco_eval.params.useCats = self.useCats
img_ids, eval_imgs = evaluate(coco_eval)
self.eval_imgs[iou_type].append(eval_imgs)
def synchronize_between_processes(self):
for iou_type in self.iou_types:
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
def accumulate(self):
for coco_eval in self.coco_eval.values():
coco_eval.accumulate()
def summarize(self):
for iou_type, coco_eval in self.coco_eval.items():
print("IoU metric: {}".format(iou_type))
coco_eval.summarize()
def prepare(self, predictions, iou_type):
if iou_type == "bbox":
return self.prepare_for_coco_detection(predictions)
elif iou_type == "segm":
return self.prepare_for_coco_segmentation(predictions)
elif iou_type == "keypoints":
return self.prepare_for_coco_keypoint(predictions)
else:
raise ValueError("Unknown iou type {}".format(iou_type))
def prepare_for_coco_detection(self, predictions):
coco_results = []
for original_id, prediction in predictions.items():
if len(prediction) == 0:
continue
boxes = prediction["boxes"]
boxes = convert_to_xywh(boxes).tolist()
scores = prediction["scores"].tolist()
labels = prediction["labels"].tolist()
coco_results.extend(
[
{
"image_id": original_id,
"category_id": labels[k],
"bbox": box,
"score": scores[k],
}
for k, box in enumerate(boxes)
]
)
return coco_results
def prepare_for_coco_segmentation(self, predictions):
coco_results = []
for original_id, prediction in predictions.items():
if len(prediction) == 0:
continue
scores = prediction["scores"]
labels = prediction["labels"]
masks = prediction["masks"]
masks = masks > 0.5
scores = prediction["scores"].tolist()
labels = prediction["labels"].tolist()
rles = [
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
for mask in masks
]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
coco_results.extend(
[
{
"image_id": original_id,
"category_id": labels[k],
"segmentation": rle,
"score": scores[k],
}
for k, rle in enumerate(rles)
]
)
return coco_results
def prepare_for_coco_keypoint(self, predictions):
coco_results = []
for original_id, prediction in predictions.items():
if len(prediction) == 0:
continue
boxes = prediction["boxes"]
boxes = convert_to_xywh(boxes).tolist()
scores = prediction["scores"].tolist()
labels = prediction["labels"].tolist()
keypoints = prediction["keypoints"]
keypoints = keypoints.flatten(start_dim=1).tolist()
coco_results.extend(
[
{
"image_id": original_id,
"category_id": labels[k],
"keypoints": keypoint,
"score": scores[k],
}
for k, keypoint in enumerate(keypoints)
]
)
return coco_results
def convert_to_xywh(boxes):
xmin, ymin, xmax, ymax = boxes.unbind(1)
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
def merge(img_ids, eval_imgs):
all_img_ids = all_gather(img_ids)
all_eval_imgs = all_gather(eval_imgs)
merged_img_ids = []
for p in all_img_ids:
merged_img_ids.extend(p)
merged_eval_imgs = []
for p in all_eval_imgs:
merged_eval_imgs.append(p)
merged_img_ids = np.array(merged_img_ids)
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
# keep only unique (and in sorted order) images
merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
merged_eval_imgs = merged_eval_imgs[..., idx]
return merged_img_ids, merged_eval_imgs
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
img_ids, eval_imgs = merge(img_ids, eval_imgs)
img_ids = list(img_ids)
eval_imgs = list(eval_imgs.flatten())
coco_eval.evalImgs = eval_imgs
coco_eval.params.imgIds = img_ids
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################
def evaluate(self):
"""
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
:return: None
"""
# tic = time.time()
# print('Running per image evaluation...')
p = self.params
# add backward compatibility if useSegm is specified in params
if p.useSegm is not None:
p.iouType = "segm" if p.useSegm == 1 else "bbox"
print("useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType))
# print('Evaluate annotation type *{}*'.format(p.iouType))
p.imgIds = list(np.unique(p.imgIds))
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)
self.params = p
self._prepare()
# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]
if p.iouType == "segm" or p.iouType == "bbox":
computeIoU = self.computeIoU
elif p.iouType == "keypoints":
computeIoU = self.computeOks
self.ious = {
(imgId, catId): computeIoU(imgId, catId)
for imgId in p.imgIds
for catId in catIds}
evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
evalImgs = [
evaluateImg(imgId, catId, areaRng, maxDet)
for catId in catIds
for areaRng in p.areaRng
for imgId in p.imgIds
]
# this is NOT in the pycocotools code, but could be done outside
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
self._paramsEval = copy.deepcopy(self.params)
# toc = time.time()
# print('DONE (t={:0.2f}s).'.format(toc-tic))
return p.imgIds, evalImgs
#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################

View File

@ -228,7 +228,6 @@ class GroundingDINO(nn.Module):
captions = kw["captions"]
else:
captions = [t["caption"] for t in targets]
len(captions)
# encoder texts
tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(

View File

@ -1 +0,0 @@
__version__ = "0.1.0"