mirror of https://github.com/open-mmlab/mmocr.git
update encoders
parent 2df8cb89a4
commit 2cb55550cd
mmocr/models/textrecog/encoders/__init__.py
@@ -1,13 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .abinet_vision_model import ABIVisionModel
+from .abi_encoder import ABIEncoder
 from .base_encoder import BaseEncoder
 from .channel_reduction_encoder import ChannelReductionEncoder
 from .nrtr_encoder import NRTREncoder
 from .sar_encoder import SAREncoder
 from .satrn_encoder import SATRNEncoder
-from .transformer import TransformerEncoder
 
 __all__ = [
     'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder',
-    'SATRNEncoder', 'TransformerEncoder', 'ABIVisionModel'
+    'SATRNEncoder', 'ABIEncoder'
 ]
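Note: after this commit the encoder is registered as 'ABIEncoder' rather than 'TransformerEncoder', and 'ABIVisionModel' no longer exists, so downstream configs referring to the old names must be updated. A minimal sketch of building the renamed module through the registry API visible in this diff:

    from mmocr.registry import MODELS

    # 'ABIEncoder' is the name registered by @MODELS.register_module() in
    # abi_encoder.py; omitted arguments fall back to the documented defaults
    # (n_layers=2, n_head=8, d_model=512, d_inner=2048, dropout=0.1).
    encoder = MODELS.build(dict(type='ABIEncoder'))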
mmocr/models/textrecog/encoders/transformer.py → mmocr/models/textrecog/encoders/abi_encoder.py
@@ -1,36 +1,43 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+from typing import Dict, List, Optional, Union
 
+import torch
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import BaseModule, ModuleList
 
+from mmocr.core import TextRecogDataSample
 from mmocr.models.common.modules import PositionalEncoding
 from mmocr.registry import MODELS
 
 
 @MODELS.register_module()
-class TransformerEncoder(BaseModule):
+class ABIEncoder(BaseModule):
     """Implement transformer encoder for text recognition, modified from
     `<https://github.com/FangShancheng/ABINet>`.
 
     Args:
-        n_layers (int): Number of attention layers.
-        n_head (int): Number of parallel attention heads.
+        n_layers (int): Number of attention layers. Defaults to 2.
+        n_head (int): Number of parallel attention heads. Defaults to 8.
         d_model (int): Dimension :math:`D_m` of the input from previous model.
-        d_inner (int): Hidden dimension of feedforward layers.
-        dropout (float): Dropout rate.
-        max_len (int): Maximum output sequence length :math:`T`.
+            Defaults to 512.
+        d_inner (int): Hidden dimension of feedforward layers. Defaults to
+            2048.
+        dropout (float): Dropout rate. Defaults to 0.1.
+        max_len (int): Maximum output sequence length :math:`T`. Defaults to
+            8 * 32.
         init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to None.
     """
 
     def __init__(self,
-                 n_layers=2,
-                 n_head=8,
-                 d_model=512,
-                 d_inner=2048,
-                 dropout=0.1,
-                 max_len=8 * 32,
-                 init_cfg=None):
+                 n_layers: int = 2,
+                 n_head: int = 8,
+                 d_model: int = 512,
+                 d_inner: int = 2048,
+                 dropout: float = 0.1,
+                 max_len: int = 8 * 32,
+                 init_cfg: Optional[Union[Dict, List[Dict]]] = None):
        super().__init__(init_cfg=init_cfg)
 
        assert d_model % n_head == 0, 'd_model must be divisible by n_head'
@@ -56,10 +63,12 @@ class TransformerEncoder(BaseModule):
         self.transformer = ModuleList(
             [copy.deepcopy(encoder_layer) for _ in range(n_layers)])
 
-    def forward(self, feature):
+    def forward(self, feature: torch.Tensor,
+                data_samples: List[TextRecogDataSample]) -> torch.Tensor:
         """
         Args:
             feature (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
+            data_samples (List[TextRecogDataSample]): List of data samples.
 
         Returns:
             Tensor: Features of shape :math:`(N, D_m, H, W)`.
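Note: forward() gains a data_samples argument to match the new TextRecogDataSample-based interface, while input and output keep the shape (N, D_m, H, W). The unit test below passes None for data_samples, which suggests this encoder does not consume it; a minimal sketch of the new call contract:

    import torch

    from mmocr.models.textrecog.encoders.abi_encoder import ABIEncoder

    encoder = ABIEncoder()                # defaults: n_layers=2, d_model=512
    feature = torch.randn(2, 512, 8, 32)  # (N, D_m, H, W) per the docstring
    out = encoder(feature, None)          # data_samples may be None here
    assert out.shape == feature.shape     # shape is preserved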
mmocr/models/textrecog/encoders/abinet_vision_model.py (deleted)
@@ -1,45 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.registry import MODELS
-from .base_encoder import BaseEncoder
-
-
-@MODELS.register_module()
-class ABIVisionModel(BaseEncoder):
-    """A wrapper of visual feature encoder and language token decoder that
-    converts visual features into text tokens.
-
-    Implementation of VisionEncoder in
-    `ABINet <https://arxiv.org/abs/1910.04396>`_.
-
-    Args:
-        encoder (dict): Config for image feature encoder.
-        decoder (dict): Config for language token decoder.
-        init_cfg (dict): Specifies the initialization method for model layers.
-    """
-
-    def __init__(self,
-                 encoder=dict(type='TransformerEncoder'),
-                 decoder=dict(type='ABIVisionDecoder'),
-                 init_cfg=dict(type='Xavier', layer='Conv2d'),
-                 **kwargs):
-        super().__init__(init_cfg=init_cfg)
-        self.encoder = MODELS.build(encoder)
-        self.decoder = MODELS.build(decoder)
-
-    def forward(self, feat, img_metas=None):
-        """
-        Args:
-            feat (Tensor): Images of shape (N, E, H, W).
-
-        Returns:
-            dict: A dict with keys ``feature``, ``logits`` and ``attn_scores``.
-
-            - | feature (Tensor): Shape (N, T, E). Raw visual features for
-                language decoder.
-            - | logits (Tensor): Shape (N, T, C). The raw logits for
-                characters. C is the number of characters.
-            - | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result
-                for vision-language aligner.
-        """
-        feat = self.encoder(feat)
-        return self.decoder(feat=feat, out_enc=None)
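Note: ABIVisionModel was a thin wrapper that simply chained a registered encoder and decoder, so its removal loses no functionality. For callers that relied on it, the equivalent composition can be sketched as below; the 'ABIVisionDecoder' config name and its feat/out_enc call signature are taken from the deleted code and are assumptions about the current decoder API:

    from mmocr.registry import MODELS

    # The two halves the wrapper used to own; 'TransformerEncoder' is now
    # registered as 'ABIEncoder' after the rename in this commit.
    encoder = MODELS.build(dict(type='ABIEncoder'))
    decoder = MODELS.build(dict(type='ABIVisionDecoder'))

    def vision_forward(feat):
        # Same composition as the removed ABIVisionModel.forward, with the
        # extra data_samples=None required by the new ABIEncoder signature.
        return decoder(feat=encoder(feat, None), out_enc=None)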
test_abi_encoder.py (new file)
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmocr.models.textrecog.encoders.abi_encoder import ABIEncoder
+
+
+class TestABIEncoder(TestCase):
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            ABIEncoder(d_model=512, n_head=10)
+
+    def test_forward(self):
+        model = ABIEncoder()
+        x = torch.randn(10, 512, 8, 32)
+        self.assertEqual(model(x, None).shape, torch.Size([10, 512, 8, 32]))
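Note: test_init exercises the divisibility assertion (512 % 10 != 0, so d_model=512 with n_head=10 must raise an AssertionError), and test_forward checks the shape-preserving forward pass with data_samples=None. The class can be selected by name with, e.g., pytest -k TestABIEncoder from the repository root.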