from abc import ABCMeta, abstractmethod from PIL import Image from torchvision.datasets import CIFAR10, CIFAR100 from ..registry import DATASOURCES class Cifar(metaclass=ABCMeta): CLASSES = None def __init__(self, root, split, has_labels=True): assert split in ['train', 'test'] self.root = root self.split = split self.has_labels = has_labels self.cifar = None self.set_cifar() self.labels = self.cifar.targets @abstractmethod def set_cifar(self): pass def get_length(self): return len(self.cifar) def get_sample(self, idx): img = Image.fromarray(self.cifar.data[idx]) if self.has_labels: target = self.labels[idx] # img: HWC, RGB return img, target else: return img @DATASOURCES.register_module class Cifar10(Cifar): CLASSES = [ 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' ] def __init__(self, root, split, has_labels=True): super().__init__(root, split, has_labels) def set_cifar(self): try: self.cifar = CIFAR10( root=self.root, train=self.split == 'train', download=False) except: raise Exception("Please download CIFAR10 manually, \ in case of downloading the dataset parallelly \ that may corrupt the dataset.") @DATASOURCES.register_module class Cifar100(Cifar): def __init__(self, root, split, has_labels=True): super().__init__(root, split, has_labels) def set_cifar(self): try: self.cifar = CIFAR100( root=self.root, train=self.split == 'train', download=False) except: raise Exception("Please download CIFAR10 manually, \ in case of downloading the dataset parallelly \ that may corrupt the dataset.")