[Fix] fix a bug of DBNet when text repr type is poly (#421)

* fix dbnet bug when text repr type is poly

* add db_decode unit test
This commit is contained in:
liukuikun 2021-08-13 21:39:53 +08:00 committed by GitHub
parent 80a0536c7c
commit 9b5b25ef71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 1 deletions

View File

@ -223,7 +223,14 @@ def db_decode(preds,
if len(poly) == 0 or isinstance(poly[0], list):
continue
poly = poly.reshape(-1, 2)
poly = points2boundary(poly, text_repr_type, score, min_text_width)
if text_repr_type == 'quad':
poly = points2boundary(poly, text_repr_type, score, min_text_width)
elif text_repr_type == 'poly':
poly = poly.flatten().tolist() + [score]
else:
raise ValueError(f'Invalid text repr type {text_repr_type}')
if poly is not None:
boundaries.append(poly)
return boundaries

View File

@ -1,4 +1,5 @@
import numpy as np
import pytest
import torch
from mmocr.models.textdet.postprocess.wrapper import (comps2boundaries,
@ -93,3 +94,22 @@ def test_textsnake_decode():
results = textsnake_decode(torch.squeeze(maps))
assert len(results) == 0
def test_db_decode():
pred = torch.zeros((1, 8, 8))
pred[0, 2:7, 2:7] = 0.8
expect_result_quad = [[
1.0, 8.0, 1.0, 1.0, 8.0, 1.0, 8.0, 8.0, 0.800000011920929
]]
expect_result_poly = [[
8, 2, 8, 6, 6, 8, 2, 8, 1, 6, 1, 2, 2, 1, 6, 1, 0.800000011920929
]]
with pytest.raises(ValueError):
db_decode(preds=pred, text_repr_type='dummpy')
result_quad = db_decode(
preds=pred, text_repr_type='quad', min_text_width=1)
result_poly = db_decode(
preds=pred, text_repr_type='poly', min_text_width=1)
assert result_quad == expect_result_quad
assert result_poly == expect_result_poly