gulou 36a3c45efa
add more data sources for auto download (#229)
* add Caltech, Flower, and MNIST data sources

* add det lvis data source

* add pose crowdPose data source

* add pose of OC Human data source

* add pose of mpii data source

* add Seg of voc data source

* add Seg of coco data source

* add Det of wider person datasource

* add Det of african wildlife datasource

* add Det of fruit datasource

* add Det of pet datasource

* add Det of artaxor and tiny person datasource

* add Det of wider face datasource

* add Det of crowd human datasource

* add Det of object365 datasource

* add Seg of coco stuff 10k and 164k datasource

Co-authored-by: Cathy0908 <30484308+Cathy0908@users.noreply.github.com>
2022-12-02 10:57:23 +08:00

51 lines
1.6 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from PIL import Image
from torchvision.datasets import MNIST, FashionMNIST
from easycv.datasets.registry import DATASOURCES
@DATASOURCES.register_module
class ClsSourceMnist(object):
    """MNIST classification data source wrapping ``torchvision.datasets.MNIST``.

    Args:
        root (str): Root directory forwarded to torchvision's ``MNIST``.
        split (str): ``'train'`` selects the training set and ``'test'`` the
            test set (passed as ``train=(split == 'train')``).
        download (bool): Forwarded unchanged to torchvision's ``MNIST``.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, root, split, download=True):
        # Explicit check rather than ``assert``: assertions are removed when
        # Python runs with ``-O``, which would let a bad split slip through.
        if split not in ('train', 'test'):
            raise ValueError(
                f"split must be 'train' or 'test', got {split!r}")
        self.mnist = MNIST(
            root=root, train=(split == 'train'), download=download)
        # Per-sample targets of the wrapped torchvision dataset.
        self.labels = self.mnist.targets
        # Class names of the data, as reported by torchvision.
        self.CLASSES = self.mnist.classes

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.mnist)

    def __getitem__(self, idx):
        """Return ``{'img': <PIL RGB image>, 'gt_labels': <label>}`` for ``idx``."""
        # Raw samples are presumably single-channel; converting to RGB gives
        # a 3-channel image (assumes downstream pipelines expect RGB - confirm).
        img = Image.fromarray(self.mnist.data[idx].numpy()).convert('RGB')
        # ``.item()`` turns the tensor scalar into a plain Python number.
        label = self.labels[idx].item()
        return {'img': img, 'gt_labels': label}
@DATASOURCES.register_module
class ClsSourceFashionMnist(object):
    """Fashion-MNIST classification data source wrapping
    ``torchvision.datasets.FashionMNIST``.

    Args:
        root (str): Root directory forwarded to torchvision's ``FashionMNIST``.
        split (str): ``'train'`` selects the training set and ``'test'`` the
            test set (passed as ``train=(split == 'train')``).
        download (bool): Forwarded unchanged to torchvision's ``FashionMNIST``.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, root, split, download=True):
        # Explicit check rather than ``assert``: assertions are removed when
        # Python runs with ``-O``, which would let a bad split slip through.
        if split not in ('train', 'test'):
            raise ValueError(
                f"split must be 'train' or 'test', got {split!r}")
        self.fashion_mnist = FashionMNIST(
            root=root, train=(split == 'train'), download=download)
        # Per-sample targets of the wrapped torchvision dataset.
        self.labels = self.fashion_mnist.targets
        # Class names of the data, as reported by torchvision.
        self.CLASSES = self.fashion_mnist.classes

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.fashion_mnist)

    def __getitem__(self, idx):
        """Return ``{'img': <PIL RGB image>, 'gt_labels': <label>}`` for ``idx``."""
        # Raw samples are presumably single-channel; converting to RGB gives
        # a 3-channel image (assumes downstream pipelines expect RGB - confirm).
        img = Image.fromarray(
            self.fashion_mnist.data[idx].numpy()).convert('RGB')
        # ``.item()`` turns the tensor scalar into a plain Python number.
        label = self.labels[idx].item()
        return {'img': img, 'gt_labels': label}