mmselfsup/openselfsup/datasets/byol.py

36 lines
1.1 KiB
Python

import torch
from torch.utils.data import Dataset
from openselfsup.utils import build_from_cfg
from torchvision.transforms import Compose
from .registry import DATASETS, PIPELINES
from .builder import build_datasource
@DATASETS.register_module
class BYOLDataset(Dataset):
    """Dataset returning two separately transformed views of each sample (BYOL).

    Every sample fetched from the data source is passed through two
    independently configured pipelines. The two results are stacked into a
    single tensor.

    Args:
        data_source (dict): Config handed to ``build_datasource`` to construct
            the underlying data source.
        pipeline1 (list[dict]): Transform configs for the first view. Each entry
            is built from the ``PIPELINES`` registry, and the results are
            chained with ``Compose``.
        pipeline2 (list[dict]): Transform configs for the second view, built the
            same way as ``pipeline1``.
    """

    def __init__(self, data_source, pipeline1, pipeline2):
        self.data_source = build_datasource(data_source)
        self.pipeline1 = self._build_pipeline(pipeline1)
        self.pipeline2 = self._build_pipeline(pipeline2)

    @staticmethod
    def _build_pipeline(cfgs):
        """Build each transform config from ``PIPELINES`` and chain them in order."""
        return Compose([build_from_cfg(p, PIPELINES) for p in cfgs])

    def __len__(self):
        """Return the sample count reported by the data source."""
        return self.data_source.get_length()

    def __getitem__(self, idx):
        """Return both views of sample ``idx``.

        Args:
            idx (int): Index forwarded to ``data_source.get_sample``.

        Returns:
            dict: ``{'img': Tensor}``. The tensor holds the outputs of
            ``pipeline1`` and ``pipeline2`` stacked along a new leading
            dimension of size 2, in that order.
        """
        img = self.data_source.get_sample(idx)
        # Both pipelines receive the same raw sample. Their outputs must be
        # tensors of identical shape so they can be stacked along dim 0.
        img1 = self.pipeline1(img)
        img2 = self.pipeline2(img)
        img_cat = torch.stack((img1, img2), dim=0)
        return dict(img=img_cat)

    def evaluate(self, scores, keyword, logger=None, **kwargs):
        """Evaluation is not implemented for this dataset.

        Raises:
            NotImplementedError: Always.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception class.
        # Raising it would surface as a confusing TypeError instead.
        raise NotImplementedError