mirror of https://github.com/open-mmlab/mmocr.git
add fce head
parent 200899b2a0
commit 17606c25fc

@@ -1,17 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import warnings
+from typing import Dict, List, Optional
 
+import torch
 import torch.nn as nn
 from mmcv.runner import BaseModule
 from mmdet.core import multi_apply
 
+from mmocr.core import TextDetDataSample
+from mmocr.models.textdet.heads import BaseTextDetHead
 from mmocr.registry import MODELS
-from ..postprocessors.utils import poly_nms
-from .head_mixin import HeadMixin
 
 
 @MODELS.register_module()
-class FCEHead(HeadMixin, BaseModule):
+class FCEHead(BaseTextDetHead):
     """The class for implementing FCENet head.
 
     FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
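The head is exposed through the MODELS registry by the decorator kept above, so configs refer to it by its registered type name. A minimal, illustrative config fragment under that assumption (the `in_channels` value is an example, not taken from a released config; the nested dicts mirror the defaults visible in this diff):

```python
# Hypothetical det_head fragment of an FCENet model config; shown only to
# illustrate how the registered name 'FCEHead' is looked up via MODELS.
det_head = dict(
    type='FCEHead',          # resolved through mmocr.registry.MODELS
    in_channels=256,         # illustrative value
    fourier_degree=5,
    loss=dict(type='FCELoss', num_sample=50),
    postprocessor=dict(
        type='FCEPostprocessor',
        text_repr_type='poly',
        num_reconstr_points=50,
        alpha=1.0,
        beta=2.0,
        score_thr=0.3))
```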
@@ -19,72 +19,42 @@ class FCEHead(HeadMixin, BaseModule):
 
     Args:
         in_channels (int): The number of input channels.
-        scales (list[int]) : The scale of each layer.
-        fourier_degree (int) : The maximum Fourier transform degree k.
-        nms_thr (float) : The threshold of nms.
+        fourier_degree (int) : The maximum Fourier transform degree k. Defaults
+            to 5.
         loss (dict): Config of loss for FCENet.
         postprocessor (dict): Config of postprocessor for FCENet.
         init_cfg (dict, optional): Initialization configs.
     """
 
-    def __init__(self,
-                 in_channels,
-                 scales,
-                 fourier_degree=5,
-                 nms_thr=0.1,
-                 loss=dict(type='FCELoss', num_sample=50),
-                 postprocessor=dict(
-                     type='FCEPostprocessor',
-                     text_repr_type='poly',
-                     num_reconstr_points=50,
-                     alpha=1.0,
-                     beta=2.0,
-                     score_thr=0.3),
-                 train_cfg=None,
-                 test_cfg=None,
-                 init_cfg=dict(
-                     type='Normal',
-                     mean=0,
-                     std=0.01,
-                     override=[
-                         dict(name='out_conv_cls'),
-                         dict(name='out_conv_reg')
-                     ]),
-                 **kwargs):
-        old_keys = [
-            'text_repr_type', 'decoding_type', 'num_reconstr_points', 'alpha',
-            'beta', 'score_thr'
-        ]
-        for key in old_keys:
-            if kwargs.get(key, None):
-                postprocessor[key] = kwargs.get(key)
-                warnings.warn(
-                    f'{key} is deprecated, please specify '
-                    'it in postprocessor config dict. See '
-                    'https://github.com/open-mmlab/mmocr/pull/640'
-                    ' for details.', UserWarning)
-        if kwargs.get('num_sample', None):
-            loss['num_sample'] = kwargs.get('num_sample')
-            warnings.warn(
-                'num_sample is deprecated, please specify '
-                'it in loss config dict. See '
-                'https://github.com/open-mmlab/mmocr/pull/640'
-                ' for details.', UserWarning)
-        BaseModule.__init__(self, init_cfg=init_cfg)
+    def __init__(
+        self,
+        in_channels: int,
+        fourier_degree: int = 5,
+        loss: Dict = dict(type='FCELoss', num_sample=50),
+        postprocessor: Dict = dict(
+            type='FCEPostprocessor',
+            text_repr_type='poly',
+            num_reconstr_points=50,
+            alpha=1.0,
+            beta=2.0,
+            score_thr=0.3),
+        init_cfg: Optional[Dict] = dict(
+            type='Normal',
+            mean=0,
+            std=0.01,
+            override=[dict(name='out_conv_cls'),
+                      dict(name='out_conv_reg')])
+    ) -> None:
         loss['fourier_degree'] = fourier_degree
         postprocessor['fourier_degree'] = fourier_degree
-        postprocessor['nms_thr'] = nms_thr
-        HeadMixin.__init__(self, loss, postprocessor)
+        super().__init__(
+            loss=loss, postprocessor=postprocessor, init_cfg=init_cfg)
 
         assert isinstance(in_channels, int)
+        assert isinstance(fourier_degree, int)
 
-        self.downsample_ratio = 1.0
         self.in_channels = in_channels
-        self.scales = scales
         self.fourier_degree = fourier_degree
 
-        self.nms_thr = nms_thr
-        self.train_cfg = train_cfg
-        self.test_cfg = test_cfg
         self.out_channels_cls = 4
         self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
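For concreteness, the two output-channel attributes set at the end of the new `__init__` are small fixed numbers; a quick sanity check with the default `fourier_degree` used in this diff (a sketch, the comments are my reading of the formula):

```python
# Channel counts implied by the code above, with k = fourier_degree.
fourier_degree = 5                               # default in this diff
out_channels_cls = 4                             # classification maps
out_channels_reg = (2 * fourier_degree + 1) * 2  # 2k + 1 Fourier coefficients,
                                                 # kept as real + imaginary parts
assert out_channels_reg == 22                    # matches the (N, 22, H, W) shapes
                                                 # asserted in the new unit test below
```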
@@ -101,49 +71,43 @@ class FCEHead(HeadMixin, BaseModule):
             stride=1,
             padding=1)
 
-    def forward(self, feats):
+    def forward(self,
+                inputs: List[torch.Tensor],
+                data_samples: List[TextDetDataSample] = None) -> Dict:
         """
         Args:
-            feats (list[Tensor]): Each tensor has the shape of :math:`(N, C_i,
+            inputs (List[Tensor]): Each tensor has the shape of :math:`(N, C_i,
                 H_i, W_i)`.
+            data_samples (List[TextDetDataSample]): List of data samples.
+                Default to None.
 
         Returns:
-            list[[Tensor, Tensor]]: Each pair of tensors corresponds to the
-            classification result and regression result computed from the input
-            tensor with the same index. They have the shapes of :math:`(N,
-            C_{cls,i}, H_i, W_i)` and :math:`(N, C_{out,i}, H_i, W_i)`.
+            list[dict]: A list of dict with keys of ``cls_res``, ``reg_res``
+            corresponds to the classification result and regression result
+            computed from the input tensor with the same index. They have
+            the shapes of :math:`(N, C_{cls,i}, H_i, W_i)` and :math:`(N,
+            C_{out,i}, H_i, W_i)`.
         """
-        cls_res, reg_res = multi_apply(self.forward_single, feats)
+        cls_res, reg_res = multi_apply(self.forward_single, inputs)
         level_num = len(cls_res)
-        preds = [[cls_res[i], reg_res[i]] for i in range(level_num)]
+        preds = [
+            dict(cls_res=cls_res[i], reg_res=reg_res[i])
+            for i in range(level_num)
+        ]
         return preds
 
-    def forward_single(self, x):
+    def forward_single(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward function for a single feature level.
+
+        Args:
+            x (Tensor): The input tensor with the shape of :math:`(N, C_i,
+                H_i, W_i)`.
+
+        Returns:
+            Tensor: The classification and regression result with the shape of
+            :math:`(N, C_{cls,i}, H_i, W_i)` and :math:`(N, C_{out,i}, H_i,
+            W_i)`.
+        """
         cls_predict = self.out_conv_cls(x)
         reg_predict = self.out_conv_reg(x)
         return cls_predict, reg_predict
-
-    def get_boundary(self, score_maps, img_metas, rescale):
-        assert len(score_maps) == len(self.scales)
-
-        boundaries = []
-        for idx, score_map in enumerate(score_maps):
-            scale = self.scales[idx]
-            boundaries = boundaries + self._get_boundary_single(
-                score_map, scale)
-
-        # nms
-        boundaries = poly_nms(boundaries, self.nms_thr)
-
-        if rescale:
-            boundaries = self.resize_boundary(
-                boundaries, 1.0 / img_metas[0]['scale_factor'])
-
-        results = dict(boundary_result=boundaries)
-        return results
-
-    def _get_boundary_single(self, score_map, scale):
-        assert len(score_map) == 2
-        assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
-
-        return self.postprocessor(score_map, scale)
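The refactored `forward` uses mmdet's `multi_apply` to run `forward_single` on every pyramid level and then repackages the per-level outputs into dicts. A rough plain-Python equivalent of that data flow, shown only to clarify the shapes involved (a sketch, not mmdet's implementation):

```python
from typing import Dict, List

import torch


def forward_plain(head, inputs: List[torch.Tensor]) -> List[Dict]:
    """What FCEHead.forward computes, written without multi_apply."""
    preds = []
    for x in inputs:  # one feature map per level, shape (N, C_i, H_i, W_i)
        # with the default fourier_degree=5: (N, 4, H_i, W_i) and (N, 22, H_i, W_i)
        cls_i, reg_i = head.forward_single(x)
        preds.append(dict(cls_res=cls_i, reg_res=reg_i))
    return preds
```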
@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmocr.models.textdet.heads import FCEHead
+
+
+class TestFCEHead(TestCase):
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            FCEHead(in_channels='test', fourier_degree=5)
+
+        with self.assertRaises(AssertionError):
+            FCEHead(in_channels=1, fourier_degree='Text')
+
+    def test_forward(self):
+        fce_head = FCEHead(in_channels=10, fourier_degree=5)
+        data = [
+            torch.randn(2, 10, 20, 20),
+            torch.randn(2, 10, 30, 30),
+            torch.randn(2, 10, 40, 40)
+        ]
+        results = fce_head(data)
+        self.assertIn('cls_res', results[0])
+        self.assertIn('reg_res', results[0])
+        self.assertEqual(results[0]['cls_res'].shape, (2, 4, 20, 20))
+        self.assertEqual(results[0]['reg_res'].shape, (2, 22, 20, 20))
+        self.assertEqual(results[1]['cls_res'].shape, (2, 4, 30, 30))
+        self.assertEqual(results[1]['reg_res'].shape, (2, 22, 30, 30))
+        self.assertEqual(results[2]['cls_res'].shape, (2, 4, 40, 40))
+        self.assertEqual(results[2]['reg_res'].shape, (2, 22, 40, 40))
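Because the head is registered in `MODELS` (see the decorator in the first hunk), it can also be built from a config dict rather than imported directly as in the test above. A minimal sketch, assuming that importing `mmocr.models` is what triggers the registration side effect and reusing the test's illustrative numbers:

```python
import torch

import mmocr.models  # noqa: F401  (assumed to register FCEHead with MODELS)
from mmocr.registry import MODELS

head = MODELS.build(dict(type='FCEHead', in_channels=10, fourier_degree=5))
out = head([torch.randn(2, 10, 20, 20)])
print(out[0]['cls_res'].shape, out[0]['reg_res'].shape)
# expected: torch.Size([2, 4, 20, 20]) torch.Size([2, 22, 20, 20])
```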