Update inference.py

Add a `transform_image` function so that users can take an already-loaded RGB image and transform it into the format expected by Grounding DINO prediction.
pull/299/head
ASHWIN UNNIKRISHNAN 2024-02-25 22:31:30 -05:00 committed by GitHub
parent d13643262e
commit 98a10ca0e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 34 additions and 0 deletions

View File

@ -37,6 +37,17 @@ def load_model(model_config_path: str, model_checkpoint_path: str, device: str =
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
"""
Load an image and apply transformations.
This function takes the path to an image file, loads the image, and applies a series of transformations to it.
The transformations include resizing the image, converting it to a tensor, and normalizing its pixel values.
Parameters:
image_path (str): The path to the image file.
Returns:
Tuple[np.array, torch.Tensor]: A tuple containing the original image as a NumPy array and the transformed image as a PyTorch tensor.
"""
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
@ -49,6 +60,29 @@ def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
image_transformed, _ = transform(image_source, None)
return image, image_transformed
def transform_image(t: PIL.Image.Image) -> Tuple[np.array, torch.Tensor]:
    """
    Transform an already-loaded PIL image into a tensor.

    This function takes a PIL Image, applies a series of transformations to it, and returns both
    the untransformed image (as a NumPy array) and the transformed image (as a PyTorch tensor).
    The transformations are T.RandomResize([800], max_size=1333), T.ToTensor() and T.Normalize()
    with three per-channel mean/std values.

    Parameters:
        t (PIL.Image.Image): The input image. If its mode is not "RGB" it is converted to RGB
            before any further processing.

    Returns:
        Tuple[np.array, torch.Tensor]: A tuple containing the (RGB) image as a NumPy array and the
        transformed image as a PyTorch tensor.
    """
    # Normalize below is configured with exactly three channel statistics, so any
    # other mode (RGBA, L, P, ...) is brought to three channels up front.
    if t.mode != "RGB":
        t = t.convert("RGB")

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image = np.asarray(t)
    # The composed transform is called with (image, None) and returns a pair;
    # only the transformed image is kept.
    image_transformed, _ = transform(t, None)
    return image, image_transformed
def predict(
model,