diff --git a/dinov2/data/datasets/custom_image_dataset.py b/dinov2/data/datasets/custom_image_dataset.py index ed1e1a4..30caa40 100644 --- a/dinov2/data/datasets/custom_image_dataset.py +++ b/dinov2/data/datasets/custom_image_dataset.py @@ -2,6 +2,7 @@ import os import pathlib from torch.utils.data import Dataset +from .decoders import ImageDataDecoder from PIL import Image class ImageDataset(Dataset): @@ -37,7 +38,7 @@ class ImageDataset(Dataset): if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')): if is_valid: try: - Image.open(os.path.join(root, file)).convert('RGB') + Image.open(os.path.join(root, file)) images.append(os.path.join(root, file)) except OSError: @@ -46,15 +47,25 @@ class ImageDataset(Dataset): images.append(os.path.join(root, file)) return images + + def get_image_data(self, index: int): + path = self.images_list[index] + with open(path) as f: + image_data = f.read() + + return image_data def __len__(self): return len(self.images_list) - def __getitem__(self, idx): - image_path = self.images_list[idx] - image = Image.open(image_path) + def __getitem__(self, index: int): + try: + image_data = self.get_image_data(index) + image = ImageDataDecoder(image_data).decode() + except Exception as e: + raise RuntimeError(f"can nor read image for sample {index}") from e - if self.transform: + if self.transform is not None: image = self.transform(image) return image