⚙️ more compact inference API - single class to load, process and infer (#16)
* ⚙️ more compact inference API - single class to load, process and infer * 👊 bump Supervision version to `0.4.0`add_grounding_dino_with_stable_diffsion
parent
f6b1145481
commit
e45c11c4c3
|
@ -13,6 +13,10 @@ from groundingdino.util.misc import clean_state_dict
|
||||||
from groundingdino.util.slconfig import SLConfig
|
from groundingdino.util.slconfig import SLConfig
|
||||||
from groundingdino.util.utils import get_phrases_from_posmap
|
from groundingdino.util.utils import get_phrases_from_posmap
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------------------------
|
||||||
|
# OLD API
|
||||||
|
# ----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def preprocess_caption(caption: str) -> str:
|
def preprocess_caption(caption: str) -> str:
|
||||||
result = caption.lower().strip()
|
result = caption.lower().strip()
|
||||||
|
@ -96,3 +100,143 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
|
||||||
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
|
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
|
||||||
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
return annotated_frame
|
return annotated_frame
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------------------------
|
||||||
|
# NEW API
|
||||||
|
# ----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class Model:
    """Compact Grounding DINO inference wrapper.

    Loads the model once in the constructor and exposes two prediction
    entry points that return `supervision` (`sv`) Detections:
    `predict_with_caption` for free-text captions and `predict_with_classes`
    for a closed list of class names.
    """

    def __init__(
        self,
        model_config_path: str,
        model_checkpoint_path: str,
        device: str = "cuda"
    ):
        """
        Args:
            model_config_path: path to the Grounding DINO config file.
            model_checkpoint_path: path to the model weights checkpoint.
            device: torch device string the model runs on (default "cuda").
        """
        self.model = load_model(
            model_config_path=model_config_path,
            model_checkpoint_path=model_checkpoint_path,
            device=device
        ).to(device)
        self.device = device

    def predict_with_caption(
        self,
        image: np.ndarray,
        caption: str,
        box_threshold: float = 0.35,
        text_threshold: float = 0.25
    ) -> Tuple[sv.Detections, List[str]]:
        """Run grounded detection for a free-text caption.

        Args:
            image: source image in BGR channel order (as read by cv2.imread).
            caption: text prompt describing what to detect.
            box_threshold: minimum box confidence to keep a detection.
            text_threshold: minimum text-match confidence for phrase grounding.

        Returns:
            (detections, phrases): detections in absolute xyxy pixel
            coordinates, and the grounded phrase for each detection.

        Example:
            import cv2

            image = cv2.imread(IMAGE_PATH)

            model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
            detections, labels = model.predict_with_caption(
                image=image,
                caption=caption,
                box_threshold=BOX_THRESHOLD,
                text_threshold=TEXT_THRESHOLD
            )

            import supervision as sv

            box_annotator = sv.BoxAnnotator()
            annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
        """
        processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=processed_image,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold)
        source_h, source_w, _ = image.shape
        detections = Model.post_process_result(
            source_h=source_h,
            source_w=source_w,
            boxes=boxes,
            logits=logits)
        return detections, phrases

    def predict_with_classes(
        self,
        image: np.ndarray,
        classes: List[str],
        box_threshold: float,
        text_threshold: float
    ) -> sv.Detections:
        """Run grounded detection for a closed list of class names.

        The class list is joined into a single caption ("a, b, c"); each
        detected phrase is mapped back to its class index via
        `phrases2classes` and stored in `detections.class_id`.

        Args:
            image: source image in BGR channel order (as read by cv2.imread).
            classes: class names to detect.
            box_threshold: minimum box confidence to keep a detection.
            text_threshold: minimum text-match confidence for phrase grounding.

        Returns:
            Detections with `class_id` set (None entries for phrases that
            match no supplied class).

        Example:
            import cv2

            image = cv2.imread(IMAGE_PATH)

            model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
            detections = model.predict_with_classes(
                image=image,
                classes=CLASSES,
                box_threshold=BOX_THRESHOLD,
                text_threshold=TEXT_THRESHOLD
            )

            import supervision as sv

            box_annotator = sv.BoxAnnotator()
            annotated_image = box_annotator.annotate(scene=image, detections=detections)
        """
        caption = ", ".join(classes)
        processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=processed_image,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold)
        source_h, source_w, _ = image.shape
        detections = Model.post_process_result(
            source_h=source_h,
            source_w=source_w,
            boxes=boxes,
            logits=logits)
        class_id = Model.phrases2classes(phrases=phrases, classes=classes)
        detections.class_id = class_id
        return detections

    @staticmethod
    def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
        """Convert a BGR image array into the normalized tensor the model expects.

        Resizes the short side to 800 (long side capped at 1333) and applies
        ImageNet mean/std normalization.
        """
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
        image_transformed, _ = transform(image_pillow, None)
        return image_transformed

    @staticmethod
    def post_process_result(
        source_h: int,
        source_w: int,
        boxes: torch.Tensor,
        logits: torch.Tensor
    ) -> sv.Detections:
        """Scale normalized cxcywh boxes to absolute xyxy pixels and wrap
        them (with per-box confidences) in an sv.Detections object."""
        boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        confidence = logits.numpy()
        return sv.Detections(xyxy=xyxy, confidence=confidence)

    @staticmethod
    def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
        """Map each predicted phrase to the index of its class name.

        Matching is case-insensitive: the caption handed to `predict` is
        lowercased (see `preprocess_caption`), so returned phrases are
        lowercase even when the supplied class names are not — an exact
        `classes.index(phrase)` would therefore never match a class name
        containing uppercase characters. Phrases that match no class map
        to None.
        """
        lowered_classes = [class_name.lower() for class_name in classes]
        class_ids = []
        for phrase in phrases:
            try:
                class_ids.append(lowered_classes.index(phrase.lower()))
            except ValueError:
                class_ids.append(None)
        return np.array(class_ids)
|
||||||
|
|
|
@ -6,5 +6,5 @@ yapf
|
||||||
timm
|
timm
|
||||||
numpy
|
numpy
|
||||||
opencv-python
|
opencv-python
|
||||||
supervision==0.3.2
|
supervision==0.4.0
|
||||||
pycocotools
|
pycocotools
|
Loading…
Reference in New Issue