Merge bffa3757b5 into 856dde20ae
[Binary diff: 9 image files added — 368 KiB, 217 KiB, 489 KiB, 438 KiB, 379 KiB, 155 KiB, 327 KiB, 271 KiB, 5.4 MiB]
@@ -0,0 +1,292 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Grounding DINO - Batched Half Precision Inference"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Prepare Environments"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from PIL import Image\n",
+    "import io\n",
+    "import os\n",
+    "import supervision as sv\n",
+    "import numpy as np\n",
+    "import requests\n",
+    "import cv2\n",
+    "\n",
+    "# Grounding DINO\n",
+    "from groundingdino.util.inference import BatchedModel\n",
+    "import torchvision.transforms.functional as F\n",
+    "from huggingface_hub import hf_hub_download\n",
+    "\n",
+    "# If you have multiple GPUs, you can set the GPU to use here.\n",
+    "# The default is to use the first GPU, which is usually GPU 0.\n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Load Grounding DINO model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Load demo image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def download_image(url, image_file_path):\n",
+    "    r = requests.get(url, timeout=4.0)\n",
+    "    if r.status_code != requests.codes.ok:\n",
+    "        raise RuntimeError('Status code error: {}.'.format(r.status_code))\n",
+    "\n",
+    "    with Image.open(io.BytesIO(r.content)) as im:\n",
+    "        im.save(image_file_path)\n",
+    "\n",
+    "    print('Image downloaded from url: {} and saved to: {}.'.format(url, image_file_path))\n",
+    "\n",
+    "def load_image(image_path):\n",
+    "    image_source = Image.open(image_path).convert(\"RGB\")\n",
+    "    image = np.asarray(image_source)\n",
+    "    image_tensor = F.to_tensor(image)\n",
+    "    return image, image_tensor\n",
+    "\n",
+    "local_image_path = \"assets/demo4.jpg\"\n",
+    "# download_image(image_url, local_image_path)\n",
+    "image_source, image_tensor = load_image(local_image_path)\n",
+    "Image.fromarray(image_source)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Run Grounding DINO for detection"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Download the checkpoint and config used to evaluate the Grounding DINO model,\n",
+    "# or download the model yourself and point these paths at local copies.\n",
+    "ckpt_repo_id = \"ShilongLiu/GroundingDINO\"\n",
+    "ckpt_filename = \"groundingdino_swint_ogc.pth\"\n",
+    "ckpt_config_filename = \"GroundingDINO_SwinT_OGC.cfg.py\"\n",
+    "device = \"cuda\"\n",
+    "\n",
+    "cache_config_file = hf_hub_download(repo_id=ckpt_repo_id, filename=ckpt_config_filename)\n",
+    "cache_file = hf_hub_download(repo_id=ckpt_repo_id, filename=ckpt_filename)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Single Precision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch = 2\n",
+    "box_threshold = 0.3\n",
+    "text_threshold = 0.25\n",
+    "iou_threshold = 0.5\n",
+    "\n",
+    "# Batch of prompts\n",
+    "text_prompt = [\n",
+    "    [\"Black dog\", \"Beige dog\"],\n",
+    "    [\"Dog\", \"Stick\"]\n",
+    "]\n",
+    "\n",
+    "dtype = \"float32\"\n",
+    "\n",
+    "# Repeat the image `batch` times\n",
+    "image_tensor = image_tensor.to(device=device).to(dtype=getattr(torch, dtype))\n",
+    "image_tensor = image_tensor[None, ...].expand(batch, -1, -1, -1)\n",
+    "\n",
+    "# Build the GroundingDINO inference model\n",
+    "grounding_dino_model = BatchedModel(\n",
+    "    model_config_path=cache_config_file,\n",
+    "    model_checkpoint_path=cache_file,\n",
+    "    device=device,\n",
+    "    dtype=dtype,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%timeit -n 10\n",
+    "with torch.no_grad():\n",
+    "    bbox_batch, conf_batch, class_id_batch = grounding_dino_model(\n",
+    "        image_batch=image_tensor,\n",
+    "        text_prompts=text_prompt,\n",
+    "        box_threshold=box_threshold,\n",
+    "        text_threshold=text_threshold,\n",
+    "        nms_threshold=iou_threshold\n",
+    "    )\n",
+    "    bbox_batch = [bbox.cpu().numpy() for bbox in bbox_batch]\n",
+    "    conf_batch = [conf.cpu().numpy() for conf in conf_batch]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Half Precision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dtype = \"float16\"\n",
+    "\n",
+    "image_tensor = image_tensor.to(device=device).to(dtype=getattr(torch, dtype))\n",
+    "\n",
+    "# Build the GroundingDINO inference model\n",
+    "grounding_dino_model = BatchedModel(\n",
+    "    model_config_path=cache_config_file,\n",
+    "    model_checkpoint_path=cache_file,\n",
+    "    device=device,\n",
+    "    dtype=dtype,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%timeit -n 10\n",
+    "with torch.no_grad():\n",
+    "    bbox_batch, conf_batch, class_id_batch = grounding_dino_model(\n",
+    "        image_batch=image_tensor,\n",
+    "        text_prompts=text_prompt,\n",
+    "        box_threshold=box_threshold,\n",
+    "        text_threshold=text_threshold,\n",
+    "        nms_threshold=iou_threshold\n",
+    "    )\n",
+    "    bbox_batch = [bbox.cpu().numpy() for bbox in bbox_batch]\n",
+    "    conf_batch = [conf.cpu().numpy() for conf in conf_batch]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Display result"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with torch.no_grad():\n",
+    "    bbox_batch, conf_batch, class_id_batch = grounding_dino_model(\n",
+    "        image_batch=image_tensor,\n",
+    "        text_prompts=text_prompt,\n",
+    "        box_threshold=box_threshold,\n",
+    "        text_threshold=text_threshold,\n",
+    "        nms_threshold=iou_threshold\n",
+    "    )\n",
+    "    bbox_batch = [bbox.cpu().numpy() for bbox in bbox_batch]\n",
+    "    conf_batch = [conf.cpu().numpy() for conf in conf_batch]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from IPython.display import display\n",
+    "def annotate(image_source, boxes, logits, phrases) -> np.ndarray:\n",
+    "    detections = sv.Detections(xyxy=boxes)\n",
+    "    labels = [\n",
+    "        f\"{phrase} {logit:.2f}\"\n",
+    "        for phrase, logit\n",
+    "        in zip(phrases, logits)\n",
+    "    ]\n",
+    "    box_annotator = sv.BoxAnnotator()\n",
+    "    annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)\n",
+    "    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)\n",
+    "    return annotated_frame[..., ::-1]\n",
+    "\n",
+    "\n",
+    "for bbox, conf, class_id, class_label in zip(bbox_batch, conf_batch, class_id_batch, text_prompt):\n",
+    "    annotated_frame = annotate(\n",
+    "        image_source=image_source,\n",
+    "        boxes=bbox,\n",
+    "        logits=conf,\n",
+    "        phrases=np.array(class_label)[class_id]\n",
+    "    )\n",
+    "\n",
+    "    display(Image.fromarray(annotated_frame))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.4"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
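Reviewer note on the half-precision setup above: BatchedModel converts the model weights themselves to float16. A lighter-touch alternative, should weight conversion ever cause trouble, is PyTorch's autocast context, which keeps fp32 weights and picks half-precision kernels per op. A minimal sketch reusing the notebook's names, assuming grounding_dino_model was built with dtype="float32" (this is not part of the PR):

import torch

# Hypothetical alternative to converting weights: run the fp32 model under
# autocast so matmuls and convolutions execute in float16 where it is safe.
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
    bbox_batch, conf_batch, class_id_batch = grounding_dino_model(
        image_batch=image_tensor,
        text_prompts=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        nms_threshold=iou_threshold,
    )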
@@ -159,6 +159,7 @@ class WindowAttention(nn.Module):
         attn = attn + relative_position_bias.unsqueeze(0)

         if mask is not None:
+            mask = mask.to(dtype=x.dtype)
             nW = mask.shape[0]
             attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
             attn = attn.view(-1, self.num_heads, N, N)
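Why this cast is needed: in half-precision mode attn is float16 while the window attention mask stays float32, and float16 + float32 silently promotes the result to float32, pulling the rest of the block back to full precision. A small self-contained demonstration of the promotion rule (plain PyTorch, not GroundingDINO code):

import torch

attn = torch.zeros(4, 8, 64, 64, dtype=torch.float16)  # half-precision scores
mask = torch.zeros(8, 64, 64)                           # float32 window mask

print((attn + mask).dtype)                 # torch.float32 -- silent upcast
print((attn + mask.to(attn.dtype)).dtype)  # torch.float16 -- stays in half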
@@ -100,7 +100,7 @@ def multi_scale_deformable_attn_pytorch(
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
-    sampling_grids = 2 * sampling_locations - 1
+    sampling_grids = 2 * sampling_locations.to(dtype=value.dtype) - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
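Casting the sampling grid matters here because the grid later feeds F.grid_sample inside this function, and grid_sample requires input and grid to share a dtype rather than promoting them. A quick repro sketch under that assumption (assumes a CUDA device, matching the rest of this PR's half-precision path):

import torch
import torch.nn.functional as F

value = torch.rand(1, 3, 8, 8, dtype=torch.float16, device="cuda")
grid = torch.zeros(1, 4, 4, 2, device="cuda")  # float32 sampling locations in [-1, 1]

# F.grid_sample(value, grid)  # raises: input and grid must have the same dtype
out = F.grid_sample(value, grid.to(value.dtype), align_corners=False)
print(out.dtype)  # torch.float16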
@@ -659,6 +659,7 @@ class TransformerDecoder(nn.Module):
        output = tgt

        intermediate = []
+        refpoints_unsigmoid = refpoints_unsigmoid.to(dtype=tgt.dtype)
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]

@@ -667,14 +668,14 @@ class TransformerDecoder(nn.Module):
            if reference_points.shape[-1] == 4:
                reference_points_input = (
                    reference_points[:, :, None]
-                    * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
+                    * torch.cat([valid_ratios, valid_ratios], -1)[None, :].to(dtype=tgt.dtype)
                )  # nq, bs, nlevel, 4
            else:
                assert reference_points.shape[-1] == 2
-                reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
+                reference_points_input = reference_points[:, :, None] * valid_ratios[None, :].to(dtype=tgt.dtype)
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :]
-            )  # nq, bs, 256*2
+            ).to(dtype=tgt.dtype)  # nq, bs, 256*2

            # conditional query
            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
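The three casts in this hunk follow one pattern: auxiliary tensors (valid_ratios, the sine position embedding) are produced in float32, so they are cast to the decoder's compute dtype at the point of use; the fp32 math is kept for accuracy and only the result is downcast. A simplified sketch of that pattern (an illustration, not the exact gen_sineembed_for_position):

import torch

def sine_embed_fp32_then_cast(ref_xy: torch.Tensor, out_dtype: torch.dtype, num_freqs: int = 128):
    # Do the frequency math in float32 for accuracy...
    dim_t = 10000 ** (torch.arange(num_freqs, dtype=torch.float32) / num_freqs)
    angles = ref_xy.float().unsqueeze(-1) / dim_t  # (..., 2, num_freqs)
    embed = torch.cat([angles.sin(), angles.cos()], dim=-1).flatten(-2)
    # ...and cast only the result, mirroring `).to(dtype=tgt.dtype)` above.
    return embed.to(out_dtype)

query_sine_embed = sine_embed_fp32_then_cast(torch.rand(900, 2, 2), torch.float16)
print(query_sine_embed.dtype, query_sine_embed.shape)  # torch.float16, (900, 2, 512)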
@@ -96,7 +96,7 @@ class TransformerEncoderLayer(nn.Module):
        self.nhead = nhead

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
-        return tensor if pos is None else tensor + pos
+        return tensor if pos is None else tensor + pos.to(dtype=tensor.dtype)

    def forward(
        self,
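This with_pos_embed cast is the last of the element-wise fixes. A practical companion check while reviewing is to scan a converted model for parameters or buffers still in float32; a small helper one might use (hypothetical, not part of the PR):

import torch

def find_fp32_stragglers(model: torch.nn.Module):
    # Any float32 parameter or buffer left after converting to half precision
    # is a candidate source of the dtype-promotion bugs patched above.
    for name, tensor in list(model.named_parameters()) + list(model.named_buffers()):
        if tensor.dtype == torch.float32:
            print(name, tuple(tensor.shape))

Note this only catches stored tensors; tensors built on the fly, like the position embeddings and masks here, still need the explicit casts this PR adds.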
@@ -1,11 +1,13 @@
-from typing import Tuple, List
+from typing import Tuple, List, Any

 import cv2
 import numpy as np
 import supervision as sv
 import torch
 from PIL import Image
+import torchvision
 from torchvision.ops import box_convert
+import torchvision.transforms.functional as F
 import bisect

 import groundingdino.datasets.transforms as T
@@ -271,3 +273,176 @@ class Model:
             else:
                 class_ids.append(None)
         return np.array(class_ids)
+
+
+# =============================================================================
+
+
+class BatchedModel(object):
+
+    # =====================================================
+
+    def __init__(
+        self,
+        model_config_path: str,
+        model_checkpoint_path: str,
+        device: str = "cuda",
+        dtype: str = "float32",
+        compile: bool = False
+    ) -> None:
+
+        self._device = device
+        self._dtype = getattr(torch, dtype)
+        self._model = load_model(
+            model_config_path=model_config_path,
+            model_checkpoint_path=model_checkpoint_path
+        ).to(device=self._device).to(dtype=self._dtype)
+
+        # Compile the model if requested
+        if compile:
+            self._model = torch.compile(self._model)
+
+    # =====================================================
+
+    @staticmethod
+    def preprocess_image(
+        image_batch: torch.Tensor
+    ) -> torch.Tensor:
+
+        # Preprocessing friendly with batches
+        image_batch = F.resize(image_batch, [800], antialias=True)
+        image_batch = F.normalize(image_batch, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+
+        return image_batch
+
+    # =====================================================
+
+    @classmethod
+    def post_process_result(
+        cls,
+        boxes_cxcywh: torch.Tensor,
+        logits: torch.Tensor,
+        nms_threshold: float,
+        source_size: Tuple[int, int],
+        phrases: List[str],
+        text_prompts: List[str]
+    ):
+
+        bbox_batch, conf_batch, class_id_batch = [], [], []
+        source_h, source_w = source_size
+        for bbox_cxcywh, conf, phrase, text_prompt in zip(boxes_cxcywh, logits, phrases, text_prompts):
+            # Scale to pixel coordinates; match dtype so this also works for half-precision boxes
+            bbox_cxcywh = bbox_cxcywh * torch.tensor([source_w, source_h, source_w, source_h], dtype=bbox_cxcywh.dtype)
+            bbox_xyxy = box_convert(boxes=bbox_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
+
+            # Perform NMS
+            nms_idx = torchvision.ops.nms(bbox_xyxy.float(), conf.float(), nms_threshold).numpy().tolist()
+            class_id = cls.phrases2classes(phrases=phrase, classes=text_prompt)
+
+            bbox_batch.append(bbox_xyxy[nms_idx])
+            conf_batch.append(conf[nms_idx])
+            class_id_batch.append(class_id[nms_idx])
+
+        return bbox_batch, conf_batch, class_id_batch
+
+    # =====================================================
+
+    def _batched_predict(
+        self,
+        image_batch,
+        text_prompts,
+        box_threshold,
+        text_threshold
+    ):
+        # predict() refactored to work with batches
+        captions = [preprocess_caption(caption) for caption in text_prompts]
+
+        outputs = self._model(image_batch, captions=captions)
+
+        prediction_logits = outputs["pred_logits"].cpu().sigmoid()  # prediction_logits.shape = (bsz, nq, 256)
+        prediction_boxes = outputs["pred_boxes"].cpu()  # prediction_boxes.shape = (bsz, nq, 4)
+
+        logits_res = []
+        boxs_res = []
+        phrases_list = []
+        tokenizer = self._model.tokenizer
+        for ub_logits, ub_boxes, ub_captions in zip(prediction_logits, prediction_boxes, captions):
+            mask = ub_logits.max(dim=1)[0] > box_threshold
+            logits = ub_logits[mask]  # logits.shape = (n, 256)
+            boxes = ub_boxes[mask]  # boxes.shape = (n, 4)
+            logits_res.append(logits.max(dim=1)[0])
+            boxs_res.append(boxes)
+
+            tokenized = tokenizer(ub_captions)
+            phrases = [
+                get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
+                for logit
+                in logits
+            ]
+            phrases_list.append(phrases)
+
+        return boxs_res, logits_res, phrases_list
+
+    def predict(
+        self,
+        image_batch: torch.Tensor,
+        text_prompts: List[str],
+        box_threshold: float = 0.3,
+        text_threshold: float = 0.3,
+        nms_threshold: float = 0.5
+    ):
+
+        # Move to device and dtype just in case
+        image_batch = image_batch.to(device=self._device).to(dtype=self._dtype)
+        source_h, source_w = image_batch.shape[-2:]
+
+        if any(isinstance(i, list) for i in text_prompts):
+            captions = [". ".join(text_prompt) for text_prompt in text_prompts]
+        else:
+            captions = [". ".join(text_prompts)]
+            text_prompts = [text_prompts]
+
+        # Extend a single caption to the whole batch
+        if len(captions) == 1:
+            captions *= image_batch.shape[0]
+        if len(text_prompts) == 1:
+            text_prompts *= image_batch.shape[0]
+
+        # Preprocess, run inference, postprocess
+        processed_image = self.preprocess_image(image_batch)
+        bboxes, logits, phrases = self._batched_predict(
+            processed_image,
+            captions,
+            box_threshold,
+            text_threshold
+        )
+        bbox_batch, conf_batch, class_id_batch = self.post_process_result(
+            bboxes,
+            logits,
+            nms_threshold,
+            (source_h, source_w),
+            phrases,
+            text_prompts
+        )
+
+        return bbox_batch, conf_batch, class_id_batch
+
+    @staticmethod
+    def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
+        class_ids = []
+        for phrase in phrases:
+            for class_ in classes:
+                if class_.lower() in phrase.lower():
+                    class_ids.append(classes.index(class_))
+                    break
+            else:
+                class_ids.append(None)
+        return np.array(class_ids)
+
+    def __call__(
+        self,
+        *args,
+        **kwargs
+    ) -> Any:
+        return self.predict(*args, **kwargs)
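One reviewer caveat on BatchedModel.phrases2classes: it appends None when a detected phrase matches none of the prompts, so class_id can come back as an object-dtype array containing None, and the notebook's np.array(class_label)[class_id] lookup would then fail. A defensive variant one could substitute (a sketch using -1 as an explicit "no match" id; not part of the PR):

import numpy as np
from typing import List

def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
    class_ids = []
    for phrase in phrases:
        # Collect all prompt classes contained in the phrase; keep the first.
        matches = [i for i, class_ in enumerate(classes) if class_.lower() in phrase.lower()]
        class_ids.append(matches[0] if matches else -1)  # -1 instead of None
    return np.array(class_ids)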
@@ -0,0 +1,7 @@
+[build-system]
+requires = [
+    "setuptools",
+    "torch",
+    "wheel"
+]
+build-backend = "setuptools.build_meta"
@@ -1,3 +1,4 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
 torch
 torchvision
 transformers
setup.py (24 changed lines)
@@ -189,8 +189,25 @@ def parse_requirements(fname="requirements.txt", with_version=True):
                item = "".join(parts)
            yield item

+    def filter_index(packages):
+
+        new_packages = []
+        dependency_links = []
+        for requirement in packages:
+            if requirement.startswith("--extra-index-url"):
+                dependency_links.append(requirement.split()[-1])
+            elif requirement.startswith("./dependencies") or requirement.startswith(
+                "dependencies"
+            ):
+                dependency_links.append(requirement)
+            else:
+                new_packages.append(requirement)
+
+        return new_packages, dependency_links
+
    packages = list(gen_packages_items())
-    return packages
+    packages, dependency_links = filter_index(packages)
+    return packages, dependency_links


if __name__ == "__main__":

@@ -201,6 +218,8 @@ if __name__ == "__main__":

    write_version_file()

+    install_requires, dependency_links = parse_requirements("requirements.txt")
+
    setup(
        name="groundingdino",
        version="0.1.0",

@@ -208,7 +227,8 @@ if __name__ == "__main__":
        url="https://github.com/IDEA-Research/GroundingDINO",
        description="open-set object detector",
        license=license,
-        install_requires=parse_requirements("requirements.txt"),
+        install_requires=install_requires,
+        dependency_links=dependency_links,
        packages=find_packages(
            exclude=(
                "configs",
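For context on the setup.py change: parse_requirements now returns a (packages, dependency_links) pair, so the call site unpacks it as the diff above does. Worth noting in review that pip dropped dependency_links support back in pip 19, so the practical effect of filter_index is to keep the --extra-index-url line out of install_requires, where it would be an invalid requirement string. A quick check of the new contract (hypothetical snippet, run from the repo root):

install_requires, dependency_links = parse_requirements("requirements.txt")
print(install_requires)   # e.g. ['torch', 'torchvision', 'transformers']
print(dependency_links)   # e.g. ['https://download.pytorch.org/whl/cu118']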