mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
* add data_source imagenet * modify data_source imagenet and add unittest * modify data_source imagenet and modify unittest * modify voc data_source and modify voc unittest and download Part * modify coco data_source and modify coco unittest and add download Part, modify voc data_source * add dataset metadata Format specification * add pose download data, modify coco.py and modify download file function, add test coco download a part * modify download file function * modify download of cifar10 and cifar100 * modify dataset_name to target_dir * create download_util and modify function * modify function * modify function * add test case, modify * modify something * modify something * modify something * modify something * modify something * modify something * add wget * add wget * modify * add new modify * add new modify * modify something * modify something and add something * modify something and add something * modify something and add something * modify something and add something * modify something and add something * modify test case * modify test case * modify test case * modify test case * modify test case
53 lines
1.4 KiB
Python
53 lines
1.4 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from PIL import Image
|
|
from torchvision.datasets import CIFAR10, CIFAR100
|
|
|
|
from easycv.datasets.registry import DATASOURCES
|
|
|
|
|
|
@DATASOURCES.register_module
class ClsSourceCifar10(object):
    """Classification data source wrapping torchvision's ``CIFAR10``.

    Args:
        root: Directory passed to ``CIFAR10`` as ``root``.
        split: Which subset to load; must be ``'train'`` or ``'test'``.
        download: Passed through to ``CIFAR10`` as ``download``.
    """

    CLASSES = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck'
    ]

    def __init__(self, root, split, download=True):
        assert split in ['train', 'test']
        use_train_set = split == 'train'
        dataset = CIFAR10(root=root, train=use_train_set, download=download)
        self.cifar = dataset
        # Expose the wrapped dataset's targets directly for label lookups.
        self.labels = dataset.targets

    def __len__(self):
        """Return the length reported by the wrapped dataset."""
        return len(self.cifar)

    def __getitem__(self, idx):
        """Return a dict with the sample image and its label.

        Args:
            idx: Index into the wrapped dataset's ``data`` and ``targets``.

        Returns:
            ``{'img': <PIL image>, 'gt_labels': <label>}``.
        """
        # img: HWC, RGB (as noted by the original author)
        sample_array = self.cifar.data[idx]
        image = Image.fromarray(sample_array)
        target = self.labels[idx]
        return {'img': image, 'gt_labels': target}
|
|
|
|
|
|
@DATASOURCES.register_module
class ClsSourceCifar100(object):
    """Classification data source wrapping torchvision's ``CIFAR100``.

    Args:
        root: Directory passed to ``CIFAR100`` as ``root``.
        split: Which subset to load; must be ``'train'`` or ``'test'``.
        download: Passed through to ``CIFAR100`` as ``download``.
    """

    # No class-name list is defined for this source.
    CLASSES = None

    def __init__(self, root, split, download=True):
        assert split in ['train', 'test']
        use_train_set = split == 'train'
        dataset = CIFAR100(root=root, train=use_train_set, download=download)
        self.cifar = dataset
        # Expose the wrapped dataset's targets directly for label lookups.
        self.labels = dataset.targets

    def __len__(self):
        """Return the length reported by the wrapped dataset."""
        return len(self.cifar)

    def __getitem__(self, idx):
        """Return a dict with the sample image and its label.

        Args:
            idx: Index into the wrapped dataset's ``data`` and ``targets``.

        Returns:
            ``{'img': <PIL image>, 'gt_labels': <label>}``.
        """
        # img: HWC, RGB (as noted by the original author)
        sample_array = self.cifar.data[idx]
        image = Image.fromarray(sample_array)
        target = self.labels[idx]
        return {'img': image, 'gt_labels': target}
|