edit tests + README

pull/422/head
Etienne Guevel 2024-05-29 17:17:33 +02:00
parent be41e7b2d3
commit d2f9b593d4
8 changed files with 185 additions and 119 deletions

View File

@ -7,9 +7,40 @@ Developing a foundation model for white blood cells is interesting for several
- The categories of white blood cells are not universally agreed upon, and hematologists / datasets define different classes.
- Some white blood cells present mutations that are visible on images, and those distinguishable features could be embedded by the model
## Installing
## Installing
To install the project and the packages required for it, use `conda env create -f conda.yaml`,
then `conda activate cell_sim` and `pip install -e .`.
The dinov2 package is then available from the `cell_sim` conda env for execution.
## Data
The dataset implementation is in the file `dinov2/data/datasets/custom_image_datasets.py`, and retrieves all the images in all the folders indicated in the dataset paths. There are options to ignore a fraction of the images in each dataset given, and to check if the images are not corrupted before using them.
## Training
Most of the code used for the training
**Disclaimer:**
Most of the code used for the training was used directly from the [dinov2 repo](https://github.com/facebookresearch/dinov2/tree/main),
and not every function has been checked. Though it should be functional and do as intended, some aspects might not work as expected.
### Config file
Config files control every element of the model and of its training. The settings in your config file will be merged with the default one located at `dinov2/configs/ssl_default_config.yaml`.
In our case the minimal requirements for the train config files should be:
- `dataset_path` (List[str] or str) that indicates the path where the training data is located
- `output_dir` (str) that indicates the path where the logs, ckpts and models will be saved
### Submit training
The script used to submit the training process is located at `dinov2/run/train/train.py`.
An example command would be `python dinov2/run/train/train.py --config-file dinov2/configs/train/vitl_cellsim_register.yaml --output-dir /home/guevel/OT4D/cell_similarity/vitl_register/ --partition hard --ngpus 2`,
that would launch the training on 1 node with 2 GPUs on the partition named `hard`.
> This command generates a shell script with the required slurm constraints, which is then executed.
**The conda env previously created should be activated before launching this command.**
## Results

View File

@ -3,15 +3,27 @@ import pathlib
import random
from typing import List
from omegaconf.listconfig import ListConfig
from torch.utils.data import Dataset
from .decoders import ImageDataDecoder
from PIL import Image
class ImageDataset(Dataset):
def __init__(self, root, transform=None, path_preserved: List[str]=[], frac: float=0.1, is_valid=True):
def __init__(
self,
root,
transform=None,
path_preserved: List[str] = [],
frac: float = 0.1,
is_valid=True,
):
self.root = root
self.transform = transform
self.path_preserved = path_preserved if isinstance(path_preserved, (list, ListConfig)) else [path_preserved]
self.path_preserved = (
path_preserved
if isinstance(path_preserved, (list, ListConfig))
else [path_preserved]
)
self.frac = frac
self.preserved_images = []
self.is_valid = is_valid
@ -24,7 +36,11 @@ class ImageDataset(Dataset):
try:
p = self.root
preserve = p in self.path_preserved
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid))
images.extend(
self._retrieve_images(
p, preserve=preserve, frac=self.frac, is_valid=self.is_valid
)
)
except OSError:
print("The root given is nor a list nor a path")
@ -33,33 +49,37 @@ class ImageDataset(Dataset):
for p in self.root:
try:
preserve = p in self.path_preserved
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid))
images.extend(
self._retrieve_images(
p, preserve=preserve, frac=self.frac, is_valid=self.is_valid
)
)
except OSError:
print(f"the path indicated at {p} cannot be found.")
return images
def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1):
images_ini = len(self.preserved_images)
images = []
for root, _, files in os.walk(path):
images_dir = []
for file in files:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')):
if file.lower().endswith((".png", ".jpg", ".jpeg", ".tiff")):
im = os.path.join(root, file)
if is_valid:
try:
with open(im, 'rb') as f:
with open(im, "rb") as f:
image_data = f.read()
ImageDataDecoder(image_data).decode()
images_dir.append(im)
except OSError:
print(f"Image at path {im} could not be opened.")
else:
images_dir.append(im)
if preserve:
random.seed(24)
random.shuffle(images_dir)
@ -69,23 +89,25 @@ class ImageDataset(Dataset):
else:
images.extend(images_dir)
images_end = len(self.preserved_images)
if preserve:
print(f"{images_end - images_ini} images have been saved for the dataset at path {path}")
print(
f"{images_end - images_ini} images have been saved for the dataset at path {path}"
)
return images
def get_image_data(self, index: int):
path = self.images_list[index]
with open(path, 'rb') as f:
with open(path, "rb") as f:
image_data = f.read()
return image_data
def __len__(self):
return len(self.images_list)
def __getitem__(self, index: int):
try:
image_data = self.get_image_data(index)

View File

@ -1,19 +0,0 @@
import os
import shutil
from pathlib import Path
path = Path('/home/manon/classification/data/Single_cells/vexas')
path_out = Path(os.getcwd()) / 'dataset_bis'
def main():
for d in os.listdir(path):
if os.path.isdir(path_out / d):
os.rmdir(path_out /d)
os.mkdir(path_out / d)
files = [f for f in os.listdir(path / d) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))]
for f in files[:64]:
shutil.copy(path / d / f, path_out / d / f)
if __name__ == '__main__':
main()

View File

@ -1,23 +1,26 @@
import os
from functools import partial
from omegaconf import OmegaConf
from pathlib import Path
import torch
from functools import partial
from pathlib import Path
from cell_similarity.data.datasets import ImageDataset
from cell_similarity.data.collate import collate_data_and_cast
from dinov2.data import DataAugmentationDINO, MaskingGenerator, SamplerType, make_data_loader
from dinov2.train.ssl_meta_arch import SSLMetaArch
from dinov2.utils.config import setup
from dinov2.train.train import get_args_parser
from dinov2.data.datasets import ImageDataset
from dinov2.data.collate import collate_data_and_cast
from dinov2.data import (
DataAugmentationDINO,
MaskingGenerator,
SamplerType,
make_data_loader,
)
def test_single_path(cfg):
cfg = OmegaConf.load(Path(__file__).parent / "config.yaml")
def test_single_path():
img_size = cfg.crops.global_crops_size
patch_size = cfg.student.patch_size
n_tokens = (img_size // patch_size) ** 2
mask_generator = MaskingGenerator(
input_size=(img_size // patch_size, img_size // patch_size),
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
)
inputs_dtype = torch.half
@ -37,8 +40,8 @@ def test_single_path(cfg):
mask_generator=mask_generator,
dtype=inputs_dtype,
)
path_dataset_test = os.path.join(os.getcwd(), 'dataset_test')
path_dataset_test = Path(__file__).parent / "dataset_test"
dataset = ImageDataset(root=path_dataset_test, transform=data_transform)
sampler_type = SamplerType.SHARDED_INFINITE
@ -48,24 +51,27 @@ def test_single_path(cfg):
num_workers=cfg.train.num_workers,
shuffle=True,
sampler_type=sampler_type,
sampler_advance=0,
sampler_advance=0,
drop_last=True,
collate_fn=collate_fn,
)
for i in data_loader:
assert i['collated_global_crops'].shape[0] == cfg.train.batch_size_per_gpu * 2
assert i['collated_local_crops'].shape[0] == cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
assert i["collated_global_crops"].shape[0] == cfg.train.batch_size_per_gpu * 2
assert (
i["collated_local_crops"].shape[0]
== cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
)
break
def test_several_paths(cfg):
def test_several_paths():
img_size = cfg.crops.global_crops_size
patch_size = cfg.student.patch_size
n_tokens = (img_size // patch_size) ** 2
mask_generator = MaskingGenerator(
input_size=(img_size // patch_size, img_size // patch_size),
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
)
inputs_dtype = torch.half
@ -86,8 +92,7 @@ def test_several_paths(cfg):
dtype=inputs_dtype,
)
base_path = Path(os.getcwd())
dirs = ['dataset_test', 'dataset_bis']
dirs = [Path(__file__).parent / i for i in ["dataset_test", "dataset_bis"]]
dataset = ImageDataset(root=dirs, transform=data_transform)
sampler_type = SamplerType.SHARDED_INFINITE
data_loader = make_data_loader(
@ -96,20 +101,20 @@ def test_several_paths(cfg):
num_workers=cfg.train.num_workers,
shuffle=True,
sampler_type=sampler_type,
sampler_advance=0,
sampler_advance=0,
drop_last=True,
collate_fn=collate_fn,
)
for i in data_loader:
assert i['collated_global_crops'].shape[0] == cfg.train.batch_size_per_gpu * 2
assert i['collated_local_crops'].shape[0] == cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
assert i["collated_global_crops"].shape[0] == cfg.train.batch_size_per_gpu * 2
assert (
i["collated_local_crops"].shape[0]
== cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
)
break
if __name__ == '__main__':
args = get_args_parser(add_help=True).parse_args()
cfg = setup(args)
test_single_path(cfg)
print("test_single_path succesfull")
test_several_paths(cfg)
print("test_several_paths successfull")
if __name__ == "__main__":
test_single_path()
test_several_paths()

View File

@ -1,47 +1,69 @@
import os
from pathlib import Path
from cell_similarity.data.datasets import ImageDataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from dinov2.data.datasets import ImageDataset
def test_single_path():
path_dataset_test = Path(os.getcwd()) / 'dataset_test'
transform = transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
path_dataset_test = Path(__file__).parent / "dataset_test"
transform = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
dataset = ImageDataset(path_dataset_test, transform=transform)
assert dataset.__len__() == len([i for i in os.listdir(path_dataset_test) if i.endswith(('.png', '.jpg', '.jpeg', '.tiff'))])
assert dataset.__len__() == len(
[
i
for i in os.listdir(path_dataset_test)
if i.endswith((".png", ".jpg", ".jpeg", ".tiff"))
]
)
dataloader = DataLoader(dataset, batch_size=32)
for i in dataloader:
assert len(i) == 32
break
base_path = Path(os.getcwd())
dirs = ['dataset_test', 'dataset_bis']
dirs = ["dataset_test", "dataset_bis"]
def test_several_paths():
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.226])
])
transform = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.226]),
]
)
dataset = ImageDataset(root=dirs, transform=transform)
expected_length = len([f for d in dirs for d, _, files in os.walk(d) for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))])
expected_length = len(
[
f
for d in dirs
for d, _, files in os.walk(d)
for f in files
if f.lower().endswith((".png", ".jpg", ".jpeg", ".tiff"))
]
)
assert dataset.__len__() == expected_length
dataloader = DataLoader(dataset, batch_size=32)
for i in dataloader:
assert len(i)==32
assert len(i) == 32
break
if __name__ == '__main__':
if __name__ == "__main__":
test_single_path()
test_several_paths()

View File

@ -1,28 +1,29 @@
import os
import torch
from cell_similarity.data.datasets import ImageDataset
from dinov2.data.datasets import ImageDataset
from dinov2.data import DataAugmentationDINO
from torch.utils.data import DataLoader
def test():
def test():
data_transform = DataAugmentationDINO(
(0.32, 1.0),
(0.05, 0.32),
8,
)
path_dataset = os.path.join(os.getcwd(), 'dataset_test')
path_dataset = os.path.join(os.getcwd(), "dataset_test")
dataset = ImageDataset(root=path_dataset, transform=data_transform)
dataloader = DataLoader(dataset, batch_size=32)
for i in dataloader:
assert len(i['global_crops']) == 2
assert i['global_crops'][0].shape == torch.Size([32, 3, 224, 224])
assert len(i['local_crops']) == 8
assert i['local_crops'][0].shape == torch.Size([32, 3, 96, 96])
assert len(i["global_crops"]) == 2
assert i["global_crops"][0].shape == torch.Size([32, 3, 224, 224])
assert len(i["local_crops"]) == 8
assert i["local_crops"][0].shape == torch.Size([32, 3, 96, 96])
break
if __name__ =='__main__':
if __name__ == "__main__":
test()

View File

@ -1,16 +1,17 @@
import logging
from omegaconf import OmegaConf
from pathlib import Path
from dinov2.train.ssl_meta_arch import SSLMetaArch
from dinov2.train.train import get_args_parser
from dinov2.utils.config import setup
logger = logging.getLogger("dinov2")
cfg = OmegaConf.load(Path(__file__).parent / "config.yaml")
def test(args):
cfg = setup(args)
def test():
model = SSLMetaArch(cfg)
logger.info("Model: \n{}".format(model))
if __name__ == "__main__":
args = get_args_parser(add_help=True).parse_args()
test(args)
test()

View File

@ -1,22 +1,25 @@
import logging
import torch
from omegaconf import OmegaConf
from pathlib import Path
from dinov2.train.ssl_meta_arch import SSLMetaArch
from dinov2.train.train import get_args_parser
from dinov2.utils.config import setup
from cell_similarity.training.train import do_train
from dinov2.train.train import do_train
logger = logging.getLogger("dinov2")
cfg = OmegaConf.load(Path(__file__).parent / "config.yaml")
def test(args):
cfg = setup(args)
model = SSLMetaArch(cfg).to(torch.device("cuda"))
model.prepare_for_distributed_training()
def test():
if torch.cuda.is_available():
model = SSLMetaArch(cfg).to(torch.device("cuda"))
model.prepare_for_distributed_training()
logger.info("Model:\n {}".format(model))
do_train(cfg, model, resume=False)
else:
print("Unable to assess the training test, as no cuda devices were found")
logger.info("Model:\n {}".format(model))
do_train(cfg, model, resume=False)
if __name__ == "__main__":
args = get_args_parser(add_help=True).parse_args()
test(args)
test()