mirror of https://github.com/open-mmlab/mmocr.git
[ASTER] Add ASTER decoder (#1625)
* add aster decoder
* aster decoder
* decoder

Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
parent 0bd62d67c8
commit 419f98d8a4
mmocr/models/textrecog/decoders/__init__.py

@@ -2,6 +2,7 @@
 from .abi_fuser import ABIFuser
 from .abi_language_decoder import ABILanguageDecoder
 from .abi_vision_decoder import ABIVisionDecoder
+from .aster_decoder import ASTERDecoder
 from .base import BaseDecoder
 from .crnn_decoder import CRNNDecoder
 from .master_decoder import MasterDecoder
@@ -17,5 +18,5 @@ __all__ = [
     'ParallelSARDecoderWithBS', 'NRTRDecoder', 'BaseDecoder',
     'SequenceAttentionDecoder', 'PositionAttentionDecoder',
     'ABILanguageDecoder', 'ABIVisionDecoder', 'MasterDecoder',
-    'RobustScannerFuser', 'ABIFuser'
+    'RobustScannerFuser', 'ABIFuser', 'ASTERDecoder'
 ]
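Exporting the decoder here, together with the `@MODELS.register_module()` decorator in the new file below, is what lets configs refer to it by type name. A hedged sketch of such a reference; the surrounding fields are illustrative only, not part of this commit:

# Hypothetical config fragment: once registered, the decoder is built
# from its type name by the MODELS registry.
decoder = dict(
    type='ASTERDecoder',
    in_channels=512,
    dictionary=dict(type='Dictionary', dict_file='...'),  # path elided
)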
mmocr/models/textrecog/decoders/aster_decoder.py (new file)

@@ -0,0 +1,180 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseDecoder


@MODELS.register_module()
class ASTERDecoder(BaseDecoder):
    """Implement the GRU-based attention decoder of ASTER.

    Args:
        in_channels (int): Number of input channels.
        emb_dims (int): Dims of the character embedding. Defaults to 512.
        attn_dims (int): Dims of the attention. Both hidden states and
            features will be projected to this dimension. Defaults to 512.
        hidden_size (int): Dims of the hidden state for GRU. Defaults to 512.
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`. Defaults to None.
        max_seq_len (int): Maximum output sequence length :math:`T`. Defaults
            to 25.
        module_loss (dict, optional): Config to build loss. Defaults to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 emb_dims: int = 512,
                 attn_dims: int = 512,
                 hidden_size: int = 512,
                 dictionary: Union[Dictionary, Dict] = None,
                 max_seq_len: int = 25,
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 init_cfg=dict(type='Xavier', layer='Conv2d')):
        super().__init__(
            init_cfg=init_cfg,
            dictionary=dictionary,
            module_loss=module_loss,
            postprocessor=postprocessor,
            max_seq_len=max_seq_len)

        self.start_idx = self.dictionary.start_idx
        self.num_classes = self.dictionary.num_classes
        self.in_channels = in_channels
        self.embedding_dim = emb_dims
        self.att_dims = attn_dims
        self.hidden_size = hidden_size

        # Projection layers for the additive attention score
        self.proj_feat = nn.Linear(in_channels, attn_dims)
        self.proj_hidden = nn.Linear(hidden_size, attn_dims)
        self.proj_sum = nn.Linear(attn_dims, 1)

        # Decoder input embedding; must be emb_dims wide so that it matches
        # the GRU input_size below
        self.embedding = nn.Embedding(self.num_classes, self.embedding_dim)

        # GRU consumes the previous character embedding concatenated with
        # the attended context vector of size in_channels
        self.gru = nn.GRU(
            input_size=self.in_channels + self.embedding_dim,
            hidden_size=self.hidden_size,
            batch_first=True)

        # Prediction layer
        self.fc = nn.Linear(hidden_size, self.dictionary.num_classes)

    def _attention(self, feat: torch.Tensor, prev_hidden: torch.Tensor,
                   prev_char: torch.Tensor
                   ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Implement the attention mechanism.

        Args:
            feat (Tensor): Feature map from encoder of shape :math:`(N, T, C)`.
            prev_hidden (Tensor): Previous hidden state from GRU of shape
                :math:`(1, N, self.hidden_size)`.
            prev_char (Tensor): Previous predicted character of shape
                :math:`(N, )`.

        Returns:
            tuple(Tensor, Tensor):

            - output (Tensor): Character logits of the current time step, of
              shape :math:`(N, self.num_classes)`.
            - state (Tensor): Hidden state from GRU of the current time step,
              of shape :math:`(1, N, self.hidden_size)`.
        """
        # Calculate the attention weights
        B, T, _ = feat.size()
        feat_proj = self.proj_feat(feat)  # [N, T, attn_dims]
        hidden_proj = self.proj_hidden(prev_hidden)  # [1, N, attn_dims]
        hidden_proj = hidden_proj.squeeze(0).unsqueeze(1)  # [N, 1, attn_dims]
        hidden_proj = hidden_proj.expand(B, T,
                                         self.att_dims)  # [N, T, attn_dims]

        sum_tanh = torch.tanh(feat_proj + hidden_proj)  # [N, T, attn_dims]
        sum_proj = self.proj_sum(sum_tanh).squeeze(-1)  # [N, T]
        attn_weights = torch.softmax(sum_proj, dim=1)  # [N, T]

        # GRU forward
        context = torch.bmm(attn_weights.unsqueeze(1), feat).squeeze(1)
        char_embed = self.embedding(prev_char.long())  # [N, emb_dims]
        output, state = self.gru(
            torch.cat([char_embed, context], 1).unsqueeze(1), prev_hidden)
        output = output.squeeze(1)
        output = self.fc(output)
        return output, state
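
    # The method above is additive (Bahdanau-style) attention: with encoder
    # features f_i and previous hidden state h_{t-1}, it scores
    #   e_{t,i} = w^T tanh(W_f f_i + W_h h_{t-1})
    # via proj_feat, proj_hidden and proj_sum, normalizes the scores with a
    # softmax to get alpha_t, forms the context c_t = sum_i alpha_{t,i} f_i,
    # and feeds [embed(y_{t-1}); c_t] into the GRU before fc yields logits.
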
    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """
        Args:
            feat (Tensor): Feature from backbone. Unused in this decoder.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            data_samples (list[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
            :math:`C` is ``num_classes``.
        """
        B = out_enc.shape[0]
        state = torch.zeros(1, B, self.hidden_size).to(out_enc.device)
        padded_targets = [
            data_sample.gt_text.padded_indexes for data_sample in data_samples
        ]
        padded_targets = torch.stack(padded_targets, dim=0).to(out_enc.device)
        outputs = []
        # Teacher forcing: the ground-truth character is fed at every step
        for i in range(self.max_seq_len):
            prev_char = padded_targets[:, i].to(out_enc.device)
            output, state = self._attention(out_enc, state, prev_char)
            outputs.append(output)
        outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1)
        return outputs
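
    # Note the train/test asymmetry: `forward_train` above feeds the
    # ground-truth character at every step (teacher forcing), while
    # `forward_test` below feeds back its own argmax prediction
    # (greedy decoding).
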
    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        data_samples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """
        Args:
            feat (Tensor): Feature from backbone. Unused in this decoder.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            data_samples (list[TextRecogDataSample], optional): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None. Unused in this decoder.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
            :math:`C` is ``num_classes``.
        """
        B = out_enc.shape[0]
        predicted = []
        state = torch.zeros(1, B, self.hidden_size).to(out_enc.device)
        outputs = []
        for i in range(self.max_seq_len):
            if i == 0:
                # Bootstrap decoding with the start token
                prev_char = torch.zeros(B).fill_(self.start_idx).to(
                    out_enc.device)
            else:
                prev_char = predicted

            output, state = self._attention(out_enc, state, prev_char)
            outputs.append(output)
            _, predicted = output.max(-1)
        outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1)
        return outputs
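For reference, a minimal sketch of driving the decoder in isolation, mirroring the unit test below; the temporary 36-character dictionary file and the random tensor standing in for real encoder output are assumptions for illustration only:

# Standalone usage sketch (not part of the commit).
import tempfile

import torch

from mmocr.models.textrecog.decoders import ASTERDecoder

# Write a dummy dictionary of digits and lowercase letters.
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('\n'.join('0123456789abcdefghijklmnopqrstuvwxyz'))
    dict_file = f.name

decoder = ASTERDecoder(
    in_channels=512,
    dictionary=dict(
        type='Dictionary',
        dict_file=dict_file,
        with_start=True,
        with_end=True,
        same_start_end=True,
        with_padding=True,
        with_unknown=True),
    max_seq_len=25)

out_enc = torch.randn(2, 25, 512)  # (N, T, C) encoder features
logits = decoder.forward_test(out_enc=out_enc)
print(logits.shape)  # (N, max_seq_len, num_classes), here torch.Size([2, 25, 39])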
test_aster_decoder.py (new file)

@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase

import torch
from mmengine.structures import LabelData

from mmocr.models.textrecog.decoders import ASTERDecoder
from mmocr.structures import TextRecogDataSample


class TestASTERDecoder(TestCase):

    def setUp(self):
        gt_text_sample1 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'Hello'
        gt_text_sample1.gt_text = gt_text

        gt_text_sample2 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'World1'
        gt_text_sample2.gt_text = gt_text

        self.data_info = [gt_text_sample1, gt_text_sample2]

    def _create_dummy_dict_file(
            self, dict_file,
            chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_init(self):
        tmp_dir = tempfile.TemporaryDirectory()
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)
        loss_cfg = dict(type='CEModuleLoss')
        ASTERDecoder(
            in_channels=512, dictionary=dict_cfg, module_loss=loss_cfg)
        tmp_dir.cleanup()

    def test_forward_train(self):
        encoder_out = torch.randn(2, 25, 512)
        tmp_dir = tempfile.TemporaryDirectory()
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        # test dictionary cfg
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)
        loss_cfg = dict(type='CEModuleLoss')
        decoder = ASTERDecoder(
            in_channels=512,
            dictionary=dict_cfg,
            module_loss=loss_cfg,
            max_seq_len=25)
        data_samples = decoder.module_loss.get_targets(self.data_info)
        output = decoder.forward_train(
            out_enc=encoder_out, data_samples=data_samples)
        self.assertTupleEqual(tuple(output.shape), (2, 25, 39))

    def test_forward_test(self):
        encoder_out = torch.randn(2, 25, 512)
        tmp_dir = tempfile.TemporaryDirectory()
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        # test dictionary cfg
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)
        loss_cfg = dict(type='CEModuleLoss')
        decoder = ASTERDecoder(
            in_channels=512,
            dictionary=dict_cfg,
            module_loss=loss_cfg,
            max_seq_len=25)
        output = decoder.forward_test(
            out_enc=encoder_out, data_samples=self.data_info)
        self.assertTupleEqual(tuple(output.shape), (2, 25, 39))
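A note on the `(2, 25, 39)` assertions: the dummy dictionary holds 36 characters, and `with_unknown`, `with_padding` and the shared start/end token add three more entries, giving 39 output classes. To run just this file locally (the path is an assumption based on mmocr's usual test layout, not stated in the commit):

# Assumed invocation; adjust to the actual location of the test file.
# pytest tests/test_models/test_textrecog/test_decoders/test_aster_decoder.py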