mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
* add caltech, flower, mnist data source * add det lvis data source * add pose crowdPose data source * add pose of OC Human data source * add pose of mpii data source * add Seg of voc data source * add Seg of coco data source * add Det of wider person datasource * add Det of african wildlife datasource * add Det of fruit datasource * add Det of pet datasource * add Det of artaxor and tiny person datasource * add Det of wider face datasource * add Det of crowd human datasource * add Det of object365 datasource * add Seg of coco stuff 10k and 164k datasource Co-authored-by: Cathy0908 <30484308+Cathy0908@users.noreply.github.com>
51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from PIL import Image
|
|
from torchvision.datasets import MNIST, FashionMNIST
|
|
|
|
from easycv.datasets.registry import DATASOURCES
|
|
|
|
|
|
@DATASOURCES.register_module
class ClsSourceMnist(object):
    """Classification data source backed by torchvision's ``MNIST`` dataset.

    Each item is a dict holding the image under ``'img'`` (a PIL image
    converted to RGB) and its label under ``'gt_labels'``.

    Args:
        root (str): Forwarded to torchvision ``MNIST`` as ``root``.
        split (str): ``'train'`` or ``'test'``; selects the partition via
            ``MNIST(train=(split == 'train'))``.
        download (bool): Forwarded unchanged to torchvision
            ``MNIST(download=...)``.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, root, split, download=True):
        # Explicit check instead of ``assert``: assertions are removed when
        # Python runs with -O, and then any unrecognised value would silently
        # select the test split (``split == 'train'`` would be False).
        if split not in ('train', 'test'):
            raise ValueError(
                "split must be 'train' or 'test', got {!r}".format(split))
        self.mnist = MNIST(
            root=root, train=(split == 'train'), download=download)
        # Per-sample labels, taken from the torchvision dataset's ``targets``.
        self.labels = self.mnist.targets
        # Class names, taken from the torchvision dataset's ``classes``.
        self.CLASSES = self.mnist.classes

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.mnist)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'img': <PIL RGB image>, 'gt_labels': label}``.

        The raw array is read directly from ``self.mnist.data`` and converted
        to a 3-channel RGB PIL image so downstream pipelines get RGB input.
        """
        img = Image.fromarray(self.mnist.data[idx].numpy()).convert('RGB')
        # ``.item()`` converts the indexed label element to a Python scalar.
        label = self.labels[idx].item()
        return {'img': img, 'gt_labels': label}
|
|
|
|
|
|
@DATASOURCES.register_module
class ClsSourceFashionMnist(object):
    """Classification data source backed by torchvision's ``FashionMNIST``.

    Each item is a dict holding the image under ``'img'`` (a PIL image
    converted to RGB) and its label under ``'gt_labels'``.

    Args:
        root (str): Forwarded to torchvision ``FashionMNIST`` as ``root``.
        split (str): ``'train'`` or ``'test'``; selects the partition via
            ``FashionMNIST(train=(split == 'train'))``.
        download (bool): Forwarded unchanged to torchvision
            ``FashionMNIST(download=...)``.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, root, split, download=True):
        # Explicit check instead of ``assert``: assertions are removed when
        # Python runs with -O, and then any unrecognised value would silently
        # select the test split (``split == 'train'`` would be False).
        if split not in ('train', 'test'):
            raise ValueError(
                "split must be 'train' or 'test', got {!r}".format(split))
        self.fashion_mnist = FashionMNIST(
            root=root, train=(split == 'train'), download=download)
        # Per-sample labels, taken from the torchvision dataset's ``targets``.
        self.labels = self.fashion_mnist.targets
        # Class names, taken from the torchvision dataset's ``classes``.
        self.CLASSES = self.fashion_mnist.classes

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.fashion_mnist)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'img': <PIL RGB image>, 'gt_labels': label}``.

        The raw array is read directly from ``self.fashion_mnist.data`` and
        converted to a 3-channel RGB PIL image so downstream pipelines get
        RGB input.
        """
        img = Image.fromarray(
            self.fashion_mnist.data[idx].numpy()).convert('RGB')
        # ``.item()`` converts the indexed label element to a Python scalar.
        label = self.labels[idx].item()
        return {'img': img, 'gt_labels': label}
|