EasyCV/tests/datasets/shared/test_raw.py

72 lines
2.3 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
import torch
from tests.ut_config import IMG_NORM_CFG, SMALL_IMAGENET_RAW_LOCAL
from easycv.datasets.builder import build_dataset
class RawDatasetTest(unittest.TestCase):
    """Smoke tests for building a ``RawDataset`` from small ImageNet lists.

    Each test builds a dataset via ``build_dataset`` from a config dict,
    checks the number of samples, then inspects the first sample produced
    by the configured pipeline.
    """

    # Expected number of entries in the ``*_200.txt`` list files used below.
    NUM_SAMPLES = 200
    # Exclusive upper bound accepted for ``gt_labels`` (same bound the test
    # has always checked against).
    NUM_CLASSES = 1000
    # Spatial size produced by the final crop/resize of every pipeline below.
    OUT_SIZE = 224

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def test_raw_dataset_with_label(self):
        """A labeled list read by ``ClsSourceImageList`` yields img + label."""
        data_train_list = os.path.join(SMALL_IMAGENET_RAW_LOCAL,
                                       'meta/train_labeled_200.txt')
        data_train_root = os.path.join(SMALL_IMAGENET_RAW_LOCAL, 'train')
        train_data = dict(
            type='RawDataset',
            data_source=dict(
                type='ClsSourceImageList',
                list_file=data_train_list,
                root=data_train_root),
            pipeline=[
                dict(type='Resize', size=256),
                dict(type='CenterCrop', size=self.OUT_SIZE),
                dict(type='ToTensor'),
                dict(type='Normalize', **IMG_NORM_CFG),
                dict(type='Collect', keys=['img', 'gt_labels'])
            ])
        dataset = build_dataset(train_data)

        # Check the size first: with an empty dataset the old
        # "iterate once and break" loop silently skipped every sample check.
        self.assertEqual(len(dataset), self.NUM_SAMPLES)

        # Map-style access to the first sample (runs the full pipeline once).
        batch = dataset[0]
        self.assertEqual(batch['img'].shape,
                         torch.Size([3, self.OUT_SIZE, self.OUT_SIZE]))
        # ``range`` supports membership directly; no need to build a list.
        self.assertIn(batch['gt_labels'], range(self.NUM_CLASSES))

    def test_raw_dataset(self):
        """An unlabeled list read by ``SSLSourceImageList`` yields an image."""
        data_train_list = os.path.join(SMALL_IMAGENET_RAW_LOCAL,
                                       'meta/train_200.txt')
        data_train_root = SMALL_IMAGENET_RAW_LOCAL
        train_data = dict(
            type='RawDataset',
            data_source=dict(
                type='SSLSourceImageList',
                list_file=data_train_list,
                root=data_train_root),
            pipeline=[
                dict(type='Resize', size=256),
                dict(type='Resize', size=(self.OUT_SIZE, self.OUT_SIZE)),
                dict(type='ToTensor'),
                dict(type='Normalize', **IMG_NORM_CFG),
                dict(type='Collect', keys=['img'])
            ])
        dataset = build_dataset(train_data)

        # Size first, so an empty dataset fails loudly instead of skipping
        # the shape check.
        self.assertEqual(len(dataset), self.NUM_SAMPLES)

        batch = dataset[0]
        self.assertEqual(batch['img'].shape,
                         torch.Size([3, self.OUT_SIZE, self.OUT_SIZE]))
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()