dinov2/tests/data/test_dataloader.py

116 lines
3.9 KiB
Python

import os
import torch
from functools import partial
from pathlib import Path
from cell_similarity.data.datasets import ImageDataset
from cell_similarity.data.collate import collate_data_and_cast
from dinov2.data import DataAugmentationDINO, MaskingGenerator, SamplerType, make_data_loader
from dinov2.train.ssl_meta_arch import SSLMetaArch
from dinov2.utils.config import setup
from dinov2.train.train import get_args_parser
def _build_data_loader(cfg, root):
    """Build a DINOv2 SSL data loader over ``root`` using the settings in ``cfg``.

    Args:
        cfg: Config object as returned by ``dinov2.utils.config.setup``; this
            reads its ``crops``, ``student.patch_size``, ``ibot`` and ``train``
            sections.
        root: Dataset location(s), forwarded unchanged to
            ``ImageDataset(root=...)`` -- a single directory path or a list of
            directory paths.

    Returns:
        The loader returned by ``make_data_loader``, configured with a
        ``SHARDED_INFINITE`` sampler, shuffling, ``drop_last=True`` and a
        collate function that casts inputs to ``torch.half``.
    """
    img_size = cfg.crops.global_crops_size
    patch_size = cfg.student.patch_size
    n_tokens = (img_size // patch_size) ** 2

    mask_generator = MaskingGenerator(
        input_size=(img_size // patch_size, img_size // patch_size),
        # NOTE(review): this is evaluated strictly left to right with floor
        # division, i.e. ((0.5 * img_size) // patch_size * img_size) // patch_size,
        # which is not always 0.5 * n_tokens (e.g. 518/14 gives 666.0, not
        # 684.5). Kept unchanged so the masking setup does not drift; confirm
        # the intended value before "fixing" it.
        max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
    )
    data_transform = DataAugmentationDINO(
        cfg.crops.global_crops_scale,
        cfg.crops.local_crops_scale,
        cfg.crops.local_crops_number,
        global_crops_size=cfg.crops.global_crops_size,
        local_crops_size=cfg.crops.local_crops_size,
    )
    collate_fn = partial(
        collate_data_and_cast,
        mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
        mask_probability=cfg.ibot.mask_sample_probability,
        n_tokens=n_tokens,
        mask_generator=mask_generator,
        dtype=torch.half,
    )

    dataset = ImageDataset(root=root, transform=data_transform)
    return make_data_loader(
        dataset=dataset,
        batch_size=cfg.train.batch_size_per_gpu,
        num_workers=cfg.train.num_workers,
        shuffle=True,
        sampler_type=SamplerType.SHARDED_INFINITE,
        sampler_advance=0,
        drop_last=True,
        collate_fn=collate_fn,
    )


def _assert_first_batch_shapes(cfg, data_loader):
    """Fetch one batch from ``data_loader`` and check its crop batch sizes.

    Raises:
        AssertionError: if the loader yields no batch, or if the leading
            dimension of ``collated_global_crops`` / ``collated_local_crops``
            is not ``cfg.train.batch_size_per_gpu`` times the expected number
            of crops per image.
    """
    # next(...) rather than `for ...: break`: an empty loader must fail the
    # test instead of silently skipping every assertion.
    batch = next(iter(data_loader), None)
    assert batch is not None, "data loader yielded no batches"

    batch_size = cfg.train.batch_size_per_gpu
    # Expected factor of 2 global crops per image.
    assert batch['collated_global_crops'].shape[0] == batch_size * 2
    assert batch['collated_local_crops'].shape[0] == batch_size * cfg.crops.local_crops_number


def test_single_path(cfg):
    """Load one batch from the ``dataset_test`` directory under the CWD.

    NOTE(review): under pytest, ``cfg`` must be supplied by a fixture that is
    not visible in this file; the ``__main__`` block passes it explicitly.
    """
    root = os.path.join(os.getcwd(), 'dataset_test')
    _assert_first_batch_shapes(cfg, _build_data_loader(cfg, root))


def test_several_paths(cfg):
    """Load one batch from an ``ImageDataset`` given several root directories.

    Uses ``dataset_test`` and ``dataset_bis`` under the CWD, passed to
    ``ImageDataset`` as a list of paths.

    NOTE(review): under pytest, ``cfg`` must be supplied by a fixture that is
    not visible in this file; the ``__main__`` block passes it explicitly.
    """
    # Absolute CWD-based paths, consistent with test_single_path.
    roots = [os.path.join(os.getcwd(), name) for name in ('dataset_test', 'dataset_bis')]
    _assert_first_batch_shapes(cfg, _build_data_loader(cfg, roots))
if __name__ == '__main__':
    # Run both checks directly, outside pytest. Arguments are parsed with the
    # same parser as the DINOv2 training entry point (dinov2.train.train), and
    # the config object built by `setup` is passed to each test. Dataset
    # directories are resolved relative to the current working directory.
    args = get_args_parser(add_help=True).parse_args()
    cfg = setup(args)
    test_single_path(cfg)
    print("test_single_path successful")
    test_several_paths(cfg)
    print("test_several_paths successful")