From 2cb55550cdca03b54d934a108827ade0aec9763e Mon Sep 17 00:00:00 2001
From: gaotongxiao
Date: Fri, 8 Jul 2022 14:42:06 +0800
Subject: [PATCH] update encoders

---
 mmocr/models/textrecog/encoders/__init__.py   |  5 +--
 .../{transformer.py => abi_encoder.py}        | 37 +++++++++------
 .../textrecog/encoders/abinet_vision_model.py | 45 -------------------
 .../test_encoders/test_abi_encoder.py         | 18 ++++++++
 4 files changed, 43 insertions(+), 62 deletions(-)
 rename mmocr/models/textrecog/encoders/{transformer.py => abi_encoder.py} (68%)
 delete mode 100644 mmocr/models/textrecog/encoders/abinet_vision_model.py
 create mode 100644 tests/test_models/test_textrecog/test_encoders/test_abi_encoder.py

diff --git a/mmocr/models/textrecog/encoders/__init__.py b/mmocr/models/textrecog/encoders/__init__.py
index 570bbe7b..661775cf 100644
--- a/mmocr/models/textrecog/encoders/__init__.py
+++ b/mmocr/models/textrecog/encoders/__init__.py
@@ -1,13 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .abinet_vision_model import ABIVisionModel
+from .abi_encoder import ABIEncoder
 from .base_encoder import BaseEncoder
 from .channel_reduction_encoder import ChannelReductionEncoder
 from .nrtr_encoder import NRTREncoder
 from .sar_encoder import SAREncoder
 from .satrn_encoder import SATRNEncoder
-from .transformer import TransformerEncoder
 
 __all__ = [
     'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder',
-    'SATRNEncoder', 'TransformerEncoder', 'ABIVisionModel'
+    'SATRNEncoder', 'ABIEncoder'
 ]
diff --git a/mmocr/models/textrecog/encoders/transformer.py b/mmocr/models/textrecog/encoders/abi_encoder.py
similarity index 68%
rename from mmocr/models/textrecog/encoders/transformer.py
rename to mmocr/models/textrecog/encoders/abi_encoder.py
index 08b29771..478ec680 100644
--- a/mmocr/models/textrecog/encoders/transformer.py
+++ b/mmocr/models/textrecog/encoders/abi_encoder.py
@@ -1,36 +1,43 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+from typing import Dict, List, Optional, Union
 
+import torch
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import BaseModule, ModuleList
 
+from mmocr.core import TextRecogDataSample
 from mmocr.models.common.modules import PositionalEncoding
 from mmocr.registry import MODELS
 
 
 @MODELS.register_module()
-class TransformerEncoder(BaseModule):
+class ABIEncoder(BaseModule):
     """Implement transformer encoder for text recognition, modified from
     `<https://github.com/FangShancheng/ABINet>`_.
 
     Args:
-        n_layers (int): Number of attention layers.
-        n_head (int): Number of parallel attention heads.
+        n_layers (int): Number of attention layers. Defaults to 2.
+        n_head (int): Number of parallel attention heads. Defaults to 8.
         d_model (int): Dimension :math:`D_m` of the input from previous model.
-        d_inner (int): Hidden dimension of feedforward layers.
-        dropout (float): Dropout rate.
-        max_len (int): Maximum output sequence length :math:`T`.
+            Defaults to 512.
+        d_inner (int): Hidden dimension of feedforward layers. Defaults to
+            2048.
+        dropout (float): Dropout rate. Defaults to 0.1.
+        max_len (int): Maximum output sequence length :math:`T`. Defaults to
+            8 * 32.
         init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to None.
""" def __init__(self, - n_layers=2, - n_head=8, - d_model=512, - d_inner=2048, - dropout=0.1, - max_len=8 * 32, - init_cfg=None): + n_layers: int = 2, + n_head: int = 8, + d_model: int = 512, + d_inner: int = 2048, + dropout: float = 0.1, + max_len: int = 8 * 32, + init_cfg: Optional[Union[Dict, List[Dict]]] = None): super().__init__(init_cfg=init_cfg) assert d_model % n_head == 0, 'd_model must be divisible by n_head' @@ -56,10 +63,12 @@ class TransformerEncoder(BaseModule): self.transformer = ModuleList( [copy.deepcopy(encoder_layer) for _ in range(n_layers)]) - def forward(self, feature): + def forward(self, feature: torch.Tensor, + data_samples: List[TextRecogDataSample]) -> torch.Tensor: """ Args: feature (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`. + data_samples (List[TextRecogDataSample]): List of data samples. Returns: Tensor: Features of shape :math:`(N, D_m, H, W)`. diff --git a/mmocr/models/textrecog/encoders/abinet_vision_model.py b/mmocr/models/textrecog/encoders/abinet_vision_model.py deleted file mode 100644 index 188063d0..00000000 --- a/mmocr/models/textrecog/encoders/abinet_vision_model.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmocr.registry import MODELS -from .base_encoder import BaseEncoder - - -@MODELS.register_module() -class ABIVisionModel(BaseEncoder): - """A wrapper of visual feature encoder and language token decoder that - converts visual features into text tokens. - - Implementation of VisionEncoder in - `ABINet `_. - - Args: - encoder (dict): Config for image feature encoder. - decoder (dict): Config for language token decoder. - init_cfg (dict): Specifies the initialization method for model layers. - """ - - def __init__(self, - encoder=dict(type='TransformerEncoder'), - decoder=dict(type='ABIVisionDecoder'), - init_cfg=dict(type='Xavier', layer='Conv2d'), - **kwargs): - super().__init__(init_cfg=init_cfg) - self.encoder = MODELS.build(encoder) - self.decoder = MODELS.build(decoder) - - def forward(self, feat, img_metas=None): - """ - Args: - feat (Tensor): Images of shape (N, E, H, W). - - Returns: - dict: A dict with keys ``feature``, ``logits`` and ``attn_scores``. - - - | feature (Tensor): Shape (N, T, E). Raw visual features for - language decoder. - - | logits (Tensor): Shape (N, T, C). The raw logits for - characters. C is the number of characters. - - | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result - for vision-language aligner. - """ - feat = self.encoder(feat) - return self.decoder(feat=feat, out_enc=None) diff --git a/tests/test_models/test_textrecog/test_encoders/test_abi_encoder.py b/tests/test_models/test_textrecog/test_encoders/test_abi_encoder.py new file mode 100644 index 00000000..7b108856 --- /dev/null +++ b/tests/test_models/test_textrecog/test_encoders/test_abi_encoder.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.textrecog.encoders.abi_encoder import ABIEncoder + + +class TestABIEncoder(TestCase): + + def test_init(self): + with self.assertRaises(AssertionError): + ABIEncoder(d_model=512, n_head=10) + + def test_forward(self): + model = ABIEncoder() + x = torch.randn(10, 512, 8, 32) + self.assertEqual(model(x, None).shape, torch.Size([10, 512, 8, 32]))