diff --git a/dinov2/data/datasets/custom_image_dataset.py b/dinov2/data/datasets/custom_image_dataset.py index e35f049..fefdac7 100644 --- a/dinov2/data/datasets/custom_image_dataset.py +++ b/dinov2/data/datasets/custom_image_dataset.py @@ -2,17 +2,19 @@ import os import pathlib import random from typing import List +from omegaconf.listconfig import ListConfig from torch.utils.data import Dataset from .decoders import ImageDataDecoder from PIL import Image class ImageDataset(Dataset): - def __init__(self, root, transform=None, path_preserved: List[str]=[], frac: float=0.1): + def __init__(self, root, transform=None, path_preserved: List[str]=[], frac: float=0.1, is_valid=True): self.root = root self.transform = transform - self.path_preserved = path_preserved if isinstance(path_preserved, list) else [path_preserved] + self.path_preserved = path_preserved if isinstance(path_preserved, (list, ListConfig)) else [path_preserved] self.frac = frac self.preserved_images = [] + self.is_valid = is_valid self.images_list = self._get_image_list() def _get_image_list(self): @@ -22,7 +24,7 @@ class ImageDataset(Dataset): try: p = self.root preserve = p in self.path_preserved - images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac)) + images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid)) except OSError: print("The root given is nor a list nor a path") @@ -31,29 +33,32 @@ class ImageDataset(Dataset): for p in self.root: try: preserve = p in self.path_preserved - images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac)) + images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid)) except OSError: print(f"the path indicated at {p} cannot be found.") return images - def _retrieve_images(self, path, is_valid=False, preserve=False, frac=1): + def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1): images_ini = len(self.preserved_images) images = [] for root, _, files in os.walk(path): images_dir = [] for file in files: if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')): + im = os.path.join(root, file) if is_valid: try: - Image.open(os.path.join(root, file)) - images_dir.append(os.path.join(root, file)) + with open(im, 'rb') as f: + image_data = f.read() + ImageDataDecoder(image_data).decode() + images_dir.append(im) except OSError: - print(f"Image at path {os.path.join(root, file)} could not be opened.") + print(f"Image at path {im} could not be opened.") else: - images_dir.append(os.path.join(root, file)) + images_dir.append(im) if preserve: random.seed(24) @@ -61,6 +66,9 @@ class ImageDataset(Dataset): split_index = int(len(images_dir) * frac) self.preserved_images.extend(images_dir[:split_index]) images.extend(images_dir[split_index:]) + + else: + images.extend(images_dir) images_end = len(self.preserved_images) if preserve: diff --git a/dinov2/train/train.py b/dinov2/train/train.py index 09f2f04..aec3ffc 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -200,7 +200,7 @@ def do_train(cfg, model, resume=False): # save the preserved images if dataset.preserved_images: - write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl')) + write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl'), dataset.preserved_images) # sampler_type = SamplerType.INFINITE sampler_type = SamplerType.SHARDED_INFINITE # define the sampler to use for fsdp