EasyCV/tests/test_datasets/ocr/test_ocr_rec_dataset.py
Cathy0908 785d8d97db
rename test dir name, fix import datasets errors (#310)
* rename test dir name, fix import datasets errors
2023-04-12 15:35:27 +08:00

67 lines
2.1 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
import torch
from tests.ut_config import SMALL_OCR_REC_DATA
from easycv.datasets.builder import build_dataset
class OCRRecsDatasetTest(unittest.TestCase):
    """Smoke test: build the OCR recognition dataset and read one sample."""

    def setUp(self):
        """Print the class and method name of the test about to run."""
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _get_dataset(self):
        """Build an ``OCRRecDataset`` over the ``SMALL_OCR_REC_DATA`` fixture.

        The label file is ``<SMALL_OCR_REC_DATA>/label.txt`` and the images
        are read from ``<SMALL_OCR_REC_DATA>/img``.

        Returns:
            Whatever ``build_dataset`` returns for the assembled config.
        """
        data_root = SMALL_OCR_REC_DATA
        data_train_list = os.path.join(data_root, 'label.txt')
        # Built with os.path.join so it follows the same convention as the
        # label path above instead of hand-concatenating a separator.
        img_dir = os.path.join(data_root, 'img')

        pipeline = [
            # NOTE(review): this shape is ordered (48, 320, 3) while
            # RecResizeImg below uses (3, 48, 320) -- presumably HWC vs CHW;
            # confirm against the transform implementations.
            dict(type='RecConAug', prob=0.5, image_shape=(48, 320, 3)),
            dict(type='RecAug'),
            dict(
                type='MultiLabelEncode',
                max_text_length=25,
                use_space_char=True,
                # The character dictionary is referenced by a remote URL.
                character_dict_path=
                'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/ocr/dict/ppocr_keys_v1.txt',
            ),
            dict(type='RecResizeImg', image_shape=(3, 48, 320)),
            dict(type='MMToTensor'),
            dict(
                type='Collect',
                keys=[
                    'img', 'label_ctc', 'label_sar', 'length', 'valid_ratio'
                ],
                meta_keys=['img_path']),
        ]

        dataset_cfg = dict(
            type='OCRRecDataset',
            data_source=dict(
                type='OCRRecSource',
                label_file=data_train_list,
                data_dir=img_dir,
                ext_data_num=0,
                test_mode=True,
            ),
            pipeline=pipeline)
        return build_dataset(dataset_cfg)

    def test_default(self):
        """Check the shapes of ``img``, ``label_ctc`` and ``label_sar``.

        Only the first sample is inspected. An empty dataset is reported as
        a failure rather than letting the test pass without any assertion.
        """
        dataset = self._get_dataset()

        # Same iteration protocol a ``for`` loop would use, but the ``None``
        # default lets us detect a dataset that yields nothing.
        sample = next(iter(dataset), None)
        self.assertIsNotNone(sample, 'OCR rec dataset yielded no samples')

        img = sample['img']
        label_ctc = sample['label_ctc']
        label_sar = sample['label_sar']

        self.assertEqual(img.shape, torch.Size([3, 48, 320]))
        # Label length 25 presumably mirrors ``max_text_length=25`` in the
        # pipeline config -- TODO confirm against MultiLabelEncode.
        self.assertEqual(label_ctc.shape, (25, ))
        self.assertEqual(label_sar.shape, (25, ))
# Script entry point: when this file is executed directly (rather than
# imported by a test runner), hand control to unittest's command-line runner.
if __name__ == '__main__':
    unittest.main()