mirror of https://github.com/open-mmlab/mmocr.git
Update outputs of DBHead and split_results in BaseTextDetPostprocessor
parent
68b0aaa2e9
commit
bf517b63e8
|
@ -1,9 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import Sequential
|
||||
from torch import Tensor
|
||||
|
||||
from mmocr.data import TextDetDataSample
|
||||
from mmocr.models.textdet.heads import BaseTextDetHead
|
||||
|
@ -54,8 +55,8 @@ class DBHead(BaseTextDetHead):
|
|||
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
|
||||
self.threshold = self._init_thr(in_channels)
|
||||
|
||||
def _diff_binarize(self, prob_map: torch.Tensor, thr_map: torch.Tensor,
|
||||
k: int) -> torch.Tensor:
|
||||
def _diff_binarize(self, prob_map: Tensor, thr_map: Tensor,
|
||||
k: int) -> Tensor:
|
||||
"""Differential binarization.
|
||||
|
||||
Args:
|
||||
|
@ -64,30 +65,29 @@ class DBHead(BaseTextDetHead):
|
|||
k (int): Amplification factor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Binary map.
|
||||
Tensor: Binary map.
|
||||
"""
|
||||
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
|
||||
|
||||
def forward(self,
|
||||
img: torch.Tensor,
|
||||
data_samples: Optional[List[TextDetDataSample]] = None
|
||||
) -> Dict:
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
data_samples: Optional[List[TextDetDataSample]] = None
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
img (torch.Tensor): Shape :math:`(N, C, H, W)`.
|
||||
img (Tensor): Shape :math:`(N, C, H, W)`.
|
||||
data_samples (list[TextDetDataSample], optional): A list of data
|
||||
samples. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: A dict with keys of ``prob_map``, ``thr_map`` and
|
||||
``binary_map``, each of shape :math:`(N, 4H, 4W)`.
|
||||
tuple(Tensor, Tensor, Tensor): A tuple of ``prob_map``, ``thr_map``
|
||||
and ``binary_map``, each of shape :math:`(N, 4H, 4W)`.
|
||||
"""
|
||||
prob_map = self.binarize(img).squeeze(1)
|
||||
thr_map = self.threshold(img).squeeze(1)
|
||||
binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
|
||||
outputs = dict(
|
||||
prob_map=prob_map, thr_map=thr_map, binary_map=binary_map)
|
||||
return outputs
|
||||
return (prob_map, thr_map, binary_map)
|
||||
|
||||
def _init_thr(self,
|
||||
inner_channels: int,
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from mmdet.core import multi_apply
|
||||
from numpy.typing import ArrayLike
|
||||
from shapely.geometry import Polygon
|
||||
from torch import nn
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mmocr.data import TextDetDataSample
|
||||
from mmocr.registry import MODELS
|
||||
|
@ -57,23 +57,21 @@ class DBLoss(nn.Module, TextKernelMixin):
|
|||
self.thr_max = thr_max
|
||||
self.min_sidelength = min_sidelength
|
||||
|
||||
def forward(self, preds: Dict,
|
||||
def forward(self, preds: Tuple[Tensor, Tensor, Tensor],
|
||||
data_samples: Sequence[TextDetDataSample]) -> Dict:
|
||||
"""Compute DBNet loss.
|
||||
|
||||
Args:
|
||||
preds (dict): Raw predictions from model, containing ``prob_map``,
|
||||
``thr_map`` and ``binary_map``. Each is a tensor of shape
|
||||
:math:`(N, H, W)`.
|
||||
preds (tuple(tensor)): Raw predictions from model, containing
|
||||
``prob_map``, ``thr_map`` and ``binary_map``. Each is a tensor
|
||||
of shape :math:`(N, H, W)`.
|
||||
data_samples (list[TextDetDataSample]): The data samples.
|
||||
|
||||
Returns:
|
||||
results(dict): The dict for dbnet losses with loss_prob, \
|
||||
loss_db and loss_thr.
|
||||
"""
|
||||
prob_map = preds['prob_map']
|
||||
thr_map = preds['thr_map']
|
||||
binary_map = preds['binary_map']
|
||||
prob_map, thr_map, binary_map = preds
|
||||
gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = self.get_targets(
|
||||
data_samples)
|
||||
gt_shrinks = gt_shrinks.to(prob_map.device)
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
|
||||
from mmocr.data import TextDetDataSample
|
||||
from mmocr.utils import boundary_iou, is_type_list, rescale_polygons
|
||||
from mmocr.utils import boundary_iou, rescale_polygons
|
||||
|
||||
|
||||
class BaseTextDetPostProcessor:
|
||||
|
@ -121,43 +123,39 @@ class BaseTextDetPostProcessor:
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def split_results(self,
|
||||
pred_results: Dict,
|
||||
fields: Optional[Sequence[str]] = None,
|
||||
keep_unsplit_fields: bool = False) -> List[Dict]:
|
||||
"""Split batched elements in pred_results along the first dimension
|
||||
into ``batch_num`` sub-elements and regather them into a list of dicts.
|
||||
def split_results(
|
||||
self, pred_results: Union[Tensor, List[Tensor]]
|
||||
) -> Union[List[Tensor], List[List[Tensor]]]:
|
||||
"""Split batched tensor(s) along the first dimension pack split tensors
|
||||
into a list.
|
||||
|
||||
Args:
|
||||
pred_results (dict): Raw result dictionary from detection head.
|
||||
Each item usually has the shape of (N, ...)
|
||||
fields (list[str], optional): Fields to split. If not specified,
|
||||
all fields in ``pred_results`` will be split.
|
||||
keep_unsplit_fields (bool): Whether to keep unsplit fields in
|
||||
result dicts. If True, the fields not specified in ``fields``
|
||||
will be copied to each result dict. Defaults to False.
|
||||
pred_results (tensor or list[tensor]): Raw result tensor(s) from
|
||||
detection head. Each tensor usually has the shape of (N, ...)
|
||||
|
||||
Returns:
|
||||
list[dict]: N dicts whose keys remains the same as that of
|
||||
pred_results.
|
||||
list[tensor] or list[list[tensor]]: N tensors if ``pred_results``
|
||||
is a tensor, or a list of N lists of tensors if
|
||||
``pred_results`` is a list of tensors.
|
||||
"""
|
||||
assert isinstance(pred_results, dict) and len(pred_results) > 0
|
||||
assert fields is None or is_type_list(fields, str)
|
||||
assert isinstance(keep_unsplit_fields, bool)
|
||||
assert isinstance(pred_results, Tensor) or mmcv.is_seq_of(
|
||||
pred_results, Tensor)
|
||||
|
||||
if fields is None:
|
||||
fields = list(pred_results.keys())
|
||||
batch_num = len(pred_results[fields[0]])
|
||||
results = [{} for _ in range(batch_num)]
|
||||
for field in fields:
|
||||
for i in range(batch_num):
|
||||
results[i][field] = pred_results[field][i]
|
||||
if keep_unsplit_fields:
|
||||
for k, v in pred_results.items():
|
||||
if k in fields:
|
||||
continue
|
||||
for i in range(batch_num):
|
||||
results[i][k] = v
|
||||
if mmcv.is_seq_of(pred_results, Tensor):
|
||||
for i in range(1, len(pred_results)):
|
||||
assert pred_results[0].shape[0] == pred_results[i].shape[0], \
|
||||
'The first dimension of all tensors should be the same'
|
||||
|
||||
batch_num = len(pred_results) if isinstance(pred_results, Tensor) else\
|
||||
len(pred_results[0])
|
||||
results = []
|
||||
for i in range(batch_num):
|
||||
if isinstance(pred_results, Tensor):
|
||||
results.append(pred_results[i])
|
||||
else:
|
||||
results.append([])
|
||||
for tensor in pred_results:
|
||||
results[i].append(tensor[i])
|
||||
return results
|
||||
|
||||
def poly_nms(self, polygons: List[np.ndarray], scores: List[float],
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import InstanceData
|
||||
from shapely.geometry import Polygon
|
||||
from torch import Tensor
|
||||
|
||||
from mmocr.data import TextDetDataSample
|
||||
from mmocr.registry import MODELS
|
||||
|
@ -54,14 +55,14 @@ class DBPostprocessor(BaseTextDetPostProcessor):
|
|||
self.unclip_ratio = unclip_ratio
|
||||
self.max_candidates = max_candidates
|
||||
|
||||
def get_text_instances(self, pred_results: dict,
|
||||
def get_text_instances(self, pred_results: Tuple[Tensor, Tensor, Tensor],
|
||||
data_sample: TextDetDataSample
|
||||
) -> TextDetDataSample:
|
||||
"""Get text instance predictions of one image.
|
||||
|
||||
Args:
|
||||
pred_result (dict): Prediction results of an image containing the
|
||||
``prob_map``, which is a tensor of shape :math:`(N, H, W)`.
|
||||
pred_result (tuple(Tensor)): A tuple of 3 tensors where the first
|
||||
tensor is ``prob_map`` of shape :math:`(N, H, W)`.
|
||||
data_sample (TextDetDataSample): Datasample of an image.
|
||||
|
||||
Returns:
|
||||
|
@ -75,7 +76,7 @@ class DBPostprocessor(BaseTextDetPostProcessor):
|
|||
data_sample.pred_instances.polygons = []
|
||||
data_sample.pred_instances.scores = []
|
||||
|
||||
prob_map = pred_results['prob_map']
|
||||
prob_map = pred_results[0]
|
||||
text_mask = prob_map > self.mask_thr
|
||||
|
||||
score_map = prob_map.data.cpu().numpy().astype(np.float32)
|
||||
|
|
|
@ -19,9 +19,6 @@ class TestDBHead(TestCase):
|
|||
db_head = DBHead(in_channels=10)
|
||||
data = torch.randn((2, 10, 40, 50))
|
||||
results = db_head(data, None)
|
||||
self.assertIn('prob_map', results)
|
||||
self.assertIn('thr_map', results)
|
||||
self.assertIn('binary_map', results)
|
||||
self.assertEqual(results['prob_map'].shape, (2, 160, 200))
|
||||
self.assertEqual(results['thr_map'].shape, (2, 160, 200))
|
||||
self.assertEqual(results['binary_map'].shape, (2, 160, 200))
|
||||
self.assertEqual(results[0].shape, (2, 160, 200))
|
||||
self.assertEqual(results[1].shape, (2, 160, 200))
|
||||
self.assertEqual(results[2].shape, (2, 160, 200))
|
||||
|
|
|
@ -26,10 +26,8 @@ class TestDBLoss(TestCase):
|
|||
ignored=torch.BoolTensor([False, False, True])))
|
||||
]
|
||||
pred_size = (1, 40, 40)
|
||||
self.preds = dict(
|
||||
prob_map=torch.rand(pred_size),
|
||||
thr_map=torch.rand(pred_size),
|
||||
binary_map=torch.rand(pred_size))
|
||||
self.preds = (torch.rand(pred_size), torch.rand(pred_size),
|
||||
torch.rand(pred_size))
|
||||
|
||||
def test_is_poly_invalid(self):
|
||||
# area < 1
|
||||
|
|
|
@ -3,6 +3,7 @@ import unittest
|
|||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import InstanceData
|
||||
|
||||
from mmocr.data import TextDetDataSample
|
||||
|
@ -29,9 +30,7 @@ class TestBaseTextDetPostProcessor(unittest.TestCase):
|
|||
|
||||
mock_get_text_instances.side_effect = mock_func
|
||||
|
||||
pred_results = {
|
||||
'prob_map': np.array([[0.1, 0.2], [0.3, 0.4]]),
|
||||
}
|
||||
pred_results = torch.Tensor([[0.1, 0.2], [0.3, 0.4]])
|
||||
data_samples = [
|
||||
TextDetDataSample(
|
||||
metainfo=dict(scale_factor=(0.5, 1)),
|
||||
|
@ -79,49 +78,30 @@ class TestBaseTextDetPostProcessor(unittest.TestCase):
|
|||
|
||||
def test_split_results(self):
|
||||
|
||||
# some shorthands
|
||||
lt = torch.LongTensor
|
||||
ft = torch.FloatTensor
|
||||
|
||||
base_postprocessor = BaseTextDetPostProcessor()
|
||||
|
||||
# test invalid arguments
|
||||
with self.assertRaises(AssertionError):
|
||||
base_postprocessor.split_results(None)
|
||||
|
||||
results = [lt([0, 1, 5]), ft([0.2, 0.3])]
|
||||
with self.assertRaises(AssertionError):
|
||||
base_postprocessor.split_results({'test': [0, 1]}, 'fields')
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
base_postprocessor.split_results({'test': [0, 1]},
|
||||
keep_unsplit_fields='true')
|
||||
base_postprocessor.split_results(results)
|
||||
|
||||
# test split_results
|
||||
results = {
|
||||
'test': [0, 1],
|
||||
'test2': np.array([2, 3], dtype=int),
|
||||
'meta': 'str'
|
||||
}
|
||||
split_results = base_postprocessor.split_results(results, ['test'])
|
||||
self.assertEqual(split_results, [{'test': 0}, {'test': 1}])
|
||||
results = [lt([0, 1, 5]), ft([0.2, 0.3, 0.6])]
|
||||
split_results = base_postprocessor.split_results(results)
|
||||
self.assertEqual(split_results,
|
||||
[[lt([0]), ft([0.2])], [lt([1]), ft([0.3])],
|
||||
[lt([5]), ft([0.6])]])
|
||||
|
||||
split_results = base_postprocessor.split_results(
|
||||
results, ['test', 'test2'])
|
||||
self.assertEqual(split_results, [{
|
||||
'test': 0,
|
||||
'test2': 2
|
||||
}, {
|
||||
'test': 1,
|
||||
'test2': 3
|
||||
}])
|
||||
|
||||
split_results = base_postprocessor.split_results(
|
||||
results, ['test', 'test2'], keep_unsplit_fields=True)
|
||||
self.assertEqual(split_results, [{
|
||||
'test': 0,
|
||||
'test2': 2,
|
||||
'meta': 'str'
|
||||
}, {
|
||||
'test': 1,
|
||||
'test2': 3,
|
||||
'meta': 'str'
|
||||
}])
|
||||
results = lt([0, 1, 5])
|
||||
split_results = base_postprocessor.split_results(results)
|
||||
self.assertEqual(split_results, [lt([0]), lt([1]), lt([5])])
|
||||
|
||||
def test_poly_nms(self):
|
||||
base_postprocessor = BaseTextDetPostProcessor(text_repr_type='poly')
|
||||
|
|
|
@ -23,7 +23,7 @@ class TestDBPostProcessor(unittest.TestCase):
|
|||
def test_get_text_instances(self, text_repr_type):
|
||||
|
||||
postprocessor = DBPostprocessor(text_repr_type=text_repr_type)
|
||||
pred_result = dict(prob_map=torch.rand(4, 5))
|
||||
pred_result = (torch.rand(4, 5), torch.rand(4, 5), torch.rand(4, 5))
|
||||
data_sample = TextDetDataSample(
|
||||
metainfo=dict(scale_factor=(0.5, 1)),
|
||||
gt_instances=InstanceData(polygons=[
|
||||
|
@ -38,7 +38,8 @@ class TestDBPostProcessor(unittest.TestCase):
|
|||
|
||||
postprocessor = DBPostprocessor(
|
||||
min_text_score=1, text_repr_type=text_repr_type)
|
||||
pred_result = dict(prob_map=torch.rand(4, 5) * 0.8)
|
||||
pred_result = (torch.rand(4, 5) * 0.8, torch.rand(4, 5) * 0.8,
|
||||
torch.rand(4, 5) * 0.8)
|
||||
results = postprocessor.get_text_instances(pred_result, data_sample)
|
||||
self.assertEqual(results.pred_instances.polygons, [])
|
||||
self.assertTrue(
|
||||
|
|
Loading…
Reference in New Issue