# mmocr/tests/test_dataset/test_textdet_targets.py
# Snapshot: 138 lines, 5.0 KiB, dated 2021-04-03 01:03:52 +08:00
from unittest import mock
import numpy as np
import mmocr.datasets.pipelines.custom_format_bundle as cf_bundle
import mmocr.datasets.pipelines.textdet_targets as textdet_targets
from mmdet.core import PolygonMasks
@mock.patch('%s.cf_bundle.show_feature' % __name__)
def test_gen_pannet_targets(mock_show_feature):
    """Test PANetTargets kernel, effective-mask and full target generation.

    ``show_feature`` on the ``cf_bundle`` module is patched (resolved through
    this test module's ``cf_bundle`` alias), so the call made by the bundle's
    ``visualize`` path is intercepted. The test asserts it happens exactly
    once.
    """
    target_generator = textdet_targets.PANetTargets()
    assert target_generator.max_shrink == 20

    # test generate_kernels: the expected kernel fills each polygon's
    # pixels with its 1-based instance index and leaves the rest at 0.
    img_size = (3, 10)
    text_polys = [[np.array([0, 0, 1, 0, 1, 1, 0, 1])],
                  [np.array([2, 0, 3, 0, 3, 1, 2, 1])]]
    shrink_ratio = 1.0
    kernel = np.array([[1, 1, 2, 2, 0, 0, 0, 0, 0, 0],
                       [1, 1, 2, 2, 0, 0, 0, 0, 0, 0],
                       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
    output, _ = target_generator.generate_kernels(img_size, text_polys,
                                                  shrink_ratio)
    # assert_allclose also rejects shape mismatches that np.allclose would
    # silently broadcast away, and reports the differing elements.
    np.testing.assert_allclose(output, kernel)

    # test generate_effective_mask: pixels covered by an ignored polygon
    # are expected to be 0, every other pixel 1.
    polys_ignore = text_polys
    output = target_generator.generate_effective_mask(img_size, polys_ignore)
    target = np.array([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
                       [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
                       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
    np.testing.assert_allclose(output, target)

    # test generate_targets via the callable interface
    results = {}
    results['img'] = np.zeros((3, 10, 3), np.uint8)
    results['gt_masks'] = PolygonMasks(text_polys, 3, 10)
    results['gt_masks_ignore'] = PolygonMasks([], 3, 10)
    results['img_shape'] = (3, 10, 3)
    results['mask_fields'] = []
    output = target_generator(results)
    assert len(output['gt_kernels']) == 2
    assert len(output['gt_mask']) == 1

    # Bundling with visualization enabled must route through show_feature.
    bundle = cf_bundle.CustomFormatBundle(
        keys=['gt_kernels', 'gt_mask'],
        visualize=dict(flag=True, boundary_key='gt_kernels'))
    bundle(output)
    assert 'gt_kernels' in output.keys()
    assert 'gt_mask' in output.keys()
    mock_show_feature.assert_called_once()
def test_gen_psenet_targets():
    """Check the default configuration of ``PSENetTargets``."""
    psenet_targets = textdet_targets.PSENetTargets()
    expected_ratios = (1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4)

    assert psenet_targets.max_shrink == 20
    assert psenet_targets.shrink_ratio == expected_ratios
# Test DBNetTargets
def test_dbnet_targets_find_invalid():
    """Test that DBNetTargets.find_invalid flags neither of two squares.

    The generator's default attributes are already asserted in
    ``test_dbnet_targets``, so they are not repeated here.
    """
    target_generator = textdet_targets.DBNetTargets()
    text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])],
                  [np.array([20, 0, 30, 0, 30, 10, 20, 10])]]
    results = {'gt_masks': PolygonMasks(text_polys, 40, 40)}

    ignore_tags = target_generator.find_invalid(results)
    # np.array_equal requires matching shape; np.allclose would broadcast a
    # single-element result against [False, False] and still pass.
    assert np.array_equal(ignore_tags, [False, False])
def test_dbnet_targets():
    """Check the default hyper-parameters of ``DBNetTargets``."""
    dbnet_targets = textdet_targets.DBNetTargets()
    expected_defaults = {'shrink_ratio': 0.4, 'thr_min': 0.3, 'thr_max': 0.7}
    for attr_name, expected in expected_defaults.items():
        assert getattr(dbnet_targets, attr_name) == expected
def test_dbnet_ignore_texts():
    """Test that DBNetTargets.ignore_texts moves tagged texts to the ignore set.

    The first polygon is tagged as ignored. It is expected to be appended to
    ``gt_masks_ignore`` (after the one pre-existing ignore polygon) and
    removed from ``gt_masks`` and ``gt_labels``.
    """
    target_generator = textdet_targets.DBNetTargets()
    ignore_tags = [True, False]
    results = {}
    text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])],
                  [np.array([20, 0, 30, 0, 30, 10, 20, 10])]]
    text_polys_ignore = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]]
    results['gt_masks_ignore'] = PolygonMasks(text_polys_ignore, 40, 40)
    results['gt_masks'] = PolygonMasks(text_polys, 40, 40)
    results['gt_bboxes'] = np.array([[0, 0, 10, 10], [20, 0, 30, 10]])
    results['gt_labels'] = np.array([0, 1])

    target_generator.ignore_texts(results, ignore_tags)

    # Exact, shape-checking comparison: np.allclose would broadcast an empty
    # or [1, 1] label array against [1] and pass.
    assert np.array_equal(results['gt_labels'], [1])
    # The ignored polygon lands after the pre-existing ignore polygon.
    assert len(results['gt_masks_ignore'].masks) == 2
    np.testing.assert_allclose(results['gt_masks_ignore'].masks[1][0],
                               text_polys[0][0])
    # Only the untagged (second) polygon remains as a regular mask.
    assert len(results['gt_masks'].masks) == 1
    np.testing.assert_allclose(results['gt_masks'].masks[0][0],
                               text_polys[1][0])
def test_dbnet_generate_thr_map():
    """Test that generate_thr_map values stay inside the threshold range.

    The bounds are a 0.01 margin around the default thr_min=0.3 and
    thr_max=0.7. The margin presumably tolerates float rounding.
    """
    target_generator = textdet_targets.DBNetTargets()
    text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])],
                  [np.array([20, 0, 30, 0, 30, 10, 20, 10])]]
    # Only the threshold map is checked; the mask is not inspected here.
    thr_map, _ = target_generator.generate_thr_map((40, 40), text_polys)
    # Combine the element-wise bounds with logical ``&`` rather than ``*``.
    assert np.all((thr_map >= 0.29) & (thr_map <= 0.71))
def test_dbnet_generate_targets():
    """Check that generate_targets adds the shrink/threshold maps to mask_fields."""
    target_generator = textdet_targets.DBNetTargets()
    text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])],
                  [np.array([20, 0, 30, 0, 30, 10, 20, 10])]]
    text_polys_ignore = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]]

    results = dict(
        mask_fields=[],
        img_shape=(40, 40, 3),
        gt_masks_ignore=PolygonMasks(text_polys_ignore, 40, 40),
        gt_masks=PolygonMasks(text_polys, 40, 40),
        gt_bboxes=np.array([[0, 0, 10, 10], [20, 0, 30, 10]]),
        gt_labels=np.array([0, 1]),
    )
    target_generator.generate_targets(results)

    # Every generated map name must be registered in mask_fields.
    for field in ('gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'):
        assert field in results['mask_fields']