dinov2/distillation/datasets/GTA5.py


import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from typing import Optional
from torchvision import transforms


class GTA5Dataset(Dataset):
    """Unlabelled image dataset over a flat directory of GTA5 frames."""

    def __init__(
        self,
        img_dir: str = "/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/GTA5/GTA5/images",
        transform: Optional[transforms.Compose] = None,
    ):
        self.img_dir = img_dir
        self.transform = transform
        # Collect all .jpg/.png image files in the directory
        self.images = []
        for img_name in os.listdir(self.img_dir):
            if img_name.endswith(('.jpg', '.png')):
                self.images.append(os.path.join(self.img_dir, img_name))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load the image as RGB and apply the optional transform
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
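

# Minimal usage sketch for GTA5Dataset (illustrative only; the resize size
# below is an assumption, not a value taken from this repository):
#
#   from torchvision import transforms
#   preview_transform = transforms.Compose([
#       transforms.Resize((224, 224)),
#       transforms.ToTensor(),
#   ])
#   dataset = GTA5Dataset(transform=preview_transform)
#   image = dataset[0]                 # torch.Tensor of shape (3, 224, 224)
#   print(len(dataset), image.shape)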


import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split

from datasets.collate_fn import collate_data_and_cast


class CustomDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping GTA5Dataset with a random train/val split."""

    def __init__(
        self,
        data_dir: str,
        transform,
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: float = 0.9,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transform
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_val_split = train_val_split

    def setup(self, stage=None):
        # Build the full dataset, then split it into train/val subsets
        dataset = GTA5Dataset(img_dir=self.data_dir, transform=self.transform)
        train_size = int(self.train_val_split * len(dataset))
        val_size = len(dataset) - train_size
        self.train_dataset, self.val_dataset = random_split(
            dataset,
            [train_size, val_size],
        )

    def train_dataloader(self):
        # Shuffle the training split; batches are collated by collate_data_and_cast
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            collate_fn=collate_data_and_cast,
        )

    def val_dataloader(self):
        # Keep validation order fixed (no shuffling)
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=collate_data_and_cast,
        )
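

# Minimal end-to-end sketch (illustrative only): the data directory below is a
# placeholder path, and the transform/batch settings are assumptions rather than
# the values used by the distillation training script. A LightningModule would
# then be trained via pl.Trainer(...).fit(model, datamodule=datamodule).
if __name__ == "__main__":
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    datamodule = CustomDataModule(
        data_dir="/path/to/GTA5/images",  # placeholder path
        transform=transform,
        batch_size=8,
        num_workers=2,
    )
    datamodule.setup()
    print(
        f"train batches: {len(datamodule.train_dataloader())}, "
        f"val batches: {len(datamodule.val_dataloader())}"
    )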