edit tests + README
parent
be41e7b2d3
commit
d2f9b593d4
35
README.md
35
README.md
|
@ -7,9 +7,40 @@ Developping a foundation model for blood white cells is interesting for several
|
||||||
- The categories of blood white cells are not unanimous, and hematologists / datasets make different classes.
|
- The categories of blood white cells are not unanimous, and hematologists / datasets make different classes.
|
||||||
- Some blood white cells present mutations that are visible on images, and those distinguishible features could be embedded by the model
|
- Some blood white cells present mutations that are visible on images, and those distinguishible features could be embedded by the model
|
||||||
|
|
||||||
## Installing
|
## Installing
|
||||||
|
|
||||||
|
To install the project and the required packages that are necessary for this project use `conda env install -f conda.yaml`,
|
||||||
|
then `conda activate cell_sim` and `pip install -e .`.
|
||||||
|
|
||||||
|
The dinov2 packages is then available from the conda env cell_sim for execution.
|
||||||
|
|
||||||
|
## Data
|
||||||
|
|
||||||
|
The dataset implementation is in the file `dinov2/data/datasets/custom_image_datasets.py`, and retrieves all the images in all the folders indicated in the dataset paths. There are options to ignore a fraction of the images in each dataset given, and to check if the images are not corrupted before using them.
|
||||||
|
|
||||||
## Training
|
## Training
|
||||||
|
|
||||||
Most of the code used for the training
|
**Disclaimer:**
|
||||||
|
|
||||||
|
Most of the code used for the training was used directly from the [dinov2 repo](https://github.com/facebookresearch/dinov2/tree/main),
|
||||||
|
and not every function has been checked. Though it should be functional and do as intended, some aspects might not work as expected.
|
||||||
|
|
||||||
|
### Config file
|
||||||
|
|
||||||
|
config files control every elements of the model, and of its training. The implementation made on your config files will merge with the one by default located at `dinov2/configs/ssl_default_config.yaml`.
|
||||||
|
In our case the minimal requirements for the train config files should be:
|
||||||
|
|
||||||
|
- `dataset_path` (List[str] or str) that indicates the path where the training data is located
|
||||||
|
- `output_dir` (str) that indicates the path where the logs, ckpts and models will be saved
|
||||||
|
|
||||||
|
### Submit training
|
||||||
|
|
||||||
|
The script used to submit the training process is located at `dinov2/run/train/train.py`.
|
||||||
|
An example of command would be `python dinov2/run/train/train.py --config-file dinov2/configs/train/vitl_cellsim_register.yaml --output-dir /home/guevel/OT4D/cell_similarity/vitl_register/ --partition hard --ngpus 2`
|
||||||
|
|
||||||
|
that would launch the training on 1 node with 2 GPUs on the partition named `hard`.
|
||||||
|
> This command makes a sh script with the required slurm constraints that is then executed.
|
||||||
|
|
||||||
|
**The conda env previously created should be activated before launching this command.**
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
|
@ -3,15 +3,27 @@ import pathlib
|
||||||
import random
|
import random
|
||||||
from typing import List
|
from typing import List
|
||||||
from omegaconf.listconfig import ListConfig
|
from omegaconf.listconfig import ListConfig
|
||||||
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from .decoders import ImageDataDecoder
|
from .decoders import ImageDataDecoder
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
class ImageDataset(Dataset):
|
class ImageDataset(Dataset):
|
||||||
def __init__(self, root, transform=None, path_preserved: List[str]=[], frac: float=0.1, is_valid=True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
root,
|
||||||
|
transform=None,
|
||||||
|
path_preserved: List[str] = [],
|
||||||
|
frac: float = 0.1,
|
||||||
|
is_valid=True,
|
||||||
|
):
|
||||||
self.root = root
|
self.root = root
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
self.path_preserved = path_preserved if isinstance(path_preserved, (list, ListConfig)) else [path_preserved]
|
self.path_preserved = (
|
||||||
|
path_preserved
|
||||||
|
if isinstance(path_preserved, (list, ListConfig))
|
||||||
|
else [path_preserved]
|
||||||
|
)
|
||||||
self.frac = frac
|
self.frac = frac
|
||||||
self.preserved_images = []
|
self.preserved_images = []
|
||||||
self.is_valid = is_valid
|
self.is_valid = is_valid
|
||||||
|
@ -24,7 +36,11 @@ class ImageDataset(Dataset):
|
||||||
try:
|
try:
|
||||||
p = self.root
|
p = self.root
|
||||||
preserve = p in self.path_preserved
|
preserve = p in self.path_preserved
|
||||||
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid))
|
images.extend(
|
||||||
|
self._retrieve_images(
|
||||||
|
p, preserve=preserve, frac=self.frac, is_valid=self.is_valid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
except OSError:
|
except OSError:
|
||||||
print("The root given is nor a list nor a path")
|
print("The root given is nor a list nor a path")
|
||||||
|
@ -33,33 +49,37 @@ class ImageDataset(Dataset):
|
||||||
for p in self.root:
|
for p in self.root:
|
||||||
try:
|
try:
|
||||||
preserve = p in self.path_preserved
|
preserve = p in self.path_preserved
|
||||||
images.extend(self._retrieve_images(p, preserve=preserve, frac=self.frac, is_valid=self.is_valid))
|
images.extend(
|
||||||
|
self._retrieve_images(
|
||||||
|
p, preserve=preserve, frac=self.frac, is_valid=self.is_valid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
except OSError:
|
except OSError:
|
||||||
print(f"the path indicated at {p} cannot be found.")
|
print(f"the path indicated at {p} cannot be found.")
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1):
|
def _retrieve_images(self, path, is_valid=True, preserve=False, frac=1):
|
||||||
images_ini = len(self.preserved_images)
|
images_ini = len(self.preserved_images)
|
||||||
images = []
|
images = []
|
||||||
for root, _, files in os.walk(path):
|
for root, _, files in os.walk(path):
|
||||||
images_dir = []
|
images_dir = []
|
||||||
for file in files:
|
for file in files:
|
||||||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff')):
|
if file.lower().endswith((".png", ".jpg", ".jpeg", ".tiff")):
|
||||||
im = os.path.join(root, file)
|
im = os.path.join(root, file)
|
||||||
if is_valid:
|
if is_valid:
|
||||||
try:
|
try:
|
||||||
with open(im, 'rb') as f:
|
with open(im, "rb") as f:
|
||||||
image_data = f.read()
|
image_data = f.read()
|
||||||
ImageDataDecoder(image_data).decode()
|
ImageDataDecoder(image_data).decode()
|
||||||
images_dir.append(im)
|
images_dir.append(im)
|
||||||
|
|
||||||
except OSError:
|
except OSError:
|
||||||
print(f"Image at path {im} could not be opened.")
|
print(f"Image at path {im} could not be opened.")
|
||||||
else:
|
else:
|
||||||
images_dir.append(im)
|
images_dir.append(im)
|
||||||
|
|
||||||
if preserve:
|
if preserve:
|
||||||
random.seed(24)
|
random.seed(24)
|
||||||
random.shuffle(images_dir)
|
random.shuffle(images_dir)
|
||||||
|
@ -69,23 +89,25 @@ class ImageDataset(Dataset):
|
||||||
|
|
||||||
else:
|
else:
|
||||||
images.extend(images_dir)
|
images.extend(images_dir)
|
||||||
|
|
||||||
images_end = len(self.preserved_images)
|
images_end = len(self.preserved_images)
|
||||||
if preserve:
|
if preserve:
|
||||||
print(f"{images_end - images_ini} images have been saved for the dataset at path {path}")
|
print(
|
||||||
|
f"{images_end - images_ini} images have been saved for the dataset at path {path}"
|
||||||
|
)
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def get_image_data(self, index: int):
|
def get_image_data(self, index: int):
|
||||||
path = self.images_list[index]
|
path = self.images_list[index]
|
||||||
with open(path, 'rb') as f:
|
with open(path, "rb") as f:
|
||||||
image_data = f.read()
|
image_data = f.read()
|
||||||
|
|
||||||
return image_data
|
return image_data
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.images_list)
|
return len(self.images_list)
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
try:
|
try:
|
||||||
image_data = self.get_image_data(index)
|
image_data = self.get_image_data(index)
|
||||||
|
|
|
@ -1,19 +0,0 @@
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
path = Path('/home/manon/classification/data/Single_cells/vexas')
|
|
||||||
path_out = Path(os.getcwd()) / 'dataset_bis'
|
|
||||||
|
|
||||||
def main():
|
|
||||||
for d in os.listdir(path):
|
|
||||||
if os.path.isdir(path_out / d):
|
|
||||||
os.rmdir(path_out /d)
|
|
||||||
|
|
||||||
os.mkdir(path_out / d)
|
|
||||||
files = [f for f in os.listdir(path / d) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))]
|
|
||||||
for f in files[:64]:
|
|
||||||
shutil.copy(path / d / f, path_out / d / f)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
|
@ -1,23 +1,26 @@
|
||||||
import os
|
from functools import partial
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from functools import partial
|
from dinov2.data.datasets import ImageDataset
|
||||||
from pathlib import Path
|
from dinov2.data.collate import collate_data_and_cast
|
||||||
from cell_similarity.data.datasets import ImageDataset
|
from dinov2.data import (
|
||||||
from cell_similarity.data.collate import collate_data_and_cast
|
DataAugmentationDINO,
|
||||||
from dinov2.data import DataAugmentationDINO, MaskingGenerator, SamplerType, make_data_loader
|
MaskingGenerator,
|
||||||
from dinov2.train.ssl_meta_arch import SSLMetaArch
|
SamplerType,
|
||||||
from dinov2.utils.config import setup
|
make_data_loader,
|
||||||
from dinov2.train.train import get_args_parser
|
)
|
||||||
|
|
||||||
def test_single_path(cfg):
|
cfg = OmegaConf.load(Path(__file__).parent / "config.yaml")
|
||||||
|
|
||||||
|
def test_single_path():
|
||||||
img_size = cfg.crops.global_crops_size
|
img_size = cfg.crops.global_crops_size
|
||||||
patch_size = cfg.student.patch_size
|
patch_size = cfg.student.patch_size
|
||||||
n_tokens = (img_size // patch_size) ** 2
|
n_tokens = (img_size // patch_size) ** 2
|
||||||
mask_generator = MaskingGenerator(
|
mask_generator = MaskingGenerator(
|
||||||
input_size=(img_size // patch_size, img_size // patch_size),
|
input_size=(img_size // patch_size, img_size // patch_size),
|
||||||
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size
|
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
|
||||||
)
|
)
|
||||||
inputs_dtype = torch.half
|
inputs_dtype = torch.half
|
||||||
|
|
||||||
|
@ -37,8 +40,8 @@ def test_single_path(cfg):
|
||||||
mask_generator=mask_generator,
|
mask_generator=mask_generator,
|
||||||
dtype=inputs_dtype,
|
dtype=inputs_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
path_dataset_test = os.path.join(os.getcwd(), 'dataset_test')
|
path_dataset_test = Path(__file__).parent / "dataset_test"
|
||||||
dataset = ImageDataset(root=path_dataset_test, transform=data_transform)
|
dataset = ImageDataset(root=path_dataset_test, transform=data_transform)
|
||||||
|
|
||||||
sampler_type = SamplerType.SHARDED_INFINITE
|
sampler_type = SamplerType.SHARDED_INFINITE
|
||||||
|
@ -48,24 +51,27 @@ def test_single_path(cfg):
|
||||||
num_workers=cfg.train.num_workers,
|
num_workers=cfg.train.num_workers,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
sampler_type=sampler_type,
|
sampler_type=sampler_type,
|
||||||
sampler_advance=0,
|
sampler_advance=0,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in data_loader:
|
for i in data_loader:
|
||||||
assert i['collated_global_crops'].shape[0] == cfg.train.batch_size_per_gpu * 2
|
assert i["collated_global_crops"].shape[0] == cfg.train.batch_size_per_gpu * 2
|
||||||
assert i['collated_local_crops'].shape[0] == cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
|
assert (
|
||||||
|
i["collated_local_crops"].shape[0]
|
||||||
|
== cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
def test_several_paths(cfg):
|
|
||||||
|
def test_several_paths():
|
||||||
img_size = cfg.crops.global_crops_size
|
img_size = cfg.crops.global_crops_size
|
||||||
patch_size = cfg.student.patch_size
|
patch_size = cfg.student.patch_size
|
||||||
n_tokens = (img_size // patch_size) ** 2
|
n_tokens = (img_size // patch_size) ** 2
|
||||||
mask_generator = MaskingGenerator(
|
mask_generator = MaskingGenerator(
|
||||||
input_size=(img_size // patch_size, img_size // patch_size),
|
input_size=(img_size // patch_size, img_size // patch_size),
|
||||||
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size
|
max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
|
||||||
)
|
)
|
||||||
inputs_dtype = torch.half
|
inputs_dtype = torch.half
|
||||||
|
|
||||||
|
@ -86,8 +92,7 @@ def test_several_paths(cfg):
|
||||||
dtype=inputs_dtype,
|
dtype=inputs_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
base_path = Path(os.getcwd())
|
dirs = [Path(__file__).parent / i for i in ["dataset_test", "dataset_bis"]]
|
||||||
dirs = ['dataset_test', 'dataset_bis']
|
|
||||||
dataset = ImageDataset(root=dirs, transform=data_transform)
|
dataset = ImageDataset(root=dirs, transform=data_transform)
|
||||||
sampler_type = SamplerType.SHARDED_INFINITE
|
sampler_type = SamplerType.SHARDED_INFINITE
|
||||||
data_loader = make_data_loader(
|
data_loader = make_data_loader(
|
||||||
|
@ -96,20 +101,20 @@ def test_several_paths(cfg):
|
||||||
num_workers=cfg.train.num_workers,
|
num_workers=cfg.train.num_workers,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
sampler_type=sampler_type,
|
sampler_type=sampler_type,
|
||||||
sampler_advance=0,
|
sampler_advance=0,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in data_loader:
|
for i in data_loader:
|
||||||
assert i['collated_global_crops'].shape[0] == cfg.train.batch_size_per_gpu * 2
|
assert i["collated_global_crops"].shape[0] == cfg.train.batch_size_per_gpu * 2
|
||||||
assert i['collated_local_crops'].shape[0] == cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
|
assert (
|
||||||
|
i["collated_local_crops"].shape[0]
|
||||||
|
== cfg.train.batch_size_per_gpu * cfg.crops.local_crops_number
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
args = get_args_parser(add_help=True).parse_args()
|
if __name__ == "__main__":
|
||||||
cfg = setup(args)
|
test_single_path()
|
||||||
test_single_path(cfg)
|
test_several_paths()
|
||||||
print("test_single_path succesfull")
|
|
||||||
test_several_paths(cfg)
|
|
||||||
print("test_several_paths successfull")
|
|
||||||
|
|
|
@ -1,47 +1,69 @@
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from cell_similarity.data.datasets import ImageDataset
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
from dinov2.data.datasets import ImageDataset
|
||||||
|
|
||||||
|
|
||||||
def test_single_path():
|
def test_single_path():
|
||||||
|
path_dataset_test = Path(__file__).parent / "dataset_test"
|
||||||
path_dataset_test = Path(os.getcwd()) / 'dataset_test'
|
transform = transforms.Compose(
|
||||||
transform = transforms.Compose([
|
[
|
||||||
transforms.Resize((256,256)),
|
transforms.Resize((256, 256)),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
dataset = ImageDataset(path_dataset_test, transform=transform)
|
dataset = ImageDataset(path_dataset_test, transform=transform)
|
||||||
|
|
||||||
assert dataset.__len__() == len([i for i in os.listdir(path_dataset_test) if i.endswith(('.png', '.jpg', '.jpeg', '.tiff'))])
|
assert dataset.__len__() == len(
|
||||||
|
[
|
||||||
|
i
|
||||||
|
for i in os.listdir(path_dataset_test)
|
||||||
|
if i.endswith((".png", ".jpg", ".jpeg", ".tiff"))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
dataloader = DataLoader(dataset, batch_size=32)
|
dataloader = DataLoader(dataset, batch_size=32)
|
||||||
for i in dataloader:
|
for i in dataloader:
|
||||||
assert len(i) == 32
|
assert len(i) == 32
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
base_path = Path(os.getcwd())
|
base_path = Path(os.getcwd())
|
||||||
dirs = ['dataset_test', 'dataset_bis']
|
dirs = ["dataset_test", "dataset_bis"]
|
||||||
|
|
||||||
|
|
||||||
def test_several_paths():
|
def test_several_paths():
|
||||||
|
transform = transforms.Compose(
|
||||||
transform = transforms.Compose([
|
[
|
||||||
transforms.Resize((256, 256)),
|
transforms.Resize((256, 256)),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.226])
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.226]),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
|
|
||||||
dataset = ImageDataset(root=dirs, transform=transform)
|
dataset = ImageDataset(root=dirs, transform=transform)
|
||||||
|
|
||||||
expected_length = len([f for d in dirs for d, _, files in os.walk(d) for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff'))])
|
expected_length = len(
|
||||||
|
[
|
||||||
|
f
|
||||||
|
for d in dirs
|
||||||
|
for d, _, files in os.walk(d)
|
||||||
|
for f in files
|
||||||
|
if f.lower().endswith((".png", ".jpg", ".jpeg", ".tiff"))
|
||||||
|
]
|
||||||
|
)
|
||||||
assert dataset.__len__() == expected_length
|
assert dataset.__len__() == expected_length
|
||||||
|
|
||||||
dataloader = DataLoader(dataset, batch_size=32)
|
dataloader = DataLoader(dataset, batch_size=32)
|
||||||
for i in dataloader:
|
for i in dataloader:
|
||||||
assert len(i)==32
|
assert len(i) == 32
|
||||||
break
|
break
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
test_single_path()
|
test_single_path()
|
||||||
test_several_paths()
|
test_several_paths()
|
||||||
|
|
|
@ -1,28 +1,29 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from cell_similarity.data.datasets import ImageDataset
|
from dinov2.data.datasets import ImageDataset
|
||||||
from dinov2.data import DataAugmentationDINO
|
from dinov2.data import DataAugmentationDINO
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
def test():
|
|
||||||
|
|
||||||
|
def test():
|
||||||
data_transform = DataAugmentationDINO(
|
data_transform = DataAugmentationDINO(
|
||||||
(0.32, 1.0),
|
(0.32, 1.0),
|
||||||
(0.05, 0.32),
|
(0.05, 0.32),
|
||||||
8,
|
8,
|
||||||
)
|
)
|
||||||
|
|
||||||
path_dataset = os.path.join(os.getcwd(), 'dataset_test')
|
path_dataset = os.path.join(os.getcwd(), "dataset_test")
|
||||||
dataset = ImageDataset(root=path_dataset, transform=data_transform)
|
dataset = ImageDataset(root=path_dataset, transform=data_transform)
|
||||||
dataloader = DataLoader(dataset, batch_size=32)
|
dataloader = DataLoader(dataset, batch_size=32)
|
||||||
|
|
||||||
for i in dataloader:
|
for i in dataloader:
|
||||||
assert len(i['global_crops']) == 2
|
assert len(i["global_crops"]) == 2
|
||||||
assert i['global_crops'][0].shape == torch.Size([32, 3, 224, 224])
|
assert i["global_crops"][0].shape == torch.Size([32, 3, 224, 224])
|
||||||
assert len(i['local_crops']) == 8
|
assert len(i["local_crops"]) == 8
|
||||||
assert i['local_crops'][0].shape == torch.Size([32, 3, 96, 96])
|
assert i["local_crops"][0].shape == torch.Size([32, 3, 96, 96])
|
||||||
break
|
break
|
||||||
|
|
||||||
if __name__ =='__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
test()
|
test()
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
import logging
|
import logging
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from dinov2.train.ssl_meta_arch import SSLMetaArch
|
from dinov2.train.ssl_meta_arch import SSLMetaArch
|
||||||
from dinov2.train.train import get_args_parser
|
|
||||||
from dinov2.utils.config import setup
|
|
||||||
|
|
||||||
logger = logging.getLogger("dinov2")
|
logger = logging.getLogger("dinov2")
|
||||||
|
cfg = OmegaConf.load(Path(__file__).parent / "config.yaml")
|
||||||
|
|
||||||
def test(args):
|
def test():
|
||||||
cfg = setup(args)
|
|
||||||
|
|
||||||
model = SSLMetaArch(cfg)
|
model = SSLMetaArch(cfg)
|
||||||
logger.info("Model: \n{}".format(model))
|
logger.info("Model: \n{}".format(model))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = get_args_parser(add_help=True).parse_args()
|
test()
|
||||||
test(args)
|
|
||||||
|
|
|
@ -1,22 +1,25 @@
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from dinov2.train.ssl_meta_arch import SSLMetaArch
|
from dinov2.train.ssl_meta_arch import SSLMetaArch
|
||||||
from dinov2.train.train import get_args_parser
|
from dinov2.train.train import do_train
|
||||||
from dinov2.utils.config import setup
|
|
||||||
from cell_similarity.training.train import do_train
|
|
||||||
|
|
||||||
logger = logging.getLogger("dinov2")
|
logger = logging.getLogger("dinov2")
|
||||||
|
cfg = OmegaConf.load(Path(__file__).parent / "config.yaml")
|
||||||
|
|
||||||
def test(args):
|
|
||||||
|
|
||||||
cfg = setup(args)
|
def test():
|
||||||
model = SSLMetaArch(cfg).to(torch.device("cuda"))
|
if torch.cuda.is_available():
|
||||||
model.prepare_for_distributed_training()
|
model = SSLMetaArch(cfg).to(torch.device("cuda"))
|
||||||
|
model.prepare_for_distributed_training()
|
||||||
|
|
||||||
|
logger.info("Model:\n {}".format(model))
|
||||||
|
do_train(cfg, model, resume=False)
|
||||||
|
else:
|
||||||
|
print("Unable to assess the training test, as no cuda devices were found")
|
||||||
|
|
||||||
logger.info("Model:\n {}".format(model))
|
|
||||||
do_train(cfg, model, resume=False)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = get_args_parser(add_help=True).parse_args()
|
test()
|
||||||
test(args)
|
|
||||||
|
|
Loading…
Reference in New Issue