mirror of https://github.com/open-mmlab/mmocr.git
update encoders
parent 2df8cb89a4
commit 2cb55550cd
mmocr/models/textrecog/encoders/__init__.py
@@ -1,13 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .abinet_vision_model import ABIVisionModel
+from .abi_encoder import ABIEncoder
 from .base_encoder import BaseEncoder
 from .channel_reduction_encoder import ChannelReductionEncoder
 from .nrtr_encoder import NRTREncoder
 from .sar_encoder import SAREncoder
 from .satrn_encoder import SATRNEncoder
-from .transformer import TransformerEncoder
 
 __all__ = [
     'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder',
-    'SATRNEncoder', 'TransformerEncoder', 'ABIVisionModel'
+    'SATRNEncoder', 'ABIEncoder'
 ]
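With the package exports switched from TransformerEncoder/ABIVisionModel to ABIEncoder, registry-built configs need the new type name. A minimal sketch of the rename's effect (the keyword values are just the defaults of the class shown below, not settings taken from this commit):

from mmocr.registry import MODELS

# Before this commit the module was registered as 'TransformerEncoder':
# encoder = MODELS.build(dict(type='TransformerEncoder'))

# After it, the same encoder is registered under its new name:
encoder = MODELS.build(dict(type='ABIEncoder', n_layers=2, n_head=8))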
mmocr/models/textrecog/encoders/transformer.py → mmocr/models/textrecog/encoders/abi_encoder.py
@@ -1,36 +1,43 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+from typing import Dict, List, Optional, Union
 
+import torch
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import BaseModule, ModuleList
 
+from mmocr.core import TextRecogDataSample
 from mmocr.models.common.modules import PositionalEncoding
 from mmocr.registry import MODELS
 
 
 @MODELS.register_module()
-class TransformerEncoder(BaseModule):
+class ABIEncoder(BaseModule):
     """Implement transformer encoder for text recognition, modified from
     `<https://github.com/FangShancheng/ABINet>`.
 
     Args:
-        n_layers (int): Number of attention layers.
-        n_head (int): Number of parallel attention heads.
+        n_layers (int): Number of attention layers. Defaults to 2.
+        n_head (int): Number of parallel attention heads. Defaults to 8.
         d_model (int): Dimension :math:`D_m` of the input from previous model.
-        d_inner (int): Hidden dimension of feedforward layers.
-        dropout (float): Dropout rate.
-        max_len (int): Maximum output sequence length :math:`T`.
+            Defaults to 512.
+        d_inner (int): Hidden dimension of feedforward layers. Defaults to
+            2048.
+        dropout (float): Dropout rate. Defaults to 0.1.
+        max_len (int): Maximum output sequence length :math:`T`. Defaults to
+            8 * 32.
         init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to None.
     """
 
     def __init__(self,
-                 n_layers=2,
-                 n_head=8,
-                 d_model=512,
-                 d_inner=2048,
-                 dropout=0.1,
-                 max_len=8 * 32,
-                 init_cfg=None):
+                 n_layers: int = 2,
+                 n_head: int = 8,
+                 d_model: int = 512,
+                 d_inner: int = 2048,
+                 dropout: float = 0.1,
+                 max_len: int = 8 * 32,
+                 init_cfg: Optional[Union[Dict, List[Dict]]] = None):
         super().__init__(init_cfg=init_cfg)
 
         assert d_model % n_head == 0, 'd_model must be divisible by n_head'
@@ -56,10 +63,12 @@ class TransformerEncoder(BaseModule):
         self.transformer = ModuleList(
             [copy.deepcopy(encoder_layer) for _ in range(n_layers)])
 
-    def forward(self, feature):
+    def forward(self, feature: torch.Tensor,
+                data_samples: List[TextRecogDataSample]) -> torch.Tensor:
         """
         Args:
             feature (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
+            data_samples (List[TextRecogDataSample]): List of data samples.
 
         Returns:
             Tensor: Features of shape :math:`(N, D_m, H, W)`.
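The forward signature above now also takes data_samples. A minimal smoke-test sketch of the new call, mirroring the unit test added below (which shows that data_samples may be None):

import torch

from mmocr.models.textrecog.encoders.abi_encoder import ABIEncoder

encoder = ABIEncoder()             # defaults: n_layers=2, n_head=8, d_model=512
feat = torch.randn(1, 512, 8, 32)  # (N, D_m, H, W)
out = encoder(feat, None)          # data_samples may be None, as in the test
assert out.shape == feat.shape     # shape is preserved: (N, D_m, H, W)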
mmocr/models/textrecog/encoders/abinet_vision_model.py (deleted)
@@ -1,45 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.registry import MODELS
-from .base_encoder import BaseEncoder
-
-
-@MODELS.register_module()
-class ABIVisionModel(BaseEncoder):
-    """A wrapper of visual feature encoder and language token decoder that
-    converts visual features into text tokens.
-
-    Implementation of VisionEncoder in
-        `ABINet <https://arxiv.org/abs/1910.04396>`_.
-
-    Args:
-        encoder (dict): Config for image feature encoder.
-        decoder (dict): Config for language token decoder.
-        init_cfg (dict): Specifies the initialization method for model layers.
-    """
-
-    def __init__(self,
-                 encoder=dict(type='TransformerEncoder'),
-                 decoder=dict(type='ABIVisionDecoder'),
-                 init_cfg=dict(type='Xavier', layer='Conv2d'),
-                 **kwargs):
-        super().__init__(init_cfg=init_cfg)
-        self.encoder = MODELS.build(encoder)
-        self.decoder = MODELS.build(decoder)
-
-    def forward(self, feat, img_metas=None):
-        """
-        Args:
-            feat (Tensor): Images of shape (N, E, H, W).
-
-        Returns:
-            dict: A dict with keys ``feature``, ``logits`` and ``attn_scores``.
-
-            - | feature (Tensor): Shape (N, T, E). Raw visual features for
-                language decoder.
-            - | logits (Tensor): Shape (N, T, C). The raw logits for
-                characters. C is the number of characters.
-            - | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result
-                for vision-language aligner.
-        """
-        feat = self.encoder(feat)
-        return self.decoder(feat=feat, out_enc=None)
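The deleted wrapper only chained two registry-built modules. If equivalent behavior is needed, it can be written out by hand; a sketch under the assumption that an 'ABIVisionDecoder' (or a successor) is still registered — this commit does not show where the composition moved:

from mmocr.registry import MODELS

encoder = MODELS.build(dict(type='ABIEncoder'))
decoder = MODELS.build(dict(type='ABIVisionDecoder'))  # hypothetical here

def vision_forward(feat, data_samples=None):
    # Mirrors the deleted ABIVisionModel.forward; note that ABIEncoder's
    # forward now also expects a data_samples argument.
    feat = encoder(feat, data_samples)
    return decoder(feat=feat, out_enc=None)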
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmocr.models.textrecog.encoders.abi_encoder import ABIEncoder
+
+
+class TestABIEncoder(TestCase):
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            ABIEncoder(d_model=512, n_head=10)
+
+    def test_forward(self):
+        model = ABIEncoder()
+        x = torch.randn(10, 512, 8, 32)
+        self.assertEqual(model(x, None).shape, torch.Size([10, 512, 8, 32]))
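The test_init case exercises the constructor assertion carried over from TransformerEncoder: 512 is not divisible by 10, so ABIEncoder(d_model=512, n_head=10) raises an AssertionError. The test_forward case checks the documented shape contract, that a (10, 512, 8, 32) feature map comes back with the same (N, D_m, H, W) shape even when data_samples is None.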