73 lines
1.9 KiB
Python
Raw Normal View History

2020-12-19 17:55:32 +08:00
from abc import ABCMeta, abstractmethod
2020-06-16 00:05:18 +08:00
from PIL import Image
from torchvision.datasets import CIFAR10, CIFAR100
from ..registry import DATASOURCES
2020-12-19 17:55:32 +08:00
class Cifar(metaclass=ABCMeta):
    """Abstract base for CIFAR-style data sources.

    A concrete subclass implements :meth:`set_cifar`, which must store a
    dataset object on ``self.cifar``. That object is read through its
    ``targets`` attribute (labels), its ``data`` attribute (per-sample
    arrays) and ``len()``.

    Args:
        root: Directory passed through to the concrete dataset loader.
        split: Either ``'train'`` or ``'test'``.
        return_label: When true, :meth:`get_sample` returns ``(img, label)``;
            otherwise it returns only the image.
    """

    # Subclasses may override with their list of class names.
    CLASSES = None

    def __init__(self, root, split, return_label=True):
        assert split in ('train', 'test')
        self.root, self.split = root, split
        self.return_label = return_label

        # Placeholder; the subclass hook below is expected to replace it.
        self.cifar = None
        self.set_cifar()
        # Labels are taken directly from the loaded dataset (same object).
        self.labels = self.cifar.targets

    @abstractmethod
    def set_cifar(self):
        """Load the underlying dataset and assign it to ``self.cifar``."""

    def get_length(self):
        """Return the number of samples in the underlying dataset."""
        return len(self.cifar)

    def get_sample(self, idx):
        """Return the sample at ``idx`` as a PIL image, optionally with label.

        Returns:
            ``(img, label)`` when ``return_label`` is true, otherwise ``img``.
        """
        # Assumes data[idx] is an HWC, RGB array (as the author noted);
        # PIL builds the image from it.
        image = Image.fromarray(self.cifar.data[idx])
        if not self.return_label:
            return image
        return image, self.labels[idx]
2020-06-16 00:05:18 +08:00
@DATASOURCES.register_module
class Cifar10(Cifar):
    """CIFAR-10 data source backed by ``torchvision.datasets.CIFAR10``.

    The dataset is loaded with ``download=False``, so it must already be
    present under ``root``.
    """

    CLASSES = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck'
    ]

    def __init__(self, root, split, return_label=True):
        super().__init__(root, split, return_label)

    def set_cifar(self):
        """Load the requested CIFAR-10 split into ``self.cifar``.

        Raises:
            RuntimeError: if torchvision cannot load the dataset from
                ``self.root``. The underlying error is chained as
                ``__cause__``. ``RuntimeError`` subclasses ``Exception``,
                so existing ``except Exception`` handlers still catch it.
        """
        try:
            self.cifar = CIFAR10(
                root=self.root, train=self.split == 'train', download=False)
        except Exception as err:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit
            # still propagate, and chain the cause so the real failure
            # (missing files, bad checksum, ...) is not lost.
            raise RuntimeError(
                "Please download CIFAR10 manually; downloading the dataset "
                "in parallel may corrupt it.") from err
@DATASOURCES.register_module
class Cifar100(Cifar):
    """CIFAR-100 data source backed by ``torchvision.datasets.CIFAR100``.

    The dataset is loaded with ``download=False``, so it must already be
    present under ``root``. ``CLASSES`` is inherited from the base class
    (``None``).
    """

    def __init__(self, root, split, return_label=True):
        super().__init__(root, split, return_label)

    def set_cifar(self):
        """Load the requested CIFAR-100 split into ``self.cifar``.

        Raises:
            RuntimeError: if torchvision cannot load the dataset from
                ``self.root``. The underlying error is chained as
                ``__cause__``. ``RuntimeError`` subclasses ``Exception``,
                so existing ``except Exception`` handlers still catch it.
        """
        try:
            self.cifar = CIFAR100(
                root=self.root, train=self.split == 'train', download=False)
        except Exception as err:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit
            # still propagate, and chain the cause so the real failure
            # is not lost. The message names CIFAR100 (previously it
            # wrongly said CIFAR10).
            raise RuntimeError(
                "Please download CIFAR100 manually; downloading the dataset "
                "in parallel may corrupt it.") from err