diff --git a/dinov2/data/datasets/custom_image_dataset.py b/dinov2/data/datasets/custom_image_dataset.py index 6b414a6..d17e4a0 100644 --- a/dinov2/data/datasets/custom_image_dataset.py +++ b/dinov2/data/datasets/custom_image_dataset.py @@ -14,6 +14,9 @@ class ImageDataset(Dataset): self.frac = frac self.preserved_images = [] self.images_list = self._get_image_list() + self.path_preserved = path_preserved if isinstance(path_preserved, list) else list(path_preserved) + self.frac = frac + self.preserved_images = [] def _get_image_list(self): images = [] @@ -52,12 +55,12 @@ class ImageDataset(Dataset): else: images_dir.append(os.path.join(root, file)) - if preserve: - random.seed(24) - random.shuffle(images_dir) - split_index = int(len(images_dir) * frac) - self.preserved_images.extend(images_dir[:split_index]) - images.extend(images_dir[split_index:]) + if preserve: + random.seed(24) + random.shuffle(images_dir) + split_index = int(len(images_dir) * frac) + self.preserved_images.extend(images_dir[:split_index]) + images.extend(images_dir[split_index:]) return images