feature/first_batch_of_model_usability_upgrades (#9)
* initial commit * test updated requirements.txt * move more code to inference utils * PIL import fix * add annotations utilities * README.md updatespull/12/head
parent
12ef464f9e
commit
2309f9f468
|
@ -1,3 +1,7 @@
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
|
34
README.md
34
README.md
|
@ -1,14 +1,21 @@
|
||||||
|
# Grounding DINO
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
|
||||||
|
|
||||||
[](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
|
[](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
|
||||||
[](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
|
[](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
|
||||||
[](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
|
[](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
|
||||||
[](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
|
[](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
|
||||||
|
|
||||||
# Grounding DINO
|
|
||||||
|
|
||||||
Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now!
|
Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now!
|
||||||
|
|
||||||
|
|
||||||
## Highlight
|
## Highlight
|
||||||
|
|
||||||
- **Open-Set Detection.** Detect **everything** with language!
|
- **Open-Set Detection.** Detect **everything** with language!
|
||||||
- **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
|
- **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
|
||||||
- **Flexible.** Collaboration with Stable Diffusion for Image Editting.
|
- **Flexible.** Collaboration with Stable Diffusion for Image Editting.
|
||||||
|
@ -23,21 +30,22 @@ Description
|
||||||
<img src=".asset/hero_figure.png" alt="ODinW" width="100%">
|
<img src=".asset/hero_figure.png" alt="ODinW" width="100%">
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## TODO List
|
## TODO
|
||||||
|
|
||||||
- [x] Release inference code and demo.
|
- [x] Release inference code and demo.
|
||||||
- [x] Release checkpoints.
|
- [x] Release checkpoints.
|
||||||
- [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
|
- [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
## Usage
|
|
||||||
### 1. Install
|
|
||||||
If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set.
|
If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. Run an inference demo
|
## Demo
|
||||||
|
|
||||||
See the `demo/inference_on_a_image.py` for more details.
|
See the `demo/inference_on_a_image.py` for more details.
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
|
CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
|
||||||
|
@ -48,7 +56,8 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
|
||||||
-t "cat ear."
|
-t "cat ear."
|
||||||
```
|
```
|
||||||
|
|
||||||
### Checkpoints
|
## Checkpoints
|
||||||
|
|
||||||
<!-- insert a table -->
|
<!-- insert a table -->
|
||||||
<table>
|
<table>
|
||||||
<thead>
|
<thead>
|
||||||
|
@ -74,6 +83,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
## Results
|
## Results
|
||||||
|
|
||||||
<details open>
|
<details open>
|
||||||
<summary><font size="4">
|
<summary><font size="4">
|
||||||
COCO Object Detection Results
|
COCO Object Detection Results
|
||||||
|
@ -102,11 +112,6 @@ Marrying Grounding DINO with <a href="https://github.com/gligen/GLIGEN">GLIGEN</
|
||||||
<img src=".asset/GD_GLIGEN.png" alt="GD_GLIGEN" width="100%">
|
<img src=".asset/GD_GLIGEN.png" alt="GD_GLIGEN" width="100%">
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Model
|
## Model
|
||||||
|
|
||||||
Includes: a text backbone, an image backbone, a feature enhancer, a language-guided query selection, and a cross-modality decoder.
|
Includes: a text backbone, an image backbone, a feature enhancer, a language-guided query selection, and a cross-modality decoder.
|
||||||
|
@ -114,7 +119,8 @@ Includes: a text backbone, an image backbone, a feature enhancer, a language-gui
|
||||||

|

|
||||||
|
|
||||||
|
|
||||||
# Links
|
## Acknowledgement
|
||||||
|
|
||||||
Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
|
Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
|
||||||
|
|
||||||
We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work are available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
|
We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work are available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
|
||||||
|
@ -122,8 +128,10 @@ We also thank great previous work including DETR, Deformable DETR, SMCA, Conditi
|
||||||
Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
|
Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
|
||||||
|
|
||||||
|
|
||||||
# Bibtex
|
## Citation
|
||||||
|
|
||||||
If you find our work helpful for your research, please consider citing the following BibTeX entry.
|
If you find our work helpful for your research, please consider citing the following BibTeX entry.
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@inproceedings{ShilongLiu2023GroundingDM,
|
@inproceedings{ShilongLiu2023GroundingDM,
|
||||||
title={Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection},
|
title={Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection},
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import supervision as sv
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.ops import box_convert
|
||||||
|
|
||||||
|
import groundingdino.datasets.transforms as T
|
||||||
|
from groundingdino.models import build_model
|
||||||
|
from groundingdino.util.misc import clean_state_dict
|
||||||
|
from groundingdino.util.slconfig import SLConfig
|
||||||
|
from groundingdino.util.utils import get_phrases_from_posmap
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_caption(caption: str) -> str:
|
||||||
|
result = caption.lower().strip()
|
||||||
|
if result.endswith("."):
|
||||||
|
return result
|
||||||
|
return result + "."
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_config_path: str, model_checkpoint_path: str):
|
||||||
|
args = SLConfig.fromfile(model_config_path)
|
||||||
|
args.device = "cuda"
|
||||||
|
model = build_model(args)
|
||||||
|
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
||||||
|
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
|
||||||
|
transform = T.Compose(
|
||||||
|
[
|
||||||
|
T.RandomResize([800], max_size=1333),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
image_source = Image.open(image_path).convert("RGB")
|
||||||
|
image = np.asarray(image_source)
|
||||||
|
image_transformed, _ = transform(image_source, None)
|
||||||
|
return image, image_transformed
|
||||||
|
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
model,
|
||||||
|
image: torch.Tensor,
|
||||||
|
caption: str,
|
||||||
|
box_threshold: float,
|
||||||
|
text_threshold: float
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
|
||||||
|
caption = preprocess_caption(caption=caption)
|
||||||
|
|
||||||
|
model = model.cuda()
|
||||||
|
image = image.cuda()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(image[None], captions=[caption])
|
||||||
|
|
||||||
|
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
|
||||||
|
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
|
||||||
|
|
||||||
|
mask = prediction_logits.max(dim=1)[0] > box_threshold
|
||||||
|
logits = prediction_logits[mask] # logits.shape = (n, 256)
|
||||||
|
boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
|
||||||
|
|
||||||
|
tokenizer = model.tokenizer
|
||||||
|
tokenized = tokenizer(caption)
|
||||||
|
|
||||||
|
phrases = [
|
||||||
|
get_phrases_from_posmap(logit > text_threshold, tokenized, caption).replace('.', '')
|
||||||
|
for logit
|
||||||
|
in logits
|
||||||
|
]
|
||||||
|
|
||||||
|
return boxes, logits.max(dim=1)[0], phrases
|
||||||
|
|
||||||
|
|
||||||
|
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
|
||||||
|
h, w, _ = image_source.shape
|
||||||
|
boxes = boxes * torch.Tensor([w, h, w, h])
|
||||||
|
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
||||||
|
detections = sv.Detections(xyxy=xyxy)
|
||||||
|
|
||||||
|
labels = [
|
||||||
|
f"{phrase} {logit:.2f}"
|
||||||
|
for phrase, logit
|
||||||
|
in zip(phrases, logits)
|
||||||
|
]
|
||||||
|
|
||||||
|
box_annotator = sv.BoxAnnotator()
|
||||||
|
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
|
||||||
|
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
|
return annotated_frame
|
|
@ -1 +1,9 @@
|
||||||
transformers==4.5.1
|
torch
|
||||||
|
torchvision
|
||||||
|
transformers
|
||||||
|
addict
|
||||||
|
yapf
|
||||||
|
timm
|
||||||
|
numpy
|
||||||
|
opencv-python
|
||||||
|
supervision==0.3.2
|
Loading…
Reference in New Issue