commit
5d7a1b8ef5
|
@ -12,6 +12,7 @@ class ContrastiveDataset(BaseDataset):
|
|||
"""
|
||||
|
||||
def __init__(self, data_source, pipeline, prefetch=False):
|
||||
data_source['return_label'] = False
|
||||
super(ContrastiveDataset, self).__init__(data_source, pipeline, prefetch)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from PIL import Image
|
||||
|
||||
from torchvision.datasets import CIFAR10, CIFAR100
|
||||
|
@ -5,51 +6,67 @@ from torchvision.datasets import CIFAR10, CIFAR100
|
|||
from ..registry import DATASOURCES
|
||||
|
||||
|
||||
@DATASOURCES.register_module
|
||||
class Cifar10(object):
|
||||
class Cifar(metaclass=ABCMeta):
|
||||
|
||||
CLASSES = [
|
||||
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
|
||||
'horse', 'ship', 'truck'
|
||||
]
|
||||
CLASSES = None
|
||||
|
||||
def __init__(self, root, split):
|
||||
def __init__(self, root, split, return_label=True):
|
||||
assert split in ['train', 'test']
|
||||
try:
|
||||
self.cifar = CIFAR10(
|
||||
root=root, train=split == 'train', download=False)
|
||||
except:
|
||||
raise Exception("Please download CIFAR10 manually, \
|
||||
in case of downloading the dataset parallelly \
|
||||
that may corrupt the dataset.")
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.return_label = return_label
|
||||
self.cifar = None
|
||||
self.set_cifar()
|
||||
self.labels = self.cifar.targets
|
||||
|
||||
@abstractmethod
|
||||
def set_cifar(self):
|
||||
pass
|
||||
|
||||
def get_length(self):
|
||||
return len(self.cifar)
|
||||
|
||||
def get_sample(self, idx):
|
||||
img = Image.fromarray(self.cifar.data[idx])
|
||||
target = self.labels[idx] # img: HWC, RGB
|
||||
return img, target
|
||||
if self.return_label:
|
||||
target = self.labels[idx] # img: HWC, RGB
|
||||
return img, target
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
@DATASOURCES.register_module
|
||||
class Cifar100(object):
|
||||
class Cifar10(Cifar):
|
||||
|
||||
CLASSES = None
|
||||
CLASSES = [
|
||||
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
|
||||
'horse', 'ship', 'truck'
|
||||
]
|
||||
|
||||
def __init__(self, root, split):
|
||||
assert split in ['train', 'test']
|
||||
def __init__(self, root, split, return_label=True):
|
||||
super().__init__(root, split, return_label)
|
||||
|
||||
def set_cifar(self):
|
||||
try:
|
||||
self.cifar = CIFAR100(
|
||||
root=root, train=split == 'train', download=False)
|
||||
self.cifar = CIFAR10(
|
||||
root=self.root, train=self.split == 'train', download=False)
|
||||
except:
|
||||
raise Exception("Please download CIFAR10 manually, \
|
||||
in case of downloading the dataset parallelly \
|
||||
that may corrupt the dataset.")
|
||||
self.labels = self.cifar.targets
|
||||
|
||||
def get_sample(self, idx):
|
||||
img = Image.fromarray(self.cifar.data[idx])
|
||||
target = self.labels[idx] # img: HWC, RGB
|
||||
return img, target
|
||||
|
||||
@DATASOURCES.register_module
|
||||
class Cifar100(Cifar):
|
||||
|
||||
def __init__(self, root, split, return_label=True):
|
||||
super().__init__(root, split, return_label)
|
||||
|
||||
def set_cifar(self):
|
||||
try:
|
||||
self.cifar = CIFAR100(
|
||||
root=self.root, train=self.split == 'train', download=False)
|
||||
except:
|
||||
raise Exception("Please download CIFAR10 manually, \
|
||||
in case of downloading the dataset parallelly \
|
||||
that may corrupt the dataset.")
|
||||
|
|
|
@ -8,14 +8,16 @@ from .utils import McLoader
|
|||
@DATASOURCES.register_module
|
||||
class ImageList(object):
|
||||
|
||||
def __init__(self, root, list_file, memcached=False, mclient_path=None):
|
||||
def __init__(self, root, list_file, memcached=False, mclient_path=None, return_label=True):
|
||||
with open(list_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
self.has_labels = len(lines[0].split()) == 2
|
||||
self.return_label = return_label
|
||||
if self.has_labels:
|
||||
self.fns, self.labels = zip(*[l.strip().split() for l in lines])
|
||||
self.labels = [int(l) for l in self.labels]
|
||||
else:
|
||||
assert self.return_label is False
|
||||
self.fns = [l.strip() for l in lines]
|
||||
self.fns = [os.path.join(root, fn) for fn in self.fns]
|
||||
self.memcached = memcached
|
||||
|
@ -39,7 +41,7 @@ class ImageList(object):
|
|||
else:
|
||||
img = Image.open(self.fns[idx])
|
||||
img = img.convert('RGB')
|
||||
if self.has_labels:
|
||||
if self.has_labels and self.return_label:
|
||||
target = self.labels[idx]
|
||||
return img, target
|
||||
else:
|
||||
|
|
|
@ -5,6 +5,6 @@ from .image_list import ImageList
|
|||
@DATASOURCES.register_module
|
||||
class ImageNet(ImageList):
|
||||
|
||||
def __init__(self, root, list_file, memcached, mclient_path):
|
||||
def __init__(self, root, list_file, memcached, mclient_path, return_label=True, *args, **kwargs):
|
||||
super(ImageNet, self).__init__(
|
||||
root, list_file, memcached, mclient_path)
|
||||
root, list_file, memcached, mclient_path, return_label)
|
||||
|
|
|
@ -5,6 +5,6 @@ from .image_list import ImageList
|
|||
@DATASOURCES.register_module
|
||||
class Places205(ImageList):
|
||||
|
||||
def __init__(self, root, list_file, memcached, mclient_path):
|
||||
def __init__(self, root, list_file, memcached, mclient_path, return_label=True, *args, **kwargs):
|
||||
super(Places205, self).__init__(
|
||||
root, list_file, memcached, mclient_path)
|
||||
root, list_file, memcached, mclient_path, return_label)
|
||||
|
|
Loading…
Reference in New Issue