# File: mmocr/models/textdet/heads/fce_head.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from mmdet.core import multi_apply

from mmocr.core import TextDetDataSample
from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS


@MODELS.register_module()
class FCEHead(BaseTextDetHead):
    """The class for implementing FCENet head.

    FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
    Detection <https://arxiv.org/abs/2104.10442>`_

    Args:
        in_channels (int): The number of input channels.
        fourier_degree (int): The maximum Fourier transform degree k.
            Defaults to 5.
        loss (dict): Config of loss for FCENet. ``fourier_degree`` is added
            to a copy of this dict, so the dict passed in is not modified.
        postprocessor (dict): Config of postprocessor for FCENet.
            ``fourier_degree`` is added to a copy of this dict, so the dict
            passed in is not modified.
        init_cfg (dict, optional): Initialization configs.
    """

    def __init__(
        self,
        in_channels: int,
        fourier_degree: int = 5,
        loss: Dict = dict(type='FCELoss', num_sample=50),
        postprocessor: Dict = dict(
            type='FCEPostprocessor',
            text_repr_type='poly',
            num_reconstr_points=50,
            alpha=1.0,
            beta=2.0,
            score_thr=0.3),
        init_cfg: Optional[Dict] = dict(
            type='Normal',
            mean=0,
            std=0.01,
            override=[dict(name='out_conv_cls'),
                      dict(name='out_conv_reg')])
    ) -> None:
        # Validate before the values are forwarded to the loss and
        # postprocessor configs, so an invalid argument always surfaces as
        # the AssertionError that callers (and the unit tests) expect.
        assert isinstance(in_channels, int)
        assert isinstance(fourier_degree, int)

        # Build new dicts instead of writing into the arguments: the
        # defaults are shared across every call, and a caller-supplied
        # config may be reused elsewhere.
        loss = dict(loss, fourier_degree=fourier_degree)
        postprocessor = dict(postprocessor, fourier_degree=fourier_degree)
        super().__init__(
            loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)

        self.in_channels = in_channels
        self.fourier_degree = fourier_degree
        self.out_channels_cls = 4
        # (2k + 1) terms for degree k, two channels per term
        # (presumably the two components of each Fourier coefficient;
        # verify against the loss / postprocessor).
        self.out_channels_reg = (2 * self.fourier_degree + 1) * 2

        # NOTE(review): the construction of these two layers lies between
        # diff hunks; only the trailing ``stride=1, padding=1)`` is shown.
        # It is reconstructed from the layer names used in ``init_cfg`` and
        # ``forward_single`` and from the unit test, which expects outputs
        # with the same spatial size as the input (hence kernel_size=3).
        # Confirm against the full file.
        self.out_conv_cls = nn.Conv2d(
            self.in_channels,
            self.out_channels_cls,
            kernel_size=3,
            stride=1,
            padding=1)
        self.out_conv_reg = nn.Conv2d(
            self.in_channels,
            self.out_channels_reg,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(
        self,
        inputs: List[torch.Tensor],
        data_samples: Optional[List[TextDetDataSample]] = None
    ) -> List[Dict]:
        """Run the head on every feature level.

        Args:
            inputs (List[Tensor]): Each tensor has the shape of :math:`(N,
                C_i, H_i, W_i)`.
            data_samples (List[TextDetDataSample], optional): List of data
                samples. Not used by this method. Defaults to None.

        Returns:
            list[dict]: A list of dict with keys of ``cls_res``, ``reg_res``
            corresponds to the classification result and regression result
            computed from the input tensor with the same index. They have
            the shapes of :math:`(N, C_{cls,i}, H_i, W_i)` and :math:`(N,
            C_{out,i}, H_i, W_i)`.
        """
        cls_res, reg_res = multi_apply(self.forward_single, inputs)
        return [
            dict(cls_res=cls_map, reg_res=reg_map)
            for cls_map, reg_map in zip(cls_res, reg_res)
        ]

    def forward_single(self,
                       x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function for a single feature level.

        Args:
            x (Tensor): The input tensor with the shape of :math:`(N, C_i,
                H_i, W_i)`.

        Returns:
            tuple[Tensor, Tensor]: The classification and regression result
            with the shape of :math:`(N, C_{cls,i}, H_i, W_i)` and
            :math:`(N, C_{out,i}, H_i, W_i)`.
        """
        cls_predict = self.out_conv_cls(x)
        reg_predict = self.out_conv_reg(x)
        return cls_predict, reg_predict
# File: tests/test_models/test_textdet/test_heads/test_fce_head.py
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmocr.models.textdet.heads import FCEHead


class TestFCEHead(TestCase):
    """Unit tests for ``FCEHead`` construction and its forward pass."""

    def test_init(self):
        """Non-integer ``in_channels`` / ``fourier_degree`` must be rejected
        with an ``AssertionError``."""
        invalid_kwargs = (
            dict(in_channels='test', fourier_degree=5),
            dict(in_channels=1, fourier_degree='Text'),
        )
        for kwargs in invalid_kwargs:
            with self.assertRaises(AssertionError):
                FCEHead(**kwargs)

    def test_forward(self):
        """Each feature level yields ``cls_res`` (4 channels) and
        ``reg_res`` (22 channels) at the input's spatial size."""
        sizes = (20, 30, 40)
        head = FCEHead(in_channels=10, fourier_degree=5)
        feats = [torch.randn(2, 10, size, size) for size in sizes]
        outputs = head(feats)

        first_level = outputs[0]
        self.assertIn('cls_res', first_level)
        self.assertIn('reg_res', first_level)

        # Index explicitly so a missing level fails loudly rather than
        # being skipped.
        for level, size in enumerate(sizes):
            self.assertEqual(outputs[level]['cls_res'].shape,
                             (2, 4, size, size))
            self.assertEqual(outputs[level]['reg_res'].shape,
                             (2, 22, size, size))