update encoders

pull/1178/head
gaotongxiao 2022-07-08 14:42:06 +08:00
parent 2df8cb89a4
commit 2cb55550cd
4 changed files with 43 additions and 62 deletions

View File

@@ -1,13 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .abinet_vision_model import ABIVisionModel
from .abi_encoder import ABIEncoder
from .base_encoder import BaseEncoder
from .channel_reduction_encoder import ChannelReductionEncoder
from .nrtr_encoder import NRTREncoder
from .sar_encoder import SAREncoder
from .satrn_encoder import SATRNEncoder
from .transformer import TransformerEncoder
__all__ = [
'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder',
'SATRNEncoder', 'TransformerEncoder', 'ABIVisionModel'
'SATRNEncoder', 'ABIEncoder'
]

View File

@@ -1,36 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Union
import torch
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import BaseModule, ModuleList
from mmocr.core import TextRecogDataSample
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
@MODELS.register_module()
class TransformerEncoder(BaseModule):
class ABIEncoder(BaseModule):
"""Implement transformer encoder for text recognition, modified from
`<https://github.com/FangShancheng/ABINet>`.
Args:
n_layers (int): Number of attention layers.
n_head (int): Number of parallel attention heads.
n_layers (int): Number of attention layers. Defaults to 2.
n_head (int): Number of parallel attention heads. Defaults to 8.
d_model (int): Dimension :math:`D_m` of the input from previous model.
d_inner (int): Hidden dimension of feedforward layers.
dropout (float): Dropout rate.
max_len (int): Maximum output sequence length :math:`T`.
Defaults to 512.
d_inner (int): Hidden dimension of feedforward layers. Defaults to
2048.
dropout (float): Dropout rate. Defaults to 0.1.
max_len (int): Maximum output sequence length :math:`T`. Defaults to
8 * 32.
init_cfg (dict or list[dict], optional): Initialization configs.
Defaults to None.
"""
def __init__(self,
n_layers=2,
n_head=8,
d_model=512,
d_inner=2048,
dropout=0.1,
max_len=8 * 32,
init_cfg=None):
n_layers: int = 2,
n_head: int = 8,
d_model: int = 512,
d_inner: int = 2048,
dropout: float = 0.1,
max_len: int = 8 * 32,
init_cfg: Optional[Union[Dict, List[Dict]]] = None):
super().__init__(init_cfg=init_cfg)
assert d_model % n_head == 0, 'd_model must be divisible by n_head'
@@ -56,10 +63,12 @@ class TransformerEncoder(BaseModule):
self.transformer = ModuleList(
[copy.deepcopy(encoder_layer) for _ in range(n_layers)])
def forward(self, feature):
def forward(self, feature: torch.Tensor,
data_samples: List[TextRecogDataSample]) -> torch.Tensor:
"""
Args:
feature (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
data_samples (List[TextRecogDataSample]): List of data samples.
Returns:
Tensor: Features of shape :math:`(N, D_m, H, W)`.

View File

@@ -1,45 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.registry import MODELS
from .base_encoder import BaseEncoder
@MODELS.register_module()
class ABIVisionModel(BaseEncoder):
    """Wrapper that chains a visual feature encoder with a language token
    decoder, turning image features into text tokens.

    Implementation of VisionEncoder in
    `ABINet <https://arxiv.org/abs/1910.04396>`_.
    (NOTE(review): verify this arXiv id actually points to the ABINet paper.)

    Args:
        encoder (dict): Config for image feature encoder.
        decoder (dict): Config for language token decoder.
        init_cfg (dict): Specifies the initialization method for model layers.
    """

    def __init__(self,
                 encoder=dict(type='TransformerEncoder'),
                 decoder=dict(type='ABIVisionDecoder'),
                 init_cfg=dict(type='Xavier', layer='Conv2d'),
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        # Both sub-modules are instantiated from registry configs so they
        # can be swapped out via the config system.
        self.encoder = MODELS.build(encoder)
        self.decoder = MODELS.build(decoder)

    def forward(self, feat, img_metas=None):
        """
        Args:
            feat (Tensor): Images of shape (N, E, H, W).

        Returns:
            dict: A dict with keys ``feature``, ``logits`` and ``attn_scores``.

            - | feature (Tensor): Shape (N, T, E). Raw visual features for
                language decoder.
            - | logits (Tensor): Shape (N, T, C). The raw logits for
                characters. C is the number of characters.
            - | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result
                for vision-language aligner.
        """
        encoded = self.encoder(feat)
        # out_enc is unused by the vision decoder; pass None explicitly.
        return self.decoder(feat=encoded, out_enc=None)

View File

@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmocr.models.textrecog.encoders.abi_encoder import ABIEncoder
class TestABIEncoder(TestCase):
    """Unit tests for :class:`ABIEncoder`."""

    def test_init(self):
        # d_model=512 is not divisible by n_head=10, so construction must
        # trip the divisibility assertion in ``ABIEncoder.__init__``.
        with self.assertRaises(AssertionError):
            ABIEncoder(d_model=512, n_head=10)

    def test_forward(self):
        encoder = ABIEncoder()
        feat = torch.randn(10, 512, 8, 32)
        out = encoder(feat, None)
        # Output must keep the input's (N, D_m, H, W) shape.
        self.assertEqual(out.shape, torch.Size([10, 512, 8, 32]))