[DBNet] Add DBHead

gaotongxiao 2022-05-30 08:42:40 +00:00
parent 7a66a84b64
commit 32ef9cc3cf
3 changed files with 78 additions and 50 deletions

View File

@@ -13,9 +13,6 @@ mmocr/datasets/pipelines/dbnet_transforms.py
# will be deleted
mmocr/models/textdet/heads/head_mixin.py
# Will be covered by det head tests
mmocr/models/textdet/heads/base_textdet_head.py
# They will be removed after all det models have been refactored
mmocr/models/common/detectors/single_stage.py
mmocr/models/textdet/detectors/text_detector_mixin.py

View File

@@ -1,61 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, Sequential
from mmcv.runner import Sequential
from mmocr.core import TextDetDataSample
from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS
from .head_mixin import HeadMixin
@MODELS.register_module()
class DBHead(HeadMixin, BaseModule):
class DBHead(BaseTextDetHead):
"""The class for DBNet head.
This was partially adapted from https://github.com/MhLiao/DB
Args:
in_channels (int): The number of input channels of the db head.
with_bias (bool): Whether to add bias in Conv2d layer.
downsample_ratio (float): The downsample ratio of ground truths.
in_channels (int): The number of input channels.
with_bias (bool): Whether to add bias in Conv2d layer. Defaults to False.
loss (dict): Config of loss for DBNet.
postprocessor (dict): Config of postprocessor for DBNet.
init_cfg (dict or list[dict], optional): Initialization configs.
"""
def __init__(
self,
in_channels,
with_bias=False,
downsample_ratio=1.0,
loss=dict(type='DBLoss'),
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'),
init_cfg=[
dict(type='Kaiming', layer='Conv'),
dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
],
train_cfg=None,
test_cfg=None,
**kwargs):
old_keys = ['text_repr_type', 'decoding_type']
for key in old_keys:
if kwargs.get(key, None):
postprocessor[key] = kwargs.get(key)
warnings.warn(
f'{key} is deprecated, please specify '
'it in postprocessor config dict. See '
'https://github.com/open-mmlab/mmocr/pull/640'
' for details.', UserWarning)
BaseModule.__init__(self, init_cfg=init_cfg)
HeadMixin.__init__(self, loss, postprocessor)
self,
in_channels: int,
with_bias: bool = False,
loss: Dict = dict(type='DBLoss'),
postprocessor: Dict = dict(
type='DBPostprocessor', text_repr_type='quad'),
init_cfg: Optional[Union[Dict, List[Dict]]] = [
dict(type='Kaiming', layer='Conv'),
dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
]
) -> None:
super().__init__(
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
assert isinstance(in_channels, int)
assert isinstance(with_bias, bool)
self.in_channels = in_channels
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.downsample_ratio = downsample_ratio
self.binarize = Sequential(
nn.Conv2d(
in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
@@ -63,27 +50,44 @@ class DBHead(HeadMixin, BaseModule):
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
self.threshold = self._init_thr(in_channels)
def diff_binarize(self, prob_map, thr_map, k):
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
def _diff_binarize(self, prob_map: torch.Tensor, thr_map: torch.Tensor,
k: int) -> torch.Tensor:
"""Differential binarization.
def forward(self, inputs):
"""
Args:
inputs (Tensor): Shape (batch_size, hidden_size, h, w).
prob_map (Tensor): Probability map.
thr_map (Tensor): Threshold map.
k (int): Amplification factor.
Returns:
Tensor: A tensor of the same shape as input.
torch.Tensor: Binary map.
"""
prob_map = self.binarize(inputs)
thr_map = self.threshold(inputs)
binary_map = self.diff_binarize(prob_map, thr_map, k=50)
outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
def forward(self, img: torch.Tensor,
data_samples: List[TextDetDataSample]) -> Dict:
"""
Args:
img (torch.Tensor): Shape :math:`(N, C, H, W)`.
data_samples (List[TextDetDataSample]): List of data samples.
Returns:
dict: A dict with keys of ``prob_map``, ``thr_map`` and
``binary_map``, each of shape :math:`(N, 4H, 4W)`.
"""
prob_map = self.binarize(img).squeeze(1)
thr_map = self.threshold(img).squeeze(1)
binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
outputs = dict(
prob_map=prob_map, thr_map=thr_map, binary_map=binary_map)
return outputs
def _init_thr(self, inner_channels, bias=False):
def _init_thr(self,
inner_channels: int,
bias: bool = False) -> nn.ModuleList:
"""Initialize threshold branch."""
in_channels = inner_channels
seq = Sequential(
nn.Conv2d(
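
For reference, the differentiable binarization that the new _diff_binarize method applies is just a steep sigmoid over the gap between the probability map and the threshold map. The standalone sketch below (plain PyTorch; the tensor shapes are illustrative assumptions, not taken from this diff) reproduces the computation:

import torch

def diff_binarize(prob_map: torch.Tensor, thr_map: torch.Tensor,
                  k: int = 50) -> torch.Tensor:
    # Equal in value to torch.sigmoid(k * (prob_map - thr_map)); a large k
    # pushes the output towards a hard 0/1 map while keeping it differentiable.
    return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))

prob_map = torch.rand(2, 160, 200)  # assumed (N, H, W) probability map
thr_map = torch.rand(2, 160, 200)   # assumed (N, H, W) threshold map
binary_map = diff_binarize(prob_map, thr_map)  # same shape, values in (0, 1)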

View File

@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmocr.models.textdet.heads import DBHead
class TestDBHead(TestCase):
def test_init(self):
with self.assertRaises(AssertionError):
DBHead(in_channels='test', with_bias=False)
with self.assertRaises(AssertionError):
DBHead(in_channels=1, with_bias='Text')
def test_forward(self):
db_head = DBHead(in_channels=10)
data = torch.randn((2, 10, 40, 50))
results = db_head(data, None)
self.assertIn('prob_map', results)
self.assertIn('thr_map', results)
self.assertIn('binary_map', results)
self.assertEqual(results['prob_map'].shape, (2, 160, 200))
self.assertEqual(results['thr_map'].shape, (2, 160, 200))
self.assertEqual(results['binary_map'].shape, (2, 160, 200))
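
Because the head is registered through MODELS.register_module(), it can also be built from a config dict via the registry. A minimal usage sketch, assuming an environment where this refactored mmocr branch is importable (in_channels=10 simply mirrors the unit test above; a real config would use the neck's output channel count):

import torch
from mmocr.registry import MODELS

# 'DBHead' is the name registered in the diff above; loss and postprocessor
# fall back to the constructor defaults (DBLoss, DBPostprocessor).
db_head = MODELS.build(dict(type='DBHead', in_channels=10))

feat = torch.randn(2, 10, 40, 50)  # (N, C, H, W) feature map, as in the test
outputs = db_head(feat, None)      # data_samples are not used inside forward()
print(outputs['prob_map'].shape)   # torch.Size([2, 160, 200]), i.e. (N, 4H, 4W)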