diff --git a/mmocr/models/textdet/postprocess/wrapper.py b/mmocr/models/textdet/postprocess/wrapper.py index 5a2d1519..ff4c7ae8 100644 --- a/mmocr/models/textdet/postprocess/wrapper.py +++ b/mmocr/models/textdet/postprocess/wrapper.py @@ -223,7 +223,14 @@ def db_decode(preds, if len(poly) == 0 or isinstance(poly[0], list): continue poly = poly.reshape(-1, 2) - poly = points2boundary(poly, text_repr_type, score, min_text_width) + + if text_repr_type == 'quad': + poly = points2boundary(poly, text_repr_type, score, min_text_width) + elif text_repr_type == 'poly': + poly = poly.flatten().tolist() + [score] + else: + raise ValueError(f'Invalid text repr type {text_repr_type}') + if poly is not None: boundaries.append(poly) return boundaries diff --git a/tests/test_utils/test_wrapper.py b/tests/test_utils/test_wrapper.py index 2cd75c2d..d92ec175 100644 --- a/tests/test_utils/test_wrapper.py +++ b/tests/test_utils/test_wrapper.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch from mmocr.models.textdet.postprocess.wrapper import (comps2boundaries, @@ -93,3 +94,22 @@ def test_textsnake_decode(): results = textsnake_decode(torch.squeeze(maps)) assert len(results) == 0 + + +def test_db_decode(): + pred = torch.zeros((1, 8, 8)) + pred[0, 2:7, 2:7] = 0.8 + expect_result_quad = [[ + 1.0, 8.0, 1.0, 1.0, 8.0, 1.0, 8.0, 8.0, 0.800000011920929 + ]] + expect_result_poly = [[ + 8, 2, 8, 6, 6, 8, 2, 8, 1, 6, 1, 2, 2, 1, 6, 1, 0.800000011920929 + ]] + with pytest.raises(ValueError): + db_decode(preds=pred, text_repr_type='dummpy') + result_quad = db_decode( + preds=pred, text_repr_type='quad', min_text_width=1) + result_poly = db_decode( + preds=pred, text_repr_type='poly', min_text_width=1) + assert result_quad == expect_result_quad + assert result_poly == expect_result_poly