edit tests + README

pull/422/head
Etienne Guevel 2024-05-29 17:17:33 +02:00
parent be41e7b2d3
commit d2f9b593d4
8 changed files with 185 additions and 119 deletions

View File

@ -7,9 +7,40 @@ Developing a foundation model for blood white cells is interesting for several
- The categories of blood white cells are not unanimous, and hematologists / datasets make different classes. - The categories of blood white cells are not unanimous, and hematologists / datasets make different classes.
- Some blood white cells present mutations that are visible on images, and those distinguishable features could be embedded by the model - Some blood white cells present mutations that are visible on images, and those distinguishable features could be embedded by the model
## Installing ## Installing
To install the project and its required packages, use `conda env create -f conda.yaml`,
then `conda activate cell_sim` and `pip install -e .`.
The dinov2 package is then available from the conda env `cell_sim` for execution.
## Data
The dataset implementation is in the file `dinov2/data/datasets/custom_image_datasets.py`, and retrieves all the images in all the folders indicated in the dataset paths. There are options to ignore a fraction of the images in each dataset given, and to check if the images are not corrupted before using them.
## Training ## Training
Most of the code used for the training **Disclaimer:**
Most of the code used for the training was used directly from the [dinov2 repo](https://github.com/facebookresearch/dinov2/tree/main),
and not every function has been checked. Though it should be functional and do as intended, some aspects might not work as expected.
### Config file
Config files control every element of the model and of its training. The settings in your config file are merged with the defaults located at `dinov2/configs/ssl_default_config.yaml`.
In our case the minimal requirements for the train config files should be:
- `dataset_path` (List[str] or str) that indicates the path where the training data is located
- `output_dir` (str) that indicates the path where the logs, ckpts and models will be saved
### Submit training
The script used to submit the training process is located at `dinov2/run/train/train.py`.
An example of command would be `python dinov2/run/train/train.py --config-file dinov2/configs/train/vitl_cellsim_register.yaml --output-dir /home/guevel/OT4D/cell_similarity/vitl_register/ --partition hard --ngpus 2`
that would launch the training on 1 node with 2 GPUs on the partition named `hard`.
> This command makes a sh script with the required slurm constraints that is then executed.
**The conda env previously created should be activated before launching this command.**
## Results

View File

@ -3,15 +3,27 @@ import pathlib
import random import random
from typing import List from typing import List
from omegaconf.listconfig import ListConfig from omegaconf.listconfig import ListConfig
from torch.utils.data import Dataset from torch.utils.data import Dataset
from .decoders import ImageDataDecoder from .decoders import ImageDataDecoder
from PIL import Image
class ImageDataset(Dataset): class ImageDataset(Dataset):
def __init__(self, root, transform=None, path_preserved: List[str]=[], frac: float=0.1, is_valid=True): def __init__(
self,
root,
transform=None,
path_preserved: List[str] = [],
frac: float = 0.1,
is_valid=True,
):
self.root = root self.root = root
self.transform = transform self.transform = transform
self.path_preserved = path_preserved if isinstance(path_preserved, (list, ListConfig)) else [path_preserved] self.path_preserved = (
path_preserved
if isinstance(path_preserved, (list, ListConfig))
else [path_preserved]
)
self.frac = frac self.frac = frac
self.preserved_images = [] self.preserved_images = []
self.is_valid = is_valid self.is_valid = is_valid
@ -24,7 +36,11 @@ class ImageDataset(Dataset):
try: try:
p = self.root p = self.root
preserve = p in self.path_preserved preserve = p in self.path_preserved
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid)) images.extend(
self._retrieve_images(
p, preserve=preserve, frac=self.frac, is_valid=self.is_valid
)
)
except OSError: except OSError:
print("The root given is nor a list nor a path") print("The root given is nor a list nor a path")
@ -33,33 +49,37 @@ class ImageDataset(Dataset):
for p in self.root: for p in self.root:
try: try:
preserve = p in self.path_preserved preserve = p in self.path_preserved
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid)) images.extend(
self._retrieve_images(
p, preserve=preserve, frac=self.frac, is_valid=self.is_valid
)
)
except OSError: except OSError:
print(f"the path indicated at {p} cannot be found.") print(f"the path indicated at {p} cannot be found.")
return images return images
def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1): def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1):
images_ini = len(self.preserved_images) images_ini = len(self.preserved_images)
images = [] images = []
for root, _, files in os.walk(path): for root, _, files in os.walk(path):
images_dir = [] images_dir = []
for file in files: for file in files:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')): if file.lower().endswith((".png", ".jpg", ".jpeg", ".tiff")):
im = os.path.join(root, file) im = os.path.join(root, file)
if is_valid: if is_valid:
try: try:
with open(im, 'rb') as f: with open(im, "rb") as f:
image_data = f.read() image_data = f.read()
ImageDataDecoder(image_data).decode() ImageDataDecoder(image_data).decode()
images_dir.append(im) images_dir.append(im)
except OSError: except OSError:
print(f"Image at path {im} could not be opened.") print(f"Image at path {im} could not be opened.")
else: else:
images_dir.append(im) images_dir.append(im)
if preserve: if preserve:
random.seed(24) random.seed(24)
random.shuffle(images_dir) random.shuffle(images_dir)
@ -69,23 +89,25 @@ class ImageDataset(Dataset):
else: else:
images.extend(images_dir) images.extend(images_dir)
images_end = len(self.preserved_images) images_end = len(self.preserved_images)
if preserve: if preserve:
print(f"{images_end - images_ini} images have been saved for the dataset at path {path}") print(
f"{images_end - images_ini} images have been saved for the dataset at path {path}"
)
return images return images
def get_image_data(self, index: int): def get_image_data(self, index: int):
path = self.images_list[index] path = self.images_list[index]
with open(path, 'rb') as f: with open(path, "rb") as f:
image_data = f.read() image_data = f.read()
return image_data return image_data
def __len__(self): def __len__(self):
return len(self.images_list) return len(self.images_list)
def __getitem__(self, index: int): def __getitem__(self, index: int):
try: try:
image_data = self.get_image_data(index) image_data = self.get_image_data(index)

View File

@ -1,19 +0,0 @@
import os
import shutil
from pathlib import Path
path = Path('/home/manon/classification/data/Single_cells/vexas')
path_out = Path(os.getcwd()) / 'dataset_bis'
def main():
for d in os.listdir(path):
if os.path.isdir(path_out / d):
os.rmdir(path_out /d)
os.mkdir(path_out / d)
files = [f for f in os.listdir(path / d) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))]
for f in files[:64]:
shutil.copy(path / d / f, path_out / d / f)
if __name__ == '__main__':
main()

View File

@ -1,23 +1,26 @@
import os from functools import partial
from omegaconf import OmegaConf
from pathlib import Path
import torch import torch
from functools import partial from dinov2.data.datasets import ImageDataset
from pathlib import Path from dinov2.data.collate import collate_data_and_cast
from cell_similarity.data.datasets import ImageDataset from dinov2.data import (
from cell_similarity.data.collate import collate_data_and_cast DataAugmentationDINO,
from dinov2.data import DataAugmentationDINO, MaskingGenerator, SamplerType, make_data_loader MaskingGenerator,
from dinov2.train.ssl_meta_arch import SSLMetaArch SamplerType,
from dinov2.utils.config import setup make_data_loader,
from dinov2.train.train import get_args_parser )
def test_single_path(cfg): cfg = OmegaConf.load(Path(__file__).parent / "config.yaml")
def test_single_path():
img_size = cfg.crops.global_crops_size img_size = cfg.crops.global_crops_size
patch_size = cfg.student.patch_size patch_size = cfg.student.patch_size
n_tokens = (img_size // patch_size) ** 2 n_tokens = (img_size // patch_size) ** 2
mask_generator = MaskingGenerator( mask_generator = MaskingGenerator(
input_size=(img_size // patch_size, img_size // patch_size), input_size=(img_size // patch_size, img_size // patch_size),
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
) )
inputs_dtype = torch.half inputs_dtype = torch.half
@ -37,8 +40,8 @@ def test_single_path(cfg):
mask_generator=mask_generator, mask_generator=mask_generator,
dtype=inputs_dtype, dtype=inputs_dtype,
) )
path_dataset_test = os.path.join(os.getcwd(), 'dataset_test') path_dataset_test = Path(__file__).parent / "dataset_test"
dataset = ImageDataset(root=path_dataset_test, transform=data_transform) dataset = ImageDataset(root=path_dataset_test, transform=data_transform)
sampler_type = SamplerType.SHARDED_INFINITE sampler_type = SamplerType.SHARDED_INFINITE
@ -48,24 +51,27 @@ def test_single_path(cfg):
num_workers=cfg.train.num_workers, num_workers=cfg.train.num_workers,
shuffle=True, shuffle=True,
sampler_type=sampler_type, sampler_type=sampler_type,
sampler_advance=0, sampler_advance=0,
drop_last=True, drop_last=True,
collate_fn=collate_fn, collate_fn=collate_fn,
) )
for i in data_loader: for i in data_loader:
assert i['collated_global_crops'].shape[0] == cfg.train.batch_size_per_gpu * 2 assert i["collated_global_crops"].shape[0] == cfg.train.batch_size_per_gpu * 2
assert i['collated_local_crops'].shape[0] == cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number assert (
i["collated_local_crops"].shape[0]
== cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
)
break break
def test_several_paths(cfg):
def test_several_paths():
img_size = cfg.crops.global_crops_size img_size = cfg.crops.global_crops_size
patch_size = cfg.student.patch_size patch_size = cfg.student.patch_size
n_tokens = (img_size // patch_size) ** 2 n_tokens = (img_size // patch_size) ** 2
mask_generator = MaskingGenerator( mask_generator = MaskingGenerator(
input_size=(img_size // patch_size, img_size // patch_size), input_size=(img_size // patch_size, img_size // patch_size),
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
) )
inputs_dtype = torch.half inputs_dtype = torch.half
@ -86,8 +92,7 @@ def test_several_paths(cfg):
dtype=inputs_dtype, dtype=inputs_dtype,
) )
base_path = Path(os.getcwd()) dirs = [Path(__file__).parent / i for i in ["dataset_test", "dataset_bis"]]
dirs = ['dataset_test', 'dataset_bis']
dataset = ImageDataset(root=dirs, transform=data_transform) dataset = ImageDataset(root=dirs, transform=data_transform)
sampler_type = SamplerType.SHARDED_INFINITE sampler_type = SamplerType.SHARDED_INFINITE
data_loader = make_data_loader( data_loader = make_data_loader(
@ -96,20 +101,20 @@ def test_several_paths(cfg):
num_workers=cfg.train.num_workers, num_workers=cfg.train.num_workers,
shuffle=True, shuffle=True,
sampler_type=sampler_type, sampler_type=sampler_type,
sampler_advance=0, sampler_advance=0,
drop_last=True, drop_last=True,
collate_fn=collate_fn, collate_fn=collate_fn,
) )
for i in data_loader: for i in data_loader:
assert i['collated_global_crops'].shape[0] == cfg.train.batch_size_per_gpu * 2 assert i["collated_global_crops"].shape[0] == cfg.train.batch_size_per_gpu * 2
assert i['collated_local_crops'].shape[0] == cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number assert (
i["collated_local_crops"].shape[0]
== cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
)
break break
if __name__ == '__main__':
args = get_args_parser(add_help=True).parse_args() if __name__ == "__main__":
cfg = setup(args) test_single_path()
test_single_path(cfg) test_several_paths()
print("test_single_path succesfull")
test_several_paths(cfg)
print("test_several_paths successfull")

View File

@ -1,47 +1,69 @@
import os import os
from pathlib import Path from pathlib import Path
from cell_similarity.data.datasets import ImageDataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import torchvision.transforms as transforms import torchvision.transforms as transforms
from dinov2.data.datasets import ImageDataset
def test_single_path(): def test_single_path():
path_dataset_test = Path(__file__).parent / "dataset_test"
path_dataset_test = Path(os.getcwd()) / 'dataset_test' transform = transforms.Compose(
transform = transforms.Compose([ [
transforms.Resize((256,256)), transforms.Resize((256, 256)),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]) ]
)
dataset = ImageDataset(path_dataset_test, transform=transform) dataset = ImageDataset(path_dataset_test, transform=transform)
assert dataset.__len__() == len([i for i in os.listdir(path_dataset_test) if i.endswith(('.png', '.jpg', '.jpeg', '.tiff'))]) assert dataset.__len__() == len(
[
i
for i in os.listdir(path_dataset_test)
if i.endswith((".png", ".jpg", ".jpeg", ".tiff"))
]
)
dataloader = DataLoader(dataset, batch_size=32) dataloader = DataLoader(dataset, batch_size=32)
for i in dataloader: for i in dataloader:
assert len(i) == 32 assert len(i) == 32
break break
base_path = Path(os.getcwd()) base_path = Path(os.getcwd())
dirs = ['dataset_test', 'dataset_bis'] dirs = ["dataset_test", "dataset_bis"]
def test_several_paths(): def test_several_paths():
transform = transforms.Compose(
transform = transforms.Compose([ [
transforms.Resize((256, 256)), transforms.Resize((256, 256)),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.226]) transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.226]),
]) ]
)
dataset = ImageDataset(root=dirs, transform=transform) dataset = ImageDataset(root=dirs, transform=transform)
expected_length = len([f for d in dirs for d, _, files in os.walk(d) for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))]) expected_length = len(
[
f
for d in dirs
for d, _, files in os.walk(d)
for f in files
if f.lower().endswith((".png", ".jpg", ".jpeg", ".tiff"))
]
)
assert dataset.__len__() == expected_length assert dataset.__len__() == expected_length
dataloader = DataLoader(dataset, batch_size=32) dataloader = DataLoader(dataset, batch_size=32)
for i in dataloader: for i in dataloader:
assert len(i)==32 assert len(i) == 32
break break
if __name__ == '__main__':
if __name__ == "__main__":
test_single_path() test_single_path()
test_several_paths() test_several_paths()

View File

@ -1,28 +1,29 @@
import os import os
import torch import torch
from cell_similarity.data.datasets import ImageDataset from dinov2.data.datasets import ImageDataset
from dinov2.data import DataAugmentationDINO from dinov2.data import DataAugmentationDINO
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
def test():
def test():
data_transform = DataAugmentationDINO( data_transform = DataAugmentationDINO(
(0.32, 1.0), (0.32, 1.0),
(0.05, 0.32), (0.05, 0.32),
8, 8,
) )
path_dataset = os.path.join(os.getcwd(), 'dataset_test') path_dataset = os.path.join(os.getcwd(), "dataset_test")
dataset = ImageDataset(root=path_dataset, transform=data_transform) dataset = ImageDataset(root=path_dataset, transform=data_transform)
dataloader = DataLoader(dataset, batch_size=32) dataloader = DataLoader(dataset, batch_size=32)
for i in dataloader: for i in dataloader:
assert len(i['global_crops']) == 2 assert len(i["global_crops"]) == 2
assert i['global_crops'][0].shape == torch.Size([32, 3, 224, 224]) assert i["global_crops"][0].shape == torch.Size([32, 3, 224, 224])
assert len(i['local_crops']) == 8 assert len(i["local_crops"]) == 8
assert i['local_crops'][0].shape == torch.Size([32, 3, 96, 96]) assert i["local_crops"][0].shape == torch.Size([32, 3, 96, 96])
break break
if __name__ =='__main__':
if __name__ == "__main__":
test() test()

View File

@ -1,16 +1,17 @@
import logging import logging
from omegaconf import OmegaConf
from pathlib import Path
from dinov2.train.ssl_meta_arch import SSLMetaArch from dinov2.train.ssl_meta_arch import SSLMetaArch
from dinov2.train.train import get_args_parser
from dinov2.utils.config import setup
logger = logging.getLogger("dinov2") logger = logging.getLogger("dinov2")
cfg = OmegaConf.load(Path(__file__).parent / "config.yaml")
def test(args): def test():
cfg = setup(args)
model = SSLMetaArch(cfg) model = SSLMetaArch(cfg)
logger.info("Model: \n{}".format(model)) logger.info("Model: \n{}".format(model))
if __name__ == "__main__": if __name__ == "__main__":
args = get_args_parser(add_help=True).parse_args() test()
test(args)

View File

@ -1,22 +1,25 @@
import logging import logging
import torch import torch
from omegaconf import OmegaConf
from pathlib import Path
from dinov2.train.ssl_meta_arch import SSLMetaArch from dinov2.train.ssl_meta_arch import SSLMetaArch
from dinov2.train.train import get_args_parser from dinov2.train.train import do_train
from dinov2.utils.config import setup
from cell_similarity.training.train import do_train
logger = logging.getLogger("dinov2") logger = logging.getLogger("dinov2")
cfg = OmegaConf.load(Path(__file__).parent / "config.yaml")
def test(args):
cfg = setup(args) def test():
model = SSLMetaArch(cfg).to(torch.device("cuda")) if torch.cuda.is_available():
model.prepare_for_distributed_training() model = SSLMetaArch(cfg).to(torch.device("cuda"))
model.prepare_for_distributed_training()
logger.info("Model:\n {}".format(model))
do_train(cfg, model, resume=False)
else:
print("Unable to assess the training test, as no cuda devices were found")
logger.info("Model:\n {}".format(model))
do_train(cfg, model, resume=False)
if __name__ == "__main__": if __name__ == "__main__":
args = get_args_parser(add_help=True).parse_args() test()
test(args)