diff --git a/configs/_base_/recog_models/nrtr_modality_transform.py b/configs/_base_/recog_models/nrtr_modality_transform.py index 40657578..3c2e87f4 100644 --- a/configs/_base_/recog_models/nrtr_modality_transform.py +++ b/configs/_base_/recog_models/nrtr_modality_transform.py @@ -4,8 +4,8 @@ label_convertor = dict( model = dict( type='NRTR', backbone=dict(type='NRTRModalityTransform'), - encoder=dict(type='TFEncoder'), - decoder=dict(type='TFDecoder'), + encoder=dict(type='NRTREncoder', n_layers=12), + decoder=dict(type='NRTRDecoder'), loss=dict(type='TFLoss'), label_convertor=label_convertor, max_seq_len=40) diff --git a/configs/_base_/recog_pipelines/nrtr_pipeline.py b/configs/_base_/recog_pipelines/nrtr_pipeline.py index 4257af40..d476346e 100644 --- a/configs/_base_/recog_pipelines/nrtr_pipeline.py +++ b/configs/_base_/recog_pipelines/nrtr_pipeline.py @@ -17,26 +17,19 @@ train_pipeline = [ 'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio' ]), ] + test_pipeline = [ dict(type='LoadImageFromFile'), dict( - type='MultiRotateAugOCR', - rotate_degrees=[0, 90, 270], - transforms=[ - dict( - type='ResizeOCR', - height=32, - min_width=32, - max_width=160, - keep_aspect_ratio=True, - width_downsample_ratio=0.25), - dict(type='ToTensorOCR'), - dict(type='NormalizeOCR', **img_norm_cfg), - dict( - type='Collect', - keys=['img'], - meta_keys=[ - 'filename', 'ori_shape', 'resize_shape', 'valid_ratio' - ]), - ]) + type='ResizeOCR', + height=32, + min_width=32, + max_width=160, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=['filename', 'ori_shape', 'resize_shape', 'valid_ratio']) ] diff --git a/configs/textrecog/nrtr/README.md b/configs/textrecog/nrtr/README.md index 61384670..984f3cfe 100644 --- a/configs/textrecog/nrtr/README.md +++ b/configs/textrecog/nrtr/README.md @@ -66,10 +66,16 @@ Backbone | Methods | Backbone | | Regular Text | | | | Irregular Text | | download | | :-------------------------------------------------------------: | :----------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | | | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | -| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py) | R31-1/16-1/8 | 93.9 | 90.0 | 93.5 | | 74.5 | 78.5 | 86.5 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_010150.log.json) | -| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py) | R31-1/8-1/4 | 94.7 | 87.5 | 93.3 | | 75.1 | 78.9 | 87.9 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_160845.log.json) | +| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py) | R31-1/16-1/8 | 94.7 | 87.3 | 94.3 | | 73.5 | 78.9 | 85.1 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20211124_002420.log.json) | +| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py) | R31-1/8-1/4 | 95.2 | 90.0 | 94.0 | | 74.1 | 79.4 | 88.2 | 
[model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20211123_232151.log.json) | **Notes:** -- `R31-1/16-1/8` means the height of feature from backbone is 1/16 of input image, where 1/8 for width. -- `R31-1/8-1/4` means the height of feature from backbone is 1/8 of input image, where 1/4 for width. +- For backbone `R31-1/16-1/8`: + - The output consists of 92 classes, including 26 lowercase letters, 26 uppercase letters, 28 symbols, 10 digits, 1 unknown token and 1 end-of-sequence token. + - The number of encoder blocks is 6. + - `1/16-1/8` means that the feature height from the backbone is 1/16 of the input image height and the feature width is 1/8 of the input image width. +- For backbone `R31-1/8-1/4`: + - The output consists of 92 classes, including 26 lowercase letters, 26 uppercase letters, 28 symbols, 10 digits, 1 unknown token and 1 end-of-sequence token. + - The number of encoder blocks is 6. + - `1/8-1/4` means that the feature height from the backbone is 1/8 of the input image height and the feature width is 1/4 of the input image width. diff --git a/configs/textrecog/nrtr/metafile.yml b/configs/textrecog/nrtr/metafile.yml index b9794252..7d5ca150 100644 --- a/configs/textrecog/nrtr/metafile.yml +++ b/configs/textrecog/nrtr/metafile.yml @@ -4,13 +4,13 @@ Collections: Training Data: OCRDataset Training Techniques: - Adam - Epochs: 5 - Batch Size: 8192 - Training Resources: 64x GeForce GTX 1080 Ti + Epochs: 6 + Batch Size: 6144 + Training Resources: 48x GeForce GTX 1080 Ti Architecture: - - ResNet31OCR - - TFEncoder - - TFDecoder + - CNN + - NRTREncoder + - NRTRDecoder Paper: URL: https://arxiv.org/pdf/1806.00926.pdf Title: 'NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition' @@ -28,28 +28,28 @@ Models: - Task: Text Recognition Dataset: IIIT5K Metrics: - word_acc: 93.9 + word_acc: 94.7 - Task: Text Recognition Dataset: SVT Metrics: - word_acc: 80.0 + word_acc: 87.3 - Task: Text Recognition Dataset: ICDAR2013 Metrics: - word_acc: 93.5 + word_acc: 94.3 - Task: Text Recognition Dataset: ICDAR2015 Metrics: - word_acc: 74.5 + word_acc: 73.5 - Task: Text Recognition Dataset: SVTP Metrics: - word_acc: 78.5 + word_acc: 78.9 - Task: Text Recognition Dataset: CT80 Metrics: - word_acc: 86.5 - Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth + word_acc: 85.1 + Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth - Name: nrtr_r31_1by8_1by4_academic In Collection: NRTR @@ -62,25 +62,25 @@ Models: - Task: Text Recognition Dataset: IIIT5K Metrics: - word_acc: 94.7 + word_acc: 95.2 - Task: Text Recognition Dataset: SVT Metrics: - word_acc: 87.5 + word_acc: 90.0 - Task: Text Recognition Dataset: ICDAR2013 Metrics: - word_acc: 93.3 + word_acc: 94.0 - Task: Text Recognition Dataset: ICDAR2015 Metrics: - word_acc: 75.1 + word_acc: 74.1 - Task: Text Recognition Dataset: SVTP Metrics: - word_acc: 78.9 + word_acc: 79.4 - Task: Text Recognition Dataset: CT80 Metrics: - word_acc: 87.9 - Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth + word_acc: 88.2 + Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth diff --git a/configs/textrecog/nrtr/nrtr_modality_transform_academic.py b/configs/textrecog/nrtr/nrtr_modality_transform_academic.py new file mode 100644 index 00000000..471926ba --- /dev/null +++ 
b/configs/textrecog/nrtr/nrtr_modality_transform_academic.py @@ -0,0 +1,32 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/recog_models/nrtr_modality_transform.py', + '../../_base_/schedules/schedule_adam_step_6e.py', + '../../_base_/recog_datasets/ST_MJ_train.py', + '../../_base_/recog_datasets/academic_test.py', + '../../_base_/recog_pipelines/nrtr_pipeline.py' +] + +train_list = {{_base_.train_list}} +test_list = {{_base_.test_list}} + +train_pipeline = {{_base_.train_pipeline}} +test_pipeline = {{_base_.test_pipeline}} + +data = dict( + samples_per_gpu=128, + workers_per_gpu=4, + train=dict( + type='UniformConcatDataset', + datasets=train_list, + pipeline=train_pipeline), + val=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline), + test=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline)) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/nrtr/nrtr_modality_transform_toy_dataset.py b/configs/textrecog/nrtr/nrtr_modality_transform_toy_dataset.py new file mode 100644 index 00000000..1bb350fc --- /dev/null +++ b/configs/textrecog/nrtr/nrtr_modality_transform_toy_dataset.py @@ -0,0 +1,31 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/recog_models/nrtr_modality_transform.py', + '../../_base_/schedules/schedule_adam_step_6e.py', + '../../_base_/recog_datasets/toy_data.py', + '../../_base_/recog_pipelines/nrtr_pipeline.py' +] + +train_list = {{_base_.train_list}} +test_list = {{_base_.test_list}} + +train_pipeline = {{_base_.train_pipeline}} +test_pipeline = {{_base_.test_pipeline}} + +data = dict( + samples_per_gpu=16, + workers_per_gpu=2, + train=dict( + type='UniformConcatDataset', + datasets=train_list, + pipeline=train_pipeline), + val=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline), + test=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline)) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py b/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py index 0b404aa7..b7adc0d3 100644 --- a/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py +++ b/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py @@ -23,8 +23,8 @@ model = dict( channels=[32, 64, 128, 256, 512, 512], stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)), last_stage_pool=True), - encoder=dict(type='TFEncoder'), - decoder=dict(type='TFDecoder'), + encoder=dict(type='NRTREncoder'), + decoder=dict(type='NRTRDecoder'), loss=dict(type='TFLoss'), label_convertor=label_convertor, max_seq_len=40) @@ -32,8 +32,6 @@ model = dict( data = dict( samples_per_gpu=128, workers_per_gpu=4, - val_dataloader=dict(samples_per_gpu=1), - test_dataloader=dict(samples_per_gpu=1), train=dict( type='UniformConcatDataset', datasets=train_list, diff --git a/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py b/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py index 72dfa12e..397122b5 100644 --- a/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py +++ b/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py @@ -23,8 +23,8 @@ model = dict( channels=[32, 64, 128, 256, 512, 512], stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)), last_stage_pool=False), - encoder=dict(type='TFEncoder'), - decoder=dict(type='TFDecoder'), + encoder=dict(type='NRTREncoder'), + decoder=dict(type='NRTRDecoder'), loss=dict(type='TFLoss'), label_convertor=label_convertor, max_seq_len=40) @@ -32,8 
+32,6 @@ model = dict( data = dict( samples_per_gpu=64, workers_per_gpu=4, - val_dataloader=dict(samples_per_gpu=1), - test_dataloader=dict(samples_per_gpu=1), train=dict( type='UniformConcatDataset', datasets=train_list, diff --git a/configs/textrecog/satrn/satrn_academic.py b/configs/textrecog/satrn/satrn_academic.py index a42ae127..00a664e2 100644 --- a/configs/textrecog/satrn/satrn_academic.py +++ b/configs/textrecog/satrn/satrn_academic.py @@ -28,7 +28,7 @@ model = dict( d_inner=512 * 4, dropout=0.1), decoder=dict( - type='TFDecoder', + type='NRTRDecoder', n_layers=6, d_embedding=512, n_head=8, diff --git a/configs/textrecog/satrn/satrn_small.py b/configs/textrecog/satrn/satrn_small.py index 35b3d358..96f86797 100644 --- a/configs/textrecog/satrn/satrn_small.py +++ b/configs/textrecog/satrn/satrn_small.py @@ -28,7 +28,7 @@ model = dict( d_inner=256 * 4, dropout=0.1), decoder=dict( - type='TFDecoder', + type='NRTRDecoder', n_layers=6, d_embedding=256, n_head=8, diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py index 7a362db9..d1899f40 100644 --- a/mmocr/apis/inference.py +++ b/mmocr/apis/inference.py @@ -68,12 +68,36 @@ def disable_text_recog_aug_test(cfg, set_types=None): assert set_types is None or isinstance(set_types, list) if set_types is None: set_types = ['val', 'test'] + warnings.simplefilter('once') + warning_msg = 'Remove "MultiRotateAugOCR" to support batch ' + \ + 'inference since samples_per_gpu > 1.' for set_type in set_types: - if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR': - cfg.data[set_type].pipeline = [ - cfg.data[set_type].pipeline[0], - *cfg.data[set_type].pipeline[1].transforms - ] + dataset_type = cfg.data[set_type].type + if dataset_type in ['OCRDataset', 'OCRSegDataset']: + if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR': + warnings.warn(warning_msg) + cfg.data[set_type].pipeline = [ + cfg.data[set_type].pipeline[0], + *cfg.data[set_type].pipeline[1].transforms + ] + elif dataset_type in ['ConcatDataset', 'UniformConcatDataset']: + if dataset_type == 'UniformConcatDataset': + uniform_pipeline = cfg.data[set_type].pipeline + if uniform_pipeline is not None: + if uniform_pipeline[1].type == 'MultiRotateAugOCR': + warnings.warn(warning_msg) + cfg.data[set_type].pipeline = [ + uniform_pipeline[0], + *uniform_pipeline[1].transforms + ] + for dataset in cfg.data[set_type].datasets: + if dataset.pipeline is not None: + if dataset.pipeline[1].type == 'MultiRotateAugOCR': + warnings.warn(warning_msg) + dataset.pipeline = [ + dataset.pipeline[0], + *dataset.pipeline[1].transforms + ] return cfg diff --git a/mmocr/models/common/__init__.py b/mmocr/models/common/__init__.py index e8a7f671..f9a578f1 100644 --- a/mmocr/models/common/__init__.py +++ b/mmocr/models/common/__init__.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import backbones, losses +from . 
import backbones, layers, losses, modules from .backbones import * # NOQA from .losses import * # NOQA +from .layers import * # NOQA +from .modules import * # NOQA -__all__ = backbones.__all__ + losses.__all__ +__all__ = backbones.__all__ + losses.__all__ + layers.__all__ + modules.__all__ diff --git a/mmocr/models/common/layers/__init__.py b/mmocr/models/common/layers/__init__.py new file mode 100644 index 00000000..87f605f1 --- /dev/null +++ b/mmocr/models/common/layers/__init__.py @@ -0,0 +1,3 @@ +from .transformer_layers import TFDecoderLayer, TFEncoderLayer + +__all__ = ['TFEncoderLayer', 'TFDecoderLayer'] diff --git a/mmocr/models/common/layers/transformer_layers.py b/mmocr/models/common/layers/transformer_layers.py new file mode 100644 index 00000000..a491ac67 --- /dev/null +++ b/mmocr/models/common/layers/transformer_layers.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.runner import BaseModule + +from mmocr.models.common.modules import (MultiHeadAttention, + PositionwiseFeedForward) + + +class TFEncoderLayer(BaseModule): + """Transformer Encoder Layer. + + Args: + d_model (int): The number of expected features + in the decoder inputs (default=512). + d_inner (int): The dimension of the feedforward + network model (default=256). + n_head (int): The number of heads in the + multiheadattention models (default=8). + d_k (int): Total number of features in key. + d_v (int): Total number of features in value. + dropout (float): Dropout layer on attn_output_weights. + qkv_bias (bool): Add bias in projection layer. Default: False. + act_cfg (dict): Activation cfg for feedforward module. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm') + or ('norm', 'self_attn', 'norm', 'ffn'). + Default:None. + """ + + def __init__(self, + d_model=512, + d_inner=256, + n_head=8, + d_k=64, + d_v=64, + dropout=0.1, + qkv_bias=False, + act_cfg=dict(type='mmcv.GELU'), + operation_order=None): + super().__init__() + self.attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout) + self.norm1 = nn.LayerNorm(d_model) + self.mlp = PositionwiseFeedForward( + d_model, d_inner, dropout=dropout, act_cfg=act_cfg) + self.norm2 = nn.LayerNorm(d_model) + + self.operation_order = operation_order + if self.operation_order is None: + self.operation_order = ('norm', 'self_attn', 'norm', 'ffn') + + assert self.operation_order in [('norm', 'self_attn', 'norm', 'ffn'), + ('self_attn', 'norm', 'ffn', 'norm')] + + def forward(self, x, mask=None): + if self.operation_order == ('self_attn', 'norm', 'ffn', 'norm'): + residual = x + x = residual + self.attn(x, x, x, mask) + x = self.norm1(x) + + residual = x + x = residual + self.mlp(x) + x = self.norm2(x) + elif self.operation_order == ('norm', 'self_attn', 'norm', 'ffn'): + residual = x + x = self.norm1(x) + x = residual + self.attn(x, x, x, mask) + + residual = x + x = self.norm2(x) + x = residual + self.mlp(x) + + return x + + +class TFDecoderLayer(nn.Module): + """Transformer Decoder Layer. + + Args: + d_model (int): The number of expected features + in the decoder inputs (default=512). + d_inner (int): The dimension of the feedforward + network model (default=256). + n_head (int): The number of heads in the + multiheadattention models (default=8). + d_k (int): Total number of features in key. + d_v (int): Total number of features in value. + dropout (float): Dropout layer on attn_output_weights. 
+ qkv_bias (bool): Add bias in projection layer. Default: False. + act_cfg (dict): Activation cfg for feedforward module. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'enc_dec_attn', + 'norm', 'ffn', 'norm') or ('norm', 'self_attn', 'norm', + 'enc_dec_attn', 'norm', 'ffn'). + Default:None. + """ + + def __init__(self, + d_model=512, + d_inner=256, + n_head=8, + d_k=64, + d_v=64, + dropout=0.1, + qkv_bias=False, + act_cfg=dict(type='mmcv.GELU'), + operation_order=None): + super().__init__() + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + + self.self_attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias) + + self.enc_attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias) + + self.mlp = PositionwiseFeedForward( + d_model, d_inner, dropout=dropout, act_cfg=act_cfg) + + self.operation_order = operation_order + if self.operation_order is None: + self.operation_order = ('norm', 'self_attn', 'norm', + 'enc_dec_attn', 'norm', 'ffn') + assert self.operation_order in [ + ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn'), + ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm') + ] + + def forward(self, + dec_input, + enc_output, + self_attn_mask=None, + dec_enc_attn_mask=None): + if self.operation_order == ('self_attn', 'norm', 'enc_dec_attn', + 'norm', 'ffn', 'norm'): + dec_attn_out = self.self_attn(dec_input, dec_input, dec_input, + self_attn_mask) + dec_attn_out += dec_input + dec_attn_out = self.norm1(dec_attn_out) + + enc_dec_attn_out = self.enc_attn(dec_attn_out, enc_output, + enc_output, dec_enc_attn_mask) + enc_dec_attn_out += dec_attn_out + enc_dec_attn_out = self.norm2(enc_dec_attn_out) + + mlp_out = self.mlp(enc_dec_attn_out) + mlp_out += enc_dec_attn_out + mlp_out = self.norm3(mlp_out) + elif self.operation_order == ('norm', 'self_attn', 'norm', + 'enc_dec_attn', 'norm', 'ffn'): + dec_input_norm = self.norm1(dec_input) + dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm, + dec_input_norm, self_attn_mask) + dec_attn_out += dec_input + + enc_dec_attn_in = self.norm2(dec_attn_out) + enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output, + enc_output, dec_enc_attn_mask) + enc_dec_attn_out += dec_attn_out + + mlp_out = self.mlp(self.norm3(enc_dec_attn_out)) + mlp_out += enc_dec_attn_out + + return mlp_out diff --git a/mmocr/models/common/modules/__init__.py b/mmocr/models/common/modules/__init__.py new file mode 100644 index 00000000..5ec6c38c --- /dev/null +++ b/mmocr/models/common/modules/__init__.py @@ -0,0 +1,8 @@ +from .transformer_module import (MultiHeadAttention, PositionalEncoding, + PositionwiseFeedForward, + ScaledDotProductAttention) + +__all__ = [ + 'ScaledDotProductAttention', 'MultiHeadAttention', + 'PositionwiseFeedForward', 'PositionalEncoding' +] diff --git a/mmocr/models/common/modules/transformer_module.py b/mmocr/models/common/modules/transformer_module.py new file mode 100644 index 00000000..6e23e5a7 --- /dev/null +++ b/mmocr/models/common/modules/transformer_module.py @@ -0,0 +1,155 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmocr.models.builder import build_activation_layer + + +class ScaledDotProductAttention(nn.Module): + """Scaled Dot-Product Attention Module. This code is adopted from + https://github.com/jadore801120/attention-is-all-you-need-pytorch. 
+ + Args: + temperature (float): The scale factor for softmax input. + attn_dropout (float): Dropout layer on attn_output_weights. + """ + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + + def forward(self, q, k, v, mask=None): + + attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) + + if mask is not None: + attn = attn.masked_fill(mask == 0, float('-inf')) + + attn = self.dropout(F.softmax(attn, dim=-1)) + output = torch.matmul(attn, v) + + return output, attn + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention module. + + Args: + n_head (int): The number of heads in the + multiheadattention models (default=8). + d_model (int): The number of expected features + in the decoder inputs (default=512). + d_k (int): Total number of features in key. + d_v (int): Total number of features in value. + dropout (float): Dropout layer on attn_output_weights. + qkv_bias (bool): Add bias in projection layer. Default: False. + """ + + def __init__(self, + n_head=8, + d_model=512, + d_k=64, + d_v=64, + dropout=0.1, + qkv_bias=False): + super().__init__() + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.dim_k = n_head * d_k + self.dim_v = n_head * d_v + + self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) + self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) + self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias) + + self.attention = ScaledDotProductAttention(d_k**0.5, dropout) + + self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias) + self.proj_drop = nn.Dropout(dropout) + + def forward(self, q, k, v, mask=None): + batch_size, len_q, _ = q.size() + _, len_k, _ = k.size() + + q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k) + k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k) + v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v) + + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + + if mask is not None: + if mask.dim() == 3: + mask = mask.unsqueeze(1) + elif mask.dim() == 2: + mask = mask.unsqueeze(1).unsqueeze(1) + + attn_out, _ = self.attention(q, k, v, mask=mask) + + attn_out = attn_out.transpose(1, 2).contiguous().view( + batch_size, len_q, self.dim_v) + + attn_out = self.fc(attn_out) + attn_out = self.proj_drop(attn_out) + + return attn_out + + +class PositionwiseFeedForward(nn.Module): + """Two-layer feed-forward module. + + Args: + d_in (int): The dimension of the input for feedforward + network model. + d_hid (int): The dimension of the feedforward + network model. + dropout (float): Dropout layer on feedforward output. + act_cfg (dict): Activation cfg for feedforward module. 
+ """ + + def __init__(self, d_in, d_hid, dropout=0.1, act_cfg=dict(type='Relu')): + super().__init__() + self.w_1 = nn.Linear(d_in, d_hid) + self.w_2 = nn.Linear(d_hid, d_in) + self.act = build_activation_layer(act_cfg) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.w_1(x) + x = self.act(x) + x = self.w_2(x) + x = self.dropout(x) + + return x + + +class PositionalEncoding(nn.Module): + """Fixed positional encoding with sine and cosine functions.""" + + def __init__(self, d_hid=512, n_position=200): + super().__init__() + # Not a parameter + self.register_buffer( + 'position_table', + self._get_sinusoid_encoding_table(n_position, d_hid)) + + def _get_sinusoid_encoding_table(self, n_position, d_hid): + """Sinusoid position encoding table.""" + denominator = torch.Tensor([ + 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ]) + denominator = denominator.view(1, -1) + pos_tensor = torch.arange(n_position).unsqueeze(-1).float() + sinusoid_table = pos_tensor * denominator + sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) + + return sinusoid_table.unsqueeze(0) + + def forward(self, x): + self.device = x.device + return x + self.position_table[:, :x.size(1)].clone().detach() diff --git a/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py b/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py index 6af55b0e..a514ffdf 100644 --- a/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py +++ b/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py @@ -10,7 +10,6 @@ class NRTRModalityTransform(BaseModule): def __init__(self, input_channels=3, - input_height=32, init_cfg=[ dict(type='Kaiming', layer='Conv2d'), dict(type='Uniform', layer='BatchNorm2d') @@ -35,9 +34,7 @@ class NRTRModalityTransform(BaseModule): self.relu_2 = nn.ReLU(True) self.bn_2 = nn.BatchNorm2d(64) - feat_height = input_height // 4 - - self.linear = nn.Linear(64 * feat_height, 512) + self.linear = nn.Linear(512, 512) def forward(self, x): x = self.conv_1(x) @@ -49,7 +46,11 @@ class NRTRModalityTransform(BaseModule): x = self.bn_2(x) n, c, h, w = x.size() + x = x.permute(0, 3, 2, 1).contiguous().view(n, w, h * c) + x = self.linear(x) + x = x.permute(0, 2, 1).contiguous().view(n, -1, 1, w) + return x diff --git a/mmocr/models/textrecog/decoders/__init__.py b/mmocr/models/textrecog/decoders/__init__.py index 7b6fad8c..d22dcf86 100755 --- a/mmocr/models/textrecog/decoders/__init__.py +++ b/mmocr/models/textrecog/decoders/__init__.py @@ -1,16 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .base_decoder import BaseDecoder from .crnn_decoder import CRNNDecoder +from .nrtr_decoder import NRTRDecoder from .position_attention_decoder import PositionAttentionDecoder from .robust_scanner_decoder import RobustScannerDecoder from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder from .sar_decoder_with_bs import ParallelSARDecoderWithBS from .sequence_attention_decoder import SequenceAttentionDecoder -from .transformer_decoder import TFDecoder __all__ = [ 'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder', - 'ParallelSARDecoderWithBS', 'TFDecoder', 'BaseDecoder', + 'ParallelSARDecoderWithBS', 'NRTRDecoder', 'BaseDecoder', 'SequenceAttentionDecoder', 'PositionAttentionDecoder', 'RobustScannerDecoder' ] diff --git a/mmocr/models/textrecog/decoders/transformer_decoder.py b/mmocr/models/textrecog/decoders/nrtr_decoder.py similarity index 78% rename from mmocr/models/textrecog/decoders/transformer_decoder.py rename to mmocr/models/textrecog/decoders/nrtr_decoder.py index 464bef99..c21c0248 100644 --- a/mmocr/models/textrecog/decoders/transformer_decoder.py +++ b/mmocr/models/textrecog/decoders/nrtr_decoder.py @@ -7,14 +7,12 @@ import torch.nn.functional as F from mmcv.runner import ModuleList from mmocr.models.builder import DECODERS -from mmocr.models.textrecog.layers.transformer_layer import ( - PositionalEncoding, TransformerDecoderLayer, get_pad_mask, - get_subsequent_mask) +from mmocr.models.common import PositionalEncoding, TFDecoderLayer from .base_decoder import BaseDecoder @DECODERS.register_module() -class TFDecoder(BaseDecoder): +class NRTRDecoder(BaseDecoder): """Transformer Decoder block with self attention mechanism. Args: @@ -71,8 +69,8 @@ class TFDecoder(BaseDecoder): self.dropout = nn.Dropout(p=dropout) self.layer_stack = ModuleList([ - TransformerDecoderLayer( - d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + TFDecoderLayer( + d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs) for _ in range(n_layers) ]) self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) @@ -80,13 +78,29 @@ class TFDecoder(BaseDecoder): pred_num_class = num_classes - 1 # ignore padding_idx self.classifier = nn.Linear(d_model, pred_num_class) + @staticmethod + def get_pad_mask(seq, pad_idx): + + return (seq != pad_idx).unsqueeze(-2) + + @staticmethod + def get_subsequent_mask(seq): + """For masking out the subsequent info.""" + len_s = seq.size(1) + subsequent_mask = 1 - torch.triu( + torch.ones((len_s, len_s), device=seq.device), diagonal=1) + subsequent_mask = subsequent_mask.unsqueeze(0).bool() + + return subsequent_mask + def _attention(self, trg_seq, src, src_mask=None): trg_embedding = self.trg_word_emb(trg_seq) trg_pos_encoded = self.position_enc(trg_embedding) tgt = self.dropout(trg_pos_encoded) - trg_mask = get_pad_mask( - trg_seq, pad_idx=self.padding_idx) & get_subsequent_mask(trg_seq) + trg_mask = self.get_pad_mask( + trg_seq, + pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq) output = tgt for dec_layer in self.layer_stack: output = dec_layer( @@ -98,11 +112,27 @@ class TFDecoder(BaseDecoder): return output + def _get_mask(self, logit, img_metas): + valid_ratios = None + if img_metas is not None: + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] + N, T, _ = logit.size() + mask = None + if valid_ratios is not None: + mask = logit.new_zeros((N, T)) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(T, math.ceil(T * valid_ratio)) + mask[i, :valid_width] = 1 + + return mask + def 
forward_train(self, feat, out_enc, targets_dict, img_metas): r""" Args: feat (None): Unused. - out_enc (Tensor): Encoder output of shape :math:`(N, D_m, H, W)` + out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``. targets_dict (dict): A dict with the key ``padded_targets``, a tensor of shape :math:`(N, T)`. Each element is the index of a @@ -113,44 +143,17 @@ class TFDecoder(BaseDecoder): Returns: Tensor: The raw logit tensor. Shape :math:`(N, T, C)`. """ - valid_ratios = None - if img_metas is not None: - valid_ratios = [ - img_meta.get('valid_ratio', 1.0) for img_meta in img_metas - ] - n, c, h, w = out_enc.size() - src_mask = None - if valid_ratios is not None: - src_mask = out_enc.new_zeros((n, h, w)) - for i, valid_ratio in enumerate(valid_ratios): - valid_width = min(w, math.ceil(w * valid_ratio)) - src_mask[i, :, :valid_width] = 1 - src_mask = src_mask.view(n, h * w) - out_enc = out_enc.view(n, c, h * w).permute(0, 2, 1) - out_enc = out_enc.contiguous() + src_mask = self._get_mask(out_enc, img_metas) targets = targets_dict['padded_targets'].to(out_enc.device) attn_output = self._attention(targets, out_enc, src_mask=src_mask) outputs = self.classifier(attn_output) + return outputs def forward_test(self, feat, out_enc, img_metas): - valid_ratios = None - if img_metas is not None: - valid_ratios = [ - img_meta.get('valid_ratio', 1.0) for img_meta in img_metas - ] - n, c, h, w = out_enc.size() - src_mask = None - if valid_ratios is not None: - src_mask = out_enc.new_zeros((n, h, w)) - for i, valid_ratio in enumerate(valid_ratios): - valid_width = min(w, math.ceil(w * valid_ratio)) - src_mask[i, :, :valid_width] = 1 - src_mask = src_mask.view(n, h * w) - out_enc = out_enc.view(n, c, h * w).permute(0, 2, 1) - out_enc = out_enc.contiguous() - - init_target_seq = torch.full((n, self.max_seq_len + 1), + src_mask = self._get_mask(out_enc, img_metas) + N = out_enc.size(0) + init_target_seq = torch.full((N, self.max_seq_len + 1), self.padding_idx, device=out_enc.device, dtype=torch.long) @@ -161,7 +164,7 @@ class TFDecoder(BaseDecoder): for step in range(0, self.max_seq_len): decoder_output = self._attention( init_target_seq, out_enc, src_mask=src_mask) - # bsz * seq_len * 512 + # bsz * seq_len * C step_result = F.softmax( self.classifier(decoder_output[:, step, :]), dim=-1) # bsz * num_classes diff --git a/mmocr/models/textrecog/encoders/__init__.py b/mmocr/models/textrecog/encoders/__init__.py index 604f5439..391b5163 100644 --- a/mmocr/models/textrecog/encoders/__init__.py +++ b/mmocr/models/textrecog/encoders/__init__.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base_encoder import BaseEncoder from .channel_reduction_encoder import ChannelReductionEncoder +from .nrtr_encoder import NRTREncoder from .sar_encoder import SAREncoder from .satrn_encoder import SatrnEncoder -from .transformer_encoder import TFEncoder __all__ = [ - 'SAREncoder', 'TFEncoder', 'BaseEncoder', 'ChannelReductionEncoder', + 'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder', 'SatrnEncoder' ] diff --git a/mmocr/models/textrecog/encoders/nrtr_encoder.py b/mmocr/models/textrecog/encoders/nrtr_encoder.py new file mode 100644 index 00000000..72b229f0 --- /dev/null +++ b/mmocr/models/textrecog/encoders/nrtr_encoder.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math + +import torch.nn as nn +from mmcv.runner import ModuleList + +from mmocr.models.builder import ENCODERS +from mmocr.models.common import TFEncoderLayer +from .base_encoder import BaseEncoder + + +@ENCODERS.register_module() +class NRTREncoder(BaseEncoder): + """Transformer Encoder block with self attention mechanism. + + Args: + n_layers (int): The number of sub-encoder-layers + in the encoder (default=6). + n_head (int): The number of heads in the + multiheadattention models (default=8). + d_k (int): Total number of features in key. + d_v (int): Total number of features in value. + d_model (int): The number of expected features + in the decoder inputs (default=512). + d_inner (int): The dimension of the feedforward + network model (default=256). + dropout (float): Dropout layer on attn_output_weights. + init_cfg (dict or list[dict], optional): Initialization configs. + """ + + def __init__(self, + n_layers=6, + n_head=8, + d_k=64, + d_v=64, + d_model=512, + d_inner=256, + dropout=0.1, + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.d_model = d_model + self.layer_stack = ModuleList([ + TFEncoderLayer( + d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs) + for _ in range(n_layers) + ]) + self.layer_norm = nn.LayerNorm(d_model) + + def _get_mask(self, logit, img_metas): + valid_ratios = None + if img_metas is not None: + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] + N, T, _ = logit.size() + mask = None + if valid_ratios is not None: + mask = logit.new_zeros((N, T)) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(T, math.ceil(T * valid_ratio)) + mask[i, :valid_width] = 1 + + return mask + + def forward(self, feat, img_metas=None): + r""" + Args: + feat (Tensor): Backbone output of shape :math:`(N, C, H, W)`. + img_metas (dict): A dict that contains meta information of input + images. Preferably with the key ``valid_ratio``. + + Returns: + Tensor: The encoder output tensor. Shape :math:`(N, T, C)`. + """ + n, c, h, w = feat.size() + + feat = feat.view(n, c, h * w).permute(0, 2, 1).contiguous() + + mask = self._get_mask(feat, img_metas) + + output = feat + for enc_layer in self.layer_stack: + output = enc_layer(output, mask) + output = self.layer_norm(output) + + return output diff --git a/mmocr/models/textrecog/encoders/satrn_encoder.py b/mmocr/models/textrecog/encoders/satrn_encoder.py index 915ca570..7056a905 100644 --- a/mmocr/models/textrecog/encoders/satrn_encoder.py +++ b/mmocr/models/textrecog/encoders/satrn_encoder.py @@ -61,7 +61,7 @@ class SatrnEncoder(BaseEncoder): images. Preferably with the key ``valid_ratio``. Returns: - Tensor: A tensor of shape :math:`(N, D_m, H, W)`. + Tensor: A tensor of shape :math:`(N, T, D_m)`. """ valid_ratios = [1.0 for _ in range(feat.size(0))] if img_metas is not None: @@ -82,7 +82,4 @@ class SatrnEncoder(BaseEncoder): output = enc_layer(output, h, w, mask) output = self.layer_norm(output) - output = output.permute(0, 2, 1).contiguous() - output = output.view(n, self.d_model, h, w) - return output diff --git a/mmocr/models/textrecog/encoders/transformer_encoder.py b/mmocr/models/textrecog/encoders/transformer_encoder.py deleted file mode 100644 index dd799e39..00000000 --- a/mmocr/models/textrecog/encoders/transformer_encoder.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import math - -import torch.nn as nn -from mmcv.runner import ModuleList - -from mmocr.models.builder import ENCODERS -from mmocr.models.textrecog.layers import TransformerEncoderLayer -from .base_encoder import BaseEncoder - - -@ENCODERS.register_module() -class TFEncoder(BaseEncoder): - """Encode 2d feature map to 1d sequence.""" - - def __init__(self, - n_layers=6, - n_head=8, - d_k=64, - d_v=64, - d_model=512, - d_inner=256, - dropout=0.1, - init_cfg=None, - **kwargs): - super().__init__(init_cfg=init_cfg) - self.d_model = d_model - self.layer_stack = ModuleList([ - TransformerEncoderLayer( - d_model, d_inner, n_head, d_k, d_v, dropout=dropout) - for _ in range(n_layers) - ]) - self.layer_norm = nn.LayerNorm(d_model) - - def forward(self, feat, img_metas=None): - valid_ratios = [1.0 for _ in range(feat.size(0))] - if img_metas is not None: - valid_ratios = [ - img_meta.get('valid_ratio', 1.0) for img_meta in img_metas - ] - n, c, h, w = feat.size() - mask = feat.new_zeros((n, h, w)) - for i, valid_ratio in enumerate(valid_ratios): - valid_width = min(w, math.ceil(w * valid_ratio)) - mask[i, :, :valid_width] = 1 - mask = mask.view(n, h * w) - feat = feat.view(n, c, h * w) - - output = feat.permute(0, 2, 1).contiguous() - for enc_layer in self.layer_stack: - output = enc_layer(output, mask) - output = self.layer_norm(output) - - output = output.permute(0, 2, 1).contiguous() - output = output.view(n, self.d_model, h, w) - - return output diff --git a/mmocr/models/textrecog/layers/__init__.py b/mmocr/models/textrecog/layers/__init__.py index 69232b3c..c92fef54 100644 --- a/mmocr/models/textrecog/layers/__init__.py +++ b/mmocr/models/textrecog/layers/__init__.py @@ -4,17 +4,10 @@ from .dot_product_attention_layer import DotProductAttentionLayer from .lstm_layer import BidirectionalLSTM from .position_aware_layer import PositionAwareLayer from .robust_scanner_fusion_layer import RobustScannerFusionLayer -from .transformer_layer import (Adaptive2DPositionalEncoding, - MultiHeadAttention, PositionalEncoding, - PositionwiseFeedForward, SatrnEncoderLayer, - TransformerDecoderLayer, - TransformerEncoderLayer, get_pad_mask, - get_subsequent_mask) +from .satrn_layers import Adaptive2DPositionalEncoding, SatrnEncoderLayer __all__ = [ - 'BidirectionalLSTM', 'MultiHeadAttention', 'PositionalEncoding', - 'Adaptive2DPositionalEncoding', 'PositionwiseFeedForward', 'BasicBlock', + 'BidirectionalLSTM', 'Adaptive2DPositionalEncoding', 'BasicBlock', 'Bottleneck', 'RobustScannerFusionLayer', 'DotProductAttentionLayer', - 'PositionAwareLayer', 'get_pad_mask', 'get_subsequent_mask', - 'TransformerDecoderLayer', 'TransformerEncoderLayer', 'SatrnEncoderLayer' + 'PositionAwareLayer', 'SatrnEncoderLayer' ] diff --git a/mmocr/models/textrecog/layers/satrn_layers.py b/mmocr/models/textrecog/layers/satrn_layers.py new file mode 100644 index 00000000..d75b6dac --- /dev/null +++ b/mmocr/models/textrecog/layers/satrn_layers.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule + +from mmocr.models.common import MultiHeadAttention + + +class SatrnEncoderLayer(BaseModule): + """""" + + def __init__(self, + d_model=512, + d_inner=512, + n_head=8, + d_k=64, + d_v=64, + dropout=0.1, + qkv_bias=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.norm1 = nn.LayerNorm(d_model) + self.attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout) + self.norm2 = nn.LayerNorm(d_model) + self.feed_forward = LocalityAwareFeedforward( + d_model, d_inner, dropout=dropout) + + def forward(self, x, h, w, mask=None): + n, hw, c = x.size() + residual = x + x = self.norm1(x) + x = residual + self.attn(x, x, x, mask) + residual = x + x = self.norm2(x) + x = x.transpose(1, 2).contiguous().view(n, c, h, w) + x = self.feed_forward(x) + x = x.view(n, c, hw).transpose(1, 2) + x = residual + x + return x + + +class LocalityAwareFeedforward(BaseModule): + """Locality-aware feedforward layer in SATRN, see `SATRN. + + `_ + """ + + def __init__(self, + d_in, + d_hid, + dropout=0.1, + init_cfg=[ + dict(type='Xavier', layer='Conv2d'), + dict(type='Constant', layer='BatchNorm2d', val=1, bias=0) + ]): + super().__init__(init_cfg=init_cfg) + self.conv1 = ConvModule( + d_in, + d_hid, + kernel_size=1, + padding=0, + bias=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU')) + + self.depthwise_conv = ConvModule( + d_hid, + d_hid, + kernel_size=3, + padding=1, + bias=False, + groups=d_hid, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU')) + + self.conv2 = ConvModule( + d_hid, + d_in, + kernel_size=1, + padding=0, + bias=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU')) + + def forward(self, x): + x = self.conv1(x) + x = self.depthwise_conv(x) + x = self.conv2(x) + + return x + + +class Adaptive2DPositionalEncoding(BaseModule): + """Implement Adaptive 2D positional encoder for SATRN, see + `SATRN `_ + Modified from https://github.com/Media-Smart/vedastr + Licensed under the Apache License, Version 2.0 (the "License"); + Args: + d_hid (int): Dimensions of hidden layer. + n_height (int): Max height of the 2D feature output. + n_width (int): Max width of the 2D feature output. + dropout (int): Size of hidden layers of the model. 
+ """ + + def __init__(self, + d_hid=512, + n_height=100, + n_width=100, + dropout=0.1, + init_cfg=[dict(type='Xavier', layer='Conv2d')]): + super().__init__(init_cfg=init_cfg) + + h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid) + h_position_encoder = h_position_encoder.transpose(0, 1) + h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1) + + w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid) + w_position_encoder = w_position_encoder.transpose(0, 1) + w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width) + + self.register_buffer('h_position_encoder', h_position_encoder) + self.register_buffer('w_position_encoder', w_position_encoder) + + self.h_scale = self.scale_factor_generate(d_hid) + self.w_scale = self.scale_factor_generate(d_hid) + self.pool = nn.AdaptiveAvgPool2d(1) + self.dropout = nn.Dropout(p=dropout) + + def _get_sinusoid_encoding_table(self, n_position, d_hid): + """Sinusoid position encoding table.""" + denominator = torch.Tensor([ + 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ]) + denominator = denominator.view(1, -1) + pos_tensor = torch.arange(n_position).unsqueeze(-1).float() + sinusoid_table = pos_tensor * denominator + sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) + + return sinusoid_table + + def scale_factor_generate(self, d_hid): + scale_factor = nn.Sequential( + nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True), + nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid()) + + return scale_factor + + def forward(self, x): + b, c, h, w = x.size() + + avg_pool = self.pool(x) + + h_pos_encoding = \ + self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :] + w_pos_encoding = \ + self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w] + + out = x + h_pos_encoding + w_pos_encoding + + out = self.dropout(out) + + return out diff --git a/mmocr/models/textrecog/layers/transformer_layer.py b/mmocr/models/textrecog/layers/transformer_layer.py deleted file mode 100644 index 8c933398..00000000 --- a/mmocr/models/textrecog/layers/transformer_layer.py +++ /dev/null @@ -1,399 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-"""This code is from https://github.com/jadore801120/attention-is-all-you-need- -pytorch.""" -import numpy as np -import torch -import torch.nn as nn -from mmcv.cnn import ConvModule -from mmcv.runner import BaseModule - - -class TransformerEncoderLayer(nn.Module): - """""" - - def __init__(self, - d_model=512, - d_inner=256, - n_head=8, - d_k=64, - d_v=64, - dropout=0.1, - qkv_bias=False, - mask_value=0, - act_layer=nn.GELU): - super().__init__() - self.norm1 = nn.LayerNorm(d_model) - self.attn = MultiHeadAttention( - n_head, - d_model, - d_k, - d_v, - qkv_bias=qkv_bias, - dropout=dropout, - mask_value=mask_value) - self.norm2 = nn.LayerNorm(d_model) - self.mlp = PositionwiseFeedForward( - d_model, d_inner, dropout=dropout, act_layer=act_layer) - - def forward(self, x, mask=None): - residual = x - x = self.norm1(x) - x = residual + self.attn(x, x, x, mask) - residual = x - x = self.norm2(x) - x = residual + self.mlp(x) - - return x - - -class SatrnEncoderLayer(BaseModule): - """""" - - def __init__(self, - d_model=512, - d_inner=512, - n_head=8, - d_k=64, - d_v=64, - dropout=0.1, - qkv_bias=False, - mask_value=0, - init_cfg=None): - super().__init__(init_cfg=init_cfg) - self.norm1 = nn.LayerNorm(d_model) - self.attn = MultiHeadAttention( - n_head, - d_model, - d_k, - d_v, - qkv_bias=qkv_bias, - dropout=dropout, - mask_value=mask_value) - self.norm2 = nn.LayerNorm(d_model) - self.feed_forward = LocalityAwareFeedforward( - d_model, d_inner, dropout=dropout) - - def forward(self, x, h, w, mask=None): - n, hw, c = x.size() - residual = x - x = self.norm1(x) - x = residual + self.attn(x, x, x, mask) - residual = x - x = self.norm2(x) - x = x.transpose(1, 2).contiguous().view(n, c, h, w) - x = self.feed_forward(x) - x = x.view(n, c, hw).transpose(1, 2) - x = residual + x - return x - - -class TransformerDecoderLayer(nn.Module): - - def __init__(self, - d_model=512, - d_inner=256, - n_head=8, - d_k=64, - d_v=64, - dropout=0.1, - qkv_bias=False, - mask_value=0, - act_layer=nn.GELU): - super().__init__() - self.self_attn = MultiHeadAttention() - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - - self.self_attn = MultiHeadAttention( - n_head, - d_model, - d_k, - d_v, - dropout=dropout, - qkv_bias=qkv_bias, - mask_value=mask_value) - self.enc_attn = MultiHeadAttention( - n_head, - d_model, - d_k, - d_v, - dropout=dropout, - qkv_bias=qkv_bias, - mask_value=mask_value) - self.mlp = PositionwiseFeedForward( - d_model, d_inner, dropout=dropout, act_layer=act_layer) - - def forward(self, - dec_input, - enc_output, - self_attn_mask=None, - dec_enc_attn_mask=None): - self_attn_in = self.norm1(dec_input) - self_attn_out = self.self_attn(self_attn_in, self_attn_in, - self_attn_in, self_attn_mask) - enc_attn_in = dec_input + self_attn_out - - enc_attn_q = self.norm2(enc_attn_in) - enc_attn_out = self.enc_attn(enc_attn_q, enc_output, enc_output, - dec_enc_attn_mask) - - mlp_in = enc_attn_in + enc_attn_out - mlp_out = self.mlp(self.norm3(mlp_in)) - out = mlp_in + mlp_out - - return out - - -class MultiHeadAttention(nn.Module): - """Multi-Head Attention module.""" - - def __init__(self, - n_head=8, - d_model=512, - d_k=64, - d_v=64, - dropout=0.1, - qkv_bias=False, - mask_value=0): - super().__init__() - - self.mask_value = mask_value - - self.n_head = n_head - self.d_k = d_k - self.d_v = d_v - - self.scale = d_k**-0.5 - - self.dim_k = n_head * d_k - self.dim_v = n_head * d_v - - self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) - - 
self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) - - self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias) - - self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias) - - self.attn_drop = nn.Dropout(dropout) - self.proj_drop = nn.Dropout(dropout) - - def forward(self, q, k, v, mask=None): - batch_size, len_q, _ = q.size() - _, len_k, _ = k.size() - - q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k) - k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k) - v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v) - - q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 3, 1) - v = v.permute(0, 2, 1, 3) - - logits = torch.matmul(q, k) * self.scale - - if mask is not None: - if mask.dim() == 3: - mask = mask.unsqueeze(1) - elif mask.dim() == 2: - mask = mask.unsqueeze(1).unsqueeze(1) - logits = logits.masked_fill(mask == self.mask_value, float('-inf')) - weights = logits.softmax(dim=-1) - weights = self.attn_drop(weights) - - attn_out = torch.matmul(weights, v).transpose(1, 2) - attn_out = attn_out.reshape(batch_size, len_q, self.dim_v) - attn_out = self.fc(attn_out) - attn_out = self.proj_drop(attn_out) - - return attn_out - - -class PositionwiseFeedForward(nn.Module): - """A two-feed-forward-layer module.""" - - def __init__(self, d_in, d_hid, dropout=0.1, act_layer=nn.GELU): - super().__init__() - self.w_1 = nn.Linear(d_in, d_hid) - self.w_2 = nn.Linear(d_hid, d_in) - self.act = act_layer() - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - x = self.w_1(x) - x = self.act(x) - x = self.dropout(x) - x = self.w_2(x) - x = self.dropout(x) - - return x - - -class LocalityAwareFeedforward(BaseModule): - """Locality-aware feedforward layer in SATRN, see `SATRN. - - `_ - """ - - def __init__(self, - d_in, - d_hid, - dropout=0.1, - init_cfg=[ - dict(type='Xavier', layer='Conv2d'), - dict(type='Constant', layer='BatchNorm2d', val=1, bias=0) - ]): - super().__init__(init_cfg=init_cfg) - self.conv1 = ConvModule( - d_in, - d_hid, - kernel_size=1, - padding=0, - bias=False, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU')) - - self.depthwise_conv = ConvModule( - d_hid, - d_hid, - kernel_size=3, - padding=1, - bias=False, - groups=d_hid, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU')) - - self.conv2 = ConvModule( - d_hid, - d_in, - kernel_size=1, - padding=0, - bias=False, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU')) - - def forward(self, x): - x = self.conv1(x) - x = self.depthwise_conv(x) - x = self.conv2(x) - - return x - - -class PositionalEncoding(nn.Module): - - def __init__(self, d_hid=512, n_position=200): - super().__init__() - - # Not a parameter - self.register_buffer( - 'position_table', - self._get_sinusoid_encoding_table(n_position, d_hid)) - - def _get_sinusoid_encoding_table(self, n_position, d_hid): - """Sinusoid position encoding table.""" - denominator = torch.Tensor([ - 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) - for hid_j in range(d_hid) - ]) - denominator = denominator.view(1, -1) - pos_tensor = torch.arange(n_position).unsqueeze(-1).float() - sinusoid_table = pos_tensor * denominator - sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) - sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) - - return sinusoid_table.unsqueeze(0) - - def forward(self, x): - self.device = x.device - return x + self.position_table[:, :x.size(1)].clone().detach() - - -class Adaptive2DPositionalEncoding(BaseModule): - """Implement Adaptive 2D positional encoder for SATRN, see 
- `SATRN `_ - Modified from https://github.com/Media-Smart/vedastr - Licensed under the Apache License, Version 2.0 (the "License"); - Args: - d_hid (int): Dimensions of hidden layer. - n_height (int): Max height of the 2D feature output. - n_width (int): Max width of the 2D feature output. - dropout (int): Size of hidden layers of the model. - """ - - def __init__(self, - d_hid=512, - n_height=100, - n_width=100, - dropout=0.1, - init_cfg=[dict(type='Xavier', layer='Conv2d')]): - super().__init__(init_cfg=init_cfg) - - h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid) - h_position_encoder = h_position_encoder.transpose(0, 1) - h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1) - - w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid) - w_position_encoder = w_position_encoder.transpose(0, 1) - w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width) - - self.register_buffer('h_position_encoder', h_position_encoder) - self.register_buffer('w_position_encoder', w_position_encoder) - - self.h_scale = self.scale_factor_generate(d_hid) - self.w_scale = self.scale_factor_generate(d_hid) - self.pool = nn.AdaptiveAvgPool2d(1) - self.dropout = nn.Dropout(p=dropout) - - def _get_sinusoid_encoding_table(self, n_position, d_hid): - """Sinusoid position encoding table.""" - denominator = torch.Tensor([ - 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) - for hid_j in range(d_hid) - ]) - denominator = denominator.view(1, -1) - pos_tensor = torch.arange(n_position).unsqueeze(-1).float() - sinusoid_table = pos_tensor * denominator - sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) - sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) - - return sinusoid_table - - def scale_factor_generate(self, d_hid): - scale_factor = nn.Sequential( - nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True), - nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid()) - - return scale_factor - - def forward(self, x): - b, c, h, w = x.size() - - avg_pool = self.pool(x) - - h_pos_encoding = \ - self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :] - w_pos_encoding = \ - self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w] - - out = x + h_pos_encoding + w_pos_encoding - - out = self.dropout(out) - - return out - - -def get_pad_mask(seq, pad_idx): - return (seq != pad_idx).unsqueeze(-2) - - -def get_subsequent_mask(seq): - """For masking out the subsequent info.""" - len_s = seq.size(1) - subsequent_mask = 1 - torch.triu( - torch.ones((len_s, len_s), device=seq.device), diagonal=1) - subsequent_mask = subsequent_mask.unsqueeze(0).bool() - return subsequent_mask diff --git a/mmocr/utils/ocr.py b/mmocr/utils/ocr.py index de369f3f..1c5f342e 100755 --- a/mmocr/utils/ocr.py +++ b/mmocr/utils/ocr.py @@ -282,12 +282,13 @@ class MMOCR: }, 'NRTR_1/16-1/8': { 'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py', - 'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth' + 'ckpt': + 'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth' }, 'NRTR_1/8-1/4': { 'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py', 'ckpt': - 'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth' + 'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth' }, 'RobustScanner': { 'config': 'robust_scanner/robustscanner_r31_academic.py', diff --git a/tests/test_apis/test_model_inference.py b/tests/test_apis/test_model_inference.py index 4476b35e..54c8eab1 100644 --- a/tests/test_apis/test_model_inference.py +++ b/tests/test_apis/test_model_inference.py @@ -1,10 +1,13 
@@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os import pytest +from mmcv import Config from mmcv.image import imread -from mmocr.apis.inference import init_detector, model_inference +from mmocr.apis.inference import (disable_text_recog_aug_test, init_detector, + model_inference) from mmocr.datasets import build_dataset # noqa: F401 from mmocr.models import build_detector # noqa: F401 from mmocr.utils import revert_sync_batchnorm @@ -121,3 +124,26 @@ def test_model_batch_inference_empty_detection(cfg_file): match='empty imgs provided, please check and try again'): model_inference(model, empty_detection, batch_mode=True) + + +@pytest.mark.parametrize('cfg_file', [ + '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py', +]) +def test_disable_text_recog_aug_test(cfg_file): + tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + config_file = os.path.join(tmp_dir, cfg_file) + + cfg = Config.fromfile(config_file) + cfg1 = copy.deepcopy(cfg) + test = cfg1.data.test.datasets[0] + test.pipeline = cfg1.data.test.pipeline + cfg1.data.test = test + disable_text_recog_aug_test(cfg1, set_types=['test']) + + cfg2 = copy.deepcopy(cfg) + cfg2.data.test.pipeline = None + disable_text_recog_aug_test(cfg2, set_types=['test']) + + cfg2 = copy.deepcopy(cfg) + cfg2.data.test = Config(dict(type='ConcatDataset', datasets=[test])) + disable_text_recog_aug_test(cfg2, set_types=['test']) diff --git a/tests/test_models/test_ocr_decoder.py b/tests/test_models/test_ocr_decoder.py index 5d41a2fc..f17ccc58 100644 --- a/tests/test_models/test_ocr_decoder.py +++ b/tests/test_models/test_ocr_decoder.py @@ -4,9 +4,10 @@ import math import pytest import torch -from mmocr.models.textrecog.decoders import (BaseDecoder, ParallelSARDecoder, +from mmocr.models.textrecog.decoders import (BaseDecoder, NRTRDecoder, + ParallelSARDecoder, ParallelSARDecoderWithBS, - SequentialSARDecoder, TFDecoder) + SequentialSARDecoder) from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode @@ -97,11 +98,11 @@ def test_parallel_sar_decoder_with_beam_search(): def test_transformer_decoder(): - decoder = TFDecoder(num_classes=37, padding_idx=36, max_seq_len=5) + decoder = NRTRDecoder(num_classes=37, padding_idx=36, max_seq_len=5) decoder.init_weights() decoder.train() - out_enc = torch.rand(1, 512, 1, 25) + out_enc = torch.rand(1, 25, 512) tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])} img_metas = [{'valid_ratio': 1.0}] tgt_dict['padded_targets'] = tgt_dict['padded_targets'] diff --git a/tests/test_models/test_ocr_encoder.py b/tests/test_models/test_ocr_encoder.py index 81eb01fc..a0c21c18 100644 --- a/tests/test_models/test_ocr_encoder.py +++ b/tests/test_models/test_ocr_encoder.py @@ -2,8 +2,8 @@ import pytest import torch -from mmocr.models.textrecog.encoders import (BaseEncoder, SAREncoder, - SatrnEncoder, TFEncoder) +from mmocr.models.textrecog.encoders import (BaseEncoder, NRTREncoder, + SAREncoder, SatrnEncoder) def test_sar_encoder(): @@ -34,14 +34,14 @@ def test_sar_encoder(): def test_transformer_encoder(): - tf_encoder = TFEncoder() + tf_encoder = NRTREncoder() tf_encoder.init_weights() tf_encoder.train() feat = torch.randn(1, 512, 1, 25) out_enc = tf_encoder(feat) print('hello', out_enc.size()) - assert out_enc.shape == torch.Size([1, 512, 1, 25]) + assert out_enc.shape == torch.Size([1, 25, 512]) def test_satrn_encoder(): @@ -51,7 +51,7 @@ def test_satrn_encoder(): feat = torch.randn(1, 512, 8, 25) out_enc = satrn_encoder(feat) - assert 
out_enc.shape == torch.Size([1, 512, 8, 25]) + assert out_enc.shape == torch.Size([1, 200, 512]) def test_base_encoder(): diff --git a/tests/test_models/test_ocr_layer.py b/tests/test_models/test_ocr_layer.py index 78cd60c2..e4b4a39b 100644 --- a/tests/test_models/test_ocr_layer.py +++ b/tests/test_models/test_ocr_layer.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck, - PositionalEncoding, - TransformerDecoderLayer, - get_pad_mask, get_subsequent_mask) +from mmocr.models.common import (PositionalEncoding, TFDecoderLayer, + TFEncoderLayer) +from mmocr.models.textrecog.layers import BasicBlock, Bottleneck from mmocr.models.textrecog.layers.conv_layer import conv3x3 @@ -34,24 +33,31 @@ def test_conv_layer(): def test_transformer_layer(): # test decoder_layer - decoder_layer = TransformerDecoderLayer() + decoder_layer = TFDecoderLayer() in_dec = torch.rand(1, 30, 512) out_enc = torch.rand(1, 128, 512) out_dec = decoder_layer(in_dec, out_enc) assert out_dec.shape == torch.Size([1, 30, 512]) + decoder_layer = TFDecoderLayer( + operation_order=('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', + 'norm')) + out_dec = decoder_layer(in_dec, out_enc) + assert out_dec.shape == torch.Size([1, 30, 512]) + # test positional_encoding pos_encoder = PositionalEncoding() x = torch.rand(1, 30, 512) out = pos_encoder(x) assert out.size() == x.size() - # test get pad mask - seq = torch.rand(1, 30) - pad_idx = 0 - out = get_pad_mask(seq, pad_idx) - assert out.shape == torch.Size([1, 1, 30]) + # test encoder_layer + encoder_layer = TFEncoderLayer() + in_enc = torch.rand(1, 20, 512) + out_enc = encoder_layer(in_enc) + assert out_enc.shape == torch.Size([1, 20, 512]) - # test get_subsequent_mask - out_mask = get_subsequent_mask(seq) - assert out_mask.shape == torch.Size([1, 30, 30]) + encoder_layer = TFEncoderLayer( + operation_order=('self_attn', 'norm', 'ffn', 'norm')) + out_enc = encoder_layer(in_enc) + assert out_enc.shape == torch.Size([1, 20, 512]) diff --git a/tests/test_models/test_recog_config.py b/tests/test_models/test_recog_config.py index e885b7ab..5084f4ad 100644 --- a/tests/test_models/test_recog_config.py +++ b/tests/test_models/test_recog_config.py @@ -102,11 +102,19 @@ def _get_detector_cfg(fname): @pytest.mark.parametrize('cfg_file', [ 'textrecog/sar/sar_r31_parallel_decoder_academic.py', + 'textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py', + 'textrecog/sar/sar_r31_sequential_decoder_academic.py', + 'textrecog/crnn/crnn_toy_dataset.py', 'textrecog/crnn/crnn_academic_dataset.py', 'textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py', + 'textrecog/nrtr/nrtr_modality_transform_academic.py', + 'textrecog/nrtr/nrtr_modality_transform_toy_dataset.py', + 'textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py', 'textrecog/robust_scanner/robustscanner_r31_academic.py', 'textrecog/seg/seg_r31_1by16_fpnocr_academic.py', - 'textrecog/satrn/satrn_academic.py' + 'textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py', + 'textrecog/satrn/satrn_academic.py', 'textrecog/satrn/satrn_small.py', + 'textrecog/tps/crnn_tps_academic_dataset.py' ]) def test_recognizer_pipeline(cfg_file): model = _get_detector_cfg(cfg_file)
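
As a quick sanity check of the renamed modules, the minimal sketch below mirrors the updated unit tests in this patch: it assumes an environment where mmocr with these changes is importable, and only verifies that `NRTREncoder` flattens an `(N, C, H, W)` backbone feature into an `(N, H*W, C)` sequence and that `NRTRDecoder.forward_train` consumes that sequence together with padded targets.

```python
import torch

from mmocr.models.textrecog.decoders import NRTRDecoder
from mmocr.models.textrecog.encoders import NRTREncoder

# Encoder: a (N, C, H, W) backbone feature becomes an (N, H*W, C) sequence.
encoder = NRTREncoder()
feat = torch.randn(1, 512, 1, 25)
out_enc = encoder(feat, img_metas=[{'valid_ratio': 1.0}])
assert out_enc.shape == torch.Size([1, 25, 512])

# Decoder: logits over num_classes - 1 symbols for each target position.
decoder = NRTRDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
targets_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
logits = decoder.forward_train(None, out_enc, targets_dict,
                               [{'valid_ratio': 1.0}])
assert logits.shape == torch.Size([1, 5, 36])
```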