add fce head

pull/1178/head
Mountchicken 2022-06-08 09:15:33 +08:00 committed by gaotongxiao
parent 200899b2a0
commit 17606c25fc
2 changed files with 94 additions and 97 deletions

@@ -1,17 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import warnings
+from typing import Dict, List, Optional
+import torch
 import torch.nn as nn
-from mmcv.runner import BaseModule
 from mmdet.core import multi_apply
+from mmocr.core import TextDetDataSample
+from mmocr.models.textdet.heads import BaseTextDetHead
+from mmocr.registry import MODELS
-from ..postprocessors.utils import poly_nms
-from .head_mixin import HeadMixin
+@MODELS.register_module()
-class FCEHead(HeadMixin, BaseModule):
+class FCEHead(BaseTextDetHead):
     """The class for implementing FCENet head.
     FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
@@ -19,72 +19,42 @@ class FCEHead(HeadMixin, BaseModule):
     Args:
         in_channels (int): The number of input channels.
-        scales (list[int]) : The scale of each layer.
-        fourier_degree (int) : The maximum Fourier transform degree k.
-        nms_thr (float) : The threshold of nms.
+        fourier_degree (int) : The maximum Fourier transform degree k. Defaults
+            to 5.
         loss (dict): Config of loss for FCENet.
         postprocessor (dict): Config of postprocessor for FCENet.
         init_cfg (dict, optional): Initialization configs.
     """
-    def __init__(self,
-                 in_channels,
-                 scales,
-                 fourier_degree=5,
-                 nms_thr=0.1,
-                 loss=dict(type='FCELoss', num_sample=50),
-                 postprocessor=dict(
-                     type='FCEPostprocessor',
-                     text_repr_type='poly',
-                     num_reconstr_points=50,
-                     alpha=1.0,
-                     beta=2.0,
-                     score_thr=0.3),
-                 train_cfg=None,
-                 test_cfg=None,
-                 init_cfg=dict(
-                     type='Normal',
-                     mean=0,
-                     std=0.01,
-                     override=[
-                         dict(name='out_conv_cls'),
-                         dict(name='out_conv_reg')
-                     ]),
-                 **kwargs):
-        old_keys = [
-            'text_repr_type', 'decoding_type', 'num_reconstr_points', 'alpha',
-            'beta', 'score_thr'
-        ]
-        for key in old_keys:
-            if kwargs.get(key, None):
-                postprocessor[key] = kwargs.get(key)
-                warnings.warn(
-                    f'{key} is deprecated, please specify '
-                    'it in postprocessor config dict. See '
-                    'https://github.com/open-mmlab/mmocr/pull/640'
-                    ' for details.', UserWarning)
-        if kwargs.get('num_sample', None):
-            loss['num_sample'] = kwargs.get('num_sample')
-            warnings.warn(
-                'num_sample is deprecated, please specify '
-                'it in loss config dict. See '
-                'https://github.com/open-mmlab/mmocr/pull/640'
-                ' for details.', UserWarning)
-        BaseModule.__init__(self, init_cfg=init_cfg)
+    def __init__(
+        self,
+        in_channels: int,
+        fourier_degree: int = 5,
+        loss: Dict = dict(type='FCELoss', num_sample=50),
+        postprocessor: Dict = dict(
+            type='FCEPostprocessor',
+            text_repr_type='poly',
+            num_reconstr_points=50,
+            alpha=1.0,
+            beta=2.0,
+            score_thr=0.3),
+        init_cfg: Optional[Dict] = dict(
+            type='Normal',
+            mean=0,
+            std=0.01,
+            override=[dict(name='out_conv_cls'),
+                      dict(name='out_conv_reg')])
+    ) -> None:
-        loss['fourier_degree'] = fourier_degree
-        postprocessor['fourier_degree'] = fourier_degree
-        postprocessor['nms_thr'] = nms_thr
-        HeadMixin.__init__(self, loss, postprocessor)
+        super().__init__(
+            loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
         assert isinstance(in_channels, int)
         assert isinstance(fourier_degree, int)
-        self.downsample_ratio = 1.0
         self.in_channels = in_channels
-        self.scales = scales
         self.fourier_degree = fourier_degree
-        self.nms_thr = nms_thr
-        self.train_cfg = train_cfg
-        self.test_cfg = test_cfg
         self.out_channels_cls = 4
         self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
@@ -101,49 +71,43 @@ class FCEHead(HeadMixin, BaseModule):
             stride=1,
             padding=1)
-    def forward(self, feats):
+    def forward(self,
+                inputs: List[torch.Tensor],
+                data_samples: List[TextDetDataSample] = None) -> Dict:
         """
         Args:
-            feats (list[Tensor]): Each tensor has the shape of :math:`(N, C_i,
+            inputs (List[Tensor]): Each tensor has the shape of :math:`(N, C_i,
                 H_i, W_i)`.
+            data_samples (List[TextDetDataSample]): List of data samples.
+                Default to None.
+        Returns:
+            list[dict]: A list of dict with keys of ``cls_res``, ``reg_res``
+            corresponds to the classification result and regression result
+            computed from the input tensor with the same index. They have
+            the shapes of :math:`(N, C_{cls,i}, H_i, W_i)` and :math:`(N,
+            C_{out,i}, H_i, W_i)`.
+        """
+        cls_res, reg_res = multi_apply(self.forward_single, inputs)
+        level_num = len(cls_res)
+        preds = [
+            dict(cls_res=cls_res[i], reg_res=reg_res[i])
+            for i in range(level_num)
+        ]
+        return preds
+    def forward_single(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward function for a single feature level.
+        Args:
+            x (Tensor): The input tensor with the shape of :math:`(N, C_i,
+                H_i, W_i)`.
         Returns:
-            list[[Tensor, Tensor]]: Each pair of tensors corresponds to the
-            classification result and regression result computed from the input
-            tensor with the same index. They have the shapes of :math:`(N,
-            C_{cls,i}, H_i, W_i)` and :math:`(N, C_{out,i}, H_i, W_i)`.
+            Tensor: The classification and regression result with the shape of
+            :math:`(N, C_{cls,i}, H_i, W_i)` and :math:`(N, C_{out,i}, H_i,
+            W_i)`.
         """
-        cls_res, reg_res = multi_apply(self.forward_single, feats)
-        level_num = len(cls_res)
-        preds = [[cls_res[i], reg_res[i]] for i in range(level_num)]
-        return preds
-    def forward_single(self, x):
         cls_predict = self.out_conv_cls(x)
         reg_predict = self.out_conv_reg(x)
         return cls_predict, reg_predict
-    def get_boundary(self, score_maps, img_metas, rescale):
-        assert len(score_maps) == len(self.scales)
-        boundaries = []
-        for idx, score_map in enumerate(score_maps):
-            scale = self.scales[idx]
-            boundaries = boundaries + self._get_boundary_single(
-                score_map, scale)
-        # nms
-        boundaries = poly_nms(boundaries, self.nms_thr)
-        if rescale:
-            boundaries = self.resize_boundary(
-                boundaries, 1.0 / img_metas[0]['scale_factor'])
-        results = dict(boundary_result=boundaries)
-        return results
-    def _get_boundary_single(self, score_map, scale):
-        assert len(score_map) == 2
-        assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
-        return self.postprocessor(score_map, scale)
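
Below the diff, for orientation, is a minimal usage sketch of the refactored interface (not part of the commit): the head is now constructed from in_channels and fourier_degree alone, and forward returns one dict per feature level instead of a [cls, reg] pair. The in_channels value and feature-map sizes here are illustrative assumptions.

# A minimal sketch of the new interface (not part of this commit); the
# channel count and feature-map sizes below are illustrative assumptions.
import torch

from mmocr.models.textdet.heads import FCEHead

head = FCEHead(in_channels=256, fourier_degree=5)

# One tensor per pyramid level, each of shape (N, C_i, H_i, W_i).
feats = [torch.randn(1, 256, size, size) for size in (80, 40, 20)]
preds = head(feats)

# Each level yields a dict: 'cls_res' has 4 channels and 'reg_res' has
# (2 * fourier_degree + 1) * 2 = 22 channels.
for level_pred in preds:
    print(level_pred['cls_res'].shape, level_pred['reg_res'].shape)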


@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmocr.models.textdet.heads import FCEHead
+
+
+class TestFCEHead(TestCase):
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            FCEHead(in_channels='test', fourier_degree=5)
+
+        with self.assertRaises(AssertionError):
+            FCEHead(in_channels=1, fourier_degree='Text')
+
+    def test_forward(self):
+        fce_head = FCEHead(in_channels=10, fourier_degree=5)
+        data = [
+            torch.randn(2, 10, 20, 20),
+            torch.randn(2, 10, 30, 30),
+            torch.randn(2, 10, 40, 40)
+        ]
+        results = fce_head(data)
+        self.assertIn('cls_res', results[0])
+        self.assertIn('reg_res', results[0])
+        self.assertEqual(results[0]['cls_res'].shape, (2, 4, 20, 20))
+        self.assertEqual(results[0]['reg_res'].shape, (2, 22, 20, 20))
+        self.assertEqual(results[1]['cls_res'].shape, (2, 4, 30, 30))
+        self.assertEqual(results[1]['reg_res'].shape, (2, 22, 30, 30))
+        self.assertEqual(results[2]['cls_res'].shape, (2, 4, 40, 40))
+        self.assertEqual(results[2]['reg_res'].shape, (2, 22, 40, 40))
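
A short arithmetic sketch (not part of the commit) of where the expected shapes in test_forward come from:

# Regression branch: 2k + 1 complex Fourier coefficients per pixel for
# fourier_degree k, stored as (real, imaginary) pairs.
fourier_degree = 5
out_channels_reg = (2 * fourier_degree + 1) * 2  # -> 22, matches 'reg_res'
out_channels_cls = 4                             # fixed, matches 'cls_res'

# Both output convs use kernel_size=3, stride=1, padding=1, so each level
# keeps its spatial size: a (2, 10, 20, 20) input maps to (2, 22, 20, 20)
# regression and (2, 4, 20, 20) classification outputs, and so on.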