diff --git a/.dev_scripts/covignore.cfg b/.dev_scripts/covignore.cfg index 82368a9a..2f31e8a4 100644 --- a/.dev_scripts/covignore.cfg +++ b/.dev_scripts/covignore.cfg @@ -13,9 +13,6 @@ mmocr/datasets/pipelines/dbnet_transforms.py # will be deleted mmocr/models/textdet/heads/head_mixin.py -# Will be covered by det head tests -mmocr/models/textdet/heads/base_textdet_head.py - # They will be removed later all det models have been refactored mmocr/models/common/detectors/single_stage.py mmocr/models/textdet/detectors/text_detector_mixin.py diff --git a/mmocr/models/textdet/heads/db_head.py b/mmocr/models/textdet/heads/db_head.py index 2aa2990c..20b4fddf 100644 --- a/mmocr/models/textdet/heads/db_head.py +++ b/mmocr/models/textdet/heads/db_head.py @@ -1,61 +1,48 @@ # Copyright (c) OpenMMLab. All rights reserved. -import warnings +from typing import Dict, List, Optional, Union import torch import torch.nn as nn -from mmcv.runner import BaseModule, Sequential +from mmcv.runner import Sequential +from mmocr.core import TextDetDataSample +from mmocr.models.textdet.heads import BaseTextDetHead from mmocr.registry import MODELS -from .head_mixin import HeadMixin @MODELS.register_module() -class DBHead(HeadMixin, BaseModule): +class DBHead(BaseTextDetHead): """The class for DBNet head. This was partially adapted from https://github.com/MhLiao/DB Args: - in_channels (int): The number of input channels of the db head. - with_bias (bool): Whether add bias in Conv2d layer. - downsample_ratio (float): The downsample ratio of ground truths. + in_channels (int): The number of input channels. + with_bias (bool): Whether add bias in Conv2d layer. Defaults to False. loss (dict): Config of loss for dbnet. postprocessor (dict): Config of postprocessor for dbnet. + init_cfg (dict or list[dict], optional): Initialization configs. """ def __init__( - self, - in_channels, - with_bias=False, - downsample_ratio=1.0, - loss=dict(type='DBLoss'), - postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'), - init_cfg=[ - dict(type='Kaiming', layer='Conv'), - dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4) - ], - train_cfg=None, - test_cfg=None, - **kwargs): - old_keys = ['text_repr_type', 'decoding_type'] - for key in old_keys: - if kwargs.get(key, None): - postprocessor[key] = kwargs.get(key) - warnings.warn( - f'{key} is deprecated, please specify ' - 'it in postprocessor config dict. See ' - 'https://github.com/open-mmlab/mmocr/pull/640' - ' for details.', UserWarning) - BaseModule.__init__(self, init_cfg=init_cfg) - HeadMixin.__init__(self, loss, postprocessor) + self, + in_channels: int, + with_bias: bool = False, + loss: Dict = dict(type='DBLoss'), + postprocessor: Dict = dict( + type='DBPostprocessor', text_repr_type='quad'), + init_cfg: Optional[Union[Dict, List[Dict]]] = [ + dict(type='Kaiming', layer='Conv'), + dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4) + ] + ) -> None: + super().__init__( + loss=loss, postprocessor=postprocessor, init_cfg=init_cfg) assert isinstance(in_channels, int) + assert isinstance(with_bias, bool) self.in_channels = in_channels - self.train_cfg = train_cfg - self.test_cfg = test_cfg - self.downsample_ratio = downsample_ratio - self.binarize = Sequential( nn.Conv2d( in_channels, in_channels // 4, 3, bias=with_bias, padding=1), @@ -63,27 +50,44 @@ class DBHead(HeadMixin, BaseModule): nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2), nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True), nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid()) - self.threshold = self._init_thr(in_channels) - def diff_binarize(self, prob_map, thr_map, k): - return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map))) + def _diff_binarize(self, prob_map: torch.Tensor, thr_map: torch.Tensor, + k: int) -> torch.Tensor: + """Differential binarization. - def forward(self, inputs): - """ Args: - inputs (Tensor): Shape (batch_size, hidden_size, h, w). + prob_map (Tensor): Probability map. + thr_map (Tensor): Threshold map. + k (int): Amplification factor. Returns: - Tensor: A tensor of the same shape as input. + torch.Tensor: Binary map. """ - prob_map = self.binarize(inputs) - thr_map = self.threshold(inputs) - binary_map = self.diff_binarize(prob_map, thr_map, k=50) - outputs = torch.cat((prob_map, thr_map, binary_map), dim=1) + return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map))) + + def forward(self, img: torch.Tensor, + data_samples: List[TextDetDataSample]) -> Dict: + """ + Args: + img (torch.Tensor): Shape :math:`(N, C, H, W)`. + data_samples (List[TextDetDataSample]): List of data samples. + + Returns: + dict: A dict with keys of ``prob_map``, ``thr_map`` and + ``binary_map``, each of shape :math:`(N, 4H, 4W)`. + """ + prob_map = self.binarize(img).squeeze(1) + thr_map = self.threshold(img).squeeze(1) + binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1) + outputs = dict( + prob_map=prob_map, thr_map=thr_map, binary_map=binary_map) return outputs - def _init_thr(self, inner_channels, bias=False): + def _init_thr(self, + inner_channels: int, + bias: bool = False) -> nn.ModuleList: + """Initialize threshold branch.""" in_channels = inner_channels seq = Sequential( nn.Conv2d( diff --git a/tests/test_models/test_textdet/test_heads/test_db_head.py b/tests/test_models/test_textdet/test_heads/test_db_head.py new file mode 100644 index 00000000..bfcb43ae --- /dev/null +++ b/tests/test_models/test_textdet/test_heads/test_db_head.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.textdet.heads import DBHead + + +class TestDBHead(TestCase): + + def test_init(self): + with self.assertRaises(AssertionError): + DBHead(in_channels='test', with_bias=False) + + with self.assertRaises(AssertionError): + DBHead(in_channels=1, with_bias='Text') + + def test_forward(self): + db_head = DBHead(in_channels=10) + data = torch.randn((2, 10, 40, 50)) + results = db_head(data, None) + self.assertIn('prob_map', results) + self.assertIn('thr_map', results) + self.assertIn('binary_map', results) + self.assertEqual(results['prob_map'].shape, (2, 160, 200)) + self.assertEqual(results['thr_map'].shape, (2, 160, 200)) + self.assertEqual(results['binary_map'].shape, (2, 160, 200))