mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
140 lines
5.5 KiB
Python
140 lines
5.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from mmocr.datasets.pipelines import PackTextDetInputs, PackTextRecogInputs
|
|
|
|
|
|
class TestPackTextDetInputs(TestCase):
    """Unit tests for the ``PackTextDetInputs`` transform."""

    def test_packdetinput(self):
        """Pack a text-detection sample under several configurations.

        Scenarios, in order: a meta key absent from the input (expects
        ``KeyError``), the default transform, ``meta_keys`` limited to
        ``img_path``, an input with ``gt_texts`` removed, and an input
        that only holds ``img_shape``.
        """
        raw_sample = {
            'img': np.random.random((10, 10)),
            'img_shape': (10, 10),
            'ori_shape': (10, 10),
            'pad_shape': (10, 10),
            'scale_factor': (1, 1),
            'img_path': 'tmp/tmp.jpg',
            'flip': True,
            'flip_direction': 'left',
            'gt_bboxes': np.array(
                [[0, 0, 10, 10], [5, 5, 15, 15]], dtype=np.float32),
            'gt_bboxes_labels': np.array([0, 0], np.int64),
            'gt_polygons': [
                np.array([0, 0, 0, 10, 10, 10, 10, 0], dtype=np.float32),
                np.array([5, 5, 5, 15, 15, 15, 15, 5], dtype=np.float32),
            ],
            'gt_texts': ['mmocr', 'mmocr_ignore'],
            'gt_ignored': np.bool_([False, True]),
        }

        # 'tmp' is not a key of the input dict, so packing must fail.
        with self.assertRaises(KeyError):
            packer = PackTextDetInputs(meta_keys=('tmp', ))
            packer(copy.deepcopy(raw_sample))

        # Default configuration.
        packer = PackTextDetInputs()
        packed = packer(copy.deepcopy(raw_sample))
        self.assertIn('inputs', packed)
        self.assertTupleEqual(tuple(packed['inputs'].shape), (1, 10, 10))
        self.assertIn('data_sample', packed)

        sample = packed['data_sample']
        instances = sample.gt_instances
        self.assertIn('bboxes', instances)
        self.assertIsInstance(instances.bboxes, torch.Tensor)
        self.assertEqual(instances.bboxes.dtype, torch.float32)
        first_polygon = instances.polygons[0]
        self.assertIsInstance(first_polygon, np.ndarray)
        self.assertEqual(first_polygon.dtype, np.float32)
        self.assertEqual(instances.ignored.dtype, torch.bool)
        self.assertEqual(instances.labels.dtype, torch.int64)
        self.assertIsInstance(instances.texts, list)

        self.assertIn('img_path', sample)
        self.assertIn('flip', sample)

        # Only 'img_path' is requested as meta information.
        packer = PackTextDetInputs(meta_keys=('img_path', ))
        packed = packer(copy.deepcopy(raw_sample))
        self.assertIn('inputs', packed)
        self.assertIn('data_sample', packed)

        sample = packed['data_sample']
        self.assertIn('bboxes', sample.gt_instances)
        self.assertIn('img_path', sample)
        self.assertNotIn('flip', sample)

        # Without 'gt_texts' in the input no 'texts' field is expected.
        raw_sample.pop('gt_texts')
        packer = PackTextDetInputs()
        packed = packer(copy.deepcopy(raw_sample))
        sample = packed['data_sample']
        self.assertNotIn('texts', sample.gt_instances)

        # Input that carries neither an image nor any annotation.
        raw_sample = {'img_shape': (10, 10)}
        packer = PackTextDetInputs(meta_keys=('img_shape', ))
        packed = packer(copy.deepcopy(raw_sample))
        self.assertNotIn('inputs', packed)
        sample = packed['data_sample']
        self.assertNotIn('texts', sample.gt_instances)

    def test_repr(self):
        """Compare ``repr`` of a default transform with the expected text."""
        expected = ("PackTextDetInputs(meta_keys=('img_path', 'ori_shape', "
                    "'img_shape', 'scale_factor', 'flip', 'flip_direction'))")
        self.assertEqual(repr(PackTextDetInputs()), expected)
|
|
|
|
|
|
class TestPackTextRecogInputs(TestCase):
    """Unit tests for the ``PackTextRecogInputs`` transform."""

    def test_packrecogtinput(self):
        """Pack a text-recognition sample under several configurations.

        Scenarios, in order: a meta key absent from the input (expects
        ``KeyError``), the default transform, ``meta_keys`` limited to
        ``img_path``, and an input that only holds ``img_shape``.
        """
        raw_sample = {
            'img': np.random.random((10, 10)),
            'img_shape': (10, 10),
            'ori_shape': (10, 10),
            'pad_shape': (10, 10),
            'scale_factor': (1, 1),
            'img_path': 'tmp/tmp.jpg',
            'flip': True,
            'flip_direction': 'left',
            'gt_bboxes': np.array([[0, 0, 10, 10]]),
            'gt_labels': np.array([0]),
            'gt_polygons': [[0, 0, 0, 10, 10, 10, 10, 0]],
            'gt_texts': ['mmocr'],
        }

        # 'tmp' is not a key of the input dict, so packing must fail.
        with self.assertRaises(KeyError):
            packer = PackTextRecogInputs(meta_keys=('tmp', ))
            packer(copy.deepcopy(raw_sample))

        # Default configuration.
        packer = PackTextRecogInputs()
        packed = packer(copy.deepcopy(raw_sample))
        self.assertIn('inputs', packed)
        self.assertTupleEqual(tuple(packed['inputs'].shape), (1, 10, 10))
        self.assertIn('data_sample', packed)
        sample = packed['data_sample']
        self.assertEqual(sample.gt_text.item, 'mmocr')
        for meta_key in ('img_path', 'valid_ratio', 'pad_shape'):
            self.assertIn(meta_key, sample)

        # Only 'img_path' is requested as meta information.
        packer = PackTextRecogInputs(meta_keys=('img_path', ))
        packed = packer(copy.deepcopy(raw_sample))
        self.assertIn('inputs', packed)
        self.assertIn('data_sample', packed)
        sample = packed['data_sample']
        self.assertEqual(sample.gt_text.item, 'mmocr')
        self.assertIn('img_path', sample)
        for meta_key in ('valid_ratio', 'pad_shape'):
            self.assertNotIn(meta_key, sample)

        # Input that carries neither an image nor any text annotation.
        raw_sample = {'img_shape': (10, 10)}
        packer = PackTextRecogInputs(meta_keys=('img_shape', ))
        packed = packer(copy.deepcopy(raw_sample))
        self.assertNotIn('inputs', packed)
        sample = packed['data_sample']
        self.assertNotIn('item', sample.gt_text)

    def test_repr(self):
        """Compare ``repr`` of a default transform with the expected text."""
        expected = ("PackTextRecogInputs(meta_keys=('img_path', 'ori_shape', "
                    "'img_shape', 'pad_shape', 'valid_ratio'))")
        self.assertEqual(repr(PackTextRecogInputs()), expected)
|