diff --git a/mmocr/models/textdet/heads/db_head.py b/mmocr/models/textdet/heads/db_head.py
index 12588607..ef4f3ccd 100644
--- a/mmocr/models/textdet/heads/db_head.py
+++ b/mmocr/models/textdet/heads/db_head.py
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from mmcv.runner import Sequential
+from torch import Tensor
 
 from mmocr.data import TextDetDataSample
 from mmocr.models.textdet.heads import BaseTextDetHead
@@ -54,8 +55,8 @@ class DBHead(BaseTextDetHead):
             nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
         self.threshold = self._init_thr(in_channels)
 
-    def _diff_binarize(self, prob_map: torch.Tensor, thr_map: torch.Tensor,
-                       k: int) -> torch.Tensor:
+    def _diff_binarize(self, prob_map: Tensor, thr_map: Tensor,
+                       k: int) -> Tensor:
         """Differential binarization.
 
         Args:
@@ -64,30 +65,29 @@ class DBHead(BaseTextDetHead):
             k (int): Amplification factor.
 
         Returns:
-            torch.Tensor: Binary map.
+            Tensor: Binary map.
         """
         return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
 
-    def forward(self,
-                img: torch.Tensor,
-                data_samples: Optional[List[TextDetDataSample]] = None
-                ) -> Dict:
+    def forward(
+        self,
+        img: Tensor,
+        data_samples: Optional[List[TextDetDataSample]] = None
+    ) -> Tuple[Tensor, Tensor, Tensor]:
         """
         Args:
-            img (torch.Tensor): Shape :math:`(N, C, H, W)`.
+            img (Tensor): Shape :math:`(N, C, H, W)`.
             data_samples (list[TextDetDataSample], optional): A list of data
                 samples. Defaults to None.
 
         Returns:
-            dict: A dict with keys of ``prob_map``, ``thr_map`` and
-                ``binary_map``, each of shape :math:`(N, 4H, 4W)`.
+            tuple(Tensor, Tensor, Tensor): A tuple of ``prob_map``, ``thr_map``
+                and ``binary_map``, each of shape :math:`(N, 4H, 4W)`.
         """
         prob_map = self.binarize(img).squeeze(1)
         thr_map = self.threshold(img).squeeze(1)
         binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
-        outputs = dict(
-            prob_map=prob_map, thr_map=thr_map, binary_map=binary_map)
-        return outputs
+        return (prob_map, thr_map, binary_map)
 
     def _init_thr(self,
                   inner_channels: int,
diff --git a/mmocr/models/textdet/losses/db_loss.py b/mmocr/models/textdet/losses/db_loss.py
index df0dec23..88f753f3 100644
--- a/mmocr/models/textdet/losses/db_loss.py
+++ b/mmocr/models/textdet/losses/db_loss.py
@@ -7,7 +7,7 @@ import torch
 from mmdet.core import multi_apply
 from numpy.typing import ArrayLike
 from shapely.geometry import Polygon
-from torch import nn
+from torch import Tensor, nn
 
 from mmocr.data import TextDetDataSample
 from mmocr.registry import MODELS
@@ -57,23 +57,21 @@ class DBLoss(nn.Module, TextKernelMixin):
         self.thr_max = thr_max
         self.min_sidelength = min_sidelength
 
-    def forward(self, preds: Dict,
+    def forward(self, preds: Tuple[Tensor, Tensor, Tensor],
                 data_samples: Sequence[TextDetDataSample]) -> Dict:
         """Compute DBNet loss.
 
         Args:
-            preds (dict): Raw predictions from model, containing ``prob_map``,
-                ``thr_map`` and ``binary_map``. Each is a tensor of shape
-                :math:`(N, H, W)`.
+            preds (tuple(tensor)): Raw predictions from model, containing
+                ``prob_map``, ``thr_map`` and ``binary_map``. Each is a tensor
+                of shape :math:`(N, H, W)`.
             data_samples (list[TextDetDataSample]): The data samples.
 
         Returns:
             results(dict): The dict for dbnet losses with loss_prob, \
                 loss_db and loss_thr.
""" - prob_map = preds['prob_map'] - thr_map = preds['thr_map'] - binary_map = preds['binary_map'] + prob_map, thr_map, binary_map = preds gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = self.get_targets( data_samples) gt_shrinks = gt_shrinks.to(prob_map.device) diff --git a/mmocr/models/textdet/postprocessors/base_postprocessor.py b/mmocr/models/textdet/postprocessors/base_postprocessor.py index d86feaba..1e96eea2 100644 --- a/mmocr/models/textdet/postprocessors/base_postprocessor.py +++ b/mmocr/models/textdet/postprocessors/base_postprocessor.py @@ -1,11 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. from functools import partial -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence, Tuple, Union +import mmcv import numpy as np +from torch import Tensor from mmocr.data import TextDetDataSample -from mmocr.utils import boundary_iou, is_type_list, rescale_polygons +from mmocr.utils import boundary_iou, rescale_polygons class BaseTextDetPostProcessor: @@ -121,43 +123,39 @@ class BaseTextDetPostProcessor: """ raise NotImplementedError - def split_results(self, - pred_results: Dict, - fields: Optional[Sequence[str]] = None, - keep_unsplit_fields: bool = False) -> List[Dict]: - """Split batched elements in pred_results along the first dimension - into ``batch_num`` sub-elements and regather them into a list of dicts. + def split_results( + self, pred_results: Union[Tensor, List[Tensor]] + ) -> Union[List[Tensor], List[List[Tensor]]]: + """Split batched tensor(s) along the first dimension pack split tensors + into a list. Args: - pred_results (dict): Raw result dictionary from detection head. - Each item usually has the shape of (N, ...) - fields (list[str], optional): Fields to split. If not specified, - all fields in ``pred_results`` will be split. - keep_unsplit_fields (bool): Whether to keep unsplit fields in - result dicts. If True, the fields not specified in ``fields`` - will be copied to each result dict. Defaults to False. + pred_results (tensor or list[tensor]): Raw result tensor(s) from + detection head. Each tensor usually has the shape of (N, ...) Returns: - list[dict]: N dicts whose keys remains the same as that of - pred_results. + list[tensor] or list[list[tensor]]: N tensors if ``pred_results`` + is a tensor, or a list of N lists of tensors if + ``pred_results`` is a list of tensors. 
""" - assert isinstance(pred_results, dict) and len(pred_results) > 0 - assert fields is None or is_type_list(fields, str) - assert isinstance(keep_unsplit_fields, bool) + assert isinstance(pred_results, Tensor) or mmcv.is_seq_of( + pred_results, Tensor) - if fields is None: - fields = list(pred_results.keys()) - batch_num = len(pred_results[fields[0]]) - results = [{} for _ in range(batch_num)] - for field in fields: - for i in range(batch_num): - results[i][field] = pred_results[field][i] - if keep_unsplit_fields: - for k, v in pred_results.items(): - if k in fields: - continue - for i in range(batch_num): - results[i][k] = v + if mmcv.is_seq_of(pred_results, Tensor): + for i in range(1, len(pred_results)): + assert pred_results[0].shape[0] == pred_results[i].shape[0], \ + 'The first dimension of all tensors should be the same' + + batch_num = len(pred_results) if isinstance(pred_results, Tensor) else\ + len(pred_results[0]) + results = [] + for i in range(batch_num): + if isinstance(pred_results, Tensor): + results.append(pred_results[i]) + else: + results.append([]) + for tensor in pred_results: + results[i].append(tensor[i]) return results def poly_nms(self, polygons: List[np.ndarray], scores: List[float], diff --git a/mmocr/models/textdet/postprocessors/db_postprocessor.py b/mmocr/models/textdet/postprocessors/db_postprocessor.py index 148ff556..bff40437 100644 --- a/mmocr/models/textdet/postprocessors/db_postprocessor.py +++ b/mmocr/models/textdet/postprocessors/db_postprocessor.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Sequence +from typing import Sequence, Tuple import cv2 import numpy as np import torch from mmengine import InstanceData from shapely.geometry import Polygon +from torch import Tensor from mmocr.data import TextDetDataSample from mmocr.registry import MODELS @@ -54,14 +55,14 @@ class DBPostprocessor(BaseTextDetPostProcessor): self.unclip_ratio = unclip_ratio self.max_candidates = max_candidates - def get_text_instances(self, pred_results: dict, + def get_text_instances(self, pred_results: Tuple[Tensor, Tensor, Tensor], data_sample: TextDetDataSample ) -> TextDetDataSample: """Get text instance predictions of one image. Args: - pred_result (dict): Prediction results of an image containing the - ``prob_map``, which is a tensor of shape :math:`(N, H, W)`. + pred_result (tuple(Tensor)): A tuple of 3 tensors where the first + tensor is ``prob_map`` of shape :math:`(N, H, W)`. data_sample (TextDetDataSample): Datasample of an image. 
         Returns:
@@ -75,7 +76,7 @@ class DBPostprocessor(BaseTextDetPostProcessor):
             data_sample.pred_instances.polygons = []
             data_sample.pred_instances.scores = []
 
-        prob_map = pred_results['prob_map']
+        prob_map = pred_results[0]
         text_mask = prob_map > self.mask_thr
 
         score_map = prob_map.data.cpu().numpy().astype(np.float32)
diff --git a/tests/test_models/test_textdet/test_heads/test_db_head.py b/tests/test_models/test_textdet/test_heads/test_db_head.py
index bfcb43ae..60d1dd2e 100644
--- a/tests/test_models/test_textdet/test_heads/test_db_head.py
+++ b/tests/test_models/test_textdet/test_heads/test_db_head.py
@@ -19,9 +19,6 @@ class TestDBHead(TestCase):
         db_head = DBHead(in_channels=10)
         data = torch.randn((2, 10, 40, 50))
         results = db_head(data, None)
-        self.assertIn('prob_map', results)
-        self.assertIn('thr_map', results)
-        self.assertIn('binary_map', results)
-        self.assertEqual(results['prob_map'].shape, (2, 160, 200))
-        self.assertEqual(results['thr_map'].shape, (2, 160, 200))
-        self.assertEqual(results['binary_map'].shape, (2, 160, 200))
+        self.assertEqual(results[0].shape, (2, 160, 200))
+        self.assertEqual(results[1].shape, (2, 160, 200))
+        self.assertEqual(results[2].shape, (2, 160, 200))
diff --git a/tests/test_models/test_textdet/test_losses/test_db_loss.py b/tests/test_models/test_textdet/test_losses/test_db_loss.py
index 4e91bab8..4f255a18 100644
--- a/tests/test_models/test_textdet/test_losses/test_db_loss.py
+++ b/tests/test_models/test_textdet/test_losses/test_db_loss.py
@@ -26,10 +26,8 @@ class TestDBLoss(TestCase):
                     ignored=torch.BoolTensor([False, False, True])))
         ]
         pred_size = (1, 40, 40)
-        self.preds = dict(
-            prob_map=torch.rand(pred_size),
-            thr_map=torch.rand(pred_size),
-            binary_map=torch.rand(pred_size))
+        self.preds = (torch.rand(pred_size), torch.rand(pred_size),
+                      torch.rand(pred_size))
 
     def test_is_poly_invalid(self):
         # area < 1
diff --git a/tests/test_models/test_textdet/test_postprocessors/test_base_postprocessor.py b/tests/test_models/test_textdet/test_postprocessors/test_base_postprocessor.py
index 13f9ff9e..e5004f62 100644
--- a/tests/test_models/test_textdet/test_postprocessors/test_base_postprocessor.py
+++ b/tests/test_models/test_textdet/test_postprocessors/test_base_postprocessor.py
@@ -3,6 +3,7 @@ import unittest
 from unittest import mock
 
 import numpy as np
+import torch
 from mmengine import InstanceData
 
 from mmocr.data import TextDetDataSample
@@ -29,9 +30,7 @@ class TestBaseTextDetPostProcessor(unittest.TestCase):
 
         mock_get_text_instances.side_effect = mock_func
 
-        pred_results = {
-            'prob_map': np.array([[0.1, 0.2], [0.3, 0.4]]),
-        }
+        pred_results = torch.Tensor([[0.1, 0.2], [0.3, 0.4]])
         data_samples = [
             TextDetDataSample(
                 metainfo=dict(scale_factor=(0.5, 1)),
@@ -79,49 +78,30 @@
 
     def test_split_results(self):
 
+        # some shorthands
+        lt = torch.LongTensor
+        ft = torch.FloatTensor
+
         base_postprocessor = BaseTextDetPostProcessor()
 
         # test invalid arguments
         with self.assertRaises(AssertionError):
             base_postprocessor.split_results(None)
 
+        results = [lt([0, 1, 5]), ft([0.2, 0.3])]
         with self.assertRaises(AssertionError):
-            base_postprocessor.split_results({'test': [0, 1]}, 'fields')
-
-        with self.assertRaises(AssertionError):
-            base_postprocessor.split_results({'test': [0, 1]},
-                                             keep_unsplit_fields='true')
+            base_postprocessor.split_results(results)
 
         # test split_results
-        results = {
-            'test': [0, 1],
-            'test2': np.array([2, 3], dtype=int),
-            'meta': 'str'
-        }
-        split_results = base_postprocessor.split_results(results, ['test'])
-        self.assertEqual(split_results, [{'test': 0}, {'test': 1}])
+        results = [lt([0, 1, 5]), ft([0.2, 0.3, 0.6])]
+        split_results = base_postprocessor.split_results(results)
+        self.assertEqual(split_results,
+                         [[lt([0]), ft([0.2])], [lt([1]), ft([0.3])],
+                          [lt([5]), ft([0.6])]])
 
-        split_results = base_postprocessor.split_results(
-            results, ['test', 'test2'])
-        self.assertEqual(split_results, [{
-            'test': 0,
-            'test2': 2
-        }, {
-            'test': 1,
-            'test2': 3
-        }])
-
-        split_results = base_postprocessor.split_results(
-            results, ['test', 'test2'], keep_unsplit_fields=True)
-        self.assertEqual(split_results, [{
-            'test': 0,
-            'test2': 2,
-            'meta': 'str'
-        }, {
-            'test': 1,
-            'test2': 3,
-            'meta': 'str'
-        }])
+        results = lt([0, 1, 5])
+        split_results = base_postprocessor.split_results(results)
+        self.assertEqual(split_results, [lt([0]), lt([1]), lt([5])])
 
     def test_poly_nms(self):
         base_postprocessor = BaseTextDetPostProcessor(text_repr_type='poly')
diff --git a/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py b/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py
index c6c1fe34..49a1e6b9 100644
--- a/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py
+++ b/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py
@@ -23,7 +23,7 @@ class TestDBPostProcessor(unittest.TestCase):
     def test_get_text_instances(self, text_repr_type):
 
         postprocessor = DBPostprocessor(text_repr_type=text_repr_type)
-        pred_result = dict(prob_map=torch.rand(4, 5))
+        pred_result = (torch.rand(4, 5), torch.rand(4, 5), torch.rand(4, 5))
         data_sample = TextDetDataSample(
             metainfo=dict(scale_factor=(0.5, 1)),
             gt_instances=InstanceData(polygons=[
@@ -38,7 +38,8 @@
 
         postprocessor = DBPostprocessor(
             min_text_score=1, text_repr_type=text_repr_type)
-        pred_result = dict(prob_map=torch.rand(4, 5) * 0.8)
+        pred_result = (torch.rand(4, 5) * 0.8, torch.rand(4, 5) * 0.8,
+                       torch.rand(4, 5) * 0.8)
         results = postprocessor.get_text_instances(pred_result, data_sample)
         self.assertEqual(results.pred_instances.polygons, [])
         self.assertTrue(
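
For reference, a minimal self-contained sketch of the interface change above: the head now returns a plain (prob_map, thr_map, binary_map) tuple, the loss unpacks it positionally, and the postprocessor only reads element 0. The dummy_head below is a hypothetical stand-in for DBHead.forward; it produces random maps of the documented shapes and applies the same soft binarization, 1 / (1 + exp(-k * (P - T))), so only torch is needed to run it.

from typing import Tuple

import torch
from torch import Tensor


def dummy_head(img: Tensor, k: int = 50) -> Tuple[Tensor, Tensor, Tensor]:
    """Hypothetical stand-in for DBHead.forward: (N, C, H, W) -> three (N, 4H, 4W) maps."""
    n, _, h, w = img.shape
    prob_map = torch.rand(n, 4 * h, 4 * w)
    thr_map = torch.rand(n, 4 * h, 4 * w)
    # Same soft binarization as DBHead._diff_binarize.
    binary_map = torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
    return prob_map, thr_map, binary_map


preds = dummy_head(torch.randn(2, 10, 40, 50))
prob_map, thr_map, binary_map = preds   # what DBLoss.forward now does
assert prob_map.shape == (2, 160, 200)
prob_map_of_first_image = preds[0][0]   # what DBPostprocessor.get_text_instances reads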
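
Similarly, a usage sketch of the reworked split_results, assuming this branch of mmocr is importable (the import path and tensor values are illustrative): a single batched tensor becomes a list of per-sample tensors, while a list of batched tensors sharing the same first dimension becomes one list of tensors per sample.

import torch

# Import taken from the module touched above; adjust if the package layout differs.
from mmocr.models.textdet.postprocessors.base_postprocessor import \
    BaseTextDetPostProcessor

postprocessor = BaseTextDetPostProcessor()

# A single batched tensor -> a list of N per-sample tensors.
batched = torch.arange(6).reshape(3, 2)
per_sample = postprocessor.split_results(batched)
# per_sample == [tensor([0, 1]), tensor([2, 3]), tensor([4, 5])]

# A list of batched tensors -> one [tensor, tensor, ...] group per sample.
scores = torch.tensor([0.2, 0.3, 0.6])
grouped = postprocessor.split_results([batched, scores])
# grouped[0] == [tensor([0, 1]), tensor(0.2000)], and so on for the other samples.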