mirror of https://github.com/open-mmlab/mmocr.git
[NRTR] NRTR Decoder
parent d41921f03d
commit 8614070e36
@@ -1,12 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList

from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from mmocr.models.textrecog.dictionary import Dictionary
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@@ -16,20 +19,25 @@ class NRTRDecoder(BaseDecoder):
    """Transformer Decoder block with self attention mechanism.

    Args:
        n_layers (int): Number of attention layers.
        d_embedding (int): Language embedding dimension.
        n_head (int): Number of parallel attention heads.
        d_k (int): Dimension of the key vector.
        d_v (int): Dimension of the value vector.
        n_layers (int): Number of attention layers. Defaults to 6.
        d_embedding (int): Language embedding dimension. Defaults to 512.
        n_head (int): Number of parallel attention heads. Defaults to 8.
        d_k (int): Dimension of the key vector. Defaults to 64.
        d_v (int): Dimension of the value vector. Defaults to 64.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
        d_inner (int): Hidden dimension of feedforward layers.
            Defaults to 512.
        d_inner (int): Hidden dimension of feedforward layers. Defaults to 256.
        n_position (int): Length of the positional encoding vector. Must be
            greater than ``max_seq_len``.
        dropout (float): Dropout rate.
        num_classes (int): Number of output classes :math:`C`.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        padding_idx (int): The index of `<PAD>`.
            greater than ``max_seq_len``. Defaults to 200.
        dropout (float): Dropout rate for text embedding, MHSA, FFN. Defaults
            to 0.1.
        loss (dict, optional): Config to build loss. Defaults to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
            the instance of `Dictionary`.
        max_seq_len (int): Maximum output sequence length :math:`T`. Defaults
            to 40.
        init_cfg (dict or list[dict], optional): Initialization configs.

    Warning:
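For illustration, a minimal sketch of how these arguments are typically supplied, mirroring the unit tests later in this commit; the dictionary file path below is only a placeholder:

from mmocr.models.textrecog.decoders import NRTRDecoder

# Placeholder dictionary file; any plain-text file with one character per
# line works, as the dummy dict in the tests below shows.
dict_cfg = dict(
    type='Dictionary',
    dict_file='path/to/dict.txt',
    with_start=True,
    with_end=True,
    same_start_end=True,
    with_padding=True,
    with_unknown=True)
decoder = NRTRDecoder(
    dictionary=dict_cfg,
    loss=dict(type='CELoss'),
    max_seq_len=40)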
@@ -40,29 +48,34 @@ class NRTRDecoder(BaseDecoder):
    """

    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
                 n_layers: int = 6,
                 d_embedding: int = 512,
                 n_head: int = 8,
                 d_k: int = 64,
                 d_v: int = 64,
                 d_model: int = 512,
                 d_inner: int = 256,
                 n_position: int = 200,
                 dropout: float = 0.1,
                 loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 dictionary: Optional[Union[Dict, Dictionary]] = None,
                 max_seq_len: int = 40,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(
            loss=loss,
            postprocessor=postprocessor,
            dictionary=dictionary,
            init_cfg=init_cfg)

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.padding_idx = self.dictionary.padding_idx
        self.start_idx = self.dictionary.start_idx
        self.max_seq_len = max_seq_len

        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)
            self.dictionary.num_classes,
            d_embedding,
            padding_idx=self.padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
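A quick arithmetic sketch of where the embedding and classifier widths come from under the dictionary configuration used in the tests below; the exact ordering of the special tokens is an assumption, but the total matches what the tests assert:

chars = list('0123456789abcdefghijklmnopqrstuvwxyz')
# 36 characters + shared <BOS/EOS> (same_start_end=True) + <PAD> + <UKN>
num_classes = len(chars) + 1 + 1 + 1
assert num_classes == 39  # matches the (2, 40, 39) shapes asserted in the tests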
@@ -70,38 +83,92 @@ class NRTRDecoder(BaseDecoder):

        self.layer_stack = ModuleList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        pred_num_class = num_classes - 1  # ignore padding_idx
        pred_num_class = self.dictionary.num_classes
        self.classifier = nn.Linear(d_model, pred_num_class)

    @staticmethod
    def get_pad_mask(seq, pad_idx):
    def _get_target_mask(self, trg_seq: torch.Tensor) -> torch.Tensor:
        """Generate mask for target sequence.

        return (seq != pad_idx).unsqueeze(-2)
        Args:
            trg_seq (torch.Tensor): Input text sequence. Shape :math:`(N, T)`.

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info."""
        len_s = seq.size(1)
        Returns:
            Tensor: Target mask. Shape :math:`(N, T, T)`.
            E.g.:
            seq = torch.Tensor([[1, 2, 0, 0]]), pad_idx = 0, then
            target_mask =
            torch.Tensor([[[True, False, False, False],
                           [True, True, False, False],
                           [True, True, False, False],
                           [True, True, False, False]]])
        """

        pad_mask = (trg_seq != self.padding_idx).unsqueeze(-2)

        len_s = trg_seq.size(1)
        subsequent_mask = 1 - torch.triu(
            torch.ones((len_s, len_s), device=seq.device), diagonal=1)
            torch.ones((len_s, len_s), device=trg_seq.device), diagonal=1)
        subsequent_mask = subsequent_mask.unsqueeze(0).bool()

        return subsequent_mask
        return pad_mask & subsequent_mask
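As a standalone illustration of the combined mask (using padding index 0 as in the docstring example above; that index is a hypothetical choice):

import torch

trg_seq = torch.tensor([[1, 2, 0, 0]])                  # (N, T), 0 = <PAD>
pad_mask = (trg_seq != 0).unsqueeze(-2)                 # (N, 1, T)
T = trg_seq.size(1)
causal = (1 - torch.triu(torch.ones(T, T), diagonal=1)).unsqueeze(0).bool()
target_mask = pad_mask & causal                         # (N, T, T)
print(target_mask)  # reproduces the True/False pattern from the docstring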
    def _get_source_mask(self, src_seq: torch.Tensor,
                         valid_ratios: Sequence[float]) -> torch.Tensor:
        """Generate mask for source sequence.

        Args:
            src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`.
            valid_ratios (list[float]): The valid ratio of input image. For
                example, if the width of the original image is w1 and the width
                after padding is w2, then valid_ratio = w1/w2. Source mask is
                used to cover the area of the padding region.

        Returns:
            Tensor or None: Source mask. Shape :math:`(N, T)`. The padding
            regions are False, and the rest are True.
        """

        N, T, _ = src_seq.size()
        mask = None
        if len(valid_ratios) > 0:
            mask = src_seq.new_zeros((N, T), device=src_seq.device)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(T, math.ceil(T * valid_ratio))
                mask[i, :valid_width] = 1

        return mask
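A standalone sketch of the valid_ratio computation, using the same numbers as the unit tests below (T = 25 encoder positions, valid ratios 0.9 and 1.0):

import math
import torch

N, T = 2, 25
valid_ratios = [0.9, 1.0]
mask = torch.zeros((N, T))
for i, valid_ratio in enumerate(valid_ratios):
    valid_width = min(T, math.ceil(T * valid_ratio))
    mask[i, :valid_width] = 1
print(mask.sum(dim=1))  # tensor([23., 25.]): the 2 padded positions are masked out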
    def _attention(self,
                   trg_seq: torch.Tensor,
                   src: torch.Tensor,
                   src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """A wrapped process for transformer based decoder including text
        embedding, position embedding, N x transformer decoder and a LayerNorm
        operation.

        Args:
            trg_seq (Tensor): Target sequence. Shape :math:`(N, T)`.
            src (Tensor): Source sequence from encoder.
                Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
            src_mask (Tensor, Optional): Mask for source sequence.
                Shape :math:`(N, T)`. Defaults to None.

        Returns:
            Tensor: Output sequence from transformer decoder.
            Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
        """

    def _attention(self, trg_seq, src, src_mask=None):
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)
        trg_mask = self._get_target_mask(trg_seq)
        tgt_seq = self.dropout(trg_pos_encoded)

        trg_mask = self.get_pad_mask(
            trg_seq,
            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
        output = tgt
        output = tgt_seq
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
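Roughly, _attention() chains token embedding, positional encoding, dropout, the decoder stack, and a final LayerNorm. A standalone analogue using torch's built-in TransformerDecoderLayer in place of mmocr's TFDecoderLayer (an assumption made only for illustration; sizes are hypothetical, positional encoding and masks omitted for brevity):

import torch
import torch.nn as nn

d_model, n_head, d_inner, n_layers = 512, 8, 256, 6
emb = nn.Embedding(39, d_model, padding_idx=38)       # hypothetical vocab/pad
layers = nn.ModuleList([
    nn.TransformerDecoderLayer(d_model, n_head, d_inner, batch_first=True)
    for _ in range(n_layers)
])
norm = nn.LayerNorm(d_model, eps=1e-6)

trg_seq = torch.randint(0, 39, (2, 40))               # (N, T) token indices
memory = torch.randn(2, 25, d_model)                  # encoder output (N, T_src, D_m)
out = emb(trg_seq)
for layer in layers:
    out = layer(out, memory)
print(norm(out).shape)                                # torch.Size([2, 40, 512])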
@@ -112,46 +179,64 @@ class NRTRDecoder(BaseDecoder):

        return output

    def _get_mask(self, logit, img_metas):
        valid_ratios = None
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ]
        N, T, _ = logit.size()
        mask = None
        if valid_ratios is not None:
            mask = logit.new_zeros((N, T))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(T, math.ceil(T * valid_ratio))
                mask[i, :valid_width] = 1
    def forward_train(self,
                      feat: Optional[torch.Tensor] = None,
                      out_enc: torch.Tensor = None,
                      data_samples: Sequence[TextRecogDataSample] = None
                      ) -> torch.Tensor:
        """Forward for training. Source mask will be used here.

        return mask

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        r"""
        Args:
            feat (None): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``.
            targets_dict (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            img_metas (dict): A dict that contains meta information of input
                images. Preferably with the key ``valid_ratio``.
            feat (Tensor, optional): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``. Defaults to None.
            data_samples (list[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text and valid_ratio
                information. Defaults to None.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C)`.
            Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
            :math:`C` is ``num_classes``.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        targets = targets_dict['padded_targets'].to(out_enc.device)
        attn_output = self._attention(targets, out_enc, src_mask=src_mask)
        valid_ratios = []
        for data_sample in data_samples:
            valid_ratios.append(data_sample.get('valid_ratio'))
        src_mask = self._get_source_mask(out_enc, valid_ratios)
        trg_seq = []
        for data_sample in data_samples:
            trg_seq.append(
                data_sample.gt_text.padded_indexes.to(out_enc.device))
        trg_seq = torch.stack(trg_seq, dim=0)
        attn_output = self._attention(trg_seq, out_enc, src_mask=src_mask)
        outputs = self.classifier(attn_output)

        return outputs
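In training, the per-sample padded ground-truth indexes are stacked into one batch before being fed to _attention(). A shape-only sketch with hypothetical values matching the tests below (two samples, 40 steps, 39 classes):

import torch

padded_indexes = [torch.randint(0, 39, (40,)) for _ in range(2)]
trg_seq = torch.stack(padded_indexes, dim=0)   # (N, T) = (2, 40)
print(trg_seq.shape)
# After _attention() and the linear classifier, the raw logits come out as
# (N, T, C), i.e. (2, 40, 39) with this configuration.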
    def forward_test(self, feat, out_enc, img_metas):
        src_mask = self._get_mask(out_enc, img_metas)
    def forward_test(self,
                     feat: Optional[torch.Tensor] = None,
                     out_enc: torch.Tensor = None,
                     data_samples: Sequence[TextRecogDataSample] = None
                     ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (Tensor, optional): Unused.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
                Defaults to None.
            data_samples (list[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text and valid_ratio
                information. Defaults to None.

        Returns:
            Tensor: The raw logit tensor.
            Shape :math:`(N, self.max_seq_len, C)` where :math:`C` is
            ``num_classes``.
        """
        valid_ratios = []
        for data_sample in data_samples:
            valid_ratios.append(data_sample.get('valid_ratio'))
        src_mask = self._get_source_mask(out_enc, valid_ratios)
        N = out_enc.size(0)
        init_target_seq = torch.full((N, self.max_seq_len + 1),
                                     self.padding_idx,
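The test-time path above initializes a target buffer of length max_seq_len + 1 filled with the padding index and then decodes step by step. A generic, standalone sketch of that kind of greedy roll-out (the step function and the index values are stand-ins, not the mmocr implementation):

import torch
import torch.nn.functional as F

def greedy_decode(step, start_idx, padding_idx, max_seq_len, N=2):
    # Buffer holds <BOS> plus max_seq_len decoded positions.
    seq = torch.full((N, max_seq_len + 1), padding_idx, dtype=torch.long)
    seq[:, 0] = start_idx
    outputs = []
    for t in range(max_seq_len):
        logits = step(seq)                      # (N, max_seq_len + 1, C)
        probs = F.softmax(logits[:, t, :], dim=-1)
        seq[:, t + 1] = probs.argmax(dim=-1)    # feed the prediction back in
        outputs.append(probs)
    return torch.stack(outputs, dim=1)          # (N, max_seq_len, C)

# Toy usage with a random stand-in for the decoder + classifier:
dummy_step = lambda s: torch.randn(s.size(0), s.size(1), 39)
print(greedy_decode(dummy_step, start_idx=37, padding_idx=38,
                    max_seq_len=40).shape)      # torch.Size([2, 40, 39])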
@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase

import torch
from mmengine.data import LabelData

from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.textrecog.decoders import NRTRDecoder


class TestNRTRDecoder(TestCase):

    def setUp(self):
        gt_text_sample1 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'Hello'
        gt_text_sample1.gt_text = gt_text
        gt_text_sample1.set_metainfo(dict(valid_ratio=0.9))

        gt_text_sample2 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'World'
        gt_text_sample2.gt_text = gt_text
        gt_text_sample2.set_metainfo(dict(valid_ratio=1.0))

        self.data_info = [gt_text_sample1, gt_text_sample2]

    def _create_dummy_dict_file(
            self, dict_file,
            chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
        with open(dict_file, 'w') as f:
            for char in chars:
                f.write(char + '\n')

    def test_init(self):
        tmp_dir = tempfile.TemporaryDirectory()
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)
        loss_cfg = dict(type='CELoss')
        NRTRDecoder(dictionary=dict_cfg, loss=loss_cfg)
        tmp_dir.cleanup()

    def test_forward_train(self):
        encoder_out = torch.randn(2, 25, 512)
        tmp_dir = tempfile.TemporaryDirectory()
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        # test dictionary cfg
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)
        loss_cfg = dict(type='CELoss')
        decoder = NRTRDecoder(dictionary=dict_cfg, loss=loss_cfg)
        data_samples = decoder.loss.get_targets(self.data_info)
        output = decoder(
            out_enc=encoder_out, data_samples=data_samples, train_mode=True)
        self.assertTupleEqual(tuple(output.shape), (2, 40, 39))

    def test_forward_test(self):
        encoder_out = torch.randn(2, 25, 512)
        tmp_dir = tempfile.TemporaryDirectory()
        dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        # test dictionary cfg
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)
        loss_cfg = dict(type='CELoss')
        decoder = NRTRDecoder(
            dictionary=dict_cfg, loss=loss_cfg, max_seq_len=40)
        output = decoder(
            out_enc=encoder_out, data_samples=self.data_info, train_mode=False)
        self.assertTupleEqual(tuple(output.shape), (2, 40, 39))


if __name__ == '__main__':
    t = TestNRTRDecoder()
    t.setUp()
    t.test_forward_test()
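An equivalent, more conventional entry point would let unittest discover and run all three test methods rather than calling one by hand; a sketch, assuming the file is executed directly:

import unittest

if __name__ == '__main__':
    unittest.main()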