mirror of https://github.com/open-mmlab/mmocr.git
76 lines
2.0 KiB
Python
76 lines
2.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import tempfile
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from mmocr.datasets.base_dataset import BaseDataset
|
|
|
|
|
|
def _create_dummy_ann_file(ann_file):
    """Write a tiny two-sample annotation file to ``ann_file``.

    Each line holds an image file name and its transcription separated by a
    single space, and every line (including the last) ends with a newline.

    Args:
        ann_file (str): Destination path. The file is opened in ``'w'`` mode,
            so any existing content is overwritten.
    """
    samples = (('sample1.jpg', 'hello'), ('sample2.jpg', 'world'))
    # Assemble the full text up front and emit it with a single write.
    content = ''.join(' '.join(sample) + '\n' for sample in samples)
    with open(ann_file, 'w') as out_file:
        out_file.write(content)
|
|
|
|
|
|
def _create_dummy_loader():
    """Build the loader config used to read the dummy annotation file.

    Returns:
        dict: A fresh ``HardDiskLoader`` config with ``repeat=1`` and a
        ``LineStrParser`` sub-config whose keys are ``file_name`` and
        ``text``. A new dict is built on every call, so callers may mutate it.
    """
    parser_cfg = {'type': 'LineStrParser', 'keys': ['file_name', 'text']}
    return {'type': 'HardDiskLoader', 'repeat': 1, 'parser': parser_cfg}
|
|
|
|
|
|
def test_custom_dataset():
    """Exercise the basic ``BaseDataset`` API on a two-line annotation file.

    The same checks are run with ``test_mode`` set to both ``True`` and
    ``False``.
    """
    import copy

    # Using the context manager guarantees the temporary directory is removed
    # even when one of the assertions below fails (a trailing ``cleanup()``
    # call would be skipped in that case).
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        ann_file = osp.join(tmp_dir, 'fake_data.txt')
        _create_dummy_ann_file(ann_file)
        loader = _create_dummy_loader()

        for mode in [True, False]:
            dataset = BaseDataset(
                ann_file, loader, pipeline=[], test_mode=mode)

            # test len
            assert len(dataset) == len(dataset.data_infos)

            # test set group flag
            assert np.allclose(dataset.flag, [0, 0])

            # test prepare_train_img
            expect_results = {
                'img_info': {
                    'file_name': 'sample1.jpg',
                    'text': 'hello'
                },
                'img_prefix': ''
            }
            assert dataset.prepare_train_img(0) == expect_results

            # test prepare_test_img
            assert dataset.prepare_test_img(0) == expect_results

            # test __getitem__
            assert dataset[0] == expect_results

            # test _get_next_index
            assert dataset._get_next_index(0) == 1

            # test format_results: it must not mutate its input. A deep copy
            # is required so that in-place changes to the nested ``img_info``
            # dict are detected as well; a shallow copy would share it.
            expect_results_copy = copy.deepcopy(expect_results)
            dataset.format_results(expect_results)
            assert expect_results_copy == expect_results

            # test evaluate
            with pytest.raises(NotImplementedError):
                dataset.evaluate(expect_results)
|