update encoders

pull/1178/head
gaotongxiao 2022-07-08 14:42:06 +08:00
parent 2df8cb89a4
commit 2cb55550cd
4 changed files with 43 additions and 62 deletions

View File

@@ -1,13 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .abinet_vision_model import ABIVisionModel
+from .abi_encoder import ABIEncoder
 from .base_encoder import BaseEncoder
 from .channel_reduction_encoder import ChannelReductionEncoder
 from .nrtr_encoder import NRTREncoder
 from .sar_encoder import SAREncoder
 from .satrn_encoder import SATRNEncoder
-from .transformer import TransformerEncoder
 
 __all__ = [
     'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder',
-    'SATRNEncoder', 'TransformerEncoder', 'ABIVisionModel'
+    'SATRNEncoder', 'ABIEncoder'
 ]
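
Since registration still goes through MODELS, downstream code only needs the new name. A minimal sketch of building the renamed encoder via the registry (the keyword values are just the documented defaults):

    from mmocr.registry import MODELS

    # 'TransformerEncoder' no longer resolves after this commit; the same
    # module is now registered as 'ABIEncoder'.
    encoder = MODELS.build(dict(type='ABIEncoder', n_layers=2, n_head=8))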

View File

@@ -1,36 +1,43 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+from typing import Dict, List, Optional, Union
 
+import torch
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import BaseModule, ModuleList
 
+from mmocr.core import TextRecogDataSample
 from mmocr.models.common.modules import PositionalEncoding
 from mmocr.registry import MODELS
 
 
 @MODELS.register_module()
-class TransformerEncoder(BaseModule):
+class ABIEncoder(BaseModule):
     """Implement transformer encoder for text recognition, modified from
     `<https://github.com/FangShancheng/ABINet>`.
 
     Args:
-        n_layers (int): Number of attention layers.
-        n_head (int): Number of parallel attention heads.
+        n_layers (int): Number of attention layers. Defaults to 2.
+        n_head (int): Number of parallel attention heads. Defaults to 8.
         d_model (int): Dimension :math:`D_m` of the input from previous model.
-        d_inner (int): Hidden dimension of feedforward layers.
-        dropout (float): Dropout rate.
-        max_len (int): Maximum output sequence length :math:`T`.
+            Defaults to 512.
+        d_inner (int): Hidden dimension of feedforward layers. Defaults to
+            2048.
+        dropout (float): Dropout rate. Defaults to 0.1.
+        max_len (int): Maximum output sequence length :math:`T`. Defaults to
+            8 * 32.
         init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to None.
     """
 
     def __init__(self,
-                 n_layers=2,
-                 n_head=8,
-                 d_model=512,
-                 d_inner=2048,
-                 dropout=0.1,
-                 max_len=8 * 32,
-                 init_cfg=None):
+                 n_layers: int = 2,
+                 n_head: int = 8,
+                 d_model: int = 512,
+                 d_inner: int = 2048,
+                 dropout: float = 0.1,
+                 max_len: int = 8 * 32,
+                 init_cfg: Optional[Union[Dict, List[Dict]]] = None):
         super().__init__(init_cfg=init_cfg)
         assert d_model % n_head == 0, 'd_model must be divisible by n_head'
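
The assert that closes this hunk keeps the per-head width integral, as multi-head attention requires. A quick illustration, reusing the values from the new unit test below:

    d_model, n_head = 512, 8
    head_dim = d_model // n_head  # 64 dimensions per head, since 512 % 8 == 0
    # ABIEncoder(d_model=512, n_head=10) raises AssertionError: 512 % 10 != 0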
@@ -56,10 +63,12 @@ class TransformerEncoder(BaseModule):
         self.transformer = ModuleList(
             [copy.deepcopy(encoder_layer) for _ in range(n_layers)])
 
-    def forward(self, feature):
+    def forward(self, feature: torch.Tensor,
+                data_samples: List[TextRecogDataSample]) -> torch.Tensor:
         """
         Args:
             feature (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
+            data_samples (List[TextRecogDataSample]): List of data samples.
 
         Returns:
             Tensor: Features of shape :math:`(N, D_m, H, W)`.
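
The reworked forward now takes data_samples explicitly. A minimal usage sketch; shapes follow the docstring, and passing None for data_samples mirrors the unit test below, which suggests this encoder does not actually consume them:

    import torch
    from mmocr.models.textrecog.encoders.abi_encoder import ABIEncoder

    model = ABIEncoder()                   # defaults: n_layers=2, n_head=8
    feature = torch.randn(1, 512, 8, 32)   # (N, D_m, H, W)
    out = model(feature, None)             # output keeps the (N, D_m, H, W) shape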

View File

@@ -1,45 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.registry import MODELS
-from .base_encoder import BaseEncoder
-
-
-@MODELS.register_module()
-class ABIVisionModel(BaseEncoder):
-    """A wrapper of visual feature encoder and language token decoder that
-    converts visual features into text tokens.
-
-    Implementation of VisionEncoder in
-        `ABINet <https://arxiv.org/abs/1910.04396>`_.
-
-    Args:
-        encoder (dict): Config for image feature encoder.
-        decoder (dict): Config for language token decoder.
-        init_cfg (dict): Specifies the initialization method for model layers.
-    """
-
-    def __init__(self,
-                 encoder=dict(type='TransformerEncoder'),
-                 decoder=dict(type='ABIVisionDecoder'),
-                 init_cfg=dict(type='Xavier', layer='Conv2d'),
-                 **kwargs):
-        super().__init__(init_cfg=init_cfg)
-        self.encoder = MODELS.build(encoder)
-        self.decoder = MODELS.build(decoder)
-
-    def forward(self, feat, img_metas=None):
-        """
-        Args:
-            feat (Tensor): Images of shape (N, E, H, W).
-
-        Returns:
-            dict: A dict with keys ``feature``, ``logits`` and ``attn_scores``.
-
-            - | feature (Tensor): Shape (N, T, E). Raw visual features for
-                language decoder.
-            - | logits (Tensor): Shape (N, T, C). The raw logits for
-                characters. C is the number of characters.
-            - | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result
-                for vision-language aligner.
-        """
-        feat = self.encoder(feat)
-        return self.decoder(feat=feat, out_enc=None)
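
With the wrapper deleted, nothing builds the encoder/decoder pair behind a single 'ABIVisionModel' registry entry any more. A hedged before/after config sketch; the 'after' form is an assumption, since the replacement recognizer config is not part of this diff:

    # Before (removed in this commit): one registry entry built both parts.
    encoder = dict(
        type='ABIVisionModel',
        encoder=dict(type='TransformerEncoder'),
        decoder=dict(type='ABIVisionDecoder'))

    # After (assumed): the transformer stack is configured directly as the
    # encoder, and the vision decoder is wired up elsewhere in the recognizer.
    encoder = dict(type='ABIEncoder', n_layers=2, n_head=8, d_model=512)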

View File

@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmocr.models.textrecog.encoders.abi_encoder import ABIEncoder
+
+
+class TestABIEncoder(TestCase):
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            ABIEncoder(d_model=512, n_head=10)
+
+    def test_forward(self):
+        model = ABIEncoder()
+        x = torch.randn(10, 512, 8, 32)
+        self.assertEqual(model(x, None).shape, torch.Size([10, 512, 8, 32]))
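
The new test is self-contained, so it can be run on its own once the file is in place; the path below is hypothetical, since the diff view does not show where the file lands:

    python -m pytest path/to/test_abi_encoder.py -q  # substitute the real test path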