mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] refactor transformer modules (#618)
* base refactor
* update config
* modify implementation of nrtr
* add config file
* add mask
* add operation order
* fix contiguous bug
* update config
* fix pytest
* fix pytest
* update readme
* update readme and metafile
* update docstring
* fix norm cfg and dict size
* rm useless
* use mmocr builder instead
* update pytest
* update
* remove useless
* fix ckpt name
* fix path
* include all config file into pytest
* update inference
* Update test_recog_config.py
parent 5a8859fe66
commit 0a1787d6bc
@@ -4,8 +4,8 @@ label_convertor = dict(
 model = dict(
     type='NRTR',
     backbone=dict(type='NRTRModalityTransform'),
-    encoder=dict(type='TFEncoder'),
-    decoder=dict(type='TFDecoder'),
+    encoder=dict(type='NRTREncoder', n_layers=12),
+    decoder=dict(type='NRTRDecoder'),
     loss=dict(type='TFLoss'),
     label_convertor=label_convertor,
     max_seq_len=40)
@@ -17,26 +17,19 @@ train_pipeline = [
         'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
     ]),
 ]
 
 test_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(
-        type='MultiRotateAugOCR',
-        rotate_degrees=[0, 90, 270],
-        transforms=[
-            dict(
-                type='ResizeOCR',
-                height=32,
-                min_width=32,
-                max_width=160,
-                keep_aspect_ratio=True,
-                width_downsample_ratio=0.25),
-            dict(type='ToTensorOCR'),
-            dict(type='NormalizeOCR', **img_norm_cfg),
-            dict(
-                type='Collect',
-                keys=['img'],
-                meta_keys=[
-                    'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
-                ]),
-        ])
+        type='ResizeOCR',
+        height=32,
+        min_width=32,
+        max_width=160,
+        keep_aspect_ratio=True),
+    dict(type='ToTensorOCR'),
+    dict(type='NormalizeOCR', **img_norm_cfg),
+    dict(
+        type='Collect',
+        keys=['img'],
+        meta_keys=['filename', 'ori_shape', 'resize_shape', 'valid_ratio'])
 ]
@@ -66,10 +66,16 @@ Backbone
 | Methods | Backbone | | Regular Text | | | | Irregular Text | | download |
 | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
 | | | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | |
-| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py) | R31-1/16-1/8 | 93.9 | 90.0 | 93.5 | | 74.5 | 78.5 | 86.5 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_010150.log.json) |
-| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py) | R31-1/8-1/4 | 94.7 | 87.5 | 93.3 | | 75.1 | 78.9 | 87.9 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_160845.log.json) |
+| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py) | R31-1/16-1/8 | 94.7 | 87.3 | 94.3 | | 73.5 | 78.9 | 85.1 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20211124_002420.log.json) |
+| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py) | R31-1/8-1/4 | 95.2 | 90.0 | 94.0 | | 74.1 | 79.4 | 88.2 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20211123_232151.log.json) |
 
 **Notes:**
 
-- `R31-1/16-1/8` means the height of the feature map from the backbone is 1/16 of the input image, and 1/8 for the width.
-- `R31-1/8-1/4` means the height of the feature map from the backbone is 1/8 of the input image, and 1/4 for the width.
+- For backbone `R31-1/16-1/8`:
+  - The output consists of 92 classes: 26 lowercase letters, 26 uppercase letters, 28 symbols, 10 digits, 1 unknown token and 1 end-of-sequence token.
+  - The encoder-block number is 6.
+  - `1/16-1/8` means the backbone feature map is 1/16 of the input image in height and 1/8 in width.
+- For backbone `R31-1/8-1/4`:
+  - The output consists of 92 classes: 26 lowercase letters, 26 uppercase letters, 28 symbols, 10 digits, 1 unknown token and 1 end-of-sequence token.
+  - The encoder-block number is 6.
+  - `1/8-1/4` means the backbone feature map is 1/8 of the input image in height and 1/4 in width.
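To sanity-check a downloaded checkpoint against the numbers above, the recognizer can be run through MMOCR's inference API. A minimal sketch, assuming it is run from the root of a local MMOCR checkout with the demo image the repo ships:

```python
# Smoke test of the refactored NRTR recognizer; the config and checkpoint
# names come from the table above, the local filenames are assumptions.
from mmocr.apis.inference import init_detector, model_inference

config = 'configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py'
checkpoint = 'nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'

model = init_detector(config, checkpoint, device='cpu')
result = model_inference(model, 'demo/demo_text_recog.jpg')
print(result['text'], result['score'])
```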
@@ -4,13 +4,13 @@ Collections:
     Training Data: OCRDataset
     Training Techniques:
       - Adam
-    Epochs: 5
-    Batch Size: 8192
-    Training Resources: 64x GeForce GTX 1080 Ti
+    Epochs: 6
+    Batch Size: 6144
+    Training Resources: 48x GeForce GTX 1080 Ti
     Architecture:
-      - ResNet31OCR
-      - TFEncoder
-      - TFDecoder
+      - CNN
+      - NRTREncoder
+      - NRTRDecoder
   Paper:
     URL: https://arxiv.org/pdf/1806.00926.pdf
     Title: 'NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition'
@@ -28,28 +28,28 @@ Models:
   - Task: Text Recognition
     Dataset: IIIT5K
     Metrics:
-      word_acc: 93.9
+      word_acc: 94.7
   - Task: Text Recognition
     Dataset: SVT
     Metrics:
-      word_acc: 80.0
+      word_acc: 87.3
   - Task: Text Recognition
     Dataset: ICDAR2013
     Metrics:
-      word_acc: 93.5
+      word_acc: 94.3
   - Task: Text Recognition
     Dataset: ICDAR2015
     Metrics:
-      word_acc: 74.5
+      word_acc: 73.5
   - Task: Text Recognition
     Dataset: SVTP
     Metrics:
-      word_acc: 78.5
+      word_acc: 78.9
   - Task: Text Recognition
     Dataset: CT80
     Metrics:
-      word_acc: 86.5
-  Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth
+      word_acc: 85.1
+  Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth
 
 - Name: nrtr_r31_1by8_1by4_academic
   In Collection: NRTR
@@ -62,25 +62,25 @@ Models:
   - Task: Text Recognition
     Dataset: IIIT5K
     Metrics:
-      word_acc: 94.7
+      word_acc: 95.2
   - Task: Text Recognition
     Dataset: SVT
     Metrics:
-      word_acc: 87.5
+      word_acc: 90.0
   - Task: Text Recognition
     Dataset: ICDAR2013
     Metrics:
-      word_acc: 93.3
+      word_acc: 94.0
   - Task: Text Recognition
     Dataset: ICDAR2015
     Metrics:
-      word_acc: 75.1
+      word_acc: 74.1
   - Task: Text Recognition
     Dataset: SVTP
     Metrics:
-      word_acc: 78.9
+      word_acc: 79.4
   - Task: Text Recognition
     Dataset: CT80
     Metrics:
-      word_acc: 87.9
-  Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth
+      word_acc: 88.2
+  Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth
@@ -0,0 +1,32 @@
+_base_ = [
+    '../../_base_/default_runtime.py',
+    '../../_base_/recog_models/nrtr_modality_transform.py',
+    '../../_base_/schedules/schedule_adam_step_6e.py',
+    '../../_base_/recog_datasets/ST_MJ_train.py',
+    '../../_base_/recog_datasets/academic_test.py',
+    '../../_base_/recog_pipelines/nrtr_pipeline.py'
+]
+
+train_list = {{_base_.train_list}}
+test_list = {{_base_.test_list}}
+
+train_pipeline = {{_base_.train_pipeline}}
+test_pipeline = {{_base_.test_pipeline}}
+
+data = dict(
+    samples_per_gpu=128,
+    workers_per_gpu=4,
+    train=dict(
+        type='UniformConcatDataset',
+        datasets=train_list,
+        pipeline=train_pipeline),
+    val=dict(
+        type='UniformConcatDataset',
+        datasets=test_list,
+        pipeline=test_pipeline),
+    test=dict(
+        type='UniformConcatDataset',
+        datasets=test_list,
+        pipeline=test_pipeline))
+
+evaluation = dict(interval=1, metric='acc')
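The `{{_base_.xxx}}` placeholders are mmcv's base-variable substitution: when the config is loaded, each one is replaced by the value defined in one of the inherited `_base_` files. A quick way to inspect the merged result is a sketch like the following; the config filename used here is an assumption:

```python
from mmcv import Config

# Hypothetical filename for the new config shown above.
cfg = Config.fromfile('configs/textrecog/nrtr/nrtr_modality_transform_academic.py')
print(cfg.data.samples_per_gpu)  # 128, set directly in this file
print(cfg.data.train.type)       # 'UniformConcatDataset'
print(cfg.model.type)            # 'NRTR', pulled in from the recog_models base
```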
@@ -0,0 +1,31 @@
+_base_ = [
+    '../../_base_/default_runtime.py',
+    '../../_base_/recog_models/nrtr_modality_transform.py',
+    '../../_base_/schedules/schedule_adam_step_6e.py',
+    '../../_base_/recog_datasets/toy_data.py',
+    '../../_base_/recog_pipelines/nrtr_pipeline.py'
+]
+
+train_list = {{_base_.train_list}}
+test_list = {{_base_.test_list}}
+
+train_pipeline = {{_base_.train_pipeline}}
+test_pipeline = {{_base_.test_pipeline}}
+
+data = dict(
+    samples_per_gpu=16,
+    workers_per_gpu=2,
+    train=dict(
+        type='UniformConcatDataset',
+        datasets=train_list,
+        pipeline=train_pipeline),
+    val=dict(
+        type='UniformConcatDataset',
+        datasets=test_list,
+        pipeline=test_pipeline),
+    test=dict(
+        type='UniformConcatDataset',
+        datasets=test_list,
+        pipeline=test_pipeline))
+
+evaluation = dict(interval=1, metric='acc')
@@ -23,8 +23,8 @@ model = dict(
         channels=[32, 64, 128, 256, 512, 512],
         stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
         last_stage_pool=True),
-    encoder=dict(type='TFEncoder'),
-    decoder=dict(type='TFDecoder'),
+    encoder=dict(type='NRTREncoder'),
+    decoder=dict(type='NRTRDecoder'),
     loss=dict(type='TFLoss'),
     label_convertor=label_convertor,
     max_seq_len=40)
@@ -32,8 +32,6 @@ model = dict(
 data = dict(
     samples_per_gpu=128,
     workers_per_gpu=4,
-    val_dataloader=dict(samples_per_gpu=1),
-    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type='UniformConcatDataset',
         datasets=train_list,
@@ -23,8 +23,8 @@ model = dict(
         channels=[32, 64, 128, 256, 512, 512],
         stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
         last_stage_pool=False),
-    encoder=dict(type='TFEncoder'),
-    decoder=dict(type='TFDecoder'),
+    encoder=dict(type='NRTREncoder'),
+    decoder=dict(type='NRTRDecoder'),
     loss=dict(type='TFLoss'),
     label_convertor=label_convertor,
     max_seq_len=40)
@@ -32,8 +32,6 @@ model = dict(
 data = dict(
     samples_per_gpu=64,
     workers_per_gpu=4,
-    val_dataloader=dict(samples_per_gpu=1),
-    test_dataloader=dict(samples_per_gpu=1),
    train=dict(
        type='UniformConcatDataset',
        datasets=train_list,
@@ -28,7 +28,7 @@ model = dict(
         d_inner=512 * 4,
         dropout=0.1),
     decoder=dict(
-        type='TFDecoder',
+        type='NRTRDecoder',
         n_layers=6,
         d_embedding=512,
         n_head=8,
@@ -28,7 +28,7 @@ model = dict(
         d_inner=256 * 4,
         dropout=0.1),
     decoder=dict(
-        type='TFDecoder',
+        type='NRTRDecoder',
         n_layers=6,
         d_embedding=256,
         n_head=8,
@@ -68,12 +68,36 @@ def disable_text_recog_aug_test(cfg, set_types=None):
     assert set_types is None or isinstance(set_types, list)
     if set_types is None:
         set_types = ['val', 'test']
+    warnings.simplefilter('once')
+    warning_msg = 'Remove "MultiRotateAugOCR" to support batch ' + \
+        'inference since samples_per_gpu > 1.'
     for set_type in set_types:
-        if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR':
-            cfg.data[set_type].pipeline = [
-                cfg.data[set_type].pipeline[0],
-                *cfg.data[set_type].pipeline[1].transforms
-            ]
+        dataset_type = cfg.data[set_type].type
+        if dataset_type in ['OCRDataset', 'OCRSegDataset']:
+            if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR':
+                warnings.warn(warning_msg)
+                cfg.data[set_type].pipeline = [
+                    cfg.data[set_type].pipeline[0],
+                    *cfg.data[set_type].pipeline[1].transforms
+                ]
+        elif dataset_type in ['ConcatDataset', 'UniformConcatDataset']:
+            if dataset_type == 'UniformConcatDataset':
+                uniform_pipeline = cfg.data[set_type].pipeline
+                if uniform_pipeline is not None:
+                    if uniform_pipeline[1].type == 'MultiRotateAugOCR':
+                        warnings.warn(warning_msg)
+                        cfg.data[set_type].pipeline = [
+                            uniform_pipeline[0],
+                            *uniform_pipeline[1].transforms
+                        ]
+            for dataset in cfg.data[set_type].datasets:
+                if dataset.pipeline is not None:
+                    if dataset.pipeline[1].type == 'MultiRotateAugOCR':
+                        warnings.warn(warning_msg)
+                        dataset.pipeline = [
+                            dataset.pipeline[0],
+                            *dataset.pipeline[1].transforms
+                        ]
 
     return cfg
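For context, this helper flattens the `MultiRotateAugOCR` wrapper into its inner transforms so that `samples_per_gpu > 1` works at test time, and after this change it handles plain, concatenated, and uniformly concatenated datasets. A hedged usage sketch with the SAR config that the new test below also uses:

```python
from mmcv import Config
from mmocr.apis.inference import disable_text_recog_aug_test

cfg = Config.fromfile('configs/textrecog/sar/sar_r31_parallel_decoder_academic.py')
cfg = disable_text_recog_aug_test(cfg, set_types=['test'])
# MultiRotateAugOCR has been replaced by the transforms it used to wrap.
print([step['type'] for step in cfg.data.test.pipeline])
```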
@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import backbones, losses
+from . import backbones, layers, losses, modules
 
 from .backbones import *  # NOQA
 from .losses import *  # NOQA
+from .layers import *  # NOQA
+from .modules import *  # NOQA
 
-__all__ = backbones.__all__ + losses.__all__
+__all__ = backbones.__all__ + losses.__all__ + layers.__all__ + modules.__all__
@@ -0,0 +1,3 @@
+from .transformer_layers import TFDecoderLayer, TFEncoderLayer
+
+__all__ = ['TFEncoderLayer', 'TFDecoderLayer']
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.runner import BaseModule
+
+from mmocr.models.common.modules import (MultiHeadAttention,
+                                         PositionwiseFeedForward)
+
+
+class TFEncoderLayer(BaseModule):
+    """Transformer Encoder Layer.
+
+    Args:
+        d_model (int): The number of expected features
+            in the decoder inputs (default=512).
+        d_inner (int): The dimension of the feedforward
+            network model (default=256).
+        n_head (int): The number of heads in the
+            multiheadattention models (default=8).
+        d_k (int): Total number of features in key.
+        d_v (int): Total number of features in value.
+        dropout (float): Dropout layer on attn_output_weights.
+        qkv_bias (bool): Add bias in projection layer. Default: False.
+        act_cfg (dict): Activation cfg for feedforward module.
+        operation_order (tuple[str]): The execution order of operations
+            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm')
+            or ('norm', 'self_attn', 'norm', 'ffn'). Default: None.
+    """
+
+    def __init__(self,
+                 d_model=512,
+                 d_inner=256,
+                 n_head=8,
+                 d_k=64,
+                 d_v=64,
+                 dropout=0.1,
+                 qkv_bias=False,
+                 act_cfg=dict(type='mmcv.GELU'),
+                 operation_order=None):
+        super().__init__()
+        self.attn = MultiHeadAttention(
+            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.mlp = PositionwiseFeedForward(
+            d_model, d_inner, dropout=dropout, act_cfg=act_cfg)
+        self.norm2 = nn.LayerNorm(d_model)
+
+        self.operation_order = operation_order
+        if self.operation_order is None:
+            self.operation_order = ('norm', 'self_attn', 'norm', 'ffn')
+
+        assert self.operation_order in [('norm', 'self_attn', 'norm', 'ffn'),
+                                        ('self_attn', 'norm', 'ffn', 'norm')]
+
+    def forward(self, x, mask=None):
+        if self.operation_order == ('self_attn', 'norm', 'ffn', 'norm'):
+            # Post-norm: residual add first, LayerNorm afterwards.
+            residual = x
+            x = residual + self.attn(x, x, x, mask)
+            x = self.norm1(x)
+
+            residual = x
+            x = residual + self.mlp(x)
+            x = self.norm2(x)
+        elif self.operation_order == ('norm', 'self_attn', 'norm', 'ffn'):
+            # Pre-norm: LayerNorm first, residual add after the sublayer.
+            residual = x
+            x = self.norm1(x)
+            x = residual + self.attn(x, x, x, mask)
+
+            residual = x
+            x = self.norm2(x)
+            x = residual + self.mlp(x)
+
+        return x
+
+
+class TFDecoderLayer(nn.Module):
+    """Transformer Decoder Layer.
+
+    Args:
+        d_model (int): The number of expected features
+            in the decoder inputs (default=512).
+        d_inner (int): The dimension of the feedforward
+            network model (default=256).
+        n_head (int): The number of heads in the
+            multiheadattention models (default=8).
+        d_k (int): Total number of features in key.
+        d_v (int): Total number of features in value.
+        dropout (float): Dropout layer on attn_output_weights.
+        qkv_bias (bool): Add bias in projection layer. Default: False.
+        act_cfg (dict): Activation cfg for feedforward module.
+        operation_order (tuple[str]): The execution order of operations
+            in the transformer, such as ('self_attn', 'norm', 'enc_dec_attn',
+            'norm', 'ffn', 'norm') or ('norm', 'self_attn', 'norm',
+            'enc_dec_attn', 'norm', 'ffn'). Default: None.
+    """
+
+    def __init__(self,
+                 d_model=512,
+                 d_inner=256,
+                 n_head=8,
+                 d_k=64,
+                 d_v=64,
+                 dropout=0.1,
+                 qkv_bias=False,
+                 act_cfg=dict(type='mmcv.GELU'),
+                 operation_order=None):
+        super().__init__()
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+
+        self.self_attn = MultiHeadAttention(
+            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
+
+        self.enc_attn = MultiHeadAttention(
+            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
+
+        self.mlp = PositionwiseFeedForward(
+            d_model, d_inner, dropout=dropout, act_cfg=act_cfg)
+
+        self.operation_order = operation_order
+        if self.operation_order is None:
+            self.operation_order = ('norm', 'self_attn', 'norm',
+                                    'enc_dec_attn', 'norm', 'ffn')
+        assert self.operation_order in [
+            ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn'),
+            ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm')
+        ]
+
+    def forward(self,
+                dec_input,
+                enc_output,
+                self_attn_mask=None,
+                dec_enc_attn_mask=None):
+        if self.operation_order == ('self_attn', 'norm', 'enc_dec_attn',
+                                    'norm', 'ffn', 'norm'):
+            # Post-norm ordering.
+            dec_attn_out = self.self_attn(dec_input, dec_input, dec_input,
+                                          self_attn_mask)
+            dec_attn_out += dec_input
+            dec_attn_out = self.norm1(dec_attn_out)
+
+            enc_dec_attn_out = self.enc_attn(dec_attn_out, enc_output,
+                                             enc_output, dec_enc_attn_mask)
+            enc_dec_attn_out += dec_attn_out
+            enc_dec_attn_out = self.norm2(enc_dec_attn_out)
+
+            mlp_out = self.mlp(enc_dec_attn_out)
+            mlp_out += enc_dec_attn_out
+            mlp_out = self.norm3(mlp_out)
+        elif self.operation_order == ('norm', 'self_attn', 'norm',
+                                      'enc_dec_attn', 'norm', 'ffn'):
+            # Pre-norm ordering.
+            dec_input_norm = self.norm1(dec_input)
+            dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
+                                          dec_input_norm, self_attn_mask)
+            dec_attn_out += dec_input
+
+            enc_dec_attn_in = self.norm2(dec_attn_out)
+            enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
+                                             enc_output, dec_enc_attn_mask)
+            enc_dec_attn_out += dec_attn_out
+
+            mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
+            mlp_out += enc_dec_attn_out
+
+        return mlp_out
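The `operation_order` tuple is what lets one layer implementation serve both the post-norm layout of the original Transformer and the pre-norm variant. A hedged construction sketch:

```python
import torch
from mmocr.models.common import TFEncoderLayer

x = torch.rand(2, 25, 512)  # (N, T, d_model)

# Pre-norm is the default when operation_order is None.
pre_norm = TFEncoderLayer(d_model=512, d_inner=256)
# Post-norm, as in the original "Attention Is All You Need" layout.
post_norm = TFEncoderLayer(
    d_model=512, d_inner=256,
    operation_order=('self_attn', 'norm', 'ffn', 'norm'))

assert pre_norm(x).shape == post_norm(x).shape == (2, 25, 512)
```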
@@ -0,0 +1,8 @@
+from .transformer_module import (MultiHeadAttention, PositionalEncoding,
+                                 PositionwiseFeedForward,
+                                 ScaledDotProductAttention)
+
+__all__ = [
+    'ScaledDotProductAttention', 'MultiHeadAttention',
+    'PositionwiseFeedForward', 'PositionalEncoding'
+]
@@ -0,0 +1,155 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmocr.models.builder import build_activation_layer
+
+
+class ScaledDotProductAttention(nn.Module):
+    """Scaled Dot-Product Attention module. This code is adapted from
+    https://github.com/jadore801120/attention-is-all-you-need-pytorch.
+
+    Args:
+        temperature (float): The scale factor for the softmax input.
+        attn_dropout (float): Dropout layer on attn_output_weights.
+    """
+
+    def __init__(self, temperature, attn_dropout=0.1):
+        super().__init__()
+        self.temperature = temperature
+        self.dropout = nn.Dropout(attn_dropout)
+
+    def forward(self, q, k, v, mask=None):
+
+        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
+
+        if mask is not None:
+            attn = attn.masked_fill(mask == 0, float('-inf'))
+
+        attn = self.dropout(F.softmax(attn, dim=-1))
+        output = torch.matmul(attn, v)
+
+        return output, attn
+
+
+class MultiHeadAttention(nn.Module):
+    """Multi-Head Attention module.
+
+    Args:
+        n_head (int): The number of heads in the
+            multiheadattention models (default=8).
+        d_model (int): The number of expected features
+            in the decoder inputs (default=512).
+        d_k (int): Total number of features in key.
+        d_v (int): Total number of features in value.
+        dropout (float): Dropout layer on attn_output_weights.
+        qkv_bias (bool): Add bias in projection layer. Default: False.
+    """
+
+    def __init__(self,
+                 n_head=8,
+                 d_model=512,
+                 d_k=64,
+                 d_v=64,
+                 dropout=0.1,
+                 qkv_bias=False):
+        super().__init__()
+        self.n_head = n_head
+        self.d_k = d_k
+        self.d_v = d_v
+
+        self.dim_k = n_head * d_k
+        self.dim_v = n_head * d_v
+
+        self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
+        self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
+        self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias)
+
+        self.attention = ScaledDotProductAttention(d_k**0.5, dropout)
+
+        self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias)
+        self.proj_drop = nn.Dropout(dropout)
+
+    def forward(self, q, k, v, mask=None):
+        batch_size, len_q, _ = q.size()
+        _, len_k, _ = k.size()
+
+        q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k)
+        k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k)
+        v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v)
+
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+        if mask is not None:
+            if mask.dim() == 3:
+                mask = mask.unsqueeze(1)
+            elif mask.dim() == 2:
+                mask = mask.unsqueeze(1).unsqueeze(1)
+
+        attn_out, _ = self.attention(q, k, v, mask=mask)
+
+        attn_out = attn_out.transpose(1, 2).contiguous().view(
+            batch_size, len_q, self.dim_v)
+
+        attn_out = self.fc(attn_out)
+        attn_out = self.proj_drop(attn_out)
+
+        return attn_out
+
+
+class PositionwiseFeedForward(nn.Module):
+    """Two-layer feed-forward module.
+
+    Args:
+        d_in (int): The dimension of the input to the feedforward
+            network model.
+        d_hid (int): The hidden dimension of the feedforward
+            network model.
+        dropout (float): Dropout layer on feedforward output.
+        act_cfg (dict): Activation cfg for feedforward module.
+    """
+
+    def __init__(self, d_in, d_hid, dropout=0.1, act_cfg=dict(type='Relu')):
+        super().__init__()
+        self.w_1 = nn.Linear(d_in, d_hid)
+        self.w_2 = nn.Linear(d_hid, d_in)
+        self.act = build_activation_layer(act_cfg)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x = self.w_1(x)
+        x = self.act(x)
+        x = self.w_2(x)
+        x = self.dropout(x)
+
+        return x
+
+
+class PositionalEncoding(nn.Module):
+    """Fixed positional encoding with sine and cosine functions."""
+
+    def __init__(self, d_hid=512, n_position=200):
+        super().__init__()
+        # Not a parameter: registered as a buffer so it moves with the module.
+        self.register_buffer(
+            'position_table',
+            self._get_sinusoid_encoding_table(n_position, d_hid))
+
+    def _get_sinusoid_encoding_table(self, n_position, d_hid):
+        """Sinusoid position encoding table."""
+        denominator = torch.Tensor([
+            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
+            for hid_j in range(d_hid)
+        ])
+        denominator = denominator.view(1, -1)
+        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
+        sinusoid_table = pos_tensor * denominator
+        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
+        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
+
+        return sinusoid_table.unsqueeze(0)
+
+    def forward(self, x):
+        self.device = x.device
+        return x + self.position_table[:, :x.size(1)].clone().detach()
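To make the tensor contract of these shared modules concrete, here is a hedged smoke test (shapes only; the import path follows the `__init__.py` added above):

```python
import torch
from mmocr.models.common.modules import (MultiHeadAttention,
                                         PositionalEncoding)

x = torch.rand(2, 25, 512)                 # (N, T, d_model)
attn = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
out = attn(x, x, x)                        # self-attention: q = k = v
assert out.shape == (2, 25, 512)

pos_enc = PositionalEncoding(d_hid=512, n_position=200)
assert pos_enc(x).shape == x.shape         # adds the sinusoid table element-wise
```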
@@ -10,7 +10,6 @@ class NRTRModalityTransform(BaseModule):
 
     def __init__(self,
                  input_channels=3,
-                 input_height=32,
                  init_cfg=[
                      dict(type='Kaiming', layer='Conv2d'),
                      dict(type='Uniform', layer='BatchNorm2d')
@@ -35,9 +34,7 @@ class NRTRModalityTransform(BaseModule):
         self.relu_2 = nn.ReLU(True)
         self.bn_2 = nn.BatchNorm2d(64)
 
-        feat_height = input_height // 4
-
-        self.linear = nn.Linear(64 * feat_height, 512)
+        self.linear = nn.Linear(512, 512)
 
     def forward(self, x):
         x = self.conv_1(x)
@@ -49,7 +46,11 @@ class NRTRModalityTransform(BaseModule):
         x = self.bn_2(x)
 
+        n, c, h, w = x.size()
+
+        x = x.permute(0, 3, 2, 1).contiguous().view(n, w, h * c)
+
         x = self.linear(x)
 
+        x = x.permute(0, 2, 1).contiguous().view(n, -1, 1, w)
+
         return x
@@ -1,16 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base_decoder import BaseDecoder
 from .crnn_decoder import CRNNDecoder
+from .nrtr_decoder import NRTRDecoder
 from .position_attention_decoder import PositionAttentionDecoder
 from .robust_scanner_decoder import RobustScannerDecoder
 from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder
 from .sar_decoder_with_bs import ParallelSARDecoderWithBS
 from .sequence_attention_decoder import SequenceAttentionDecoder
-from .transformer_decoder import TFDecoder
 
 __all__ = [
     'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder',
-    'ParallelSARDecoderWithBS', 'TFDecoder', 'BaseDecoder',
+    'ParallelSARDecoderWithBS', 'NRTRDecoder', 'BaseDecoder',
     'SequenceAttentionDecoder', 'PositionAttentionDecoder',
     'RobustScannerDecoder'
 ]
@@ -7,14 +7,12 @@ import torch.nn.functional as F
 from mmcv.runner import ModuleList
 
 from mmocr.models.builder import DECODERS
-from mmocr.models.textrecog.layers.transformer_layer import (
-    PositionalEncoding, TransformerDecoderLayer, get_pad_mask,
-    get_subsequent_mask)
+from mmocr.models.common import PositionalEncoding, TFDecoderLayer
 from .base_decoder import BaseDecoder
 
 
 @DECODERS.register_module()
-class TFDecoder(BaseDecoder):
+class NRTRDecoder(BaseDecoder):
     """Transformer Decoder block with self attention mechanism.
 
     Args:
@@ -71,8 +69,8 @@ class NRTRDecoder(BaseDecoder):
         self.dropout = nn.Dropout(p=dropout)
 
         self.layer_stack = ModuleList([
-            TransformerDecoderLayer(
-                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+            TFDecoderLayer(
+                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
             for _ in range(n_layers)
         ])
         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
@@ -80,13 +78,29 @@ class NRTRDecoder(BaseDecoder):
         pred_num_class = num_classes - 1  # ignore padding_idx
         self.classifier = nn.Linear(d_model, pred_num_class)
 
+    @staticmethod
+    def get_pad_mask(seq, pad_idx):
+
+        return (seq != pad_idx).unsqueeze(-2)
+
+    @staticmethod
+    def get_subsequent_mask(seq):
+        """For masking out the subsequent info."""
+        len_s = seq.size(1)
+        subsequent_mask = 1 - torch.triu(
+            torch.ones((len_s, len_s), device=seq.device), diagonal=1)
+        subsequent_mask = subsequent_mask.unsqueeze(0).bool()
+
+        return subsequent_mask
+
     def _attention(self, trg_seq, src, src_mask=None):
         trg_embedding = self.trg_word_emb(trg_seq)
         trg_pos_encoded = self.position_enc(trg_embedding)
         tgt = self.dropout(trg_pos_encoded)
 
-        trg_mask = get_pad_mask(
-            trg_seq, pad_idx=self.padding_idx) & get_subsequent_mask(trg_seq)
+        trg_mask = self.get_pad_mask(
+            trg_seq,
+            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
         output = tgt
         for dec_layer in self.layer_stack:
             output = dec_layer(
@@ -98,11 +112,27 @@ class NRTRDecoder(BaseDecoder):
 
         return output
 
+    def _get_mask(self, logit, img_metas):
+        valid_ratios = None
+        if img_metas is not None:
+            valid_ratios = [
+                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+            ]
+        N, T, _ = logit.size()
+        mask = None
+        if valid_ratios is not None:
+            mask = logit.new_zeros((N, T))
+            for i, valid_ratio in enumerate(valid_ratios):
+                valid_width = min(T, math.ceil(T * valid_ratio))
+                mask[i, :valid_width] = 1
+
+        return mask
+
     def forward_train(self, feat, out_enc, targets_dict, img_metas):
         r"""
         Args:
             feat (None): Unused.
-            out_enc (Tensor): Encoder output of shape :math:`(N, D_m, H, W)`
+            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                 where :math:`D_m` is ``d_model``.
             targets_dict (dict): A dict with the key ``padded_targets``, a
                 tensor of shape :math:`(N, T)`. Each element is the index of a
@@ -113,44 +143,17 @@ class NRTRDecoder(BaseDecoder):
         Returns:
             Tensor: The raw logit tensor. Shape :math:`(N, T, C)`.
         """
-        valid_ratios = None
-        if img_metas is not None:
-            valid_ratios = [
-                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
-            ]
-        n, c, h, w = out_enc.size()
-        src_mask = None
-        if valid_ratios is not None:
-            src_mask = out_enc.new_zeros((n, h, w))
-            for i, valid_ratio in enumerate(valid_ratios):
-                valid_width = min(w, math.ceil(w * valid_ratio))
-                src_mask[i, :, :valid_width] = 1
-            src_mask = src_mask.view(n, h * w)
-        out_enc = out_enc.view(n, c, h * w).permute(0, 2, 1)
-        out_enc = out_enc.contiguous()
+        src_mask = self._get_mask(out_enc, img_metas)
         targets = targets_dict['padded_targets'].to(out_enc.device)
         attn_output = self._attention(targets, out_enc, src_mask=src_mask)
         outputs = self.classifier(attn_output)
 
         return outputs
 
     def forward_test(self, feat, out_enc, img_metas):
-        valid_ratios = None
-        if img_metas is not None:
-            valid_ratios = [
-                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
-            ]
-        n, c, h, w = out_enc.size()
-        src_mask = None
-        if valid_ratios is not None:
-            src_mask = out_enc.new_zeros((n, h, w))
-            for i, valid_ratio in enumerate(valid_ratios):
-                valid_width = min(w, math.ceil(w * valid_ratio))
-                src_mask[i, :, :valid_width] = 1
-            src_mask = src_mask.view(n, h * w)
-        out_enc = out_enc.view(n, c, h * w).permute(0, 2, 1)
-        out_enc = out_enc.contiguous()
-
-        init_target_seq = torch.full((n, self.max_seq_len + 1),
+        src_mask = self._get_mask(out_enc, img_metas)
+        N = out_enc.size(0)
+        init_target_seq = torch.full((N, self.max_seq_len + 1),
                                      self.padding_idx,
                                      device=out_enc.device,
                                      dtype=torch.long)
@@ -161,7 +164,7 @@ class NRTRDecoder(BaseDecoder):
         for step in range(0, self.max_seq_len):
             decoder_output = self._attention(
                 init_target_seq, out_enc, src_mask=src_mask)
-            # bsz * seq_len * 512
+            # bsz * seq_len * C
            step_result = F.softmax(
                self.classifier(decoder_output[:, step, :]), dim=-1)
            # bsz * num_classes
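The two mask helpers moved onto the decoder combine by broadcasting: `get_pad_mask` yields `(N, 1, T)` and `get_subsequent_mask` yields `(1, T, T)`, so their `&` gives a full `(N, T, T)` causal-plus-padding mask. A tiny hedged check of that logic in isolation:

```python
import torch

seq = torch.tensor([[3, 5, 0, 0]])          # pad_idx = 0, two real tokens
pad_mask = (seq != 0).unsqueeze(-2)          # (1, 1, 4)
len_s = seq.size(1)
sub_mask = (1 - torch.triu(
    torch.ones((len_s, len_s)), diagonal=1)).unsqueeze(0).bool()  # (1, 4, 4)
mask = pad_mask & sub_mask                   # broadcasts to (1, 4, 4)
print(mask.int())
# Row i allows attention to positions <= i that are not padding.
```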
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base_encoder import BaseEncoder
 from .channel_reduction_encoder import ChannelReductionEncoder
+from .nrtr_encoder import NRTREncoder
 from .sar_encoder import SAREncoder
 from .satrn_encoder import SatrnEncoder
-from .transformer_encoder import TFEncoder
 
 __all__ = [
-    'SAREncoder', 'TFEncoder', 'BaseEncoder', 'ChannelReductionEncoder',
+    'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder',
     'SatrnEncoder'
 ]
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch.nn as nn
+from mmcv.runner import ModuleList
+
+from mmocr.models.builder import ENCODERS
+from mmocr.models.common import TFEncoderLayer
+from .base_encoder import BaseEncoder
+
+
+@ENCODERS.register_module()
+class NRTREncoder(BaseEncoder):
+    """Transformer Encoder block with self attention mechanism.
+
+    Args:
+        n_layers (int): The number of sub-encoder-layers
+            in the encoder (default=6).
+        n_head (int): The number of heads in the
+            multiheadattention models (default=8).
+        d_k (int): Total number of features in key.
+        d_v (int): Total number of features in value.
+        d_model (int): The number of expected features
+            in the decoder inputs (default=512).
+        d_inner (int): The dimension of the feedforward
+            network model (default=256).
+        dropout (float): Dropout layer on attn_output_weights.
+        init_cfg (dict or list[dict], optional): Initialization configs.
+    """
+
+    def __init__(self,
+                 n_layers=6,
+                 n_head=8,
+                 d_k=64,
+                 d_v=64,
+                 d_model=512,
+                 d_inner=256,
+                 dropout=0.1,
+                 init_cfg=None,
+                 **kwargs):
+        super().__init__(init_cfg=init_cfg)
+        self.d_model = d_model
+        self.layer_stack = ModuleList([
+            TFEncoderLayer(
+                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
+            for _ in range(n_layers)
+        ])
+        self.layer_norm = nn.LayerNorm(d_model)
+
+    def _get_mask(self, logit, img_metas):
+        valid_ratios = None
+        if img_metas is not None:
+            valid_ratios = [
+                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+            ]
+        N, T, _ = logit.size()
+        mask = None
+        if valid_ratios is not None:
+            mask = logit.new_zeros((N, T))
+            for i, valid_ratio in enumerate(valid_ratios):
+                valid_width = min(T, math.ceil(T * valid_ratio))
+                mask[i, :valid_width] = 1
+
+        return mask
+
+    def forward(self, feat, img_metas=None):
+        r"""
+        Args:
+            feat (Tensor): Backbone output of shape :math:`(N, C, H, W)`.
+            img_metas (dict): A dict that contains meta information of input
+                images. Preferably with the key ``valid_ratio``.
+
+        Returns:
+            Tensor: The encoder output tensor. Shape :math:`(N, T, C)`.
+        """
+        n, c, h, w = feat.size()
+
+        feat = feat.view(n, c, h * w).permute(0, 2, 1).contiguous()
+
+        mask = self._get_mask(feat, img_metas)
+
+        output = feat
+        for enc_layer in self.layer_stack:
+            output = enc_layer(output, mask)
+        output = self.layer_norm(output)
+
+        return output
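The refactored encoder now returns sequence-shaped features `(N, T, C)` with `T = H * W` instead of a 2D map, which is what lets `NRTRDecoder._get_mask` above treat width masking one-dimensionally. A hedged shape check:

```python
import torch
from mmocr.models.textrecog.encoders import NRTREncoder

encoder = NRTREncoder(n_layers=2)        # small stack, enough for the check
feat = torch.rand(2, 512, 1, 25)         # backbone output (N, C, H, W)
out = encoder(feat, img_metas=[{'valid_ratio': 1.0}] * 2)
assert out.shape == (2, 25, 512)         # (N, H*W, d_model)
```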
@@ -61,7 +61,7 @@ class SatrnEncoder(BaseEncoder):
                 images. Preferably with the key ``valid_ratio``.
 
         Returns:
-            Tensor: A tensor of shape :math:`(N, D_m, H, W)`.
+            Tensor: A tensor of shape :math:`(N, T, D_m)`.
         """
         valid_ratios = [1.0 for _ in range(feat.size(0))]
         if img_metas is not None:
@@ -82,7 +82,4 @@ class SatrnEncoder(BaseEncoder):
             output = enc_layer(output, h, w, mask)
         output = self.layer_norm(output)
 
-        output = output.permute(0, 2, 1).contiguous()
-        output = output.view(n, self.d_model, h, w)
-
         return output
@@ -1,57 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-
-import torch.nn as nn
-from mmcv.runner import ModuleList
-
-from mmocr.models.builder import ENCODERS
-from mmocr.models.textrecog.layers import TransformerEncoderLayer
-from .base_encoder import BaseEncoder
-
-
-@ENCODERS.register_module()
-class TFEncoder(BaseEncoder):
-    """Encode 2d feature map to 1d sequence."""
-
-    def __init__(self,
-                 n_layers=6,
-                 n_head=8,
-                 d_k=64,
-                 d_v=64,
-                 d_model=512,
-                 d_inner=256,
-                 dropout=0.1,
-                 init_cfg=None,
-                 **kwargs):
-        super().__init__(init_cfg=init_cfg)
-        self.d_model = d_model
-        self.layer_stack = ModuleList([
-            TransformerEncoderLayer(
-                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
-            for _ in range(n_layers)
-        ])
-        self.layer_norm = nn.LayerNorm(d_model)
-
-    def forward(self, feat, img_metas=None):
-        valid_ratios = [1.0 for _ in range(feat.size(0))]
-        if img_metas is not None:
-            valid_ratios = [
-                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
-            ]
-        n, c, h, w = feat.size()
-        mask = feat.new_zeros((n, h, w))
-        for i, valid_ratio in enumerate(valid_ratios):
-            valid_width = min(w, math.ceil(w * valid_ratio))
-            mask[i, :, :valid_width] = 1
-        mask = mask.view(n, h * w)
-        feat = feat.view(n, c, h * w)
-
-        output = feat.permute(0, 2, 1).contiguous()
-        for enc_layer in self.layer_stack:
-            output = enc_layer(output, mask)
-        output = self.layer_norm(output)
-
-        output = output.permute(0, 2, 1).contiguous()
-        output = output.view(n, self.d_model, h, w)
-
-        return output
@@ -4,17 +4,10 @@ from .dot_product_attention_layer import DotProductAttentionLayer
 from .lstm_layer import BidirectionalLSTM
 from .position_aware_layer import PositionAwareLayer
 from .robust_scanner_fusion_layer import RobustScannerFusionLayer
-from .transformer_layer import (Adaptive2DPositionalEncoding,
-                                MultiHeadAttention, PositionalEncoding,
-                                PositionwiseFeedForward, SatrnEncoderLayer,
-                                TransformerDecoderLayer,
-                                TransformerEncoderLayer, get_pad_mask,
-                                get_subsequent_mask)
+from .satrn_layers import Adaptive2DPositionalEncoding, SatrnEncoderLayer
 
 __all__ = [
-    'BidirectionalLSTM', 'MultiHeadAttention', 'PositionalEncoding',
-    'Adaptive2DPositionalEncoding', 'PositionwiseFeedForward', 'BasicBlock',
+    'BidirectionalLSTM', 'Adaptive2DPositionalEncoding', 'BasicBlock',
     'Bottleneck', 'RobustScannerFusionLayer', 'DotProductAttentionLayer',
-    'PositionAwareLayer', 'get_pad_mask', 'get_subsequent_mask',
-    'TransformerDecoderLayer', 'TransformerEncoderLayer', 'SatrnEncoderLayer'
+    'PositionAwareLayer', 'SatrnEncoderLayer'
 ]
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+
+from mmocr.models.common import MultiHeadAttention
+
+
+class SatrnEncoderLayer(BaseModule):
+    """Encoder layer for SATRN: self attention followed by a locality-aware
+    feedforward module."""
+
+    def __init__(self,
+                 d_model=512,
+                 d_inner=512,
+                 n_head=8,
+                 d_k=64,
+                 d_v=64,
+                 dropout=0.1,
+                 qkv_bias=False,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.attn = MultiHeadAttention(
+            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.feed_forward = LocalityAwareFeedforward(
+            d_model, d_inner, dropout=dropout)
+
+    def forward(self, x, h, w, mask=None):
+        n, hw, c = x.size()
+        residual = x
+        x = self.norm1(x)
+        x = residual + self.attn(x, x, x, mask)
+        residual = x
+        x = self.norm2(x)
+        # The locality-aware feedforward works on a 2D map, so fold the
+        # sequence back into (n, c, h, w) and flatten it again afterwards.
+        x = x.transpose(1, 2).contiguous().view(n, c, h, w)
+        x = self.feed_forward(x)
+        x = x.view(n, c, hw).transpose(1, 2)
+        x = residual + x
+        return x
+
+
+class LocalityAwareFeedforward(BaseModule):
+    """Locality-aware feedforward layer in SATRN, see `SATRN
+    <https://arxiv.org/abs/1910.04396>`_.
+    """
+
+    def __init__(self,
+                 d_in,
+                 d_hid,
+                 dropout=0.1,
+                 init_cfg=[
+                     dict(type='Xavier', layer='Conv2d'),
+                     dict(type='Constant', layer='BatchNorm2d', val=1, bias=0)
+                 ]):
+        super().__init__(init_cfg=init_cfg)
+        self.conv1 = ConvModule(
+            d_in,
+            d_hid,
+            kernel_size=1,
+            padding=0,
+            bias=False,
+            norm_cfg=dict(type='BN'),
+            act_cfg=dict(type='ReLU'))
+
+        self.depthwise_conv = ConvModule(
+            d_hid,
+            d_hid,
+            kernel_size=3,
+            padding=1,
+            bias=False,
+            groups=d_hid,
+            norm_cfg=dict(type='BN'),
+            act_cfg=dict(type='ReLU'))
+
+        self.conv2 = ConvModule(
+            d_hid,
+            d_in,
+            kernel_size=1,
+            padding=0,
+            bias=False,
+            norm_cfg=dict(type='BN'),
+            act_cfg=dict(type='ReLU'))
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.depthwise_conv(x)
+        x = self.conv2(x)
+
+        return x
+
+
+class Adaptive2DPositionalEncoding(BaseModule):
+    """Implement Adaptive 2D positional encoder for SATRN, see
+    `SATRN <https://arxiv.org/abs/1910.04396>`_.
+    Modified from https://github.com/Media-Smart/vedastr
+    Licensed under the Apache License, Version 2.0 (the "License");
+
+    Args:
+        d_hid (int): Dimensions of hidden layer.
+        n_height (int): Max height of the 2D feature output.
+        n_width (int): Max width of the 2D feature output.
+        dropout (float): Dropout rate.
+    """
+
+    def __init__(self,
+                 d_hid=512,
+                 n_height=100,
+                 n_width=100,
+                 dropout=0.1,
+                 init_cfg=[dict(type='Xavier', layer='Conv2d')]):
+        super().__init__(init_cfg=init_cfg)
+
+        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
+        h_position_encoder = h_position_encoder.transpose(0, 1)
+        h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)
+
+        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
+        w_position_encoder = w_position_encoder.transpose(0, 1)
+        w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)
+
+        self.register_buffer('h_position_encoder', h_position_encoder)
+        self.register_buffer('w_position_encoder', w_position_encoder)
+
+        self.h_scale = self.scale_factor_generate(d_hid)
+        self.w_scale = self.scale_factor_generate(d_hid)
+        self.pool = nn.AdaptiveAvgPool2d(1)
+        self.dropout = nn.Dropout(p=dropout)
+
+    def _get_sinusoid_encoding_table(self, n_position, d_hid):
+        """Sinusoid position encoding table."""
+        denominator = torch.Tensor([
+            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
+            for hid_j in range(d_hid)
+        ])
+        denominator = denominator.view(1, -1)
+        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
+        sinusoid_table = pos_tensor * denominator
+        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
+        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
+
+        return sinusoid_table
+
+    def scale_factor_generate(self, d_hid):
+        scale_factor = nn.Sequential(
+            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
+            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())
+
+        return scale_factor
+
+    def forward(self, x):
+        b, c, h, w = x.size()
+
+        avg_pool = self.pool(x)
+
+        h_pos_encoding = \
+            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
+        w_pos_encoding = \
+            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]
+
+        out = x + h_pos_encoding + w_pos_encoding
+
+        out = self.dropout(out)
+
+        return out
@ -1,399 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""This code is from https://github.com/jadore801120/attention-is-all-you-need-
|
||||
pytorch."""
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
""""""
|
||||
|
||||
def __init__(self,
|
||||
d_model=512,
|
||||
d_inner=256,
|
||||
n_head=8,
|
||||
d_k=64,
|
||||
d_v=64,
|
||||
dropout=0.1,
|
||||
qkv_bias=False,
|
||||
mask_value=0,
|
||||
act_layer=nn.GELU):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.attn = MultiHeadAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
qkv_bias=qkv_bias,
|
||||
dropout=dropout,
|
||||
mask_value=mask_value)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.mlp = PositionwiseFeedForward(
|
||||
d_model, d_inner, dropout=dropout, act_layer=act_layer)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
residual = x
|
||||
x = self.norm1(x)
|
||||
x = residual + self.attn(x, x, x, mask)
|
||||
residual = x
|
||||
x = self.norm2(x)
|
||||
x = residual + self.mlp(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SatrnEncoderLayer(BaseModule):
|
||||
""""""
|
||||
|
||||
def __init__(self,
|
||||
d_model=512,
|
||||
d_inner=512,
|
||||
n_head=8,
|
||||
d_k=64,
|
||||
d_v=64,
|
||||
dropout=0.1,
|
||||
qkv_bias=False,
|
||||
mask_value=0,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.attn = MultiHeadAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
qkv_bias=qkv_bias,
|
||||
dropout=dropout,
|
||||
mask_value=mask_value)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.feed_forward = LocalityAwareFeedforward(
|
||||
d_model, d_inner, dropout=dropout)
|
||||
|
||||
def forward(self, x, h, w, mask=None):
|
||||
n, hw, c = x.size()
|
||||
residual = x
|
||||
x = self.norm1(x)
|
||||
x = residual + self.attn(x, x, x, mask)
|
||||
residual = x
|
||||
x = self.norm2(x)
|
||||
x = x.transpose(1, 2).contiguous().view(n, c, h, w)
|
||||
x = self.feed_forward(x)
|
||||
x = x.view(n, c, hw).transpose(1, 2)
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
d_model=512,
|
||||
d_inner=256,
|
||||
n_head=8,
|
||||
d_k=64,
|
||||
d_v=64,
|
||||
dropout=0.1,
|
||||
qkv_bias=False,
|
||||
mask_value=0,
|
||||
act_layer=nn.GELU):
|
||||
super().__init__()
|
||||
self.self_attn = MultiHeadAttention()
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
|
||||
self.self_attn = MultiHeadAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
qkv_bias=qkv_bias,
|
||||
mask_value=mask_value)
|
||||
self.enc_attn = MultiHeadAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
dropout=dropout,
|
||||
qkv_bias=qkv_bias,
|
||||
mask_value=mask_value)
|
||||
self.mlp = PositionwiseFeedForward(
|
||||
d_model, d_inner, dropout=dropout, act_layer=act_layer)
|
||||
|
||||
def forward(self,
|
||||
dec_input,
|
||||
enc_output,
|
||||
self_attn_mask=None,
|
||||
dec_enc_attn_mask=None):
|
||||
self_attn_in = self.norm1(dec_input)
|
||||
self_attn_out = self.self_attn(self_attn_in, self_attn_in,
|
||||
self_attn_in, self_attn_mask)
|
||||
enc_attn_in = dec_input + self_attn_out
|
||||
|
||||
enc_attn_q = self.norm2(enc_attn_in)
|
||||
enc_attn_out = self.enc_attn(enc_attn_q, enc_output, enc_output,
|
||||
dec_enc_attn_mask)
|
||||
|
||||
mlp_in = enc_attn_in + enc_attn_out
|
||||
mlp_out = self.mlp(self.norm3(mlp_in))
|
||||
out = mlp_in + mlp_out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Multi-Head Attention module."""
|
||||
|
||||
def __init__(self,
|
||||
n_head=8,
|
||||
d_model=512,
|
||||
d_k=64,
|
||||
d_v=64,
|
||||
dropout=0.1,
|
||||
qkv_bias=False,
|
||||
mask_value=0):
|
||||
super().__init__()
|
||||
|
||||
self.mask_value = mask_value
|
||||
|
||||
self.n_head = n_head
|
||||
self.d_k = d_k
|
||||
self.d_v = d_v
|
||||
|
||||
self.scale = d_k**-0.5
|
||||
|
||||
self.dim_k = n_head * d_k
|
||||
self.dim_v = n_head * d_v
|
||||
|
||||
self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
|
||||
|
||||
self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
|
||||
|
||||
self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias)
|
||||
|
||||
self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias)
|
||||
|
||||
self.attn_drop = nn.Dropout(dropout)
|
||||
self.proj_drop = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
batch_size, len_q, _ = q.size()
|
||||
_, len_k, _ = k.size()
|
||||
|
||||
q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k)
|
||||
k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k)
|
||||
v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v)
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 3, 1)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
logits = torch.matmul(q, k) * self.scale
|
||||
|
||||
if mask is not None:
|
||||
if mask.dim() == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
elif mask.dim() == 2:
|
||||
mask = mask.unsqueeze(1).unsqueeze(1)
|
||||
logits = logits.masked_fill(mask == self.mask_value, float('-inf'))
|
||||
weights = logits.softmax(dim=-1)
|
||||
weights = self.attn_drop(weights)
|
||||
|
||||
attn_out = torch.matmul(weights, v).transpose(1, 2)
|
||||
attn_out = attn_out.reshape(batch_size, len_q, self.dim_v)
|
||||
attn_out = self.fc(attn_out)
|
||||
attn_out = self.proj_drop(attn_out)
|
||||
|
||||
return attn_out
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
"""A two-feed-forward-layer module."""
|
||||
|
||||
def __init__(self, d_in, d_hid, dropout=0.1, act_layer=nn.GELU):
|
||||
super().__init__()
|
||||
self.w_1 = nn.Linear(d_in, d_hid)
|
||||
self.w_2 = nn.Linear(d_hid, d_in)
|
||||
self.act = act_layer()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.w_1(x)
|
||||
x = self.act(x)
|
||||
x = self.dropout(x)
|
||||
x = self.w_2(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LocalityAwareFeedforward(BaseModule):
|
||||
"""Locality-aware feedforward layer in SATRN, see `SATRN.
|
||||
|
||||
<https://arxiv.org/abs/1910.04396>`_
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_in,
|
||||
d_hid,
|
||||
dropout=0.1,
|
||||
init_cfg=[
|
||||
dict(type='Xavier', layer='Conv2d'),
|
||||
dict(type='Constant', layer='BatchNorm2d', val=1, bias=0)
|
||||
]):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.conv1 = ConvModule(
|
||||
d_in,
|
||||
d_hid,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'))
|
||||
|
||||
self.depthwise_conv = ConvModule(
|
||||
d_hid,
|
||||
d_hid,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
groups=d_hid,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'))
|
||||
|
||||
self.conv2 = ConvModule(
|
||||
d_hid,
|
||||
d_in,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
|
||||
def __init__(self, d_hid=512, n_position=200):
|
||||
super().__init__()
|
||||
|
||||
# Not a parameter
|
||||
self.register_buffer(
|
||||
'position_table',
|
||||
self._get_sinusoid_encoding_table(n_position, d_hid))
|
||||
|
||||
def _get_sinusoid_encoding_table(self, n_position, d_hid):
|
||||
"""Sinusoid position encoding table."""
|
||||
denominator = torch.Tensor([
|
||||
1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
||||
for hid_j in range(d_hid)
|
||||
])
|
||||
denominator = denominator.view(1, -1)
|
||||
pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
|
||||
sinusoid_table = pos_tensor * denominator
|
||||
sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
|
||||
sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
|
||||
|
||||
return sinusoid_table.unsqueeze(0)
|
||||
|
||||
def forward(self, x):
|
||||
self.device = x.device
|
||||
return x + self.position_table[:, :x.size(1)].clone().detach()
|
||||
|
||||
|
||||
class Adaptive2DPositionalEncoding(BaseModule):
|
||||
"""Implement Adaptive 2D positional encoder for SATRN, see
|
||||
`SATRN <https://arxiv.org/abs/1910.04396>`_
|
||||
Modified from https://github.com/Media-Smart/vedastr
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
Args:
|
||||
d_hid (int): Dimensions of hidden layer.
|
||||
n_height (int): Max height of the 2D feature output.
|
||||
n_width (int): Max width of the 2D feature output.
|
||||
dropout (int): Size of hidden layers of the model.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_hid=512,
|
||||
n_height=100,
|
||||
n_width=100,
|
||||
dropout=0.1,
|
||||
init_cfg=[dict(type='Xavier', layer='Conv2d')]):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
|
||||
h_position_encoder = h_position_encoder.transpose(0, 1)
|
||||
h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)
|
||||
|
||||
w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
|
||||
w_position_encoder = w_position_encoder.transpose(0, 1)
|
||||
w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)
|
||||
|
||||
self.register_buffer('h_position_encoder', h_position_encoder)
|
||||
self.register_buffer('w_position_encoder', w_position_encoder)
|
||||
|
||||
self.h_scale = self.scale_factor_generate(d_hid)
|
||||
self.w_scale = self.scale_factor_generate(d_hid)
|
||||
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def _get_sinusoid_encoding_table(self, n_position, d_hid):
|
||||
"""Sinusoid position encoding table."""
|
||||
denominator = torch.Tensor([
|
||||
1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
||||
for hid_j in range(d_hid)
|
||||
])
|
||||
denominator = denominator.view(1, -1)
|
||||
pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
|
||||
sinusoid_table = pos_tensor * denominator
|
||||
sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
|
||||
sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
|
||||
|
||||
return sinusoid_table
|
||||
|
||||
def scale_factor_generate(self, d_hid):
|
||||
scale_factor = nn.Sequential(
|
||||
nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())
|
||||
|
||||
return scale_factor
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.size()
|
||||
|
||||
avg_pool = self.pool(x)
|
||||
|
||||
h_pos_encoding = \
|
||||
self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
|
||||
w_pos_encoding = \
|
||||
self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]
|
||||
|
||||
out = x + h_pos_encoding + w_pos_encoding
|
||||
|
||||
out = self.dropout(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def get_pad_mask(seq, pad_idx):
|
||||
return (seq != pad_idx).unsqueeze(-2)
|
||||
|
||||
|
||||
def get_subsequent_mask(seq):
    """For masking out the subsequent info."""
    len_s = seq.size(1)
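    # Lower-triangular (diagonal included) causal mask: position i may only
    # attend to positions <= i.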
    subsequent_mask = 1 - torch.triu(
        torch.ones((len_s, len_s), device=seq.device), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).bool()
    return subsequent_mask
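
A quick shape check for the two mask helpers (a minimal sketch; the tensor values and the combined-mask line are illustrative, not part of the diff):

    import torch

    seq = torch.tensor([[3, 7, 2, 0, 0]])     # (1, 5) token ids, pad_idx = 0
    pad_mask = get_pad_mask(seq, pad_idx=0)   # (1, 1, 5), False at the padding
    causal_mask = get_subsequent_mask(seq)    # (1, 5, 5) lower-triangular bool
    attn_mask = pad_mask & causal_mask        # broadcasts to (1, 5, 5)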

@@ -282,12 +282,13 @@ class MMOCR:
             },
             'NRTR_1/16-1/8': {
                 'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
-                'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth'
+                'ckpt':
+                'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'
             },
             'NRTR_1/8-1/4': {
                 'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
                 'ckpt':
-                'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth'
+                'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth'
             },
             'RobustScanner': {
                 'config': 'robust_scanner/robustscanner_r31_academic.py',
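
Both NRTR entries move from the 20210406 checkpoints to the 2021-11 ones (f60cebf4 and e1fdb322), presumably retrained to match the weight layout of the refactored NRTREncoder/NRTRDecoder.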

@@ -1,10 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
+
 import pytest
+from mmcv import Config
 from mmcv.image import imread

-from mmocr.apis.inference import init_detector, model_inference
+from mmocr.apis.inference import (disable_text_recog_aug_test, init_detector,
+                                  model_inference)
 from mmocr.datasets import build_dataset  # noqa: F401
 from mmocr.models import build_detector  # noqa: F401
 from mmocr.utils import revert_sync_batchnorm

@@ -121,3 +124,26 @@ def test_model_batch_inference_empty_detection(cfg_file):
             match='empty imgs provided, please check and try again'):

         model_inference(model, empty_detection, batch_mode=True)
+
+
+@pytest.mark.parametrize('cfg_file', [
+    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
+])
+def test_disable_text_recog_aug_test(cfg_file):
+    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    config_file = os.path.join(tmp_dir, cfg_file)
+
+    cfg = Config.fromfile(config_file)
+    cfg1 = copy.deepcopy(cfg)
+    test = cfg1.data.test.datasets[0]
+    test.pipeline = cfg1.data.test.pipeline
+    cfg1.data.test = test
+    disable_text_recog_aug_test(cfg1, set_types=['test'])
+
+    cfg2 = copy.deepcopy(cfg)
+    cfg2.data.test.pipeline = None
+    disable_text_recog_aug_test(cfg2, set_types=['test'])
+
+    cfg2 = copy.deepcopy(cfg)
+    cfg2.data.test = Config(dict(type='ConcatDataset', datasets=[test]))
+    disable_text_recog_aug_test(cfg2, set_types=['test'])
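
The new test drives disable_text_recog_aug_test through the three test-set layouts it must normalize: a single dataset lifted out of a datasets list, a dataset whose pipeline is None, and a ConcatDataset wrapper.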

@@ -4,9 +4,10 @@ import math
 import pytest
 import torch

-from mmocr.models.textrecog.decoders import (BaseDecoder, ParallelSARDecoder,
+from mmocr.models.textrecog.decoders import (BaseDecoder, NRTRDecoder,
+                                             ParallelSARDecoder,
                                              ParallelSARDecoderWithBS,
-                                             SequentialSARDecoder, TFDecoder)
+                                             SequentialSARDecoder)
 from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode

@@ -97,11 +98,11 @@ def test_parallel_sar_decoder_with_beam_search():


 def test_transformer_decoder():
-    decoder = TFDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
+    decoder = NRTRDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
     decoder.init_weights()
     decoder.train()

-    out_enc = torch.rand(1, 512, 1, 25)
+    out_enc = torch.rand(1, 25, 512)
     tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
     img_metas = [{'valid_ratio': 1.0}]
     tgt_dict['padded_targets'] = tgt_dict['padded_targets']
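
Note the dummy encoder output: the refactored decoder consumes a (N, T, C) sequence, torch.rand(1, 25, 512), instead of the old (N, C, H, W) feature map.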

@@ -2,8 +2,8 @@
 import pytest
 import torch

-from mmocr.models.textrecog.encoders import (BaseEncoder, SAREncoder,
-                                             SatrnEncoder, TFEncoder)
+from mmocr.models.textrecog.encoders import (BaseEncoder, NRTREncoder,
+                                             SAREncoder, SatrnEncoder)


 def test_sar_encoder():

@@ -34,14 +34,14 @@ def test_sar_encoder():


 def test_transformer_encoder():
-    tf_encoder = TFEncoder()
+    tf_encoder = NRTREncoder()
     tf_encoder.init_weights()
     tf_encoder.train()

     feat = torch.randn(1, 512, 1, 25)
     out_enc = tf_encoder(feat)
     print('hello', out_enc.size())
-    assert out_enc.shape == torch.Size([1, 512, 1, 25])
+    assert out_enc.shape == torch.Size([1, 25, 512])


 def test_satrn_encoder():

@@ -51,7 +51,7 @@ def test_satrn_encoder():

     feat = torch.randn(1, 512, 8, 25)
     out_enc = satrn_encoder(feat)
-    assert out_enc.shape == torch.Size([1, 512, 8, 25])
+    assert out_enc.shape == torch.Size([1, 200, 512])


 def test_base_encoder():
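
These assert changes pin down the refactor's new encoder contract: NRTREncoder and SatrnEncoder now return (N, H*W, C) sequences rather than (N, C, H, W) maps, so a 1x25 map yields (1, 25, 512) and an 8x25 map yields (1, 8*25, 512) = (1, 200, 512).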

@@ -1,10 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch

-from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck,
-                                           PositionalEncoding,
-                                           TransformerDecoderLayer,
-                                           get_pad_mask, get_subsequent_mask)
+from mmocr.models.common import (PositionalEncoding, TFDecoderLayer,
+                                 TFEncoderLayer)
+from mmocr.models.textrecog.layers import BasicBlock, Bottleneck
 from mmocr.models.textrecog.layers.conv_layer import conv3x3

@@ -34,24 +33,31 @@ def test_conv_layer():

 def test_transformer_layer():
     # test decoder_layer
-    decoder_layer = TransformerDecoderLayer()
+    decoder_layer = TFDecoderLayer()
     in_dec = torch.rand(1, 30, 512)
     out_enc = torch.rand(1, 128, 512)
     out_dec = decoder_layer(in_dec, out_enc)
     assert out_dec.shape == torch.Size([1, 30, 512])

+    decoder_layer = TFDecoderLayer(
+        operation_order=('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn',
+                         'norm'))
+    out_dec = decoder_layer(in_dec, out_enc)
+    assert out_dec.shape == torch.Size([1, 30, 512])
+
     # test positional_encoding
     pos_encoder = PositionalEncoding()
     x = torch.rand(1, 30, 512)
     out = pos_encoder(x)
     assert out.size() == x.size()

-    # test get pad mask
-    seq = torch.rand(1, 30)
-    pad_idx = 0
-    out = get_pad_mask(seq, pad_idx)
-    assert out.shape == torch.Size([1, 1, 30])
+    # test encoder_layer
+    encoder_layer = TFEncoderLayer()
+    in_enc = torch.rand(1, 20, 512)
+    out_enc = encoder_layer(in_enc)
+    assert out_enc.shape == torch.Size([1, 20, 512])

-    # test get_subsequent_mask
-    out_mask = get_subsequent_mask(seq)
-    assert out_mask.shape == torch.Size([1, 30, 30])
+    encoder_layer = TFEncoderLayer(
+        operation_order=('self_attn', 'norm', 'ffn', 'norm'))
+    out_enc = encoder_layer(in_enc)
+    assert out_enc.shape == torch.Size([1, 20, 512])

@@ -102,11 +102,19 @@ def _get_detector_cfg(fname):

 @pytest.mark.parametrize('cfg_file', [
     'textrecog/sar/sar_r31_parallel_decoder_academic.py',
+    'textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py',
+    'textrecog/sar/sar_r31_sequential_decoder_academic.py',
+    'textrecog/crnn/crnn_toy_dataset.py',
     'textrecog/crnn/crnn_academic_dataset.py',
     'textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
+    'textrecog/nrtr/nrtr_modality_transform_academic.py',
+    'textrecog/nrtr/nrtr_modality_transform_toy_dataset.py',
+    'textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py',
     'textrecog/robust_scanner/robustscanner_r31_academic.py',
     'textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
-    'textrecog/satrn/satrn_academic.py'
+    'textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py',
+    'textrecog/satrn/satrn_academic.py', 'textrecog/satrn/satrn_small.py',
+    'textrecog/tps/crnn_tps_academic_dataset.py'
 ])
 def test_recognizer_pipeline(cfg_file):
     model = _get_detector_cfg(cfg_file)
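
The parametrized list now smoke-tests every text-recognition config family — SAR, CRNN, NRTR (including the modality-transform variants), RobustScanner, seg, SATRN, and TPS — rather than a single config per family.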