gulou 645fed2d1b
add data source and support for automatic download (#206)
* add data_source imagenet

* modify data_source imagenet and add unittest

* modify data_source imagenet and modify unittest

* modify VOC data_source, VOC unittest, and download part

* modify COCO data_source and COCO unittest, add download part, modify VOC data_source

* add dataset metadata format specification

* add pose download data, modify coco.py and the download-file function, add a test for downloading part of COCO

* modify download file function

* modify download of cifar10 and cifar100

* rename dataset_name to target_dir

* create download_util and modify function

* modify function

* modify function

* add test case , modify

* modify something

* modify something

* modify something

* modify something

* modify something

* modify something

* add wget

* add wget

* modify

* add new modify

* add new modify

* modify something

* modify something and add something

* modify something and add something

* modify something and add something

* modify something and add something

* modify something and add something

* modify test case

* modify test case

* modify test case

* modify test case

* modify test case
2022-11-04 19:36:37 +08:00

53 lines
1.4 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from PIL import Image
from torchvision.datasets import CIFAR10, CIFAR100
from easycv.datasets.registry import DATASOURCES
@DATASOURCES.register_module
class ClsSourceCifar10(object):
    """Classification data source wrapping torchvision's ``CIFAR10``.

    Each item is a dict with a ``PIL.Image.Image`` under ``'img'`` and the
    class index (taken from ``CIFAR10.targets``) under ``'gt_labels'``.

    Args:
        root (str): Dataset root directory, forwarded to ``CIFAR10``.
        split (str): ``'train'`` or ``'test'``.
        download (bool): Forwarded as-is to ``CIFAR10``'s ``download``
            argument.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'test'``.
    """

    # Human-readable names; list position corresponds to the label index.
    CLASSES = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck'
    ]

    def __init__(self, root, split, download=True):
        # Validate with an explicit raise instead of ``assert``: asserts are
        # removed under ``python -O``, and any split other than 'train' would
        # then silently load the test set.
        if split not in ('train', 'test'):
            raise ValueError(
                "split must be 'train' or 'test', got {!r}".format(split))
        self.cifar = CIFAR10(
            root=root, train=(split == 'train'), download=download)
        self.labels = self.cifar.targets

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.cifar)

    def __getitem__(self, idx):
        """Return ``{'img': PIL image, 'gt_labels': label}`` for ``idx``."""
        img = Image.fromarray(self.cifar.data[idx])  # source array: HWC, RGB
        label = self.labels[idx]
        return {'img': img, 'gt_labels': label}
@DATASOURCES.register_module
class ClsSourceCifar100(object):
    """Classification data source wrapping torchvision's ``CIFAR100``.

    Each item is a dict with a ``PIL.Image.Image`` under ``'img'`` and the
    class index (taken from ``CIFAR100.targets``) under ``'gt_labels'``.

    Args:
        root (str): Dataset root directory, forwarded to ``CIFAR100``.
        split (str): ``'train'`` or ``'test'``.
        download (bool): Forwarded as-is to ``CIFAR100``'s ``download``
            argument.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'test'``.
    """

    # Class names are not provided by this data source.
    CLASSES = None

    def __init__(self, root, split, download=True):
        # Validate with an explicit raise instead of ``assert``: asserts are
        # removed under ``python -O``, and any split other than 'train' would
        # then silently load the test set.
        if split not in ('train', 'test'):
            raise ValueError(
                "split must be 'train' or 'test', got {!r}".format(split))
        self.cifar = CIFAR100(
            root=root, train=(split == 'train'), download=download)
        self.labels = self.cifar.targets

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.cifar)

    def __getitem__(self, idx):
        """Return ``{'img': PIL image, 'gt_labels': label}`` for ``idx``."""
        img = Image.fromarray(self.cifar.data[idx])  # source array: HWC, RGB
        label = self.labels[idx]
        return {'img': img, 'gt_labels': label}