Update inference.py (#298)

pull/279/head^2
ASHWIN UNNIKRISHNAN 2024-02-23 02:10:00 -05:00 committed by GitHub
parent 2b62f419c2
commit d13643262e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 0 deletions

View File

@ -98,6 +98,18 @@ def predict(
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
"""
This function annotates an image with bounding boxes and labels.
Parameters:
image_source (np.ndarray): The source image to be annotated.
boxes (torch.Tensor): A tensor containing bounding box coordinates.
logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
phrases (List[str]): A list of labels for each bounding box.
Returns:
np.ndarray: The annotated image.
"""
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()