[Refactor] refactor transformer modules (#618)

* base refactor

* update config

* modify implementation of nrtr

* add config file

* add mask

* add operation order

* fix contiguous bug

* update config

* fix pytest

* fix pytest

* update readme

* update readme and metafile

* update docstring

* fix norm cfg and dict size

* rm useless

* use mmocr builder instead

* update pytest

* update

* remove useless

* fix ckpt name

* fix path

* include all config file into pytest

* update inference

* Update test_recog_config.py
pull/521/head^2
Hongbin Sun 2021-12-04 17:12:31 +08:00 committed by GitHub
parent 5a8859fe66
commit 0a1787d6bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 859 additions and 608 deletions

View File

@ -4,8 +4,8 @@ label_convertor = dict(
model = dict(
type='NRTR',
backbone=dict(type='NRTRModalityTransform'),
encoder=dict(type='TFEncoder'),
decoder=dict(type='TFDecoder'),
encoder=dict(type='NRTREncoder', n_layers=12),
decoder=dict(type='NRTRDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)

View File

@ -17,26 +17,19 @@ train_pipeline = [
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiRotateAugOCR',
rotate_degrees=[0, 90, 270],
transforms=[
dict(
type='ResizeOCR',
height=32,
min_width=32,
max_width=160,
keep_aspect_ratio=True,
width_downsample_ratio=0.25),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
type='ResizeOCR',
height=32,
min_width=32,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'resize_shape', 'valid_ratio'])
]

View File

@ -66,10 +66,16 @@ Backbone
| Methods | Backbone | | Regular Text | | | | Irregular Text | | download |
| :-------------------------------------------------------------: | :----------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| | | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py) | R31-1/16-1/8 | 93.9 | 90.0 | 93.5 | | 74.5 | 78.5 | 86.5 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_010150.log.json) |
| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py) | R31-1/8-1/4 | 94.7 | 87.5 | 93.3 | | 75.1 | 78.9 | 87.9 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_160845.log.json) |
| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py) | R31-1/16-1/8 | 94.7 | 87.3 | 94.3 | | 73.5 | 78.9 | 85.1 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20211124_002420.log.json) |
| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py) | R31-1/8-1/4 | 95.2 | 90.0 | 94.0 | | 74.1 | 79.4 | 88.2 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20211123_232151.log.json) |
**Notes:**
- `R31-1/16-1/8` means the height of feature from backbone is 1/16 of input image, where 1/8 for width.
- `R31-1/8-1/4` means the height of feature from backbone is 1/8 of input image, where 1/4 for width.
- For backbone `R31-1/16-1/8`:
- The output consists of 92 classes, including 26 lowercase letters, 26 uppercase letters, 28 symbols, 10 digital numbers, 1 unknown token and 1 end-of-sequence token.
- The encoder-block number is 6.
- `1/16-1/8` means the height of feature from backbone is 1/16 of input image, where 1/8 for width.
- For backbone `R31-1/8-1/4`:
- The output consists of 92 classes, including 26 lowercase letters, 26 uppercase letters, 28 symbols, 10 digital numbers, 1 unknown token and 1 end-of-sequence token.
- The encoder-block number is 6.
- `1/8-1/4` means the height of feature from backbone is 1/8 of input image, where 1/4 for width.

View File

@ -4,13 +4,13 @@ Collections:
Training Data: OCRDataset
Training Techniques:
- Adam
Epochs: 5
Batch Size: 8192
Training Resources: 64x GeForce GTX 1080 Ti
Epochs: 6
Batch Size: 6144
Training Resources: 48x GeForce GTX 1080 Ti
Architecture:
- ResNet31OCR
- TFEncoder
- TFDecoder
- CNN
- NRTREncoder
- NRTRDecoder
Paper:
URL: https://arxiv.org/pdf/1806.00926.pdf
Title: 'NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition'
@ -28,28 +28,28 @@ Models:
- Task: Text Recognition
Dataset: IIIT5K
Metrics:
word_acc: 93.9
word_acc: 94.7
- Task: Text Recognition
Dataset: SVT
Metrics:
word_acc: 80.0
word_acc: 87.3
- Task: Text Recognition
Dataset: ICDAR2013
Metrics:
word_acc: 93.5
word_acc: 94.3
- Task: Text Recognition
Dataset: ICDAR2015
Metrics:
word_acc: 74.5
word_acc: 73.5
- Task: Text Recognition
Dataset: SVTP
Metrics:
word_acc: 78.5
word_acc: 78.9
- Task: Text Recognition
Dataset: CT80
Metrics:
word_acc: 86.5
Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth
word_acc: 85.1
Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth
- Name: nrtr_r31_1by8_1by4_academic
In Collection: NRTR
@ -62,25 +62,25 @@ Models:
- Task: Text Recognition
Dataset: IIIT5K
Metrics:
word_acc: 94.7
word_acc: 95.2
- Task: Text Recognition
Dataset: SVT
Metrics:
word_acc: 87.5
word_acc: 90.0
- Task: Text Recognition
Dataset: ICDAR2013
Metrics:
word_acc: 93.3
word_acc: 94.0
- Task: Text Recognition
Dataset: ICDAR2015
Metrics:
word_acc: 75.1
word_acc: 74.1
- Task: Text Recognition
Dataset: SVTP
Metrics:
word_acc: 78.9
word_acc: 79.4
- Task: Text Recognition
Dataset: CT80
Metrics:
word_acc: 87.9
Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth
word_acc: 88.2
Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth

View File

@ -0,0 +1,32 @@
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/recog_models/nrtr_modality_transform.py',
'../../_base_/schedules/schedule_adam_step_6e.py',
'../../_base_/recog_datasets/ST_MJ_train.py',
'../../_base_/recog_datasets/academic_test.py',
'../../_base_/recog_pipelines/nrtr_pipeline.py'
]
train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}
train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}
data = dict(
samples_per_gpu=128,
workers_per_gpu=4,
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline),
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='acc')

View File

@ -0,0 +1,31 @@
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/recog_models/nrtr_modality_transform.py',
'../../_base_/schedules/schedule_adam_step_6e.py',
'../../_base_/recog_datasets/toy_data.py',
'../../_base_/recog_pipelines/nrtr_pipeline.py'
]
train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}
train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}
data = dict(
samples_per_gpu=16,
workers_per_gpu=2,
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline),
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='acc')

View File

@ -23,8 +23,8 @@ model = dict(
channels=[32, 64, 128, 256, 512, 512],
stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
last_stage_pool=True),
encoder=dict(type='TFEncoder'),
decoder=dict(type='TFDecoder'),
encoder=dict(type='NRTREncoder'),
decoder=dict(type='NRTRDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)
@ -32,8 +32,6 @@ model = dict(
data = dict(
samples_per_gpu=128,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,

View File

@ -23,8 +23,8 @@ model = dict(
channels=[32, 64, 128, 256, 512, 512],
stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
last_stage_pool=False),
encoder=dict(type='TFEncoder'),
decoder=dict(type='TFDecoder'),
encoder=dict(type='NRTREncoder'),
decoder=dict(type='NRTRDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)
@ -32,8 +32,6 @@ model = dict(
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,

View File

@ -28,7 +28,7 @@ model = dict(
d_inner=512 * 4,
dropout=0.1),
decoder=dict(
type='TFDecoder',
type='NRTRDecoder',
n_layers=6,
d_embedding=512,
n_head=8,

View File

@ -28,7 +28,7 @@ model = dict(
d_inner=256 * 4,
dropout=0.1),
decoder=dict(
type='TFDecoder',
type='NRTRDecoder',
n_layers=6,
d_embedding=256,
n_head=8,

View File

@ -68,12 +68,36 @@ def disable_text_recog_aug_test(cfg, set_types=None):
assert set_types is None or isinstance(set_types, list)
if set_types is None:
set_types = ['val', 'test']
warnings.simplefilter('once')
warning_msg = 'Remove "MultiRotateAugOCR" to support batch ' + \
'inference since samples_per_gpu > 1.'
for set_type in set_types:
if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR':
cfg.data[set_type].pipeline = [
cfg.data[set_type].pipeline[0],
*cfg.data[set_type].pipeline[1].transforms
]
dataset_type = cfg.data[set_type].type
if dataset_type in ['OCRDataset', 'OCRSegDataset']:
if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR':
warnings.warn(warning_msg)
cfg.data[set_type].pipeline = [
cfg.data[set_type].pipeline[0],
*cfg.data[set_type].pipeline[1].transforms
]
elif dataset_type in ['ConcatDataset', 'UniformConcatDataset']:
if dataset_type == 'UniformConcatDataset':
uniform_pipeline = cfg.data[set_type].pipeline
if uniform_pipeline is not None:
if uniform_pipeline[1].type == 'MultiRotateAugOCR':
warnings.warn(warning_msg)
cfg.data[set_type].pipeline = [
uniform_pipeline[0],
*uniform_pipeline[1].transforms
]
for dataset in cfg.data[set_type].datasets:
if dataset.pipeline is not None:
if dataset.pipeline[1].type == 'MultiRotateAugOCR':
warnings.warn(warning_msg)
dataset.pipeline = [
dataset.pipeline[0],
*dataset.pipeline[1].transforms
]
return cfg

View File

@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import backbones, losses
from . import backbones, layers, losses, modules
from .backbones import * # NOQA
from .losses import * # NOQA
from .layers import * # NOQA
from .modules import * # NOQA
__all__ = backbones.__all__ + losses.__all__
__all__ = backbones.__all__ + losses.__all__ + layers.__all__ + modules.__all__

View File

@ -0,0 +1,3 @@
from .transformer_layers import TFDecoderLayer, TFEncoderLayer
__all__ = ['TFEncoderLayer', 'TFDecoderLayer']

View File

@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner import BaseModule
from mmocr.models.common.modules import (MultiHeadAttention,
PositionwiseFeedForward)
class TFEncoderLayer(BaseModule):
"""Transformer Encoder Layer.
Args:
d_model (int): The number of expected features
in the decoder inputs (default=512).
d_inner (int): The dimension of the feedforward
network model (default=256).
n_head (int): The number of heads in the
multiheadattention models (default=8).
d_k (int): Total number of features in key.
d_v (int): Total number of features in value.
dropout (float): Dropout layer on attn_output_weights.
qkv_bias (bool): Add bias in projection layer. Default: False.
act_cfg (dict): Activation cfg for feedforward module.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm')
or ('norm', 'self_attn', 'norm', 'ffn').
DefaultNone.
"""
def __init__(self,
d_model=512,
d_inner=256,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1,
qkv_bias=False,
act_cfg=dict(type='mmcv.GELU'),
operation_order=None):
super().__init__()
self.attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
self.norm1 = nn.LayerNorm(d_model)
self.mlp = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout, act_cfg=act_cfg)
self.norm2 = nn.LayerNorm(d_model)
self.operation_order = operation_order
if self.operation_order is None:
self.operation_order = ('norm', 'self_attn', 'norm', 'ffn')
assert self.operation_order in [('norm', 'self_attn', 'norm', 'ffn'),
('self_attn', 'norm', 'ffn', 'norm')]
def forward(self, x, mask=None):
if self.operation_order == ('self_attn', 'norm', 'ffn', 'norm'):
residual = x
x = residual + self.attn(x, x, x, mask)
x = self.norm1(x)
residual = x
x = residual + self.mlp(x)
x = self.norm2(x)
elif self.operation_order == ('norm', 'self_attn', 'norm', 'ffn'):
residual = x
x = self.norm1(x)
x = residual + self.attn(x, x, x, mask)
residual = x
x = self.norm2(x)
x = residual + self.mlp(x)
return x
class TFDecoderLayer(nn.Module):
"""Transformer Decoder Layer.
Args:
d_model (int): The number of expected features
in the decoder inputs (default=512).
d_inner (int): The dimension of the feedforward
network model (default=256).
n_head (int): The number of heads in the
multiheadattention models (default=8).
d_k (int): Total number of features in key.
d_v (int): Total number of features in value.
dropout (float): Dropout layer on attn_output_weights.
qkv_bias (bool): Add bias in projection layer. Default: False.
act_cfg (dict): Activation cfg for feedforward module.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'enc_dec_attn',
'norm', 'ffn', 'norm') or ('norm', 'self_attn', 'norm',
'enc_dec_attn', 'norm', 'ffn').
DefaultNone.
"""
def __init__(self,
d_model=512,
d_inner=256,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1,
qkv_bias=False,
act_cfg=dict(type='mmcv.GELU'),
operation_order=None):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.self_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
self.enc_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
self.mlp = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout, act_cfg=act_cfg)
self.operation_order = operation_order
if self.operation_order is None:
self.operation_order = ('norm', 'self_attn', 'norm',
'enc_dec_attn', 'norm', 'ffn')
assert self.operation_order in [
('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn'),
('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm')
]
def forward(self,
dec_input,
enc_output,
self_attn_mask=None,
dec_enc_attn_mask=None):
if self.operation_order == ('self_attn', 'norm', 'enc_dec_attn',
'norm', 'ffn', 'norm'):
dec_attn_out = self.self_attn(dec_input, dec_input, dec_input,
self_attn_mask)
dec_attn_out += dec_input
dec_attn_out = self.norm1(dec_attn_out)
enc_dec_attn_out = self.enc_attn(dec_attn_out, enc_output,
enc_output, dec_enc_attn_mask)
enc_dec_attn_out += dec_attn_out
enc_dec_attn_out = self.norm2(enc_dec_attn_out)
mlp_out = self.mlp(enc_dec_attn_out)
mlp_out += enc_dec_attn_out
mlp_out = self.norm3(mlp_out)
elif self.operation_order == ('norm', 'self_attn', 'norm',
'enc_dec_attn', 'norm', 'ffn'):
dec_input_norm = self.norm1(dec_input)
dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
dec_input_norm, self_attn_mask)
dec_attn_out += dec_input
enc_dec_attn_in = self.norm2(dec_attn_out)
enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
enc_output, dec_enc_attn_mask)
enc_dec_attn_out += dec_attn_out
mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
mlp_out += enc_dec_attn_out
return mlp_out

View File

@ -0,0 +1,8 @@
from .transformer_module import (MultiHeadAttention, PositionalEncoding,
PositionwiseFeedForward,
ScaledDotProductAttention)
__all__ = [
'ScaledDotProductAttention', 'MultiHeadAttention',
'PositionwiseFeedForward', 'PositionalEncoding'
]

View File

@ -0,0 +1,155 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmocr.models.builder import build_activation_layer
class ScaledDotProductAttention(nn.Module):
"""Scaled Dot-Product Attention Module. This code is adopted from
https://github.com/jadore801120/attention-is-all-you-need-pytorch.
Args:
temperature (float): The scale factor for softmax input.
attn_dropout (float): Dropout layer on attn_output_weights.
"""
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
def forward(self, q, k, v, mask=None):
attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
if mask is not None:
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = self.dropout(F.softmax(attn, dim=-1))
output = torch.matmul(attn, v)
return output, attn
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention module.
Args:
n_head (int): The number of heads in the
multiheadattention models (default=8).
d_model (int): The number of expected features
in the decoder inputs (default=512).
d_k (int): Total number of features in key.
d_v (int): Total number of features in value.
dropout (float): Dropout layer on attn_output_weights.
qkv_bias (bool): Add bias in projection layer. Default: False.
"""
def __init__(self,
n_head=8,
d_model=512,
d_k=64,
d_v=64,
dropout=0.1,
qkv_bias=False):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.dim_k = n_head * d_k
self.dim_v = n_head * d_v
self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias)
self.attention = ScaledDotProductAttention(d_k**0.5, dropout)
self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias)
self.proj_drop = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
batch_size, len_q, _ = q.size()
_, len_k, _ = k.size()
q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k)
k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k)
v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if mask is not None:
if mask.dim() == 3:
mask = mask.unsqueeze(1)
elif mask.dim() == 2:
mask = mask.unsqueeze(1).unsqueeze(1)
attn_out, _ = self.attention(q, k, v, mask=mask)
attn_out = attn_out.transpose(1, 2).contiguous().view(
batch_size, len_q, self.dim_v)
attn_out = self.fc(attn_out)
attn_out = self.proj_drop(attn_out)
return attn_out
class PositionwiseFeedForward(nn.Module):
"""Two-layer feed-forward module.
Args:
d_in (int): The dimension of the input for feedforward
network model.
d_hid (int): The dimension of the feedforward
network model.
dropout (float): Dropout layer on feedforward output.
act_cfg (dict): Activation cfg for feedforward module.
"""
def __init__(self, d_in, d_hid, dropout=0.1, act_cfg=dict(type='Relu')):
super().__init__()
self.w_1 = nn.Linear(d_in, d_hid)
self.w_2 = nn.Linear(d_hid, d_in)
self.act = build_activation_layer(act_cfg)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.w_1(x)
x = self.act(x)
x = self.w_2(x)
x = self.dropout(x)
return x
class PositionalEncoding(nn.Module):
"""Fixed positional encoding with sine and cosine functions."""
def __init__(self, d_hid=512, n_position=200):
super().__init__()
# Not a parameter
self.register_buffer(
'position_table',
self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table."""
denominator = torch.Tensor([
1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
])
denominator = denominator.view(1, -1)
pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
sinusoid_table = pos_tensor * denominator
sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
return sinusoid_table.unsqueeze(0)
def forward(self, x):
self.device = x.device
return x + self.position_table[:, :x.size(1)].clone().detach()

View File

@ -10,7 +10,6 @@ class NRTRModalityTransform(BaseModule):
def __init__(self,
input_channels=3,
input_height=32,
init_cfg=[
dict(type='Kaiming', layer='Conv2d'),
dict(type='Uniform', layer='BatchNorm2d')
@ -35,9 +34,7 @@ class NRTRModalityTransform(BaseModule):
self.relu_2 = nn.ReLU(True)
self.bn_2 = nn.BatchNorm2d(64)
feat_height = input_height // 4
self.linear = nn.Linear(64 * feat_height, 512)
self.linear = nn.Linear(512, 512)
def forward(self, x):
x = self.conv_1(x)
@ -49,7 +46,11 @@ class NRTRModalityTransform(BaseModule):
x = self.bn_2(x)
n, c, h, w = x.size()
x = x.permute(0, 3, 2, 1).contiguous().view(n, w, h * c)
x = self.linear(x)
x = x.permute(0, 2, 1).contiguous().view(n, -1, 1, w)
return x

View File

@ -1,16 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_decoder import BaseDecoder
from .crnn_decoder import CRNNDecoder
from .nrtr_decoder import NRTRDecoder
from .position_attention_decoder import PositionAttentionDecoder
from .robust_scanner_decoder import RobustScannerDecoder
from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder
from .sar_decoder_with_bs import ParallelSARDecoderWithBS
from .sequence_attention_decoder import SequenceAttentionDecoder
from .transformer_decoder import TFDecoder
__all__ = [
'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder',
'ParallelSARDecoderWithBS', 'TFDecoder', 'BaseDecoder',
'ParallelSARDecoderWithBS', 'NRTRDecoder', 'BaseDecoder',
'SequenceAttentionDecoder', 'PositionAttentionDecoder',
'RobustScannerDecoder'
]

View File

@ -7,14 +7,12 @@ import torch.nn.functional as F
from mmcv.runner import ModuleList
from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers.transformer_layer import (
PositionalEncoding, TransformerDecoderLayer, get_pad_mask,
get_subsequent_mask)
from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from .base_decoder import BaseDecoder
@DECODERS.register_module()
class TFDecoder(BaseDecoder):
class NRTRDecoder(BaseDecoder):
"""Transformer Decoder block with self attention mechanism.
Args:
@ -71,8 +69,8 @@ class TFDecoder(BaseDecoder):
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = ModuleList([
TransformerDecoderLayer(
d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
TFDecoderLayer(
d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
for _ in range(n_layers)
])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
@ -80,13 +78,29 @@ class TFDecoder(BaseDecoder):
pred_num_class = num_classes - 1 # ignore padding_idx
self.classifier = nn.Linear(d_model, pred_num_class)
@staticmethod
def get_pad_mask(seq, pad_idx):
return (seq != pad_idx).unsqueeze(-2)
@staticmethod
def get_subsequent_mask(seq):
"""For masking out the subsequent info."""
len_s = seq.size(1)
subsequent_mask = 1 - torch.triu(
torch.ones((len_s, len_s), device=seq.device), diagonal=1)
subsequent_mask = subsequent_mask.unsqueeze(0).bool()
return subsequent_mask
def _attention(self, trg_seq, src, src_mask=None):
trg_embedding = self.trg_word_emb(trg_seq)
trg_pos_encoded = self.position_enc(trg_embedding)
tgt = self.dropout(trg_pos_encoded)
trg_mask = get_pad_mask(
trg_seq, pad_idx=self.padding_idx) & get_subsequent_mask(trg_seq)
trg_mask = self.get_pad_mask(
trg_seq,
pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
output = tgt
for dec_layer in self.layer_stack:
output = dec_layer(
@ -98,11 +112,27 @@ class TFDecoder(BaseDecoder):
return output
def _get_mask(self, logit, img_metas):
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
N, T, _ = logit.size()
mask = None
if valid_ratios is not None:
mask = logit.new_zeros((N, T))
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(T, math.ceil(T * valid_ratio))
mask[i, :valid_width] = 1
return mask
def forward_train(self, feat, out_enc, targets_dict, img_metas):
r"""
Args:
feat (None): Unused.
out_enc (Tensor): Encoder output of shape :math:`(N, D_m, H, W)`
out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
where :math:`D_m` is ``d_model``.
targets_dict (dict): A dict with the key ``padded_targets``, a
tensor of shape :math:`(N, T)`. Each element is the index of a
@ -113,44 +143,17 @@ class TFDecoder(BaseDecoder):
Returns:
Tensor: The raw logit tensor. Shape :math:`(N, T, C)`.
"""
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
n, c, h, w = out_enc.size()
src_mask = None
if valid_ratios is not None:
src_mask = out_enc.new_zeros((n, h, w))
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio))
src_mask[i, :, :valid_width] = 1
src_mask = src_mask.view(n, h * w)
out_enc = out_enc.view(n, c, h * w).permute(0, 2, 1)
out_enc = out_enc.contiguous()
src_mask = self._get_mask(out_enc, img_metas)
targets = targets_dict['padded_targets'].to(out_enc.device)
attn_output = self._attention(targets, out_enc, src_mask=src_mask)
outputs = self.classifier(attn_output)
return outputs
def forward_test(self, feat, out_enc, img_metas):
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
n, c, h, w = out_enc.size()
src_mask = None
if valid_ratios is not None:
src_mask = out_enc.new_zeros((n, h, w))
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio))
src_mask[i, :, :valid_width] = 1
src_mask = src_mask.view(n, h * w)
out_enc = out_enc.view(n, c, h * w).permute(0, 2, 1)
out_enc = out_enc.contiguous()
init_target_seq = torch.full((n, self.max_seq_len + 1),
src_mask = self._get_mask(out_enc, img_metas)
N = out_enc.size(0)
init_target_seq = torch.full((N, self.max_seq_len + 1),
self.padding_idx,
device=out_enc.device,
dtype=torch.long)
@ -161,7 +164,7 @@ class TFDecoder(BaseDecoder):
for step in range(0, self.max_seq_len):
decoder_output = self._attention(
init_target_seq, out_enc, src_mask=src_mask)
# bsz * seq_len * 512
# bsz * seq_len * C
step_result = F.softmax(
self.classifier(decoder_output[:, step, :]), dim=-1)
# bsz * num_classes

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_encoder import BaseEncoder
from .channel_reduction_encoder import ChannelReductionEncoder
from .nrtr_encoder import NRTREncoder
from .sar_encoder import SAREncoder
from .satrn_encoder import SatrnEncoder
from .transformer_encoder import TFEncoder
__all__ = [
'SAREncoder', 'TFEncoder', 'BaseEncoder', 'ChannelReductionEncoder',
'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder',
'SatrnEncoder'
]

View File

@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch.nn as nn
from mmcv.runner import ModuleList
from mmocr.models.builder import ENCODERS
from mmocr.models.common import TFEncoderLayer
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
class NRTREncoder(BaseEncoder):
"""Transformer Encoder block with self attention mechanism.
Args:
n_layers (int): The number of sub-encoder-layers
in the encoder (default=6).
n_head (int): The number of heads in the
multiheadattention models (default=8).
d_k (int): Total number of features in key.
d_v (int): Total number of features in value.
d_model (int): The number of expected features
in the decoder inputs (default=512).
d_inner (int): The dimension of the feedforward
network model (default=256).
dropout (float): Dropout layer on attn_output_weights.
init_cfg (dict or list[dict], optional): Initialization configs.
"""
def __init__(self,
n_layers=6,
n_head=8,
d_k=64,
d_v=64,
d_model=512,
d_inner=256,
dropout=0.1,
init_cfg=None,
**kwargs):
super().__init__(init_cfg=init_cfg)
self.d_model = d_model
self.layer_stack = ModuleList([
TFEncoderLayer(
d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
for _ in range(n_layers)
])
self.layer_norm = nn.LayerNorm(d_model)
def _get_mask(self, logit, img_metas):
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
N, T, _ = logit.size()
mask = None
if valid_ratios is not None:
mask = logit.new_zeros((N, T))
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(T, math.ceil(T * valid_ratio))
mask[i, :valid_width] = 1
return mask
def forward(self, feat, img_metas=None):
r"""
Args:
feat (Tensor): Backbone output of shape :math:`(N, C, H, W)`.
img_metas (dict): A dict that contains meta information of input
images. Preferably with the key ``valid_ratio``.
Returns:
Tensor: The encoder output tensor. Shape :math:`(N, T, C)`.
"""
n, c, h, w = feat.size()
feat = feat.view(n, c, h * w).permute(0, 2, 1).contiguous()
mask = self._get_mask(feat, img_metas)
output = feat
for enc_layer in self.layer_stack:
output = enc_layer(output, mask)
output = self.layer_norm(output)
return output

View File

@ -61,7 +61,7 @@ class SatrnEncoder(BaseEncoder):
images. Preferably with the key ``valid_ratio``.
Returns:
Tensor: A tensor of shape :math:`(N, D_m, H, W)`.
Tensor: A tensor of shape :math:`(N, T, D_m)`.
"""
valid_ratios = [1.0 for _ in range(feat.size(0))]
if img_metas is not None:
@ -82,7 +82,4 @@ class SatrnEncoder(BaseEncoder):
output = enc_layer(output, h, w, mask)
output = self.layer_norm(output)
output = output.permute(0, 2, 1).contiguous()
output = output.view(n, self.d_model, h, w)
return output

View File

@ -1,57 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch.nn as nn
from mmcv.runner import ModuleList
from mmocr.models.builder import ENCODERS
from mmocr.models.textrecog.layers import TransformerEncoderLayer
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
class TFEncoder(BaseEncoder):
"""Encode 2d feature map to 1d sequence."""
def __init__(self,
n_layers=6,
n_head=8,
d_k=64,
d_v=64,
d_model=512,
d_inner=256,
dropout=0.1,
init_cfg=None,
**kwargs):
super().__init__(init_cfg=init_cfg)
self.d_model = d_model
self.layer_stack = ModuleList([
TransformerEncoderLayer(
d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
])
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, feat, img_metas=None):
valid_ratios = [1.0 for _ in range(feat.size(0))]
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
n, c, h, w = feat.size()
mask = feat.new_zeros((n, h, w))
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio))
mask[i, :, :valid_width] = 1
mask = mask.view(n, h * w)
feat = feat.view(n, c, h * w)
output = feat.permute(0, 2, 1).contiguous()
for enc_layer in self.layer_stack:
output = enc_layer(output, mask)
output = self.layer_norm(output)
output = output.permute(0, 2, 1).contiguous()
output = output.view(n, self.d_model, h, w)
return output

View File

@ -4,17 +4,10 @@ from .dot_product_attention_layer import DotProductAttentionLayer
from .lstm_layer import BidirectionalLSTM
from .position_aware_layer import PositionAwareLayer
from .robust_scanner_fusion_layer import RobustScannerFusionLayer
from .transformer_layer import (Adaptive2DPositionalEncoding,
MultiHeadAttention, PositionalEncoding,
PositionwiseFeedForward, SatrnEncoderLayer,
TransformerDecoderLayer,
TransformerEncoderLayer, get_pad_mask,
get_subsequent_mask)
from .satrn_layers import Adaptive2DPositionalEncoding, SatrnEncoderLayer
__all__ = [
'BidirectionalLSTM', 'MultiHeadAttention', 'PositionalEncoding',
'Adaptive2DPositionalEncoding', 'PositionwiseFeedForward', 'BasicBlock',
'BidirectionalLSTM', 'Adaptive2DPositionalEncoding', 'BasicBlock',
'Bottleneck', 'RobustScannerFusionLayer', 'DotProductAttentionLayer',
'PositionAwareLayer', 'get_pad_mask', 'get_subsequent_mask',
'TransformerDecoderLayer', 'TransformerEncoderLayer', 'SatrnEncoderLayer'
'PositionAwareLayer', 'SatrnEncoderLayer'
]

View File

@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmocr.models.common import MultiHeadAttention
class SatrnEncoderLayer(BaseModule):
""""""
def __init__(self,
d_model=512,
d_inner=512,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1,
qkv_bias=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.norm1 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
self.norm2 = nn.LayerNorm(d_model)
self.feed_forward = LocalityAwareFeedforward(
d_model, d_inner, dropout=dropout)
def forward(self, x, h, w, mask=None):
n, hw, c = x.size()
residual = x
x = self.norm1(x)
x = residual + self.attn(x, x, x, mask)
residual = x
x = self.norm2(x)
x = x.transpose(1, 2).contiguous().view(n, c, h, w)
x = self.feed_forward(x)
x = x.view(n, c, hw).transpose(1, 2)
x = residual + x
return x
class LocalityAwareFeedforward(BaseModule):
"""Locality-aware feedforward layer in SATRN, see `SATRN.
<https://arxiv.org/abs/1910.04396>`_
"""
def __init__(self,
d_in,
d_hid,
dropout=0.1,
init_cfg=[
dict(type='Xavier', layer='Conv2d'),
dict(type='Constant', layer='BatchNorm2d', val=1, bias=0)
]):
super().__init__(init_cfg=init_cfg)
self.conv1 = ConvModule(
d_in,
d_hid,
kernel_size=1,
padding=0,
bias=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'))
self.depthwise_conv = ConvModule(
d_hid,
d_hid,
kernel_size=3,
padding=1,
bias=False,
groups=d_hid,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'))
self.conv2 = ConvModule(
d_hid,
d_in,
kernel_size=1,
padding=0,
bias=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'))
def forward(self, x):
x = self.conv1(x)
x = self.depthwise_conv(x)
x = self.conv2(x)
return x
class Adaptive2DPositionalEncoding(BaseModule):
"""Implement Adaptive 2D positional encoder for SATRN, see
`SATRN <https://arxiv.org/abs/1910.04396>`_
Modified from https://github.com/Media-Smart/vedastr
Licensed under the Apache License, Version 2.0 (the "License");
Args:
d_hid (int): Dimensions of hidden layer.
n_height (int): Max height of the 2D feature output.
n_width (int): Max width of the 2D feature output.
dropout (int): Size of hidden layers of the model.
"""
def __init__(self,
d_hid=512,
n_height=100,
n_width=100,
dropout=0.1,
init_cfg=[dict(type='Xavier', layer='Conv2d')]):
super().__init__(init_cfg=init_cfg)
h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
h_position_encoder = h_position_encoder.transpose(0, 1)
h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)
w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
w_position_encoder = w_position_encoder.transpose(0, 1)
w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)
self.register_buffer('h_position_encoder', h_position_encoder)
self.register_buffer('w_position_encoder', w_position_encoder)
self.h_scale = self.scale_factor_generate(d_hid)
self.w_scale = self.scale_factor_generate(d_hid)
self.pool = nn.AdaptiveAvgPool2d(1)
self.dropout = nn.Dropout(p=dropout)
def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table."""
denominator = torch.Tensor([
1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
])
denominator = denominator.view(1, -1)
pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
sinusoid_table = pos_tensor * denominator
sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
return sinusoid_table
def scale_factor_generate(self, d_hid):
scale_factor = nn.Sequential(
nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())
return scale_factor
def forward(self, x):
b, c, h, w = x.size()
avg_pool = self.pool(x)
h_pos_encoding = \
self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
w_pos_encoding = \
self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]
out = x + h_pos_encoding + w_pos_encoding
out = self.dropout(out)
return out

View File

@ -1,399 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This code is from https://github.com/jadore801120/attention-is-all-you-need-
pytorch."""
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
class TransformerEncoderLayer(nn.Module):
""""""
def __init__(self,
d_model=512,
d_inner=256,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1,
qkv_bias=False,
mask_value=0,
act_layer=nn.GELU):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(
n_head,
d_model,
d_k,
d_v,
qkv_bias=qkv_bias,
dropout=dropout,
mask_value=mask_value)
self.norm2 = nn.LayerNorm(d_model)
self.mlp = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout, act_layer=act_layer)
def forward(self, x, mask=None):
residual = x
x = self.norm1(x)
x = residual + self.attn(x, x, x, mask)
residual = x
x = self.norm2(x)
x = residual + self.mlp(x)
return x
class SatrnEncoderLayer(BaseModule):
""""""
def __init__(self,
d_model=512,
d_inner=512,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1,
qkv_bias=False,
mask_value=0,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.norm1 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(
n_head,
d_model,
d_k,
d_v,
qkv_bias=qkv_bias,
dropout=dropout,
mask_value=mask_value)
self.norm2 = nn.LayerNorm(d_model)
self.feed_forward = LocalityAwareFeedforward(
d_model, d_inner, dropout=dropout)
def forward(self, x, h, w, mask=None):
n, hw, c = x.size()
residual = x
x = self.norm1(x)
x = residual + self.attn(x, x, x, mask)
residual = x
x = self.norm2(x)
x = x.transpose(1, 2).contiguous().view(n, c, h, w)
x = self.feed_forward(x)
x = x.view(n, c, hw).transpose(1, 2)
x = residual + x
return x
class TransformerDecoderLayer(nn.Module):
def __init__(self,
d_model=512,
d_inner=256,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1,
qkv_bias=False,
mask_value=0,
act_layer=nn.GELU):
super().__init__()
self.self_attn = MultiHeadAttention()
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.self_attn = MultiHeadAttention(
n_head,
d_model,
d_k,
d_v,
dropout=dropout,
qkv_bias=qkv_bias,
mask_value=mask_value)
self.enc_attn = MultiHeadAttention(
n_head,
d_model,
d_k,
d_v,
dropout=dropout,
qkv_bias=qkv_bias,
mask_value=mask_value)
self.mlp = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout, act_layer=act_layer)
def forward(self,
dec_input,
enc_output,
self_attn_mask=None,
dec_enc_attn_mask=None):
self_attn_in = self.norm1(dec_input)
self_attn_out = self.self_attn(self_attn_in, self_attn_in,
self_attn_in, self_attn_mask)
enc_attn_in = dec_input + self_attn_out
enc_attn_q = self.norm2(enc_attn_in)
enc_attn_out = self.enc_attn(enc_attn_q, enc_output, enc_output,
dec_enc_attn_mask)
mlp_in = enc_attn_in + enc_attn_out
mlp_out = self.mlp(self.norm3(mlp_in))
out = mlp_in + mlp_out
return out
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention module."""
def __init__(self,
n_head=8,
d_model=512,
d_k=64,
d_v=64,
dropout=0.1,
qkv_bias=False,
mask_value=0):
super().__init__()
self.mask_value = mask_value
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.scale = d_k**-0.5
self.dim_k = n_head * d_k
self.dim_v = n_head * d_v
self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias)
self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias)
self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias)
self.attn_drop = nn.Dropout(dropout)
self.proj_drop = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
batch_size, len_q, _ = q.size()
_, len_k, _ = k.size()
q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k)
k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k)
v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 3, 1)
v = v.permute(0, 2, 1, 3)
logits = torch.matmul(q, k) * self.scale
if mask is not None:
if mask.dim() == 3:
mask = mask.unsqueeze(1)
elif mask.dim() == 2:
mask = mask.unsqueeze(1).unsqueeze(1)
logits = logits.masked_fill(mask == self.mask_value, float('-inf'))
weights = logits.softmax(dim=-1)
weights = self.attn_drop(weights)
attn_out = torch.matmul(weights, v).transpose(1, 2)
attn_out = attn_out.reshape(batch_size, len_q, self.dim_v)
attn_out = self.fc(attn_out)
attn_out = self.proj_drop(attn_out)
return attn_out
class PositionwiseFeedForward(nn.Module):
"""A two-feed-forward-layer module."""
def __init__(self, d_in, d_hid, dropout=0.1, act_layer=nn.GELU):
super().__init__()
self.w_1 = nn.Linear(d_in, d_hid)
self.w_2 = nn.Linear(d_hid, d_in)
self.act = act_layer()
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.w_1(x)
x = self.act(x)
x = self.dropout(x)
x = self.w_2(x)
x = self.dropout(x)
return x
class LocalityAwareFeedforward(BaseModule):
"""Locality-aware feedforward layer in SATRN, see `SATRN.
<https://arxiv.org/abs/1910.04396>`_
"""
def __init__(self,
d_in,
d_hid,
dropout=0.1,
init_cfg=[
dict(type='Xavier', layer='Conv2d'),
dict(type='Constant', layer='BatchNorm2d', val=1, bias=0)
]):
super().__init__(init_cfg=init_cfg)
self.conv1 = ConvModule(
d_in,
d_hid,
kernel_size=1,
padding=0,
bias=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'))
self.depthwise_conv = ConvModule(
d_hid,
d_hid,
kernel_size=3,
padding=1,
bias=False,
groups=d_hid,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'))
self.conv2 = ConvModule(
d_hid,
d_in,
kernel_size=1,
padding=0,
bias=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'))
def forward(self, x):
x = self.conv1(x)
x = self.depthwise_conv(x)
x = self.conv2(x)
return x
class PositionalEncoding(nn.Module):
def __init__(self, d_hid=512, n_position=200):
super().__init__()
# Not a parameter
self.register_buffer(
'position_table',
self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table."""
denominator = torch.Tensor([
1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
])
denominator = denominator.view(1, -1)
pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
sinusoid_table = pos_tensor * denominator
sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
return sinusoid_table.unsqueeze(0)
def forward(self, x):
self.device = x.device
return x + self.position_table[:, :x.size(1)].clone().detach()
class Adaptive2DPositionalEncoding(BaseModule):
"""Implement Adaptive 2D positional encoder for SATRN, see
`SATRN <https://arxiv.org/abs/1910.04396>`_
Modified from https://github.com/Media-Smart/vedastr
Licensed under the Apache License, Version 2.0 (the "License");
Args:
d_hid (int): Dimensions of hidden layer.
n_height (int): Max height of the 2D feature output.
n_width (int): Max width of the 2D feature output.
dropout (int): Size of hidden layers of the model.
"""
def __init__(self,
d_hid=512,
n_height=100,
n_width=100,
dropout=0.1,
init_cfg=[dict(type='Xavier', layer='Conv2d')]):
super().__init__(init_cfg=init_cfg)
h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
h_position_encoder = h_position_encoder.transpose(0, 1)
h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)
w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
w_position_encoder = w_position_encoder.transpose(0, 1)
w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)
self.register_buffer('h_position_encoder', h_position_encoder)
self.register_buffer('w_position_encoder', w_position_encoder)
self.h_scale = self.scale_factor_generate(d_hid)
self.w_scale = self.scale_factor_generate(d_hid)
self.pool = nn.AdaptiveAvgPool2d(1)
self.dropout = nn.Dropout(p=dropout)
def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table."""
denominator = torch.Tensor([
1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
])
denominator = denominator.view(1, -1)
pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
sinusoid_table = pos_tensor * denominator
sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
return sinusoid_table
def scale_factor_generate(self, d_hid):
scale_factor = nn.Sequential(
nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())
return scale_factor
def forward(self, x):
b, c, h, w = x.size()
avg_pool = self.pool(x)
h_pos_encoding = \
self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
w_pos_encoding = \
self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]
out = x + h_pos_encoding + w_pos_encoding
out = self.dropout(out)
return out
def get_pad_mask(seq, pad_idx):
return (seq != pad_idx).unsqueeze(-2)
def get_subsequent_mask(seq):
"""For masking out the subsequent info."""
len_s = seq.size(1)
subsequent_mask = 1 - torch.triu(
torch.ones((len_s, len_s), device=seq.device), diagonal=1)
subsequent_mask = subsequent_mask.unsqueeze(0).bool()
return subsequent_mask

View File

@ -282,12 +282,13 @@ class MMOCR:
},
'NRTR_1/16-1/8': {
'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth'
'ckpt':
'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'
},
'NRTR_1/8-1/4': {
'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
'ckpt':
'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth'
'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth'
},
'RobustScanner': {
'config': 'robust_scanner/robustscanner_r31_academic.py',

View File

@ -1,10 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import pytest
from mmcv import Config
from mmcv.image import imread
from mmocr.apis.inference import init_detector, model_inference
from mmocr.apis.inference import (disable_text_recog_aug_test, init_detector,
model_inference)
from mmocr.datasets import build_dataset # noqa: F401
from mmocr.models import build_detector # noqa: F401
from mmocr.utils import revert_sync_batchnorm
@ -121,3 +124,26 @@ def test_model_batch_inference_empty_detection(cfg_file):
match='empty imgs provided, please check and try again'):
model_inference(model, empty_detection, batch_mode=True)
@pytest.mark.parametrize('cfg_file', [
'../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
])
def test_disable_text_recog_aug_test(cfg_file):
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
config_file = os.path.join(tmp_dir, cfg_file)
cfg = Config.fromfile(config_file)
cfg1 = copy.deepcopy(cfg)
test = cfg1.data.test.datasets[0]
test.pipeline = cfg1.data.test.pipeline
cfg1.data.test = test
disable_text_recog_aug_test(cfg1, set_types=['test'])
cfg2 = copy.deepcopy(cfg)
cfg2.data.test.pipeline = None
disable_text_recog_aug_test(cfg2, set_types=['test'])
cfg2 = copy.deepcopy(cfg)
cfg2.data.test = Config(dict(type='ConcatDataset', datasets=[test]))
disable_text_recog_aug_test(cfg2, set_types=['test'])

View File

@ -4,9 +4,10 @@ import math
import pytest
import torch
from mmocr.models.textrecog.decoders import (BaseDecoder, ParallelSARDecoder,
from mmocr.models.textrecog.decoders import (BaseDecoder, NRTRDecoder,
ParallelSARDecoder,
ParallelSARDecoderWithBS,
SequentialSARDecoder, TFDecoder)
SequentialSARDecoder)
from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode
@ -97,11 +98,11 @@ def test_parallel_sar_decoder_with_beam_search():
def test_transformer_decoder():
decoder = TFDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
decoder = NRTRDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
decoder.init_weights()
decoder.train()
out_enc = torch.rand(1, 512, 1, 25)
out_enc = torch.rand(1, 25, 512)
tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
img_metas = [{'valid_ratio': 1.0}]
tgt_dict['padded_targets'] = tgt_dict['padded_targets']

View File

@ -2,8 +2,8 @@
import pytest
import torch
from mmocr.models.textrecog.encoders import (BaseEncoder, SAREncoder,
SatrnEncoder, TFEncoder)
from mmocr.models.textrecog.encoders import (BaseEncoder, NRTREncoder,
SAREncoder, SatrnEncoder)
def test_sar_encoder():
@ -34,14 +34,14 @@ def test_sar_encoder():
def test_transformer_encoder():
tf_encoder = TFEncoder()
tf_encoder = NRTREncoder()
tf_encoder.init_weights()
tf_encoder.train()
feat = torch.randn(1, 512, 1, 25)
out_enc = tf_encoder(feat)
print('hello', out_enc.size())
assert out_enc.shape == torch.Size([1, 512, 1, 25])
assert out_enc.shape == torch.Size([1, 25, 512])
def test_satrn_encoder():
@ -51,7 +51,7 @@ def test_satrn_encoder():
feat = torch.randn(1, 512, 8, 25)
out_enc = satrn_encoder(feat)
assert out_enc.shape == torch.Size([1, 512, 8, 25])
assert out_enc.shape == torch.Size([1, 200, 512])
def test_base_encoder():

View File

@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck,
PositionalEncoding,
TransformerDecoderLayer,
get_pad_mask, get_subsequent_mask)
from mmocr.models.common import (PositionalEncoding, TFDecoderLayer,
TFEncoderLayer)
from mmocr.models.textrecog.layers import BasicBlock, Bottleneck
from mmocr.models.textrecog.layers.conv_layer import conv3x3
@ -34,24 +33,31 @@ def test_conv_layer():
def test_transformer_layer():
# test decoder_layer
decoder_layer = TransformerDecoderLayer()
decoder_layer = TFDecoderLayer()
in_dec = torch.rand(1, 30, 512)
out_enc = torch.rand(1, 128, 512)
out_dec = decoder_layer(in_dec, out_enc)
assert out_dec.shape == torch.Size([1, 30, 512])
decoder_layer = TFDecoderLayer(
operation_order=('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn',
'norm'))
out_dec = decoder_layer(in_dec, out_enc)
assert out_dec.shape == torch.Size([1, 30, 512])
# test positional_encoding
pos_encoder = PositionalEncoding()
x = torch.rand(1, 30, 512)
out = pos_encoder(x)
assert out.size() == x.size()
# test get pad mask
seq = torch.rand(1, 30)
pad_idx = 0
out = get_pad_mask(seq, pad_idx)
assert out.shape == torch.Size([1, 1, 30])
# test encoder_layer
encoder_layer = TFEncoderLayer()
in_enc = torch.rand(1, 20, 512)
out_enc = encoder_layer(in_enc)
assert out_dec.shape == torch.Size([1, 30, 512])
# test get_subsequent_mask
out_mask = get_subsequent_mask(seq)
assert out_mask.shape == torch.Size([1, 30, 30])
encoder_layer = TFEncoderLayer(
operation_order=('self_attn', 'norm', 'ffn', 'norm'))
out_enc = encoder_layer(in_enc)
assert out_dec.shape == torch.Size([1, 30, 512])

View File

@ -102,11 +102,19 @@ def _get_detector_cfg(fname):
@pytest.mark.parametrize('cfg_file', [
'textrecog/sar/sar_r31_parallel_decoder_academic.py',
'textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py',
'textrecog/sar/sar_r31_sequential_decoder_academic.py',
'textrecog/crnn/crnn_toy_dataset.py',
'textrecog/crnn/crnn_academic_dataset.py',
'textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
'textrecog/nrtr/nrtr_modality_transform_academic.py',
'textrecog/nrtr/nrtr_modality_transform_toy_dataset.py',
'textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py',
'textrecog/robust_scanner/robustscanner_r31_academic.py',
'textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
'textrecog/satrn/satrn_academic.py'
'textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py',
'textrecog/satrn/satrn_academic.py', 'textrecog/satrn/satrn_small.py',
'textrecog/tps/crnn_tps_academic_dataset.py'
])
def test_recognizer_pipeline(cfg_file):
model = _get_detector_cfg(cfg_file)