From 2e05e6b4efe65fe1b106ca2de8ea335d52a76c9b Mon Sep 17 00:00:00 2001 From: Etienne Guevel Date: Fri, 24 May 2024 16:55:10 +0200 Subject: [PATCH] preserve lab data changes --- .../configs/train/vitl_cellsim_register.yaml | 32 +++++++++++++++++++ dinov2/data/datasets/custom_image_dataset.py | 29 ++++++++++++----- 2 files changed, 53 insertions(+), 8 deletions(-) create mode 100644 dinov2/configs/train/vitl_cellsim_register.yaml diff --git a/dinov2/configs/train/vitl_cellsim_register.yaml b/dinov2/configs/train/vitl_cellsim_register.yaml new file mode 100644 index 0000000..85ccf70 --- /dev/null +++ b/dinov2/configs/train/vitl_cellsim_register.yaml @@ -0,0 +1,32 @@ +train: + dataset_path: + - /home/manon/classification/data/Single_cells/medhi + - /home/manon/classification/data/Single_cells/vexas_original/Unlabeled + - /home/manon/classification/data/Single_cells/matek + centering: sinkhorn_knopp + batch_size_per_gpu: 64 + output_dir: /home/guevel/OT4D/cell_similarity/vitl_register + OFFICIAL_EPOCH_LENGTH: 1 +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 +ibot: + separate_head: true + head_n_prototypes: 131072 +student: + arch: vit_large + patch_size: 14 + drop_path_rate: 0.4 + ffn_layer: swiglufused + block_chunks: 4 + num_register_tokens: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 500 + weight_decay_end: 0.2 + base_lr: 2.0e-04 # learning rate for a batch size of 1024 + warmup_epochs: 80 + layerwise_decay: 1.0 +crops: + local_crops_size: 98 diff --git a/dinov2/data/datasets/custom_image_dataset.py b/dinov2/data/datasets/custom_image_dataset.py index ed7e4db..0bd2704 100644 --- a/dinov2/data/datasets/custom_image_dataset.py +++ b/dinov2/data/datasets/custom_image_dataset.py @@ -1,22 +1,27 @@ import os import pathlib - +import random +from typing import List from torch.utils.data import Dataset from .decoders import ImageDataDecoder from PIL import Image class ImageDataset(Dataset): - def __init__(self, root, transform=None): + def __init__(self, root, transform=None, path_preserved: List[str]=[], frac: float=0.1): self.root = root self.transform = transform self.images_list = self._get_image_list() + self.path_preserved = path_preserved if isinstance(path_preserved, list) else list(path_preserved) + self.frac = frac + self.preserved_images = [] def _get_image_list(self): images = [] if isinstance(self.root, (str, pathlib.PosixPath)): try: - images.extend(self._retrieve_images(self.root)) + p = self.root + images.extend(self._retrieve_images(p, preserve=p in self.path_preserved, frac=self.frac)) except OSError: print("The root given is nor a list nor a path") @@ -24,28 +29,36 @@ class ImageDataset(Dataset): else: for p in self.root: try: - images.extend(self._retrieve_images(p)) + images.extend(self._retrieve_images(p, preserve=p in self.path_preserved, frac=self.frac)) except OSError: print(f"the path indicated at {p} cannot be found.") return images - def _retrieve_images(self, path, is_valid=False): + def _retrieve_images(self, path, is_valid=False, preserve=False, frac=1): images = [] for root, _, files in os.walk(path): + images_dir = [] for file in files: if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')): if is_valid: try: Image.open(os.path.join(root, file)) - images.append(os.path.join(root, file)) + images_dir.append(os.path.join(root, file)) except OSError: print(f"Image at path {os.path.join(root, file)} could not be opened.") else: - images.append(os.path.join(root, file)) - + images_dir.append(os.path.join(root, file)) + + if preserve: + random.seed(24) + random.shuffle(images_dir) + split_index = int(len(images_dir) * frac) + self.preserved_images.extend(images_dir[:split_index]) + images.extend(images_dir[split_index:]) + return images def get_image_data(self, index: int):