diff --git a/dinov2/data/datasets/custom_image_dataset.py b/dinov2/data/datasets/custom_image_dataset.py index e35f049..be10f83 100644 --- a/dinov2/data/datasets/custom_image_dataset.py +++ b/dinov2/data/datasets/custom_image_dataset.py @@ -38,22 +38,25 @@ class ImageDataset(Dataset): return images - def _retrieve_images(self, path, is_valid=False, preserve=False, frac=1): + def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1): images_ini = len(self.preserved_images) images = [] for root, _, files in os.walk(path): images_dir = [] for file in files: if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')): + im = os.path.join(root, file) if is_valid: try: - Image.open(os.path.join(root, file)) - images_dir.append(os.path.join(root, file)) + with open(im, 'rb') as f: + image_data = f.read() + ImageDataDecoder(image_data).decode() + images_dir.append(im) except OSError: - print(f"Image at path {os.path.join(root, file)} could not be opened.") + print(f"Image at path {im} could not be opened.") else: - images_dir.append(os.path.join(root, file)) + images_dir.append(im) if preserve: random.seed(24) @@ -61,6 +64,9 @@ class ImageDataset(Dataset): split_index = int(len(images_dir) * frac) self.preserved_images.extend(images_dir[:split_index]) images.extend(images_dir[split_index:]) + + else: + images.extend(images_dir) images_end = len(self.preserved_images) if preserve: diff --git a/dinov2/train/train.py b/dinov2/train/train.py index 09f2f04..aec3ffc 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -200,7 +200,7 @@ def do_train(cfg, model, resume=False): # save the preserved images if dataset.preserved_images: - write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl')) + write_list(os.path.join(cfg.train.output_dir, 'preserved_images.pkl'), dataset.preserved_images) # sampler_type = SamplerType.INFINITE sampler_type = SamplerType.SHARDED_INFINITE # define the sampler to use for fsdp