diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py index 8168b96..66461bc 100644 --- a/groundingdino/util/inference.py +++ b/groundingdino/util/inference.py @@ -13,6 +13,10 @@ from groundingdino.util.misc import clean_state_dict from groundingdino.util.slconfig import SLConfig from groundingdino.util.utils import get_phrases_from_posmap +# ---------------------------------------------------------------------------------------------------------------------- +# OLD API +# ---------------------------------------------------------------------------------------------------------------------- + def preprocess_caption(caption: str) -> str: result = caption.lower().strip() @@ -96,3 +100,143 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR) annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) return annotated_frame + + +# ---------------------------------------------------------------------------------------------------------------------- +# NEW API +# ---------------------------------------------------------------------------------------------------------------------- + + +class Model: + + def __init__( + self, + model_config_path: str, + model_checkpoint_path: str, + device: str = "cuda" + ): + self.model = load_model( + model_config_path=model_config_path, + model_checkpoint_path=model_checkpoint_path, + device=device + ).to(device) + self.device = device + + def predict_with_caption( + self, + image: np.ndarray, + caption: str, + box_threshold: float = 0.35, + text_threshold: float = 0.25 + ) -> Tuple[sv.Detections, List[str]]: + """ + import cv2 + + image = cv2.imread(IMAGE_PATH) + + model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH) + detections, labels = model.predict_with_caption( + image=image, + caption=caption, + box_threshold=BOX_THRESHOLD, + text_threshold=TEXT_THRESHOLD + ) + + import supervision as sv + + box_annotator = sv.BoxAnnotator() + annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels) + """ + processed_image = Model.preprocess_image(image_bgr=image).to(self.device) + boxes, logits, phrases = predict( + model=self.model, + image=processed_image, + caption=caption, + box_threshold=box_threshold, + text_threshold=text_threshold) + source_h, source_w, _ = image.shape + detections = Model.post_process_result( + source_h=source_h, + source_w=source_w, + boxes=boxes, + logits=logits) + return detections, phrases + + def predict_with_classes( + self, + image: np.ndarray, + classes: List[str], + box_threshold: float, + text_threshold: float + ) -> sv.Detections: + """ + import cv2 + + image = cv2.imread(IMAGE_PATH) + + model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH) + detections = model.predict_with_classes( + image=image, + classes=CLASSES, + box_threshold=BOX_THRESHOLD, + text_threshold=TEXT_THRESHOLD + ) + + + import supervision as sv + + box_annotator = sv.BoxAnnotator() + annotated_image = box_annotator.annotate(scene=image, detections=detections) + """ + caption = ", ".join(classes) + processed_image = Model.preprocess_image(image_bgr=image).to(self.device) + boxes, logits, phrases = predict( + model=self.model, + image=processed_image, + caption=caption, + box_threshold=box_threshold, + text_threshold=text_threshold) + source_h, source_w, _ = image.shape + detections = Model.post_process_result( + source_h=source_h, + source_w=source_w, + boxes=boxes, + logits=logits) + class_id = Model.phrases2classes(phrases=phrases, classes=classes) + detections.class_id = class_id + return detections + + @staticmethod + def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor: + transform = T.Compose( + [ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)) + image_transformed, _ = transform(image_pillow, None) + return image_transformed + + @staticmethod + def post_process_result( + source_h: int, + source_w: int, + boxes: torch.Tensor, + logits: torch.Tensor + ) -> sv.Detections: + boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h]) + xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() + confidence = logits.numpy() + return sv.Detections(xyxy=xyxy, confidence=confidence) + + @staticmethod + def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray: + class_ids = [] + for phrase in phrases: + try: + class_ids.append(classes.index(phrase)) + except ValueError: + class_ids.append(None) + return np.array(class_ids) diff --git a/requirements.txt b/requirements.txt index f52ed0a..dc8575b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,5 @@ yapf timm numpy opencv-python -supervision==0.3.2 +supervision==0.4.0 pycocotools \ No newline at end of file