Textsnake tests (#51)

* add textsnake unit tests
pull/2/head
Theo Chan 2021-04-06 05:20:43 -05:00 committed by GitHub
parent b64f3906c0
commit cbb4ec349b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 173 additions and 2 deletions

View File

@ -135,3 +135,78 @@ def test_dbnet_generate_targets():
assert 'gt_shrink_mask' in results['mask_fields']
assert 'gt_thr' in results['mask_fields']
assert 'gt_thr_mask' in results['mask_fields']
@mock.patch('%s.cf_bundle.show_feature' % __name__)
def test_gen_textsnake_targets(mock_show_feature):
target_generator = textdet_targets.TextSnakeTargets()
assert np.allclose(target_generator.orientation_thr, 2.0)
assert np.allclose(target_generator.resample_step, 4.0)
assert np.allclose(target_generator.center_region_shrink_ratio, 0.3)
# test find_head_tail
polygon = np.array([[1.0, 1.0], [5.0, 1.0], [5.0, 3.0], [1.0, 3.0]])
head_inds, tail_inds = target_generator.find_head_tail(polygon, 2.0)
assert np.allclose(head_inds, [3, 0])
assert np.allclose(tail_inds, [1, 2])
# test generate_text_region_mask
img_size = (3, 10)
text_polys = [[np.array([0, 0, 1, 0, 1, 1, 0, 1])],
[np.array([2, 0, 3, 0, 3, 1, 2, 1])]]
output = target_generator.generate_text_region_mask(img_size, text_polys)
target = np.array([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
assert np.allclose(output, target)
# test generate_center_region_mask
target_generator.center_region_shrink_ratio = 1.0
(center_region_mask, radius_map, sin_map,
cos_map) = target_generator.generate_center_mask_attrib_maps(
img_size, text_polys)
target = np.array([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
assert np.allclose(center_region_mask, target)
assert np.allclose(sin_map, np.zeros(img_size))
assert np.allclose(cos_map, target)
# test generate_effective_mask
polys_ignore = text_polys
output = target_generator.generate_effective_mask(img_size, polys_ignore)
target = np.array([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
assert np.allclose(output, target)
# test generate_targets
results = {}
results['img'] = np.zeros((3, 10, 3), np.uint8)
results['gt_masks'] = PolygonMasks(text_polys, 3, 10)
results['gt_masks_ignore'] = PolygonMasks([], 3, 10)
results['img_shape'] = (3, 10, 3)
results['mask_fields'] = []
output = target_generator(results)
assert len(output['gt_text_mask']) == 1
assert len(output['gt_center_region_mask']) == 1
assert len(output['gt_mask']) == 1
assert len(output['gt_radius_map']) == 1
assert len(output['gt_sin_map']) == 1
assert len(output['gt_cos_map']) == 1
bundle = cf_bundle.CustomFormatBundle(
keys=[
'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
],
visualize=dict(flag=True, boundary_key='gt_text_mask'))
bundle(output)
assert 'gt_text_mask' in output.keys()
assert 'gt_center_region_mask' in output.keys()
assert 'gt_mask' in output.keys()
assert 'gt_radius_map' in output.keys()
assert 'gt_sin_map' in output.keys()
assert 'gt_cos_map' in output.keys()
mock_show_feature.assert_called_once()

View File

@ -81,7 +81,12 @@ def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
'gt_masks': gt_masks,
'gt_kernels': gt_kernels,
'gt_mask': gt_effective_mask,
'gt_thr_mask': gt_effective_mask
'gt_thr_mask': gt_effective_mask,
'gt_text_mask': gt_effective_mask,
'gt_center_region_mask': gt_effective_mask,
'gt_radius_map': gt_kernels,
'gt_sin_map': gt_kernels,
'gt_cos_map': gt_kernels,
}
return mm_inputs
@ -309,3 +314,55 @@ def test_dbnet(cfg_file):
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
img = np.random.rand(5, 5)
detector.show_result(img, results)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.parametrize(
'cfg_file', ['textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py'])
def test_textsnake(cfg_file):
model = _get_detector_cfg(cfg_file)
model['pretrained'] = None
model['backbone']['norm_cfg']['type'] = 'BN'
from mmocr.models import build_detector
detector = build_detector(model)
detector = detector.cuda()
input_shape = (1, 3, 64, 64)
num_kernels = 1
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
imgs = mm_inputs.pop('imgs')
imgs = imgs.cuda()
img_metas = mm_inputs.pop('img_metas')
gt_text_mask = mm_inputs.pop('gt_text_mask')
gt_center_region_mask = mm_inputs.pop('gt_center_region_mask')
gt_mask = mm_inputs.pop('gt_mask')
gt_radius_map = mm_inputs.pop('gt_radius_map')
gt_sin_map = mm_inputs.pop('gt_sin_map')
gt_cos_map = mm_inputs.pop('gt_cos_map')
# Test forward train
losses = detector.forward(
imgs,
img_metas,
gt_text_mask=gt_text_mask,
gt_center_region_mask=gt_center_region_mask,
gt_mask=gt_mask,
gt_radius_map=gt_radius_map,
gt_sin_map=gt_sin_map,
gt_cos_map=gt_cos_map)
assert isinstance(losses, dict)
# Test forward test
with torch.no_grad():
img_list = [g[None, :] for g in imgs]
batch_results = []
for one_img, one_meta in zip(img_list, img_metas):
result = detector.forward([one_img], [[one_meta]],
return_loss=False)
batch_results.append(result)
# Test show result
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
img = np.random.rand(5, 5)
detector.show_result(img, results)

View File

@ -19,3 +19,15 @@ def test_panloss():
assert len(results) == 1
assert torch.sum(torch.abs(results[0].float() -
torch.Tensor(target))).item() == 0
def test_textsnakeloss():
textsnakeloss = losses.TextSnakeLoss()
# test balanced_bce_loss
pred = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=torch.float)
target = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
mask = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
bce_loss = textsnakeloss.balanced_bce_loss(pred, target, mask).item()
assert np.allclose(bce_loss, 0)

View File

@ -1,6 +1,7 @@
import pytest
import torch
from mmocr.models.textdet.necks import FPNC
from mmocr.models.textdet.necks import FPN_UNET, FPNC
def test_fpnc():
@ -21,3 +22,29 @@ def test_fpnc():
inputs.append(torch.rand(1, in_channels[i], size[i], size[i]))
outputs = fpnc.forward(inputs)
assert list(outputs.size()) == [1, 256, 112, 112]
def test_fpn_unet_neck():
s = 64
feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
in_channels = [8, 16, 32, 64]
out_channels = 4
# len(in_channcels) is not equal to 4
with pytest.raises(AssertionError):
FPN_UNET(in_channels + [128], out_channels)
# `out_channels` is not int type
with pytest.raises(AssertionError):
FPN_UNET(in_channels, [2, 4])
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
fpn_unet_neck = FPN_UNET(in_channels, out_channels)
fpn_unet_neck.init_weights()
out_neck = fpn_unet_neck(feats)
assert out_neck.shape == torch.Size([1, out_channels, s * 4, s * 4])