mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[DBNet] Add DBHead
This commit is contained in:
parent
7a66a84b64
commit
32ef9cc3cf
@ -13,9 +13,6 @@ mmocr/datasets/pipelines/dbnet_transforms.py
|
||||
# will be deleted
|
||||
mmocr/models/textdet/heads/head_mixin.py
|
||||
|
||||
# Will be covered by det head tests
|
||||
mmocr/models/textdet/heads/base_textdet_head.py
|
||||
|
||||
# They will be removed later all det models have been refactored
|
||||
mmocr/models/common/detectors/single_stage.py
|
||||
mmocr/models/textdet/detectors/text_detector_mixin.py
|
||||
|
@ -1,61 +1,48 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule, Sequential
|
||||
from mmcv.runner import Sequential
|
||||
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.models.textdet.heads import BaseTextDetHead
|
||||
from mmocr.registry import MODELS
|
||||
from .head_mixin import HeadMixin
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DBHead(HeadMixin, BaseModule):
|
||||
class DBHead(BaseTextDetHead):
|
||||
"""The class for DBNet head.
|
||||
|
||||
This was partially adapted from https://github.com/MhLiao/DB
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels of the db head.
|
||||
with_bias (bool): Whether add bias in Conv2d layer.
|
||||
downsample_ratio (float): The downsample ratio of ground truths.
|
||||
in_channels (int): The number of input channels.
|
||||
with_bias (bool): Whether add bias in Conv2d layer. Defaults to False.
|
||||
loss (dict): Config of loss for dbnet.
|
||||
postprocessor (dict): Config of postprocessor for dbnet.
|
||||
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
with_bias=False,
|
||||
downsample_ratio=1.0,
|
||||
loss=dict(type='DBLoss'),
|
||||
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'),
|
||||
init_cfg=[
|
||||
dict(type='Kaiming', layer='Conv'),
|
||||
dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
|
||||
],
|
||||
train_cfg=None,
|
||||
test_cfg=None,
|
||||
**kwargs):
|
||||
old_keys = ['text_repr_type', 'decoding_type']
|
||||
for key in old_keys:
|
||||
if kwargs.get(key, None):
|
||||
postprocessor[key] = kwargs.get(key)
|
||||
warnings.warn(
|
||||
f'{key} is deprecated, please specify '
|
||||
'it in postprocessor config dict. See '
|
||||
'https://github.com/open-mmlab/mmocr/pull/640'
|
||||
' for details.', UserWarning)
|
||||
BaseModule.__init__(self, init_cfg=init_cfg)
|
||||
HeadMixin.__init__(self, loss, postprocessor)
|
||||
self,
|
||||
in_channels: int,
|
||||
with_bias: bool = False,
|
||||
loss: Dict = dict(type='DBLoss'),
|
||||
postprocessor: Dict = dict(
|
||||
type='DBPostprocessor', text_repr_type='quad'),
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = [
|
||||
dict(type='Kaiming', layer='Conv'),
|
||||
dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
|
||||
]
|
||||
) -> None:
|
||||
super().__init__(
|
||||
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
||||
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(with_bias, bool)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.train_cfg = train_cfg
|
||||
self.test_cfg = test_cfg
|
||||
self.downsample_ratio = downsample_ratio
|
||||
|
||||
self.binarize = Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
|
||||
@ -63,27 +50,44 @@ class DBHead(HeadMixin, BaseModule):
|
||||
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
|
||||
nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
|
||||
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
|
||||
|
||||
self.threshold = self._init_thr(in_channels)
|
||||
|
||||
def diff_binarize(self, prob_map, thr_map, k):
|
||||
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
|
||||
def _diff_binarize(self, prob_map: torch.Tensor, thr_map: torch.Tensor,
|
||||
k: int) -> torch.Tensor:
|
||||
"""Differential binarization.
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
Args:
|
||||
inputs (Tensor): Shape (batch_size, hidden_size, h, w).
|
||||
prob_map (Tensor): Probability map.
|
||||
thr_map (Tensor): Threshold map.
|
||||
k (int): Amplification factor.
|
||||
|
||||
Returns:
|
||||
Tensor: A tensor of the same shape as input.
|
||||
torch.Tensor: Binary map.
|
||||
"""
|
||||
prob_map = self.binarize(inputs)
|
||||
thr_map = self.threshold(inputs)
|
||||
binary_map = self.diff_binarize(prob_map, thr_map, k=50)
|
||||
outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
|
||||
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
|
||||
|
||||
def forward(self, img: torch.Tensor,
|
||||
data_samples: List[TextDetDataSample]) -> Dict:
|
||||
"""
|
||||
Args:
|
||||
img (torch.Tensor): Shape :math:`(N, C, H, W)`.
|
||||
data_samples (List[TextDetDataSample]): List of data samples.
|
||||
|
||||
Returns:
|
||||
dict: A dict with keys of ``prob_map``, ``thr_map`` and
|
||||
``binary_map``, each of shape :math:`(N, 4H, 4W)`.
|
||||
"""
|
||||
prob_map = self.binarize(img).squeeze(1)
|
||||
thr_map = self.threshold(img).squeeze(1)
|
||||
binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
|
||||
outputs = dict(
|
||||
prob_map=prob_map, thr_map=thr_map, binary_map=binary_map)
|
||||
return outputs
|
||||
|
||||
def _init_thr(self, inner_channels, bias=False):
|
||||
def _init_thr(self,
|
||||
inner_channels: int,
|
||||
bias: bool = False) -> nn.ModuleList:
|
||||
"""Initialize threshold branch."""
|
||||
in_channels = inner_channels
|
||||
seq = Sequential(
|
||||
nn.Conv2d(
|
||||
|
27
tests/test_models/test_textdet/test_heads/test_db_head.py
Normal file
27
tests/test_models/test_textdet/test_heads/test_db_head.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.models.textdet.heads import DBHead
|
||||
|
||||
|
||||
class TestDBHead(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
DBHead(in_channels='test', with_bias=False)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
DBHead(in_channels=1, with_bias='Text')
|
||||
|
||||
def test_forward(self):
|
||||
db_head = DBHead(in_channels=10)
|
||||
data = torch.randn((2, 10, 40, 50))
|
||||
results = db_head(data, None)
|
||||
self.assertIn('prob_map', results)
|
||||
self.assertIn('thr_map', results)
|
||||
self.assertIn('binary_map', results)
|
||||
self.assertEqual(results['prob_map'].shape, (2, 160, 200))
|
||||
self.assertEqual(results['thr_map'].shape, (2, 160, 200))
|
||||
self.assertEqual(results['binary_map'].shape, (2, 160, 200))
|
Loading…
x
Reference in New Issue
Block a user