preserve lab data changes
parent
6684955cef
commit
2e05e6b4ef
dinov2
configs/train
data/datasets
|
@ -0,0 +1,32 @@
|
|||
# Training overrides for a ViT-L/14 student with 4 register tokens.
# NOTE(review): section nesting was reconstructed from the empty-valued
# section keys; confirm it against the project's default config.
train:
  # Absolute paths: adjust to the local machine.
  dataset_path:
    - /home/manon/classification/data/Single_cells/medhi
    - /home/manon/classification/data/Single_cells/vexas_original/Unlabeled
    - /home/manon/classification/data/Single_cells/matek
  centering: sinkhorn_knopp
  batch_size_per_gpu: 64
  output_dir: /home/guevel/OT4D/cell_similarity/vitl_register
  OFFICIAL_EPOCH_LENGTH: 1
dino:
  head_n_prototypes: 131072
  head_bottleneck_dim: 384
ibot:
  separate_head: true
  head_n_prototypes: 131072
student:
  arch: vit_large
  patch_size: 14
  drop_path_rate: 0.4
  ffn_layer: swiglufused
  block_chunks: 4
  num_register_tokens: 4
teacher:
  momentum_teacher: 0.994
optim:
  epochs: 500
  weight_decay_end: 0.2
  base_lr: 2.0e-04  # learning rate for a batch size of 1024
  warmup_epochs: 80
  layerwise_decay: 1.0
crops:
  local_crops_size: 98
|
|
@ -1,22 +1,27 @@
|
|||
import os
|
||||
import pathlib
|
||||
|
||||
import random
|
||||
from typing import List
|
||||
from torch.utils.data import Dataset
|
||||
from .decoders import ImageDataDecoder
|
||||
from PIL import Image
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
def __init__(self, root, transform=None):
|
||||
def __init__(self, root, transform=None, path_preserved: List[str]=[], frac: float=0.1):
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.images_list = self._get_image_list()
|
||||
self.path_preserved = path_preserved if isinstance(path_preserved, list) else list(path_preserved)
|
||||
self.frac = frac
|
||||
self.preserved_images = []
|
||||
|
||||
def _get_image_list(self):
|
||||
images = []
|
||||
|
||||
if isinstance(self.root, (str, pathlib.PosixPath)):
|
||||
try:
|
||||
images.extend(self._retrieve_images(self.root))
|
||||
p = self.root
|
||||
images.extend(self._retrieve_images(p, preserve=p in self.path_preserved, frac=self.frac))
|
||||
|
||||
except OSError:
|
||||
print("The root given is nor a list nor a path")
|
||||
|
@ -24,28 +29,36 @@ class ImageDataset(Dataset):
|
|||
else:
|
||||
for p in self.root:
|
||||
try:
|
||||
images.extend(self._retrieve_images(p))
|
||||
images.extend(self._retrieve_images(p, preserve=p in self.path_preserved, frac=self.frac))
|
||||
|
||||
except OSError:
|
||||
print(f"the path indicated at {p} cannot be found.")
|
||||
|
||||
return images
|
||||
|
||||
def _retrieve_images(self, path, is_valid=False):
|
||||
def _retrieve_images(self, path, is_valid=False, preserve=False, frac=1):
|
||||
images = []
|
||||
for root, _, files in os.walk(path):
|
||||
images_dir = []
|
||||
for file in files:
|
||||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')):
|
||||
if is_valid:
|
||||
try:
|
||||
Image.open(os.path.join(root, file))
|
||||
images.append(os.path.join(root, file))
|
||||
images_dir.append(os.path.join(root, file))
|
||||
|
||||
except OSError:
|
||||
print(f"Image at path {os.path.join(root, file)} could not be opened.")
|
||||
else:
|
||||
images.append(os.path.join(root, file))
|
||||
|
||||
images_dir.append(os.path.join(root, file))
|
||||
|
||||
if preserve:
|
||||
random.seed(24)
|
||||
random.shuffle(images_dir)
|
||||
split_index = int(len(images_dir) * frac)
|
||||
self.preserved_images.extend(images_dir[:split_index])
|
||||
images.extend(images_dir[split_index:])
|
||||
|
||||
return images
|
||||
|
||||
def get_image_data(self, index: int):
|
||||
|
|
Loading…
Reference in New Issue