pull/299/merge
ASHWIN UNNIKRISHNAN 2024-08-13 21:17:37 +08:00 committed by GitHub
commit 896cd3a5f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 34 additions and 0 deletions

View File

@ -37,6 +37,17 @@ def load_model(model_config_path: str, model_checkpoint_path: str, device: str =
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
"""
Load an image and apply transformations.
This function takes the path to an image file, loads the image, and applies a series of transformations to it.
The transformations include resizing the image, converting it to a tensor, and normalizing its pixel values.
Parameters:
image_path (str): The path to the image file.
Returns:
Tuple[np.array, torch.Tensor]: A tuple containing the original image as a NumPy array and the transformed image as a PyTorch tensor.
"""
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
@ -49,6 +60,29 @@ def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
image_transformed, _ = transform(image_source, None)
return image, image_transformed
def transform_image(PIL_image: PIL.Image.Image) -> Tuple[np.array, torch.Tensor]:
"""
Transform an RGB image and convert it to a tensor.
This function takes a PIL Image, applies a series of transformations to it, and returns the original and transformed images.
The transformations include resizing the image, converting it to a tensor, and normalizing its pixel values.
Parameters:
PIL_image (PIL.Image.Image): The input image.
Returns:
Tuple[np.array, torch.Tensor]: A tuple containing the original image as a NumPy array and the transformed image as a PyTorch tensor.
"""
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image = np.asarray(PIL_image)
image_transformed, _ = transform(PIL_image, None)
return image, image_transformed
def predict(
model,