mirror of https://github.com/open-mmlab/mmocr.git
parent
b64f3906c0
commit
cbb4ec349b
|
@ -135,3 +135,78 @@ def test_dbnet_generate_targets():
|
|||
assert 'gt_shrink_mask' in results['mask_fields']
|
||||
assert 'gt_thr' in results['mask_fields']
|
||||
assert 'gt_thr_mask' in results['mask_fields']
|
||||
|
||||
|
||||
@mock.patch('%s.cf_bundle.show_feature' % __name__)
|
||||
def test_gen_textsnake_targets(mock_show_feature):
|
||||
|
||||
target_generator = textdet_targets.TextSnakeTargets()
|
||||
assert np.allclose(target_generator.orientation_thr, 2.0)
|
||||
assert np.allclose(target_generator.resample_step, 4.0)
|
||||
assert np.allclose(target_generator.center_region_shrink_ratio, 0.3)
|
||||
|
||||
# test find_head_tail
|
||||
polygon = np.array([[1.0, 1.0], [5.0, 1.0], [5.0, 3.0], [1.0, 3.0]])
|
||||
head_inds, tail_inds = target_generator.find_head_tail(polygon, 2.0)
|
||||
assert np.allclose(head_inds, [3, 0])
|
||||
assert np.allclose(tail_inds, [1, 2])
|
||||
|
||||
# test generate_text_region_mask
|
||||
img_size = (3, 10)
|
||||
text_polys = [[np.array([0, 0, 1, 0, 1, 1, 0, 1])],
|
||||
[np.array([2, 0, 3, 0, 3, 1, 2, 1])]]
|
||||
output = target_generator.generate_text_region_mask(img_size, text_polys)
|
||||
target = np.array([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
|
||||
assert np.allclose(output, target)
|
||||
|
||||
# test generate_center_region_mask
|
||||
target_generator.center_region_shrink_ratio = 1.0
|
||||
(center_region_mask, radius_map, sin_map,
|
||||
cos_map) = target_generator.generate_center_mask_attrib_maps(
|
||||
img_size, text_polys)
|
||||
target = np.array([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
|
||||
assert np.allclose(center_region_mask, target)
|
||||
assert np.allclose(sin_map, np.zeros(img_size))
|
||||
assert np.allclose(cos_map, target)
|
||||
|
||||
# test generate_effective_mask
|
||||
polys_ignore = text_polys
|
||||
output = target_generator.generate_effective_mask(img_size, polys_ignore)
|
||||
target = np.array([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
|
||||
assert np.allclose(output, target)
|
||||
|
||||
# test generate_targets
|
||||
results = {}
|
||||
results['img'] = np.zeros((3, 10, 3), np.uint8)
|
||||
results['gt_masks'] = PolygonMasks(text_polys, 3, 10)
|
||||
results['gt_masks_ignore'] = PolygonMasks([], 3, 10)
|
||||
results['img_shape'] = (3, 10, 3)
|
||||
results['mask_fields'] = []
|
||||
output = target_generator(results)
|
||||
assert len(output['gt_text_mask']) == 1
|
||||
assert len(output['gt_center_region_mask']) == 1
|
||||
assert len(output['gt_mask']) == 1
|
||||
assert len(output['gt_radius_map']) == 1
|
||||
assert len(output['gt_sin_map']) == 1
|
||||
assert len(output['gt_cos_map']) == 1
|
||||
|
||||
bundle = cf_bundle.CustomFormatBundle(
|
||||
keys=[
|
||||
'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
|
||||
'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
|
||||
],
|
||||
visualize=dict(flag=True, boundary_key='gt_text_mask'))
|
||||
bundle(output)
|
||||
assert 'gt_text_mask' in output.keys()
|
||||
assert 'gt_center_region_mask' in output.keys()
|
||||
assert 'gt_mask' in output.keys()
|
||||
assert 'gt_radius_map' in output.keys()
|
||||
assert 'gt_sin_map' in output.keys()
|
||||
assert 'gt_cos_map' in output.keys()
|
||||
mock_show_feature.assert_called_once()
|
||||
|
|
|
@ -81,7 +81,12 @@ def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
|
|||
'gt_masks': gt_masks,
|
||||
'gt_kernels': gt_kernels,
|
||||
'gt_mask': gt_effective_mask,
|
||||
'gt_thr_mask': gt_effective_mask
|
||||
'gt_thr_mask': gt_effective_mask,
|
||||
'gt_text_mask': gt_effective_mask,
|
||||
'gt_center_region_mask': gt_effective_mask,
|
||||
'gt_radius_map': gt_kernels,
|
||||
'gt_sin_map': gt_kernels,
|
||||
'gt_cos_map': gt_kernels,
|
||||
}
|
||||
return mm_inputs
|
||||
|
||||
|
@ -309,3 +314,55 @@ def test_dbnet(cfg_file):
|
|||
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
||||
img = np.random.rand(5, 5)
|
||||
detector.show_result(img, results)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||
@pytest.mark.parametrize(
|
||||
'cfg_file', ['textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py'])
|
||||
def test_textsnake(cfg_file):
|
||||
model = _get_detector_cfg(cfg_file)
|
||||
model['pretrained'] = None
|
||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
||||
|
||||
from mmocr.models import build_detector
|
||||
detector = build_detector(model)
|
||||
detector = detector.cuda()
|
||||
input_shape = (1, 3, 64, 64)
|
||||
num_kernels = 1
|
||||
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
|
||||
|
||||
imgs = mm_inputs.pop('imgs')
|
||||
imgs = imgs.cuda()
|
||||
img_metas = mm_inputs.pop('img_metas')
|
||||
gt_text_mask = mm_inputs.pop('gt_text_mask')
|
||||
gt_center_region_mask = mm_inputs.pop('gt_center_region_mask')
|
||||
gt_mask = mm_inputs.pop('gt_mask')
|
||||
gt_radius_map = mm_inputs.pop('gt_radius_map')
|
||||
gt_sin_map = mm_inputs.pop('gt_sin_map')
|
||||
gt_cos_map = mm_inputs.pop('gt_cos_map')
|
||||
|
||||
# Test forward train
|
||||
losses = detector.forward(
|
||||
imgs,
|
||||
img_metas,
|
||||
gt_text_mask=gt_text_mask,
|
||||
gt_center_region_mask=gt_center_region_mask,
|
||||
gt_mask=gt_mask,
|
||||
gt_radius_map=gt_radius_map,
|
||||
gt_sin_map=gt_sin_map,
|
||||
gt_cos_map=gt_cos_map)
|
||||
assert isinstance(losses, dict)
|
||||
|
||||
# Test forward test
|
||||
with torch.no_grad():
|
||||
img_list = [g[None, :] for g in imgs]
|
||||
batch_results = []
|
||||
for one_img, one_meta in zip(img_list, img_metas):
|
||||
result = detector.forward([one_img], [[one_meta]],
|
||||
return_loss=False)
|
||||
batch_results.append(result)
|
||||
|
||||
# Test show result
|
||||
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
||||
img = np.random.rand(5, 5)
|
||||
detector.show_result(img, results)
|
||||
|
|
|
@ -19,3 +19,15 @@ def test_panloss():
|
|||
assert len(results) == 1
|
||||
assert torch.sum(torch.abs(results[0].float() -
|
||||
torch.Tensor(target))).item() == 0
|
||||
|
||||
|
||||
def test_textsnakeloss():
|
||||
textsnakeloss = losses.TextSnakeLoss()
|
||||
|
||||
# test balanced_bce_loss
|
||||
pred = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=torch.float)
|
||||
target = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
|
||||
mask = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
|
||||
bce_loss = textsnakeloss.balanced_bce_loss(pred, target, mask).item()
|
||||
|
||||
assert np.allclose(bce_loss, 0)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.models.textdet.necks import FPNC
|
||||
from mmocr.models.textdet.necks import FPN_UNET, FPNC
|
||||
|
||||
|
||||
def test_fpnc():
|
||||
|
@ -21,3 +22,29 @@ def test_fpnc():
|
|||
inputs.append(torch.rand(1, in_channels[i], size[i], size[i]))
|
||||
outputs = fpnc.forward(inputs)
|
||||
assert list(outputs.size()) == [1, 256, 112, 112]
|
||||
|
||||
|
||||
def test_fpn_unet_neck():
|
||||
s = 64
|
||||
feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
|
||||
in_channels = [8, 16, 32, 64]
|
||||
out_channels = 4
|
||||
|
||||
# len(in_channcels) is not equal to 4
|
||||
with pytest.raises(AssertionError):
|
||||
FPN_UNET(in_channels + [128], out_channels)
|
||||
|
||||
# `out_channels` is not int type
|
||||
with pytest.raises(AssertionError):
|
||||
FPN_UNET(in_channels, [2, 4])
|
||||
|
||||
feats = [
|
||||
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
|
||||
for i in range(len(in_channels))
|
||||
]
|
||||
|
||||
fpn_unet_neck = FPN_UNET(in_channels, out_channels)
|
||||
fpn_unet_neck.init_weights()
|
||||
|
||||
out_neck = fpn_unet_neck(feats)
|
||||
assert out_neck.shape == torch.Size([1, out_channels, s * 4, s * 4])
|
||||
|
|
Loading…
Reference in New Issue