EasyCV/tests/datasets/classification/test_cls_raw_dataset.py

69 lines
2.2 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
import torch
from tests.ut_config import IMG_NORM_CFG, SMALL_IMAGENET_RAW_LOCAL
from easycv.datasets.builder import build_dataset
class ClsDatasetTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def _get_dataset(self):
data_root = SMALL_IMAGENET_RAW_LOCAL
data_train_list = os.path.join(data_root, 'meta/train_labeled_200.txt')
pipeline = [
dict(type='RandomResizedCrop', size=224),
dict(type='RandomHorizontalFlip'),
dict(type='ToTensor'),
dict(type='Normalize', **IMG_NORM_CFG),
dict(type='Collect', keys=['img', 'gt_labels'])
]
data = dict(
train=dict(
type='ClsDataset',
data_source=dict(
list_file=data_train_list,
root=os.path.join(data_root, 'train'),
type='ClsSourceImageList'),
pipeline=pipeline))
dataset = build_dataset(data['train'])
return dataset
def test_default(self):
dataset = self._get_dataset()
for _, batch in enumerate(dataset):
img, target = batch['img'], batch['gt_labels']
self.assertEqual(img.shape, torch.Size([3, 224, 224]))
self.assertIn(target, list(range(1000)))
break
def test_visualize(self):
# TODO: add img_metas for classification
return
dataset = self._get_dataset()
count = 5
classes = []
img_metas = []
for i, data in enumerate(dataset):
classes.append(data['gt_labels'])
img_metas.append(data['img_metas'].data)
if i > count:
break
results = {'class': classes, 'img_metas': img_metas}
output = dataset.visualize(results, vis_num=2)
self.assertEqual(len(output['images']), 2)
self.assertEqual(len(output['img_metas']), 2)
self.assertEqual(len(output['images'][0].shape), 3)
self.assertIn('filename', output['img_metas'][0])
if __name__ == '__main__':
unittest.main()