diff --git a/openselfsup/datasets/contrastive.py b/openselfsup/datasets/contrastive.py index f323dd7f..86179d75 100644 --- a/openselfsup/datasets/contrastive.py +++ b/openselfsup/datasets/contrastive.py @@ -12,7 +12,7 @@ class ContrastiveDataset(BaseDataset): """ def __init__(self, data_source, pipeline, prefetch=False): - data_source['has_labels'] = False + data_source['return_label'] = False super(ContrastiveDataset, self).__init__(data_source, pipeline, prefetch) def __getitem__(self, idx): diff --git a/openselfsup/datasets/data_sources/cifar.py b/openselfsup/datasets/data_sources/cifar.py index 5665e45c..0be150fb 100644 --- a/openselfsup/datasets/data_sources/cifar.py +++ b/openselfsup/datasets/data_sources/cifar.py @@ -10,11 +10,11 @@ class Cifar(metaclass=ABCMeta): CLASSES = None - def __init__(self, root, split, has_labels=True): + def __init__(self, root, split, return_label=True): assert split in ['train', 'test'] self.root = root self.split = split - self.has_labels = has_labels + self.return_label = return_label self.cifar = None self.set_cifar() self.labels = self.cifar.targets @@ -28,7 +28,7 @@ class Cifar(metaclass=ABCMeta): def get_sample(self, idx): img = Image.fromarray(self.cifar.data[idx]) - if self.has_labels: + if self.return_label: target = self.labels[idx] # img: HWC, RGB return img, target else: @@ -43,8 +43,8 @@ class Cifar10(Cifar): 'horse', 'ship', 'truck' ] - def __init__(self, root, split, has_labels=True): - super().__init__(root, split, has_labels) + def __init__(self, root, split, return_label=True): + super().__init__(root, split, return_label) def set_cifar(self): try: @@ -59,8 +59,8 @@ class Cifar10(Cifar): @DATASOURCES.register_module class Cifar100(Cifar): - def __init__(self, root, split, has_labels=True): - super().__init__(root, split, has_labels) + def __init__(self, root, split, return_label=True): + super().__init__(root, split, return_label) def set_cifar(self): try: diff --git a/openselfsup/datasets/data_sources/image_list.py b/openselfsup/datasets/data_sources/image_list.py index fb7dec02..ad64b711 100644 --- a/openselfsup/datasets/data_sources/image_list.py +++ b/openselfsup/datasets/data_sources/image_list.py @@ -8,14 +8,16 @@ from .utils import McLoader @DATASOURCES.register_module class ImageList(object): - def __init__(self, root, list_file, memcached=False, mclient_path=None): + def __init__(self, root, list_file, memcached=False, mclient_path=None, return_label=True): with open(list_file, 'r') as f: lines = f.readlines() self.has_labels = len(lines[0].split()) == 2 + self.return_label = return_label if self.has_labels: self.fns, self.labels = zip(*[l.strip().split() for l in lines]) self.labels = [int(l) for l in self.labels] else: + assert self.return_label is False self.fns = [l.strip() for l in lines] self.fns = [os.path.join(root, fn) for fn in self.fns] self.memcached = memcached @@ -39,7 +41,7 @@ class ImageList(object): else: img = Image.open(self.fns[idx]) img = img.convert('RGB') - if self.has_labels: + if self.has_labels and self.return_label: target = self.labels[idx] return img, target else: diff --git a/openselfsup/datasets/data_sources/imagenet.py b/openselfsup/datasets/data_sources/imagenet.py index e42e2d9b..4d7617ee 100644 --- a/openselfsup/datasets/data_sources/imagenet.py +++ b/openselfsup/datasets/data_sources/imagenet.py @@ -5,6 +5,6 @@ from .image_list import ImageList @DATASOURCES.register_module class ImageNet(ImageList): - def __init__(self, root, list_file, memcached, mclient_path): + def __init__(self, root, list_file, memcached, mclient_path, return_label=True, *args, **kwargs): super(ImageNet, self).__init__( - root, list_file, memcached, mclient_path) + root, list_file, memcached, mclient_path, return_label) diff --git a/openselfsup/datasets/data_sources/places205.py b/openselfsup/datasets/data_sources/places205.py index 28815ccc..e2e49154 100644 --- a/openselfsup/datasets/data_sources/places205.py +++ b/openselfsup/datasets/data_sources/places205.py @@ -5,6 +5,6 @@ from .image_list import ImageList @DATASOURCES.register_module class Places205(ImageList): - def __init__(self, root, list_file, memcached, mclient_path): + def __init__(self, root, list_file, memcached, mclient_path, return_label=True, *args, **kwargs): super(Places205, self).__init__( - root, list_file, memcached, mclient_path) + root, list_file, memcached, mclient_path, return_label)