Update outputs of DBHead and split_results in BaseTextDetPostprocessor

pull/1178/head
gaotongxiao 2022-07-13 12:12:40 +00:00
parent 68b0aaa2e9
commit bf517b63e8
8 changed files with 81 additions and 108 deletions

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.runner import Sequential
from torch import Tensor
from mmocr.data import TextDetDataSample
from mmocr.models.textdet.heads import BaseTextDetHead
@ -54,8 +55,8 @@ class DBHead(BaseTextDetHead):
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
self.threshold = self._init_thr(in_channels)
def _diff_binarize(self, prob_map: torch.Tensor, thr_map: torch.Tensor,
k: int) -> torch.Tensor:
def _diff_binarize(self, prob_map: Tensor, thr_map: Tensor,
k: int) -> Tensor:
"""Differential binarization.
Args:
@ -64,30 +65,29 @@ class DBHead(BaseTextDetHead):
k (int): Amplification factor.
Returns:
torch.Tensor: Binary map.
Tensor: Binary map.
"""
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
def forward(self,
img: torch.Tensor,
data_samples: Optional[List[TextDetDataSample]] = None
) -> Dict:
def forward(
self,
img: Tensor,
data_samples: Optional[List[TextDetDataSample]] = None
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Args:
img (torch.Tensor): Shape :math:`(N, C, H, W)`.
img (Tensor): Shape :math:`(N, C, H, W)`.
data_samples (list[TextDetDataSample], optional): A list of data
samples. Defaults to None.
Returns:
dict: A dict with keys of ``prob_map``, ``thr_map`` and
``binary_map``, each of shape :math:`(N, 4H, 4W)`.
tuple(Tensor, Tensor, Tensor): A tuple of ``prob_map``, ``thr_map``
and ``binary_map``, each of shape :math:`(N, 4H, 4W)`.
"""
prob_map = self.binarize(img).squeeze(1)
thr_map = self.threshold(img).squeeze(1)
binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
outputs = dict(
prob_map=prob_map, thr_map=thr_map, binary_map=binary_map)
return outputs
return (prob_map, thr_map, binary_map)
def _init_thr(self,
inner_channels: int,

View File

@ -7,7 +7,7 @@ import torch
from mmdet.core import multi_apply
from numpy.typing import ArrayLike
from shapely.geometry import Polygon
from torch import nn
from torch import Tensor, nn
from mmocr.data import TextDetDataSample
from mmocr.registry import MODELS
@ -57,23 +57,21 @@ class DBLoss(nn.Module, TextKernelMixin):
self.thr_max = thr_max
self.min_sidelength = min_sidelength
def forward(self, preds: Dict,
def forward(self, preds: Tuple[Tensor, Tensor, Tensor],
data_samples: Sequence[TextDetDataSample]) -> Dict:
"""Compute DBNet loss.
Args:
preds (dict): Raw predictions from model, containing ``prob_map``,
``thr_map`` and ``binary_map``. Each is a tensor of shape
:math:`(N, H, W)`.
preds (tuple(tensor)): Raw predictions from model, containing
``prob_map``, ``thr_map`` and ``binary_map``. Each is a tensor
of shape :math:`(N, H, W)`.
data_samples (list[TextDetDataSample]): The data samples.
Returns:
results(dict): The dict for dbnet losses with loss_prob, \
loss_db and loss_thr.
"""
prob_map = preds['prob_map']
thr_map = preds['thr_map']
binary_map = preds['binary_map']
prob_map, thr_map, binary_map = preds
gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = self.get_targets(
data_samples)
gt_shrinks = gt_shrinks.to(prob_map.device)

View File

@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
from torch import Tensor
from mmocr.data import TextDetDataSample
from mmocr.utils import boundary_iou, is_type_list, rescale_polygons
from mmocr.utils import boundary_iou, rescale_polygons
class BaseTextDetPostProcessor:
@ -121,43 +123,39 @@ class BaseTextDetPostProcessor:
"""
raise NotImplementedError
def split_results(self,
pred_results: Dict,
fields: Optional[Sequence[str]] = None,
keep_unsplit_fields: bool = False) -> List[Dict]:
"""Split batched elements in pred_results along the first dimension
into ``batch_num`` sub-elements and regather them into a list of dicts.
def split_results(
self, pred_results: Union[Tensor, List[Tensor]]
) -> Union[List[Tensor], List[List[Tensor]]]:
"""Split batched tensor(s) along the first dimension pack split tensors
into a list.
Args:
pred_results (dict): Raw result dictionary from detection head.
Each item usually has the shape of (N, ...)
fields (list[str], optional): Fields to split. If not specified,
all fields in ``pred_results`` will be split.
keep_unsplit_fields (bool): Whether to keep unsplit fields in
result dicts. If True, the fields not specified in ``fields``
will be copied to each result dict. Defaults to False.
pred_results (tensor or list[tensor]): Raw result tensor(s) from
detection head. Each tensor usually has the shape of (N, ...)
Returns:
list[dict]: N dicts whose keys remains the same as that of
pred_results.
list[tensor] or list[list[tensor]]: N tensors if ``pred_results``
is a tensor, or a list of N lists of tensors if
``pred_results`` is a list of tensors.
"""
assert isinstance(pred_results, dict) and len(pred_results) > 0
assert fields is None or is_type_list(fields, str)
assert isinstance(keep_unsplit_fields, bool)
assert isinstance(pred_results, Tensor) or mmcv.is_seq_of(
pred_results, Tensor)
if fields is None:
fields = list(pred_results.keys())
batch_num = len(pred_results[fields[0]])
results = [{} for _ in range(batch_num)]
for field in fields:
for i in range(batch_num):
results[i][field] = pred_results[field][i]
if keep_unsplit_fields:
for k, v in pred_results.items():
if k in fields:
continue
for i in range(batch_num):
results[i][k] = v
if mmcv.is_seq_of(pred_results, Tensor):
for i in range(1, len(pred_results)):
assert pred_results[0].shape[0] == pred_results[i].shape[0], \
'The first dimension of all tensors should be the same'
batch_num = len(pred_results) if isinstance(pred_results, Tensor) else\
len(pred_results[0])
results = []
for i in range(batch_num):
if isinstance(pred_results, Tensor):
results.append(pred_results[i])
else:
results.append([])
for tensor in pred_results:
results[i].append(tensor[i])
return results
def poly_nms(self, polygons: List[np.ndarray], scores: List[float],

View File

@ -1,11 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
from typing import Sequence, Tuple
import cv2
import numpy as np
import torch
from mmengine import InstanceData
from shapely.geometry import Polygon
from torch import Tensor
from mmocr.data import TextDetDataSample
from mmocr.registry import MODELS
@ -54,14 +55,14 @@ class DBPostprocessor(BaseTextDetPostProcessor):
self.unclip_ratio = unclip_ratio
self.max_candidates = max_candidates
def get_text_instances(self, pred_results: dict,
def get_text_instances(self, pred_results: Tuple[Tensor, Tensor, Tensor],
data_sample: TextDetDataSample
) -> TextDetDataSample:
"""Get text instance predictions of one image.
Args:
pred_result (dict): Prediction results of an image containing the
``prob_map``, which is a tensor of shape :math:`(N, H, W)`.
pred_result (tuple(Tensor)): A tuple of 3 tensors where the first
tensor is ``prob_map`` of shape :math:`(N, H, W)`.
data_sample (TextDetDataSample): Datasample of an image.
Returns:
@ -75,7 +76,7 @@ class DBPostprocessor(BaseTextDetPostProcessor):
data_sample.pred_instances.polygons = []
data_sample.pred_instances.scores = []
prob_map = pred_results['prob_map']
prob_map = pred_results[0]
text_mask = prob_map > self.mask_thr
score_map = prob_map.data.cpu().numpy().astype(np.float32)

View File

@ -19,9 +19,6 @@ class TestDBHead(TestCase):
db_head = DBHead(in_channels=10)
data = torch.randn((2, 10, 40, 50))
results = db_head(data, None)
self.assertIn('prob_map', results)
self.assertIn('thr_map', results)
self.assertIn('binary_map', results)
self.assertEqual(results['prob_map'].shape, (2, 160, 200))
self.assertEqual(results['thr_map'].shape, (2, 160, 200))
self.assertEqual(results['binary_map'].shape, (2, 160, 200))
self.assertEqual(results[0].shape, (2, 160, 200))
self.assertEqual(results[1].shape, (2, 160, 200))
self.assertEqual(results[2].shape, (2, 160, 200))

View File

@ -26,10 +26,8 @@ class TestDBLoss(TestCase):
ignored=torch.BoolTensor([False, False, True])))
]
pred_size = (1, 40, 40)
self.preds = dict(
prob_map=torch.rand(pred_size),
thr_map=torch.rand(pred_size),
binary_map=torch.rand(pred_size))
self.preds = (torch.rand(pred_size), torch.rand(pred_size),
torch.rand(pred_size))
def test_is_poly_invalid(self):
# area < 1

View File

@ -3,6 +3,7 @@ import unittest
from unittest import mock
import numpy as np
import torch
from mmengine import InstanceData
from mmocr.data import TextDetDataSample
@ -29,9 +30,7 @@ class TestBaseTextDetPostProcessor(unittest.TestCase):
mock_get_text_instances.side_effect = mock_func
pred_results = {
'prob_map': np.array([[0.1, 0.2], [0.3, 0.4]]),
}
pred_results = torch.Tensor([[0.1, 0.2], [0.3, 0.4]])
data_samples = [
TextDetDataSample(
metainfo=dict(scale_factor=(0.5, 1)),
@ -79,49 +78,30 @@ class TestBaseTextDetPostProcessor(unittest.TestCase):
def test_split_results(self):
# some shorthands
lt = torch.LongTensor
ft = torch.FloatTensor
base_postprocessor = BaseTextDetPostProcessor()
# test invalid arguments
with self.assertRaises(AssertionError):
base_postprocessor.split_results(None)
results = [lt([0, 1, 5]), ft([0.2, 0.3])]
with self.assertRaises(AssertionError):
base_postprocessor.split_results({'test': [0, 1]}, 'fields')
with self.assertRaises(AssertionError):
base_postprocessor.split_results({'test': [0, 1]},
keep_unsplit_fields='true')
base_postprocessor.split_results(results)
# test split_results
results = {
'test': [0, 1],
'test2': np.array([2, 3], dtype=int),
'meta': 'str'
}
split_results = base_postprocessor.split_results(results, ['test'])
self.assertEqual(split_results, [{'test': 0}, {'test': 1}])
results = [lt([0, 1, 5]), ft([0.2, 0.3, 0.6])]
split_results = base_postprocessor.split_results(results)
self.assertEqual(split_results,
[[lt([0]), ft([0.2])], [lt([1]), ft([0.3])],
[lt([5]), ft([0.6])]])
split_results = base_postprocessor.split_results(
results, ['test', 'test2'])
self.assertEqual(split_results, [{
'test': 0,
'test2': 2
}, {
'test': 1,
'test2': 3
}])
split_results = base_postprocessor.split_results(
results, ['test', 'test2'], keep_unsplit_fields=True)
self.assertEqual(split_results, [{
'test': 0,
'test2': 2,
'meta': 'str'
}, {
'test': 1,
'test2': 3,
'meta': 'str'
}])
results = lt([0, 1, 5])
split_results = base_postprocessor.split_results(results)
self.assertEqual(split_results, [lt([0]), lt([1]), lt([5])])
def test_poly_nms(self):
base_postprocessor = BaseTextDetPostProcessor(text_repr_type='poly')

View File

@ -23,7 +23,7 @@ class TestDBPostProcessor(unittest.TestCase):
def test_get_text_instances(self, text_repr_type):
postprocessor = DBPostprocessor(text_repr_type=text_repr_type)
pred_result = dict(prob_map=torch.rand(4, 5))
pred_result = (torch.rand(4, 5), torch.rand(4, 5), torch.rand(4, 5))
data_sample = TextDetDataSample(
metainfo=dict(scale_factor=(0.5, 1)),
gt_instances=InstanceData(polygons=[
@ -38,7 +38,8 @@ class TestDBPostProcessor(unittest.TestCase):
postprocessor = DBPostprocessor(
min_text_score=1, text_repr_type=text_repr_type)
pred_result = dict(prob_map=torch.rand(4, 5) * 0.8)
pred_result = (torch.rand(4, 5) * 0.8, torch.rand(4, 5) * 0.8,
torch.rand(4, 5) * 0.8)
results = postprocessor.get_text_instances(pred_result, data_sample)
self.assertEqual(results.pred_instances.polygons, [])
self.assertTrue(