[DBNet] Add DBHead

This commit is contained in:
gaotongxiao 2022-05-30 08:42:40 +00:00
parent 7a66a84b64
commit 32ef9cc3cf
3 changed files with 78 additions and 50 deletions

View File

@ -13,9 +13,6 @@ mmocr/datasets/pipelines/dbnet_transforms.py
# will be deleted # will be deleted
mmocr/models/textdet/heads/head_mixin.py mmocr/models/textdet/heads/head_mixin.py
# Will be covered by det head tests
mmocr/models/textdet/heads/base_textdet_head.py
# They will be removed later, once all det models have been refactored # They will be removed later, once all det models have been refactored
mmocr/models/common/detectors/single_stage.py mmocr/models/common/detectors/single_stage.py
mmocr/models/textdet/detectors/text_detector_mixin.py mmocr/models/textdet/detectors/text_detector_mixin.py

View File

@ -1,61 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings from typing import Dict, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.runner import BaseModule, Sequential from mmcv.runner import Sequential
from mmocr.core import TextDetDataSample
from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS from mmocr.registry import MODELS
from .head_mixin import HeadMixin
@MODELS.register_module() @MODELS.register_module()
class DBHead(HeadMixin, BaseModule): class DBHead(BaseTextDetHead):
"""The class for DBNet head. """The class for DBNet head.
This was partially adapted from https://github.com/MhLiao/DB This was partially adapted from https://github.com/MhLiao/DB
Args: Args:
in_channels (int): The number of input channels of the db head. in_channels (int): The number of input channels.
with_bias (bool): Whether to add bias in the Conv2d layer. with_bias (bool): Whether to add bias in the Conv2d layer. Defaults to False.
downsample_ratio (float): The downsample ratio of ground truths.
loss (dict): Config of loss for dbnet. loss (dict): Config of loss for dbnet.
postprocessor (dict): Config of postprocessor for dbnet. postprocessor (dict): Config of postprocessor for dbnet.
init_cfg (dict or list[dict], optional): Initialization configs.
""" """
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
with_bias=False, with_bias: bool = False,
downsample_ratio=1.0, loss: Dict = dict(type='DBLoss'),
loss=dict(type='DBLoss'), postprocessor: Dict = dict(
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'), type='DBPostprocessor', text_repr_type='quad'),
init_cfg=[ init_cfg: Optional[Union[Dict, List[Dict]]] = [
dict(type='Kaiming', layer='Conv'), dict(type='Kaiming', layer='Conv'),
dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4) dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
], ]
train_cfg=None, ) -> None:
test_cfg=None, super().__init__(
**kwargs): loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
old_keys = ['text_repr_type', 'decoding_type']
for key in old_keys:
if kwargs.get(key, None):
postprocessor[key] = kwargs.get(key)
warnings.warn(
f'{key} is deprecated, please specify '
'it in postprocessor config dict. See '
'https://github.com/open-mmlab/mmocr/pull/640'
' for details.', UserWarning)
BaseModule.__init__(self, init_cfg=init_cfg)
HeadMixin.__init__(self, loss, postprocessor)
assert isinstance(in_channels, int) assert isinstance(in_channels, int)
assert isinstance(with_bias, bool)
self.in_channels = in_channels self.in_channels = in_channels
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.downsample_ratio = downsample_ratio
self.binarize = Sequential( self.binarize = Sequential(
nn.Conv2d( nn.Conv2d(
in_channels, in_channels // 4, 3, bias=with_bias, padding=1), in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
@ -63,27 +50,44 @@ class DBHead(HeadMixin, BaseModule):
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2), nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True), nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid()) nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
self.threshold = self._init_thr(in_channels) self.threshold = self._init_thr(in_channels)
def diff_binarize(self, prob_map, thr_map, k): def _diff_binarize(self, prob_map: torch.Tensor, thr_map: torch.Tensor,
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map))) k: int) -> torch.Tensor:
"""Differential binarization.
def forward(self, inputs):
"""
Args: Args:
inputs (Tensor): Shape (batch_size, hidden_size, h, w). prob_map (Tensor): Probability map.
thr_map (Tensor): Threshold map.
k (int): Amplification factor.
Returns: Returns:
Tensor: A tensor of the same shape as input. torch.Tensor: Binary map.
""" """
prob_map = self.binarize(inputs) return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
thr_map = self.threshold(inputs)
binary_map = self.diff_binarize(prob_map, thr_map, k=50) def forward(self, img: torch.Tensor,
outputs = torch.cat((prob_map, thr_map, binary_map), dim=1) data_samples: List[TextDetDataSample]) -> Dict:
"""
Args:
img (torch.Tensor): Shape :math:`(N, C, H, W)`.
data_samples (List[TextDetDataSample]): List of data samples.
Returns:
dict: A dict with keys of ``prob_map``, ``thr_map`` and
``binary_map``, each of shape :math:`(N, 4H, 4W)`.
"""
prob_map = self.binarize(img).squeeze(1)
thr_map = self.threshold(img).squeeze(1)
binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
outputs = dict(
prob_map=prob_map, thr_map=thr_map, binary_map=binary_map)
return outputs return outputs
def _init_thr(self, inner_channels, bias=False): def _init_thr(self,
inner_channels: int,
bias: bool = False) -> nn.ModuleList:
"""Initialize threshold branch."""
in_channels = inner_channels in_channels = inner_channels
seq = Sequential( seq = Sequential(
nn.Conv2d( nn.Conv2d(

View File

@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmocr.models.textdet.heads import DBHead
class TestDBHead(TestCase):
    """Unit tests covering ``DBHead`` construction and its forward output."""

    def test_init(self):
        """Each invalid argument set is expected to trip an assertion."""
        invalid_kwargs = (
            dict(in_channels='test', with_bias=False),
            dict(in_channels=1, with_bias='Text'),
        )
        for kwargs in invalid_kwargs:
            with self.assertRaises(AssertionError):
                DBHead(**kwargs)

    def test_forward(self):
        """Forward yields three named maps, each 4x the input height/width."""
        batch, channels, height, width = 2, 10, 40, 50
        head = DBHead(in_channels=channels)
        feats = torch.randn((batch, channels, height, width))
        # The head is called without data samples in this test.
        outputs = head(feats, None)

        map_names = ('prob_map', 'thr_map', 'binary_map')
        # Check presence of every key first, then the shapes, in that order.
        for name in map_names:
            self.assertIn(name, outputs)
        expected_shape = (batch, 4 * height, 4 * width)
        for name in map_names:
            self.assertEqual(outputs[name].shape, expected_shape)