EasyCV/tests/test_datasets/ocr/test_ocr_cls_dataset.py
Cathy0908 785d8d97db
rename test dir name, fix import datasets errors (#310)
* rename test dir name, fix import datasets errors
2023-04-12 15:35:27 +08:00

51 lines
1.5 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
import torch
from tests.ut_config import SMALL_OCR_CLS_DATA
from easycv.datasets.builder import build_dataset
class OCRClsDatasetTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def _get_dataset(self):
data_root = SMALL_OCR_CLS_DATA
data_train_list = os.path.join(data_root, 'label.txt')
pipeline = [
dict(type='RecAug', use_tia=False),
dict(type='ClsResizeImg', img_shape=(3, 48, 192)),
dict(type='MMToTensor'),
dict(
type='Collect', keys=['img', 'label'], meta_keys=['img_path'])
]
data = dict(
train=dict(
type='OCRClsDataset',
data_source=dict(
type='OCRClsSource',
label_file=data_train_list,
data_dir=SMALL_OCR_CLS_DATA + '/img',
label_list=['0', '180'],
),
pipeline=pipeline))
dataset = build_dataset(data['train'])
return dataset
def test_default(self):
dataset = self._get_dataset()
for _, batch in enumerate(dataset):
img, target = batch['img'], batch['label']
self.assertEqual(img.shape, torch.Size([3, 48, 192]))
self.assertIn(target, list(range(2)))
break
if __name__ == '__main__':
unittest.main()