dinov2/tests/data/test_datatransformation.py

30 lines
770 B
Python

import os
import torch
from dinov2.data.datasets import ImageDataset
from dinov2.data import DataAugmentationDINO
from torch.utils.data import DataLoader
def test():
data_transform = DataAugmentationDINO(
(0.32, 1.0),
(0.05, 0.32),
8,
)
path_dataset = os.path.join(os.getcwd(), "dataset_test")
dataset = ImageDataset(root=path_dataset, transform=data_transform)
dataloader = DataLoader(dataset, batch_size=32)
for i in dataloader:
assert len(i["global_crops"]) == 2
assert i["global_crops"][0].shape == torch.Size([32, 3, 224, 224])
assert len(i["local_crops"]) == 8
assert i["local_crops"][0].shape == torch.Size([32, 3, 96, 96])
break
if __name__ == "__main__":
test()