[NRTR] NRTR Decoder

pull/1178/head
jiangqing.vendor 2022-06-10 10:14:36 +00:00 committed by gaotongxiao
parent d41921f03d
commit 8614070e36
2 changed files with 264 additions and 79 deletions


@@ -1,12 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List, Optional, Sequence, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList
from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from mmocr.models.textrecog.dictionary import Dictionary
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@@ -16,20 +19,25 @@ class NRTRDecoder(BaseDecoder):
"""Transformer Decoder block with self attention mechanism.
Args:
n_layers (int): Number of attention layers. Defaults to 6.
d_embedding (int): Language embedding dimension. Defaults to 512.
n_head (int): Number of parallel attention heads. Defaults to 8.
d_k (int): Dimension of the key vector. Defaults to 64.
d_v (int): Dimension of the value vector. Defaults to 64.
d_model (int): Dimension :math:`D_m` of the input from previous model. Defaults to 512.
d_inner (int): Hidden dimension of feedforward layers. Defaults to 256.
n_position (int): Length of the positional encoding vector. Must be greater than ``max_seq_len``. Defaults to 200.
dropout (float): Dropout rate for text embedding, MHSA, FFN. Defaults to 0.1.
loss (dict, optional): Config to build loss. Defaults to None.
postprocessor (dict, optional): Config to build postprocessor. Defaults to None.
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or the instance of `Dictionary`.
max_seq_len (int): Maximum output sequence length :math:`T`. Defaults to 40.
init_cfg (dict or list[dict], optional): Initialization configs.
Warning:
@@ -40,29 +48,34 @@ class NRTRDecoder(BaseDecoder):
"""
def __init__(self,
n_layers: int = 6,
d_embedding: int = 512,
n_head: int = 8,
d_k: int = 64,
d_v: int = 64,
d_model: int = 512,
d_inner: int = 256,
n_position: int = 200,
dropout: float = 0.1,
loss: Optional[Dict] = None,
postprocessor: Optional[Dict] = None,
dictionary: Optional[Union[Dict, Dictionary]] = None,
max_seq_len: int = 40,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super().__init__(
loss=loss,
postprocessor=postprocessor,
dictionary=dictionary,
init_cfg=init_cfg)
self.padding_idx = self.dictionary.padding_idx
self.start_idx = self.dictionary.start_idx
self.max_seq_len = max_seq_len
self.trg_word_emb = nn.Embedding(
self.dictionary.num_classes,
d_embedding,
padding_idx=self.padding_idx)
self.position_enc = PositionalEncoding(
d_embedding, n_position=n_position)
@@ -70,38 +83,92 @@ class NRTRDecoder(BaseDecoder):
self.layer_stack = ModuleList([
TFDecoderLayer(
d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
pred_num_class = self.dictionary.num_classes
self.classifier = nn.Linear(d_model, pred_num_class)
def _get_target_mask(self, trg_seq: torch.Tensor) -> torch.Tensor:
"""Generate mask for target sequence.
Args:
trg_seq (torch.Tensor): Input text sequence. Shape :math:`(N, T)`.
Returns:
Tensor: Target mask. Shape :math:`(N, T, T)`.
E.g.:
seq = torch.Tensor([[1, 2, 0, 0]]), pad_idx = 0, then
target_mask =
torch.Tensor([[[True, False, False, False],
[True, True, False, False],
[True, True, False, False],
[True, True, False, False]]])
"""
pad_mask = (trg_seq != self.padding_idx).unsqueeze(-2)
len_s = trg_seq.size(1)
subsequent_mask = 1 - torch.triu(
torch.ones((len_s, len_s), device=trg_seq.device), diagonal=1)
subsequent_mask = subsequent_mask.unsqueeze(0).bool()
return pad_mask & subsequent_mask
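
As a standalone illustration (not part of this diff), the same pad-and-causal masking logic reproduces the docstring example above; `padding_idx=0` is taken from that example:

    import torch

    def target_mask(trg_seq: torch.Tensor, padding_idx: int) -> torch.Tensor:
        # Padding mask: True where the token is not <PAD>, shape (N, 1, T).
        pad_mask = (trg_seq != padding_idx).unsqueeze(-2)
        # Causal mask: lower-triangular (1, T, T), True where attention is allowed.
        len_s = trg_seq.size(1)
        subsequent_mask = 1 - torch.triu(
            torch.ones((len_s, len_s), device=trg_seq.device), diagonal=1)
        subsequent_mask = subsequent_mask.unsqueeze(0).bool()
        # Broadcasting combines them into a single (N, T, T) mask.
        return pad_mask & subsequent_mask

    seq = torch.tensor([[1, 2, 0, 0]])
    print(target_mask(seq, padding_idx=0))
    # tensor([[[ True, False, False, False],
    #          [ True,  True, False, False],
    #          [ True,  True, False, False],
    #          [ True,  True, False, False]]])
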
def _get_source_mask(self, src_seq: torch.Tensor,
valid_ratios: Sequence[float]) -> torch.Tensor:
"""Generate mask for source sequence.
Args:
src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`.
valid_ratios (list[float]): The valid ratio of input image. For
example, if the width of the original image is w1 and the width
after padding is w2, then valid_ratio = w1/w2. Source mask is
used to cover the area of the padding region.
Returns:
Tensor or None: Source mask. Shape :math:`(N, T)`. Positions in the padding region are masked out (0) and the rest are kept (1).
"""
N, T, _ = src_seq.size()
mask = None
if len(valid_ratios) > 0:
mask = src_seq.new_zeros((N, T), device=src_seq.device)
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(T, math.ceil(T * valid_ratio))
mask[i, :valid_width] = 1
return mask
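
A standalone sketch of the same valid-ratio masking (illustrative shapes and ratios, not taken from this diff):

    import math
    import torch

    def source_mask(src_seq: torch.Tensor, valid_ratios) -> torch.Tensor:
        # src_seq: (N, T, C) encoder output; the mask is 1 over the valid width
        # and 0 over the right-side region introduced by image padding.
        N, T, _ = src_seq.size()
        mask = src_seq.new_zeros((N, T))
        for i, valid_ratio in enumerate(valid_ratios):
            valid_width = min(T, math.ceil(T * valid_ratio))
            mask[i, :valid_width] = 1
        return mask

    src = torch.randn(2, 5, 8)
    print(source_mask(src, [0.6, 1.0]))
    # tensor([[1., 1., 1., 0., 0.],
    #         [1., 1., 1., 1., 1.]])
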
def _attention(self,
trg_seq: torch.Tensor,
src: torch.Tensor,
src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""A wrapped process for transformer-based decoding: text embedding,
positional encoding, N stacked transformer decoder layers and a final
LayerNorm operation.
Args:
trg_seq (Tensor): Target sequence. Shape :math:`(N, T)`.
src (Tensor): Source sequence from encoder. Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
src_mask (Tensor, optional): Mask for source sequence. Shape :math:`(N, T)`. Defaults to None.
Returns:
Tensor: Output sequence from transformer decoder. Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
"""
trg_embedding = self.trg_word_emb(trg_seq)
trg_pos_encoded = self.position_enc(trg_embedding)
trg_mask = self._get_target_mask(trg_seq)
tgt_seq = self.dropout(trg_pos_encoded)
output = tgt_seq
for dec_layer in self.layer_stack:
output = dec_layer(
output,
@@ -112,46 +179,64 @@ class NRTRDecoder(BaseDecoder):
return output
def forward_train(self,
feat: Optional[torch.Tensor] = None,
out_enc: torch.Tensor = None,
data_samples: Sequence[TextRecogDataSample] = None
) -> torch.Tensor:
"""Forward for training. Source mask will be used here.
Args:
feat (Tensor, optional): Unused.
out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``. Defaults to None.
data_samples (list[TextRecogDataSample]): Batch of TextRecogDataSample, containing gt_text and valid_ratio information. Defaults to None.
Returns:
Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where :math:`C` is ``num_classes``.
"""
valid_ratios = []
for data_sample in data_samples:
valid_ratios.append(data_sample.get('valid_ratio'))
src_mask = self._get_source_mask(out_enc, valid_ratios)
trg_seq = []
for data_sample in data_samples:
trg_seq.append(
data_sample.gt_text.padded_indexes.to(out_enc.device))
trg_seq = torch.stack(trg_seq, dim=0)
attn_output = self._attention(trg_seq, out_enc, src_mask=src_mask)
outputs = self.classifier(attn_output)
return outputs
def forward_test(self,
feat: Optional[torch.Tensor] = None,
out_enc: torch.Tensor = None,
data_samples: Sequence[TextRecogDataSample] = None
) -> torch.Tensor:
"""Forward for testing.
Args:
feat (Tensor, optional): Unused.
out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``. Defaults to None.
data_samples (list[TextRecogDataSample]): Batch of TextRecogDataSample, containing gt_text and valid_ratio information. Defaults to None.
Returns:
Tensor: The raw logit tensor. Shape :math:`(N, self.max_seq_len, C)` where :math:`C` is ``num_classes``.
"""
valid_ratios = []
for data_sample in data_samples:
valid_ratios.append(data_sample.get('valid_ratio'))
src_mask = self._get_source_mask(out_enc, valid_ratios)
N = out_enc.size(0)
init_target_seq = torch.full((N, self.max_seq_len + 1),
self.padding_idx,
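
The hunk continues beyond this point. As a sketch only (not taken from this diff), greedy autoregressive decoding in the NRTR style typically proceeds from `init_target_seq` as below; `attention` and `classifier` are stand-ins for `self._attention` and `self.classifier` above:

    import torch

    def greedy_decode(attention, classifier, out_enc, src_mask,
                      start_idx, padding_idx, max_seq_len):
        N = out_enc.size(0)
        # Target buffer starts as all <PAD>, with <BOS> in slot 0.
        init_target_seq = torch.full((N, max_seq_len + 1),
                                     padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        init_target_seq[:, 0] = start_idx
        outputs = []
        for step in range(max_seq_len):
            decoder_output = attention(
                init_target_seq, out_enc, src_mask=src_mask)
            # Classify the current step only and feed its argmax back in.
            step_result = classifier(decoder_output[:, step, :])
            outputs.append(step_result)
            _, step_max_index = torch.max(step_result, dim=-1)
            init_target_seq[:, step + 1] = step_max_index
        # (N, max_seq_len, C) raw logits.
        return torch.stack(outputs, dim=1)
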


@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
import torch
from mmengine.data import LabelData
from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.textrecog.decoders import NRTRDecoder
class TestNRTRDecoder(TestCase):
def setUp(self):
gt_text_sample1 = TextRecogDataSample()
gt_text = LabelData()
gt_text.item = 'Hello'
gt_text_sample1.gt_text = gt_text
gt_text_sample1.set_metainfo(dict(valid_ratio=0.9))
gt_text_sample2 = TextRecogDataSample()
gt_text = LabelData()
gt_text.item = 'World'
gt_text_sample2.gt_text = gt_text
gt_text_sample2.set_metainfo(dict(valid_ratio=1.0))
self.data_info = [gt_text_sample1, gt_text_sample2]
def _create_dummy_dict_file(
self, dict_file,
chars=list('0123456789abcdefghijklmnopqrstuvwxyz')): # NOQA
with open(dict_file, 'w') as f:
for char in chars:
f.write(char + '\n')
def test_init(self):
tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
self._create_dummy_dict_file(dict_file)
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=True,
with_padding=True,
with_unknown=True)
loss_cfg = dict(type='CELoss')
NRTRDecoder(dictionary=dict_cfg, loss=loss_cfg)
tmp_dir.cleanup()
def test_forward_train(self):
encoder_out = torch.randn(2, 25, 512)
tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
self._create_dummy_dict_file(dict_file)
# test dictionary cfg
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=True,
with_padding=True,
with_unknown=True)
loss_cfg = dict(type='CELoss')
decoder = NRTRDecoder(dictionary=dict_cfg, loss=loss_cfg)
data_samples = decoder.loss.get_targets(self.data_info)
output = decoder(
out_enc=encoder_out, data_samples=data_samples, train_mode=True)
self.assertTupleEqual(tuple(output.shape), (2, 40, 39))
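
The expected 39 output classes follow from the dummy dictionary, assuming MMOCR's `Dictionary` adds one index per enabled special token (a shared `<BOS/EOS>` when `same_start_end=True`, plus `<PAD>` and `<UKN>`):

    num_chars = len('0123456789abcdefghijklmnopqrstuvwxyz')  # 36
    num_classes = num_chars + 1 + 1 + 1  # + <BOS/EOS> + <PAD> + <UKN> = 39
    # With batch size 2 and the default max_seq_len of 40, this gives the
    # asserted output shape (2, 40, 39); the same holds for the test below.
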
def test_forward_test(self):
encoder_out = torch.randn(2, 25, 512)
tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
self._create_dummy_dict_file(dict_file)
# test dictionary cfg
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=True,
with_padding=True,
with_unknown=True)
loss_cfg = dict(type='CELoss')
decoder = NRTRDecoder(
dictionary=dict_cfg, loss=loss_cfg, max_seq_len=40)
output = decoder(
out_enc=encoder_out, data_samples=self.data_info, train_mode=False)
self.assertTupleEqual(tuple(output.shape), (2, 40, 39))
if __name__ == '__main__':
t = TestNRTRDecoder()
t.setUp()
t.test_forward_test()