⚙️ more compact inference API - single class to load, process and infer (#16)
* ⚙️ more compact inference API - single class to load, process and infer * 👊 bump Supervision version to `0.4.0`add_grounding_dino_with_stable_diffsion
parent
f6b1145481
commit
e45c11c4c3
groundingdino/util
|
@ -13,6 +13,10 @@ from groundingdino.util.misc import clean_state_dict
|
|||
from groundingdino.util.slconfig import SLConfig
|
||||
from groundingdino.util.utils import get_phrases_from_posmap
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
# OLD API
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def preprocess_caption(caption: str) -> str:
|
||||
result = caption.lower().strip()
|
||||
|
@ -96,3 +100,143 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
|
|||
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
|
||||
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||
return annotated_frame
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
# NEW API
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Model:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config_path: str,
|
||||
model_checkpoint_path: str,
|
||||
device: str = "cuda"
|
||||
):
|
||||
self.model = load_model(
|
||||
model_config_path=model_config_path,
|
||||
model_checkpoint_path=model_checkpoint_path,
|
||||
device=device
|
||||
).to(device)
|
||||
self.device = device
|
||||
|
||||
def predict_with_caption(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
caption: str,
|
||||
box_threshold: float = 0.35,
|
||||
text_threshold: float = 0.25
|
||||
) -> Tuple[sv.Detections, List[str]]:
|
||||
"""
|
||||
import cv2
|
||||
|
||||
image = cv2.imread(IMAGE_PATH)
|
||||
|
||||
model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
|
||||
detections, labels = model.predict_with_caption(
|
||||
image=image,
|
||||
caption=caption,
|
||||
box_threshold=BOX_THRESHOLD,
|
||||
text_threshold=TEXT_THRESHOLD
|
||||
)
|
||||
|
||||
import supervision as sv
|
||||
|
||||
box_annotator = sv.BoxAnnotator()
|
||||
annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
|
||||
"""
|
||||
processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
|
||||
boxes, logits, phrases = predict(
|
||||
model=self.model,
|
||||
image=processed_image,
|
||||
caption=caption,
|
||||
box_threshold=box_threshold,
|
||||
text_threshold=text_threshold)
|
||||
source_h, source_w, _ = image.shape
|
||||
detections = Model.post_process_result(
|
||||
source_h=source_h,
|
||||
source_w=source_w,
|
||||
boxes=boxes,
|
||||
logits=logits)
|
||||
return detections, phrases
|
||||
|
||||
def predict_with_classes(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
classes: List[str],
|
||||
box_threshold: float,
|
||||
text_threshold: float
|
||||
) -> sv.Detections:
|
||||
"""
|
||||
import cv2
|
||||
|
||||
image = cv2.imread(IMAGE_PATH)
|
||||
|
||||
model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
|
||||
detections = model.predict_with_classes(
|
||||
image=image,
|
||||
classes=CLASSES,
|
||||
box_threshold=BOX_THRESHOLD,
|
||||
text_threshold=TEXT_THRESHOLD
|
||||
)
|
||||
|
||||
|
||||
import supervision as sv
|
||||
|
||||
box_annotator = sv.BoxAnnotator()
|
||||
annotated_image = box_annotator.annotate(scene=image, detections=detections)
|
||||
"""
|
||||
caption = ", ".join(classes)
|
||||
processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
|
||||
boxes, logits, phrases = predict(
|
||||
model=self.model,
|
||||
image=processed_image,
|
||||
caption=caption,
|
||||
box_threshold=box_threshold,
|
||||
text_threshold=text_threshold)
|
||||
source_h, source_w, _ = image.shape
|
||||
detections = Model.post_process_result(
|
||||
source_h=source_h,
|
||||
source_w=source_w,
|
||||
boxes=boxes,
|
||||
logits=logits)
|
||||
class_id = Model.phrases2classes(phrases=phrases, classes=classes)
|
||||
detections.class_id = class_id
|
||||
return detections
|
||||
|
||||
@staticmethod
|
||||
def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
|
||||
transform = T.Compose(
|
||||
[
|
||||
T.RandomResize([800], max_size=1333),
|
||||
T.ToTensor(),
|
||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
|
||||
image_transformed, _ = transform(image_pillow, None)
|
||||
return image_transformed
|
||||
|
||||
@staticmethod
|
||||
def post_process_result(
|
||||
source_h: int,
|
||||
source_w: int,
|
||||
boxes: torch.Tensor,
|
||||
logits: torch.Tensor
|
||||
) -> sv.Detections:
|
||||
boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
|
||||
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
||||
confidence = logits.numpy()
|
||||
return sv.Detections(xyxy=xyxy, confidence=confidence)
|
||||
|
||||
@staticmethod
|
||||
def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
|
||||
class_ids = []
|
||||
for phrase in phrases:
|
||||
try:
|
||||
class_ids.append(classes.index(phrase))
|
||||
except ValueError:
|
||||
class_ids.append(None)
|
||||
return np.array(class_ids)
|
||||
|
|
|
@ -6,5 +6,5 @@ yapf
|
|||
timm
|
||||
numpy
|
||||
opencv-python
|
||||
supervision==0.3.2
|
||||
supervision==0.4.0
|
||||
pycocotools
|
Loading…
Reference in New Issue