# GroundingDINO/groundingdino/util/inference.py
# (captured from a web source view: 306 lines, 10 KiB, Python)

from typing import Tuple, List
import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from torchvision.ops import box_convert
import bisect
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.misc import clean_state_dict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import get_phrases_from_posmap
# ----------------------------------------------------------------------------------------------------------------------
# OLD API
# ----------------------------------------------------------------------------------------------------------------------
def preprocess_caption(caption: str) -> str:
    """Normalise a free-text prompt before it is fed to the model.

    The caption is lower-cased and stripped of surrounding whitespace. A
    trailing ``"."`` is appended unless the normalised text already ends
    with one.

    Args:
        caption: The raw text prompt.

    Returns:
        The normalised caption, always ending in ``"."``.
    """
    normalized = caption.lower().strip()
    suffix = "" if normalized.endswith(".") else "."
    return normalized + suffix
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
    """Build a GroundingDINO model from a config file and load its weights.

    Args:
        model_config_path: Path to the config file read by ``SLConfig.fromfile``.
        model_checkpoint_path: Path to a checkpoint whose ``"model"`` entry
            holds the state dict.
        device: Stored on the config as ``device`` before the model is built.
            The model itself is NOT moved to this device here.

    Returns:
        The model, switched to eval mode.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    net = build_model(config)
    # Deserialise onto the CPU first so loading works regardless of the
    # device the weights were saved from.
    # NOTE(review): depending on the installed torch version's `weights_only`
    # default, torch.load may unpickle arbitrary objects -- only load
    # checkpoints from trusted sources.
    state = torch.load(model_checkpoint_path, map_location="cpu")
    # strict=False: keys missing from / unexpected in the checkpoint are
    # ignored instead of raising.
    net.load_state_dict(clean_state_dict(state["model"]), strict=False)
    net.eval()
    return net
def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
    """
    Load an image from disk and apply the inference transformations.

    The image is opened with PIL, converted to RGB, and passed through a
    resize -> to-tensor -> normalize pipeline.

    Parameters:
        image_path (str): The path to the image file.

    Returns:
        Tuple[np.ndarray, torch.Tensor]: The RGB image as a NumPy array (at its
        original resolution) and the transformed image as a PyTorch tensor.
    """
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # Open inside a context manager so the underlying file handle is always
    # released (PIL can keep it open after loading, e.g. for multi-frame
    # formats). convert() returns an independent in-memory copy, so it stays
    # valid after the file is closed.
    with Image.open(image_path) as pil_image:
        image_source = pil_image.convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed
def transform_image(t: Image.Image) -> Tuple[np.ndarray, torch.Tensor]:
    """
    Apply the inference transformations to an already-loaded PIL image.

    Unlike ``load_image``, no RGB conversion is done here. The caller is
    expected to pass an RGB image, because the normalization below uses three
    per-channel mean/std values.

    Parameters:
        t (PIL.Image.Image): The input image, presumably already in RGB mode.

    Returns:
        Tuple[np.ndarray, torch.Tensor]: The input image as a NumPy array and
        the transformed image as a PyTorch tensor.
    """
    # The annotation uses ``Image.Image`` because this module only binds
    # ``Image`` (``from PIL import Image``). The former ``PIL.Image.Image``
    # raised NameError as soon as the function was defined.
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image = np.asarray(t)
    image_transformed, _ = transform(t, None)
    return image, image_transformed
def predict(
    model,
    image: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    device: str = "cuda",
    remove_combined: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """
    Run the model on a single preprocessed image and extract one phrase per box.

    Parameters:
        model: A GroundingDINO model exposing a ``tokenizer`` attribute.
        image (torch.Tensor): Transformed image without a batch axis (one is
            added here via ``image[None]``).
        caption (str): Text prompt, normalised with ``preprocess_caption``.
        box_threshold (float): Queries whose best token score does not exceed
            this value are discarded.
        text_threshold (float): A token must score above this value to be
            part of a query's phrase.
        device (str): Device the model and image are moved to.
        remove_combined (bool): If True, the separator positions surrounding
            each query's highest-scoring token are passed to
            ``get_phrases_from_posmap``. This presumably restricts the phrase
            to that caption segment instead of combining tokens from the whole
            caption.

    Returns:
        Tuple of (boxes, scores, phrases): the kept boxes (on CPU, as produced
        by the model), the maximum token score of each kept box, and one phrase
        per box.
    """
    caption = preprocess_caption(caption=caption)

    model = model.to(device)
    image = image.to(device)

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
    prediction_boxes = outputs["pred_boxes"].cpu()[0]  # prediction_boxes.shape = (nq, 4)

    mask = prediction_logits.max(dim=1)[0] > box_threshold
    logits = prediction_logits[mask]  # logits.shape = (n, 256)
    boxes = prediction_boxes[mask]  # boxes.shape = (n, 4)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    if remove_combined:
        # Token ids treated as segment delimiters; presumably [CLS]=101,
        # [SEP]=102 and "."=1012 in a BERT-uncased vocabulary -- confirm
        # against model.tokenizer.
        separator_ids = {101, 102, 1012}
        sep_idx = [
            position
            for position, token_id in enumerate(tokenized['input_ids'])
            if token_id in separator_ids
        ]
        phrases = []
        for logit in logits:
            max_idx = int(logit.argmax())
            insert_idx = bisect.bisect_left(sep_idx, max_idx)
            # Clamp so (left_idx, right_idx) is a pair of adjacent separators.
            # Without this, insert_idx == 0 made sep_idx[-1] wrap around to
            # the LAST separator, and insert_idx == len(sep_idx) raised
            # IndexError.
            insert_idx = min(max(insert_idx, 1), len(sep_idx) - 1)
            right_idx = sep_idx[insert_idx]
            left_idx = sep_idx[insert_idx - 1]
            phrases.append(
                get_phrases_from_posmap(
                    logit > text_threshold, tokenized, tokenizer, left_idx, right_idx
                ).replace('.', '')
            )
    else:
        phrases = [
            get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
            for logit
            in logits
        ]

    return boxes, logits.max(dim=1)[0], phrases
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
    """
    Draw labelled bounding boxes on an image.

    Parameters:
        image_source (np.ndarray): RGB image indexed as (H, W, C) to draw on.
        boxes (torch.Tensor): Boxes in cxcywh format, presumably normalised to
            [0, 1] (they are multiplied by the image width/height here).
        logits (torch.Tensor): One confidence score per box.
        phrases (List[str]): One label per box.

    Returns:
        np.ndarray: The annotated image, in BGR channel order.
    """
    height, width, _ = image_source.shape
    scale = torch.Tensor([width, height, width, height])
    pixel_boxes = boxes * scale
    xyxy = box_convert(boxes=pixel_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    detections = sv.Detections(xyxy=xyxy)

    labels = [f"{text} {score:.2f}" for text, score in zip(phrases, logits)]

    # The input is RGB; it is converted to BGR before boxes are drawn.
    bgr_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
    annotator = sv.BoxAnnotator()
    return annotator.annotate(scene=bgr_frame, detections=detections, labels=labels)
# ----------------------------------------------------------------------------------------------------------------------
# NEW API
# ----------------------------------------------------------------------------------------------------------------------
class Model:
    """Wrapper around a GroundingDINO model that takes BGR ``np.ndarray``
    images (e.g. from ``cv2.imread``) and returns ``supervision`` detections."""

    def __init__(
        self,
        model_config_path: str,
        model_checkpoint_path: str,
        device: str = "cuda"
    ):
        """Load the model from a config/checkpoint pair and move it to ``device``."""
        self.model = load_model(
            model_config_path=model_config_path,
            model_checkpoint_path=model_checkpoint_path,
            device=device
        ).to(device)
        self.device = device

    def predict_with_caption(
        self,
        image: np.ndarray,
        caption: str,
        box_threshold: float = 0.35,
        text_threshold: float = 0.25
    ) -> Tuple[sv.Detections, List[str]]:
        """Detect the objects described by a free-text caption.

        Returns the detections (xyxy boxes scaled to the source image size,
        with confidences) and the phrase predicted for each detection.

        Example::

            import cv2

            image = cv2.imread(IMAGE_PATH)

            model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
            detections, labels = model.predict_with_caption(
                image=image,
                caption=caption,
                box_threshold=BOX_THRESHOLD,
                text_threshold=TEXT_THRESHOLD
            )

            import supervision as sv

            box_annotator = sv.BoxAnnotator()
            annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
        """
        processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=processed_image,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device)
        source_h, source_w, _ = image.shape
        detections = Model.post_process_result(
            source_h=source_h,
            source_w=source_w,
            boxes=boxes,
            logits=logits)
        return detections, phrases

    def predict_with_classes(
        self,
        image: np.ndarray,
        classes: List[str],
        box_threshold: float,
        text_threshold: float
    ) -> sv.Detections:
        """Detect objects of the given classes.

        The caption is built by joining ``classes`` with ``". "``. Each
        detection's ``class_id`` is the index of the first class whose name
        occurs (case-insensitively) in the predicted phrase, or ``None`` when
        no class matches.

        Example::

            import cv2

            image = cv2.imread(IMAGE_PATH)

            model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
            detections = model.predict_with_classes(
                image=image,
                classes=CLASSES,
                box_threshold=BOX_THRESHOLD,
                text_threshold=TEXT_THRESHOLD
            )

            import supervision as sv

            box_annotator = sv.BoxAnnotator()
            annotated_image = box_annotator.annotate(scene=image, detections=detections)
        """
        caption = ". ".join(classes)
        processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=processed_image,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device)
        source_h, source_w, _ = image.shape
        detections = Model.post_process_result(
            source_h=source_h,
            source_w=source_w,
            boxes=boxes,
            logits=logits)
        class_id = Model.phrases2classes(phrases=phrases, classes=classes)
        detections.class_id = class_id
        return detections

    @staticmethod
    def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
        """Convert a BGR array to an RGB PIL image and apply the inference
        transform. No batch axis is added."""
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
        image_transformed, _ = transform(image_pillow, None)
        return image_transformed

    @staticmethod
    def post_process_result(
        source_h: int,
        source_w: int,
        boxes: torch.Tensor,
        logits: torch.Tensor
    ) -> sv.Detections:
        """Scale cxcywh boxes (presumably normalised) to the source image size,
        convert them to xyxy, and wrap them with their confidences.

        ``boxes`` and ``logits`` must be CPU tensors, since ``.numpy()`` is
        called on both.
        """
        boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        confidence = logits.numpy()
        return sv.Detections(xyxy=xyxy, confidence=confidence)

    @staticmethod
    def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
        """Map each phrase to the index of the first class name it contains.

        Matching is a case-insensitive substring test. ``predict`` lower-cases
        the caption before tokenizing, so its phrases are lower-case. A
        case-sensitive test would never match a class name containing capitals.

        Returns:
            An array with one entry per phrase: the matching class index, or
            ``None`` when no class matches (the array then has ``object``
            dtype).
        """
        lowered_classes = [class_.lower() for class_ in classes]
        class_ids = []
        for phrase in phrases:
            phrase_lower = phrase.lower()
            # enumerate yields the first matching position, which is what the
            # former classes.index(class_) lookup returned.
            class_ids.append(
                next(
                    (idx for idx, name in enumerate(lowered_classes) if name in phrase_lower),
                    None,
                )
            )
        return np.array(class_ids)