30 lines
770 B
Python
30 lines
770 B
Python
import os
|
|
|
|
import torch
|
|
from dinov2.data.datasets import ImageDataset
|
|
from dinov2.data import DataAugmentationDINO
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
def test():
    """Smoke-test that DINO multi-crop augmentation batches through a DataLoader.

    Builds a ``DataAugmentationDINO`` transform with 8 local crops, wraps an
    ``ImageDataset`` rooted at ``<cwd>/dataset_test``, pulls exactly ONE batch
    of 32 and checks the number and tensor shape of every global and local
    crop.

    The spatial sizes checked here (224 for global crops, 96 for local crops)
    are not passed to the transform; presumably they are its defaults —
    confirm against ``DataAugmentationDINO`` if those defaults change.

    Raises:
        AssertionError: if the loader yields no batch, or if any crop count or
            crop shape differs from the expected values.
    """
    batch_size = 32
    n_global_crops = 2
    n_local_crops = 8
    global_size = 224
    local_size = 96

    data_transform = DataAugmentationDINO(
        (0.32, 1.0),  # presumably the global-crop scale range — confirm
        (0.05, 0.32),  # presumably the local-crop scale range — confirm
        n_local_crops,  # number of local crops; checked against the batch below
    )

    # NOTE(review): resolved against the current working directory, so the
    # test must be run from the directory that contains `dataset_test`.
    path_dataset = os.path.join(os.getcwd(), "dataset_test")
    dataset = ImageDataset(root=path_dataset, transform=data_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # Take exactly one batch. A `for ...: break` loop would pass silently if
    # the loader yielded nothing (e.g. empty or missing dataset directory).
    batch = next(iter(dataloader), None)
    assert batch is not None, f"DataLoader yielded no batches from {path_dataset}"

    expected_global = torch.Size([batch_size, 3, global_size, global_size])
    expected_local = torch.Size([batch_size, 3, local_size, local_size])

    global_crops = batch["global_crops"]
    assert len(global_crops) == n_global_crops
    for idx, crop in enumerate(global_crops):
        assert crop.shape == expected_global, (
            f"global crop {idx}: got {tuple(crop.shape)}, "
            f"expected {tuple(expected_global)}"
        )

    local_crops = batch["local_crops"]
    assert len(local_crops) == n_local_crops
    for idx, crop in enumerate(local_crops):
        assert crop.shape == expected_local, (
            f"local crop {idx}: got {tuple(crop.shape)}, "
            f"expected {tuple(expected_local)}"
        )


if __name__ == "__main__":
    # Direct-execution entry point (`python <this file>`); under pytest the
    # module is imported instead and `test` is collected automatically.
    test()