mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[DBNet] Add DBHead
This commit is contained in:
parent
7a66a84b64
commit
32ef9cc3cf
@ -13,9 +13,6 @@ mmocr/datasets/pipelines/dbnet_transforms.py
|
|||||||
# will be deleted
|
# will be deleted
|
||||||
mmocr/models/textdet/heads/head_mixin.py
|
mmocr/models/textdet/heads/head_mixin.py
|
||||||
|
|
||||||
# Will be covered by det head tests
|
|
||||||
mmocr/models/textdet/heads/base_textdet_head.py
|
|
||||||
|
|
||||||
# They will be removed later all det models have been refactored
|
# They will be removed later all det models have been refactored
|
||||||
mmocr/models/common/detectors/single_stage.py
|
mmocr/models/common/detectors/single_stage.py
|
||||||
mmocr/models/textdet/detectors/text_detector_mixin.py
|
mmocr/models/textdet/detectors/text_detector_mixin.py
|
||||||
|
@ -1,61 +1,48 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import warnings
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.runner import BaseModule, Sequential
|
from mmcv.runner import Sequential
|
||||||
|
|
||||||
|
from mmocr.core import TextDetDataSample
|
||||||
|
from mmocr.models.textdet.heads import BaseTextDetHead
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from .head_mixin import HeadMixin
|
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class DBHead(HeadMixin, BaseModule):
|
class DBHead(BaseTextDetHead):
|
||||||
"""The class for DBNet head.
|
"""The class for DBNet head.
|
||||||
|
|
||||||
This was partially adapted from https://github.com/MhLiao/DB
|
This was partially adapted from https://github.com/MhLiao/DB
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (int): The number of input channels of the db head.
|
in_channels (int): The number of input channels.
|
||||||
with_bias (bool): Whether add bias in Conv2d layer.
|
with_bias (bool): Whether add bias in Conv2d layer. Defaults to False.
|
||||||
downsample_ratio (float): The downsample ratio of ground truths.
|
|
||||||
loss (dict): Config of loss for dbnet.
|
loss (dict): Config of loss for dbnet.
|
||||||
postprocessor (dict): Config of postprocessor for dbnet.
|
postprocessor (dict): Config of postprocessor for dbnet.
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels,
|
in_channels: int,
|
||||||
with_bias=False,
|
with_bias: bool = False,
|
||||||
downsample_ratio=1.0,
|
loss: Dict = dict(type='DBLoss'),
|
||||||
loss=dict(type='DBLoss'),
|
postprocessor: Dict = dict(
|
||||||
postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'),
|
type='DBPostprocessor', text_repr_type='quad'),
|
||||||
init_cfg=[
|
init_cfg: Optional[Union[Dict, List[Dict]]] = [
|
||||||
dict(type='Kaiming', layer='Conv'),
|
dict(type='Kaiming', layer='Conv'),
|
||||||
dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
|
dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
|
||||||
],
|
]
|
||||||
train_cfg=None,
|
) -> None:
|
||||||
test_cfg=None,
|
super().__init__(
|
||||||
**kwargs):
|
loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
|
||||||
old_keys = ['text_repr_type', 'decoding_type']
|
|
||||||
for key in old_keys:
|
|
||||||
if kwargs.get(key, None):
|
|
||||||
postprocessor[key] = kwargs.get(key)
|
|
||||||
warnings.warn(
|
|
||||||
f'{key} is deprecated, please specify '
|
|
||||||
'it in postprocessor config dict. See '
|
|
||||||
'https://github.com/open-mmlab/mmocr/pull/640'
|
|
||||||
' for details.', UserWarning)
|
|
||||||
BaseModule.__init__(self, init_cfg=init_cfg)
|
|
||||||
HeadMixin.__init__(self, loss, postprocessor)
|
|
||||||
|
|
||||||
assert isinstance(in_channels, int)
|
assert isinstance(in_channels, int)
|
||||||
|
assert isinstance(with_bias, bool)
|
||||||
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.train_cfg = train_cfg
|
|
||||||
self.test_cfg = test_cfg
|
|
||||||
self.downsample_ratio = downsample_ratio
|
|
||||||
|
|
||||||
self.binarize = Sequential(
|
self.binarize = Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
|
in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
|
||||||
@ -63,27 +50,44 @@ class DBHead(HeadMixin, BaseModule):
|
|||||||
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
|
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
|
||||||
nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
|
nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
|
||||||
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
|
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
|
||||||
|
|
||||||
self.threshold = self._init_thr(in_channels)
|
self.threshold = self._init_thr(in_channels)
|
||||||
|
|
||||||
def diff_binarize(self, prob_map, thr_map, k):
|
def _diff_binarize(self, prob_map: torch.Tensor, thr_map: torch.Tensor,
|
||||||
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
|
k: int) -> torch.Tensor:
|
||||||
|
"""Differential binarization.
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
"""
|
|
||||||
Args:
|
Args:
|
||||||
inputs (Tensor): Shape (batch_size, hidden_size, h, w).
|
prob_map (Tensor): Probability map.
|
||||||
|
thr_map (Tensor): Threshold map.
|
||||||
|
k (int): Amplification factor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: A tensor of the same shape as input.
|
torch.Tensor: Binary map.
|
||||||
"""
|
"""
|
||||||
prob_map = self.binarize(inputs)
|
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
|
||||||
thr_map = self.threshold(inputs)
|
|
||||||
binary_map = self.diff_binarize(prob_map, thr_map, k=50)
|
def forward(self, img: torch.Tensor,
|
||||||
outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
|
data_samples: List[TextDetDataSample]) -> Dict:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
img (torch.Tensor): Shape :math:`(N, C, H, W)`.
|
||||||
|
data_samples (List[TextDetDataSample]): List of data samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dict with keys of ``prob_map``, ``thr_map`` and
|
||||||
|
``binary_map``, each of shape :math:`(N, 4H, 4W)`.
|
||||||
|
"""
|
||||||
|
prob_map = self.binarize(img).squeeze(1)
|
||||||
|
thr_map = self.threshold(img).squeeze(1)
|
||||||
|
binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
|
||||||
|
outputs = dict(
|
||||||
|
prob_map=prob_map, thr_map=thr_map, binary_map=binary_map)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def _init_thr(self, inner_channels, bias=False):
|
def _init_thr(self,
|
||||||
|
inner_channels: int,
|
||||||
|
bias: bool = False) -> nn.ModuleList:
|
||||||
|
"""Initialize threshold branch."""
|
||||||
in_channels = inner_channels
|
in_channels = inner_channels
|
||||||
seq = Sequential(
|
seq = Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
|
27
tests/test_models/test_textdet/test_heads/test_db_head.py
Normal file
27
tests/test_models/test_textdet/test_heads/test_db_head.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmocr.models.textdet.heads import DBHead
|
||||||
|
|
||||||
|
|
||||||
|
class TestDBHead(TestCase):
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
DBHead(in_channels='test', with_bias=False)
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
DBHead(in_channels=1, with_bias='Text')
|
||||||
|
|
||||||
|
def test_forward(self):
|
||||||
|
db_head = DBHead(in_channels=10)
|
||||||
|
data = torch.randn((2, 10, 40, 50))
|
||||||
|
results = db_head(data, None)
|
||||||
|
self.assertIn('prob_map', results)
|
||||||
|
self.assertIn('thr_map', results)
|
||||||
|
self.assertIn('binary_map', results)
|
||||||
|
self.assertEqual(results['prob_map'].shape, (2, 160, 200))
|
||||||
|
self.assertEqual(results['thr_map'].shape, (2, 160, 200))
|
||||||
|
self.assertEqual(results['binary_map'].shape, (2, 160, 200))
|
Loading…
x
Reference in New Issue
Block a user