mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix] fix a bug of DBNet when text repr type is poly (#421)
* fix dbnet bug when text repr type is poly * add db_decode unit test
This commit is contained in:
parent
80a0536c7c
commit
9b5b25ef71
@ -223,7 +223,14 @@ def db_decode(preds,
|
||||
if len(poly) == 0 or isinstance(poly[0], list):
|
||||
continue
|
||||
poly = poly.reshape(-1, 2)
|
||||
poly = points2boundary(poly, text_repr_type, score, min_text_width)
|
||||
|
||||
if text_repr_type == 'quad':
|
||||
poly = points2boundary(poly, text_repr_type, score, min_text_width)
|
||||
elif text_repr_type == 'poly':
|
||||
poly = poly.flatten().tolist() + [score]
|
||||
else:
|
||||
raise ValueError(f'Invalid text repr type {text_repr_type}')
|
||||
|
||||
if poly is not None:
|
||||
boundaries.append(poly)
|
||||
return boundaries
|
||||
|
@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.models.textdet.postprocess.wrapper import (comps2boundaries,
|
||||
@ -93,3 +94,22 @@ def test_textsnake_decode():
|
||||
|
||||
results = textsnake_decode(torch.squeeze(maps))
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_db_decode():
|
||||
pred = torch.zeros((1, 8, 8))
|
||||
pred[0, 2:7, 2:7] = 0.8
|
||||
expect_result_quad = [[
|
||||
1.0, 8.0, 1.0, 1.0, 8.0, 1.0, 8.0, 8.0, 0.800000011920929
|
||||
]]
|
||||
expect_result_poly = [[
|
||||
8, 2, 8, 6, 6, 8, 2, 8, 1, 6, 1, 2, 2, 1, 6, 1, 0.800000011920929
|
||||
]]
|
||||
with pytest.raises(ValueError):
|
||||
db_decode(preds=pred, text_repr_type='dummpy')
|
||||
result_quad = db_decode(
|
||||
preds=pred, text_repr_type='quad', min_text_width=1)
|
||||
result_poly = db_decode(
|
||||
preds=pred, text_repr_type='poly', min_text_width=1)
|
||||
assert result_quad == expect_result_quad
|
||||
assert result_poly == expect_result_poly
|
||||
|
Loading…
x
Reference in New Issue
Block a user