# File metadata:
#   2022-07-21 10:50:58 +08:00
#   140 lines
#   5.5 KiB
#   Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import numpy as np
import torch
from mmocr.datasets.pipelines import PackTextDetInputs, PackTextRecogInputs
class TestPackTextDetInputs(TestCase):
    """Unit tests for the ``PackTextDetInputs`` transform."""

    def test_packdetinput(self):
        """Pack a text-detection sample under several key configurations.

        Scenarios covered, in order: a meta key missing from the input,
        the default meta keys, a restricted ``meta_keys`` tuple, an input
        without ``gt_texts``, and an input that only carries ``img_shape``.
        """
        datainfo = dict(
            img=np.random.random((10, 10)),
            img_shape=(10, 10),
            ori_shape=(10, 10),
            pad_shape=(10, 10),
            scale_factor=(1, 1),
            img_path='tmp/tmp.jpg',
            flip=True,
            flip_direction='left',
            gt_bboxes=np.array([[0, 0, 10, 10], [5, 5, 15, 15]],
                               dtype=np.float32),
            gt_bboxes_labels=np.array([0, 0], np.int64),
            gt_polygons=[
                np.array([0, 0, 0, 10, 10, 10, 10, 0], dtype=np.float32),
                np.array([5, 5, 5, 15, 15, 15, 15, 5], dtype=np.float32)
            ],
            gt_texts=['mmocr', 'mmocr_ignore'],
            gt_ignored=np.bool_([False, True]))

        # 'tmp' is not a key of ``datainfo``. Only the call is expected to
        # raise, so the transform is constructed outside the context manager
        # and a KeyError from construction cannot satisfy the assertion.
        transform = PackTextDetInputs(meta_keys=('tmp', ))
        with self.assertRaises(KeyError):
            transform(copy.deepcopy(datainfo))

        # Default meta keys: check the packed image and every gt field.
        transform = PackTextDetInputs()
        results = transform(copy.deepcopy(datainfo))
        self.assertIn('inputs', results)
        # The 2-D (10, 10) image is expected to come back as (1, 10, 10).
        self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
        self.assertIn('data_sample', results)

        data_sample = results['data_sample']
        gt_instances = data_sample.gt_instances
        self.assertIn('bboxes', gt_instances)
        self.assertIsInstance(gt_instances.bboxes, torch.Tensor)
        self.assertEqual(gt_instances.bboxes.dtype, torch.float32)
        self.assertIsInstance(gt_instances.polygons[0], np.ndarray)
        self.assertEqual(gt_instances.polygons[0].dtype, np.float32)
        self.assertEqual(gt_instances.ignored.dtype, torch.bool)
        self.assertEqual(gt_instances.labels.dtype, torch.int64)
        self.assertIsInstance(gt_instances.texts, list)
        self.assertIn('img_path', data_sample)
        self.assertIn('flip', data_sample)

        # Restricted meta keys: only 'img_path' is requested, so 'flip'
        # must be absent from the data sample.
        transform = PackTextDetInputs(meta_keys=('img_path', ))
        results = transform(copy.deepcopy(datainfo))
        self.assertIn('inputs', results)
        self.assertIn('data_sample', results)
        data_sample = results['data_sample']
        self.assertIn('bboxes', data_sample.gt_instances)
        self.assertIn('img_path', data_sample)
        self.assertNotIn('flip', data_sample)

        # Input without 'gt_texts': no 'texts' field is expected.
        datainfo.pop('gt_texts')
        transform = PackTextDetInputs()
        results = transform(copy.deepcopy(datainfo))
        data_sample = results['data_sample']
        self.assertNotIn('texts', data_sample.gt_instances)

        # Input with neither an image nor annotations: no 'inputs' entry
        # and no 'texts' field are expected.
        datainfo = dict(img_shape=(10, 10))
        transform = PackTextDetInputs(meta_keys=('img_shape', ))
        results = transform(copy.deepcopy(datainfo))
        self.assertNotIn('inputs', results)
        data_sample = results['data_sample']
        self.assertNotIn('texts', data_sample.gt_instances)

    def test_repr(self):
        """``repr`` of a default transform lists the default meta keys."""
        transform = PackTextDetInputs()
        self.assertEqual(
            repr(transform),
            ("PackTextDetInputs(meta_keys=('img_path', 'ori_shape', "
             "'img_shape', 'scale_factor', 'flip', 'flip_direction'))"))
class TestPackTextRecogInputs(TestCase):
    """Unit tests for the ``PackTextRecogInputs`` transform."""

    def test_packrecogtinput(self):
        """Pack a text-recognition sample under several key configurations.

        Scenarios covered, in order: a meta key missing from the input,
        the default meta keys, a restricted ``meta_keys`` tuple, and an
        input that only carries ``img_shape``.
        """
        datainfo = dict(
            img=np.random.random((10, 10)),
            img_shape=(10, 10),
            ori_shape=(10, 10),
            pad_shape=(10, 10),
            scale_factor=(1, 1),
            img_path='tmp/tmp.jpg',
            flip=True,
            flip_direction='left',
            gt_bboxes=np.array([[0, 0, 10, 10]]),
            gt_labels=np.array([0]),
            gt_polygons=[[0, 0, 0, 10, 10, 10, 10, 0]],
            gt_texts=['mmocr'],
        )

        # 'tmp' is not a key of ``datainfo``. Only the call is expected to
        # raise, so the transform is constructed outside the context manager
        # and a KeyError from construction cannot satisfy the assertion.
        transform = PackTextRecogInputs(meta_keys=('tmp', ))
        with self.assertRaises(KeyError):
            transform(copy.deepcopy(datainfo))

        # Default meta keys: check the packed image, the text and the metas.
        transform = PackTextRecogInputs()
        results = transform(copy.deepcopy(datainfo))
        self.assertIn('inputs', results)
        # The 2-D (10, 10) image is expected to come back as (1, 10, 10).
        self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
        self.assertIn('data_sample', results)
        data_sample = results['data_sample']
        self.assertEqual(data_sample.gt_text.item, 'mmocr')
        self.assertIn('img_path', data_sample)
        # 'valid_ratio' is not in ``datainfo`` yet is expected in the output.
        self.assertIn('valid_ratio', data_sample)
        self.assertIn('pad_shape', data_sample)

        # Restricted meta keys: only 'img_path' is requested, so the other
        # metas must be absent from the data sample.
        transform = PackTextRecogInputs(meta_keys=('img_path', ))
        results = transform(copy.deepcopy(datainfo))
        self.assertIn('inputs', results)
        self.assertIn('data_sample', results)
        data_sample = results['data_sample']
        self.assertEqual(data_sample.gt_text.item, 'mmocr')
        self.assertIn('img_path', data_sample)
        self.assertNotIn('valid_ratio', data_sample)
        self.assertNotIn('pad_shape', data_sample)

        # Input with neither an image nor a text: no 'inputs' entry and no
        # 'item' field on ``gt_text`` are expected.
        datainfo = dict(img_shape=(10, 10))
        transform = PackTextRecogInputs(meta_keys=('img_shape', ))
        results = transform(copy.deepcopy(datainfo))
        self.assertNotIn('inputs', results)
        data_sample = results['data_sample']
        self.assertNotIn('item', data_sample.gt_text)

    def test_repr(self):
        """``repr`` of a default transform lists the default meta keys."""
        transform = PackTextRecogInputs()
        self.assertEqual(
            repr(transform),
            ("PackTextRecogInputs(meta_keys=('img_path', 'ori_shape', "
             "'img_shape', 'pad_shape', 'valid_ratio'))"))