mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix] fix a bug of DBNet when text repr type is poly (#421)
* fix dbnet bug when text repr type is poly * add db_decode unit test
This commit is contained in:
parent
80a0536c7c
commit
9b5b25ef71
@ -223,7 +223,14 @@ def db_decode(preds,
|
|||||||
if len(poly) == 0 or isinstance(poly[0], list):
|
if len(poly) == 0 or isinstance(poly[0], list):
|
||||||
continue
|
continue
|
||||||
poly = poly.reshape(-1, 2)
|
poly = poly.reshape(-1, 2)
|
||||||
poly = points2boundary(poly, text_repr_type, score, min_text_width)
|
|
||||||
|
if text_repr_type == 'quad':
|
||||||
|
poly = points2boundary(poly, text_repr_type, score, min_text_width)
|
||||||
|
elif text_repr_type == 'poly':
|
||||||
|
poly = poly.flatten().tolist() + [score]
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Invalid text repr type {text_repr_type}')
|
||||||
|
|
||||||
if poly is not None:
|
if poly is not None:
|
||||||
boundaries.append(poly)
|
boundaries.append(poly)
|
||||||
return boundaries
|
return boundaries
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmocr.models.textdet.postprocess.wrapper import (comps2boundaries,
|
from mmocr.models.textdet.postprocess.wrapper import (comps2boundaries,
|
||||||
@ -93,3 +94,22 @@ def test_textsnake_decode():
|
|||||||
|
|
||||||
results = textsnake_decode(torch.squeeze(maps))
|
results = textsnake_decode(torch.squeeze(maps))
|
||||||
assert len(results) == 0
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_decode():
|
||||||
|
pred = torch.zeros((1, 8, 8))
|
||||||
|
pred[0, 2:7, 2:7] = 0.8
|
||||||
|
expect_result_quad = [[
|
||||||
|
1.0, 8.0, 1.0, 1.0, 8.0, 1.0, 8.0, 8.0, 0.800000011920929
|
||||||
|
]]
|
||||||
|
expect_result_poly = [[
|
||||||
|
8, 2, 8, 6, 6, 8, 2, 8, 1, 6, 1, 2, 2, 1, 6, 1, 0.800000011920929
|
||||||
|
]]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
db_decode(preds=pred, text_repr_type='dummpy')
|
||||||
|
result_quad = db_decode(
|
||||||
|
preds=pred, text_repr_type='quad', min_text_width=1)
|
||||||
|
result_poly = db_decode(
|
||||||
|
preds=pred, text_repr_type='poly', min_text_width=1)
|
||||||
|
assert result_quad == expect_result_quad
|
||||||
|
assert result_poly == expect_result_poly
|
||||||
|
Loading…
x
Reference in New Issue
Block a user