mirror of https://github.com/open-mmlab/mmocr.git
[Model] Full ABINet Framework (#651)
Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>pull/672/head
parent
a4237ad568
commit
9104667112
|
@ -62,6 +62,7 @@ Supported algorithms:
|
|||
<details open>
|
||||
<summary>Text Recognition</summary>
|
||||
|
||||
- [x] [ABINet](configs/textrecog/abinet/README.md) (CVPR'2021)
|
||||
- [x] [CRNN](configs/textrecog/crnn/README.md) (TPAMI'2016)
|
||||
- [x] [NRTR](configs/textrecog/nrtr/README.md) (ICDAR'2019)
|
||||
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
|
||||
|
|
|
@ -62,6 +62,7 @@ MMOCR 是基于 PyTorch 和 mmdetection 的开源工具箱,专注于文本检
|
|||
<details open>
|
||||
<summary>文字识别</summary>
|
||||
|
||||
- [x] [ABINet](configs/textrecog/abinet/README.md) (CVPR'2021)
|
||||
- [x] [CRNN](configs/textrecog/crnn/README.md) (TPAMI'2016)
|
||||
- [x] [NRTR](configs/textrecog/nrtr/README.md) (ICDAR'2019)
|
||||
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# Text Recognition Training set, including:
|
||||
# Synthetic Datasets: SynthText, Syn90k
|
||||
# Both annotations are filtered so that
|
||||
# only alphanumeric terms are left
|
||||
|
||||
train_root = 'data/mixture'
|
||||
|
||||
train_img_prefix1 = f'{train_root}/Syn90k/mnt/ramdisk/max/90kDICT32px'
|
||||
train_ann_file1 = f'{train_root}/Syn90k/label.lmdb'
|
||||
|
||||
train1 = dict(
|
||||
type='OCRDataset',
|
||||
img_prefix=train_img_prefix1,
|
||||
ann_file=train_ann_file1,
|
||||
loader=dict(
|
||||
type='LmdbLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineStrParser',
|
||||
keys=['filename', 'text'],
|
||||
keys_idx=[0, 1],
|
||||
separator=' ')),
|
||||
pipeline=None,
|
||||
test_mode=False)
|
||||
|
||||
train_img_prefix2 = f'{train_root}/SynthText/' + \
|
||||
'synthtext/SynthText_patch_horizontal'
|
||||
train_ann_file2 = f'{train_root}/SynthText/alphanumeric_label.lmdb'
|
||||
|
||||
train2 = {key: value for key, value in train1.items()}
|
||||
train2['img_prefix'] = train_img_prefix2
|
||||
train2['ann_file'] = train_ann_file2
|
||||
|
||||
train_list = [train1, train2]
|
|
@ -0,0 +1,62 @@
|
|||
num_chars = 37
|
||||
max_seq_len = 26
|
||||
|
||||
label_convertor = dict(
|
||||
type='ABIConvertor',
|
||||
dict_type='DICT36',
|
||||
with_unknown=False,
|
||||
with_padding=False,
|
||||
lower=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type='ABINet',
|
||||
backbone=dict(type='ResNetABI'),
|
||||
encoder=dict(
|
||||
type='ABIVisionModel',
|
||||
encoder=dict(
|
||||
type='TransformerEncoder',
|
||||
n_layers=3,
|
||||
n_head=8,
|
||||
d_model=512,
|
||||
d_inner=2048,
|
||||
dropout=0.1,
|
||||
max_len=8 * 32,
|
||||
),
|
||||
decoder=dict(
|
||||
type='ABIVisionDecoder',
|
||||
in_channels=512,
|
||||
num_channels=64,
|
||||
attn_height=8,
|
||||
attn_width=32,
|
||||
attn_mode='nearest',
|
||||
use_result='feature',
|
||||
num_chars=num_chars,
|
||||
max_seq_len=max_seq_len,
|
||||
init_cfg=dict(type='Xavier', layer='Conv2d')),
|
||||
),
|
||||
decoder=dict(
|
||||
type='ABILanguageDecoder',
|
||||
d_model=512,
|
||||
n_head=8,
|
||||
d_inner=2048,
|
||||
n_layers=4,
|
||||
dropout=0.1,
|
||||
detach_tokens=True,
|
||||
use_self_attn=False,
|
||||
pad_idx=num_chars - 1,
|
||||
num_chars=num_chars,
|
||||
max_seq_len=max_seq_len,
|
||||
init_cfg=None),
|
||||
fuser=dict(
|
||||
type='ABIFuser',
|
||||
d_model=512,
|
||||
num_chars=num_chars,
|
||||
init_cfg=None,
|
||||
max_seq_len=max_seq_len,
|
||||
),
|
||||
loss=dict(
|
||||
type='ABILoss', enc_weight=1.0, dec_weight=1.0, fusion_weight=1.0),
|
||||
label_convertor=label_convertor,
|
||||
max_seq_len=max_seq_len,
|
||||
iter_size=3)
|
|
@ -0,0 +1,96 @@
|
|||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeOCR',
|
||||
height=32,
|
||||
min_width=128,
|
||||
max_width=128,
|
||||
keep_aspect_ratio=False,
|
||||
width_downsample_ratio=0.25),
|
||||
dict(
|
||||
type='RandomWrapper',
|
||||
p=0.5,
|
||||
transforms=[
|
||||
dict(
|
||||
type='OneOfWrapper',
|
||||
transforms=[
|
||||
dict(
|
||||
type='RandomRotateTextDet',
|
||||
max_angle=15,
|
||||
),
|
||||
dict(
|
||||
type='TorchVisionWrapper',
|
||||
op='RandomAffine',
|
||||
degrees=15,
|
||||
translate=(0.3, 0.3),
|
||||
scale=(0.5, 2.),
|
||||
shear=(-45, 45),
|
||||
),
|
||||
dict(
|
||||
type='TorchVisionWrapper',
|
||||
op='RandomPerspective',
|
||||
distortion_scale=0.5,
|
||||
p=1,
|
||||
),
|
||||
])
|
||||
],
|
||||
),
|
||||
dict(
|
||||
type='RandomWrapper',
|
||||
p=0.25,
|
||||
transforms=[
|
||||
dict(type='PyramidRescale'),
|
||||
dict(
|
||||
type='Albu',
|
||||
transforms=[
|
||||
dict(type='GaussNoise', var_limit=(20, 20), p=0.5),
|
||||
dict(type='MotionBlur', blur_limit=6, p=0.5),
|
||||
]),
|
||||
]),
|
||||
dict(
|
||||
type='RandomWrapper',
|
||||
p=0.25,
|
||||
transforms=[
|
||||
dict(
|
||||
type='TorchVisionWrapper',
|
||||
op='ColorJitter',
|
||||
brightness=0.5,
|
||||
saturation=0.5,
|
||||
contrast=0.5,
|
||||
hue=0.1),
|
||||
]),
|
||||
dict(type='ToTensorOCR'),
|
||||
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=[
|
||||
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio',
|
||||
'resize_shape'
|
||||
]),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiRotateAugOCR',
|
||||
rotate_degrees=[0, 90, 270],
|
||||
transforms=[
|
||||
dict(
|
||||
type='ResizeOCR',
|
||||
height=32,
|
||||
min_width=128,
|
||||
max_width=128,
|
||||
keep_aspect_ratio=False,
|
||||
width_downsample_ratio=0.25),
|
||||
dict(type='ToTensorOCR'),
|
||||
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=[
|
||||
'filename', 'ori_shape', 'img_shape', 'valid_ratio',
|
||||
'resize_shape'
|
||||
]),
|
||||
])
|
||||
]
|
|
@ -0,0 +1,10 @@
|
|||
optimizer = dict(type='Adam', lr=1e-4)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
lr_config = dict(
|
||||
policy='step',
|
||||
step=[16, 18],
|
||||
warmup='linear',
|
||||
warmup_iters=1,
|
||||
warmup_ratio=0.001,
|
||||
warmup_by_epoch=True)
|
||||
total_epochs = 20
|
|
@ -0,0 +1,59 @@
|
|||
# Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition
|
||||
|
||||
## Abstract
|
||||
|
||||
<!-- [ABSTRACT] -->
|
||||
Linguistic knowledge is of great benefit to scene text recognition. However, how to effectively model linguistic rules in end-to-end deep networks remains a research challenge. In this paper, we argue that the limited capacity of language models comes from: 1) implicitly language modeling; 2) unidirectional feature representation; and 3) language model with noise input. Correspondingly, we propose an autonomous, bidirectional and iterative ABINet for scene text recognition. Firstly, the autonomous suggests to block gradient flow between vision and language models to enforce explicitly language modeling. Secondly, a novel bidirectional cloze network (BCN) as the language model is proposed based on bidirectional feature representation. Thirdly, we propose an execution manner of iterative correction for language model which can effectively alleviate the impact of noise input. Additionally, based on the ensemble of iterative predictions, we propose a self-training method which can learn from unlabeled images effectively. Extensive experiments indicate that ABINet has superiority on low-quality images and achieves state-of-the-art results on several mainstream benchmarks. Besides, the ABINet trained with ensemble self-training shows promising improvement in realizing human-level recognition.
|
||||
|
||||
<!-- [IMAGE] -->
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/22607038/145804331-9ae955dc-0d3b-41eb-a6b2-dc7c9f7c1bef.png"/>
|
||||
</div>
|
||||
|
||||
## Citation
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
```bibtex
|
||||
@article{fang2021read,
|
||||
title={Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition},
|
||||
author={Fang, Shancheng and Xie, Hongtao and Wang, Yuxin and Mao, Zhendong and Zhang, Yongdong},
|
||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
|
||||
## Dataset
|
||||
|
||||
### Train Dataset
|
||||
|
||||
| trainset | instance_num | repeat_num | note |
|
||||
| :-------: | :----------: | :--------: | :----------: |
|
||||
| Syn90k | 8919273 | 1 | synth |
|
||||
| SynthText | 7239272 | 1 | alphanumeric |
|
||||
|
||||
### Test Dataset
|
||||
|
||||
| testset | instance_num | note |
|
||||
| :-----: | :----------: | :-------: |
|
||||
| IIIT5K | 3000 | regular |
|
||||
| SVT | 647 | regular |
|
||||
| IC13 | 1015 | regular |
|
||||
| IC15 | 2077 | irregular |
|
||||
| SVTP | 645 | irregular |
|
||||
| CT80 | 288 | irregular |
|
||||
|
||||
## Results and models
|
||||
|
||||
| methods | pretrained | | Regular Text | | | Irregular Text | | download |
|
||||
| :----------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------: | :----: | :----------: | :--: | :--: | :------------: | :--: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| | | IIIT5K | SVT | IC13 | IC15 | SVTP | CT80 | |
|
||||
| [ABINet-Vision](https://github.com/open-mmlab/mmocr/tree/master/configs/textrecog/abinet/abinet_vision_only_academic.py) | - | 94.7 | 91.7 | 93.6 | 83.0 | 85.1 | 86.5 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_vision_only_academic-e6b9ea89.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/20211201_195512.log) |
|
||||
| [ABINet](https://github.com/open-mmlab/mmocr/tree/master/configs/textrecog/abinet/abinet_academic.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-1bed979b.pth) | 95.7 | 94.6 | 95.7 | 85.1 | 90.4 | 90.3 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_academic-f718abf6.pth) \| [log1](https://download.openmmlab.com/mmocr/textrecog/abinet/20211210_095832.log) \| [log2](https://download.openmmlab.com/mmocr/textrecog/abinet/20211213_131724.log) |
|
||||
|
||||
:::{note}
|
||||
1. ABINet allows its encoder to run and be trained without decoder and fuser. Its encoder is designed to recognize texts as a stand-alone model and therefore can work as an independent text recognizer. We release it as ABINet-Vision.
|
||||
2. Facts about the pretrained model: MMOCR does not have a systematic pipeline to pretrain the language model (LM) yet, thus the weights of LM are converted from [the official pretrained model](https://github.com/FangShancheng/ABINet). The weights of ABINet-Vision are directly used as the vision model of ABINet.
|
||||
3. Due to some technical issues, the training process of ABINet was interrupted at the 13th epoch and we resumed it later. Both logs are released for full reference.
|
||||
4. The model architecture in the logs looks slightly different from the final released version, since it was refactored afterward. However, both architectures are essentially equivalent.
|
||||
:::
|
|
@ -0,0 +1,34 @@
|
|||
_base_ = [
|
||||
'../../_base_/default_runtime.py',
|
||||
'../../_base_/schedules/schedule_adam_step_20e.py',
|
||||
'../../_base_/recog_pipelines/abinet_pipeline.py',
|
||||
'../../_base_/recog_models/abinet.py',
|
||||
'../../_base_/recog_datasets/ST_MJ_alphanumeric_train.py',
|
||||
'../../_base_/recog_datasets/academic_test.py'
|
||||
]
|
||||
|
||||
train_list = {{_base_.train_list}}
|
||||
test_list = {{_base_.test_list}}
|
||||
|
||||
train_pipeline = {{_base_.train_pipeline}}
|
||||
test_pipeline = {{_base_.test_pipeline}}
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=192,
|
||||
workers_per_gpu=8,
|
||||
val_dataloader=dict(samples_per_gpu=1),
|
||||
test_dataloader=dict(samples_per_gpu=1),
|
||||
train=dict(
|
||||
type='UniformConcatDataset',
|
||||
datasets=train_list,
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type='UniformConcatDataset',
|
||||
datasets=test_list,
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type='UniformConcatDataset',
|
||||
datasets=test_list,
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=1, metric='acc')
|
|
@ -0,0 +1,76 @@
|
|||
_base_ = [
|
||||
'../../_base_/default_runtime.py',
|
||||
'../../_base_/schedules/schedule_adam_step_20e.py',
|
||||
'../../_base_/recog_pipelines/abinet_pipeline.py',
|
||||
'../../_base_/recog_datasets/ST_MJ_alphanumeric_train.py',
|
||||
'../../_base_/recog_datasets/academic_test.py'
|
||||
]
|
||||
|
||||
train_list = {{_base_.train_list}}
|
||||
test_list = {{_base_.test_list}}
|
||||
|
||||
train_pipeline = {{_base_.train_pipeline}}
|
||||
test_pipeline = {{_base_.test_pipeline}}
|
||||
|
||||
# Model
|
||||
num_chars = 37
|
||||
max_seq_len = 26
|
||||
label_convertor = dict(
|
||||
type='ABIConvertor',
|
||||
dict_type='DICT36',
|
||||
with_unknown=False,
|
||||
with_padding=False,
|
||||
lower=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type='ABINet',
|
||||
backbone=dict(type='ResNetABI'),
|
||||
encoder=dict(
|
||||
type='ABIVisionModel',
|
||||
encoder=dict(
|
||||
type='TransformerEncoder',
|
||||
n_layers=3,
|
||||
n_head=8,
|
||||
d_model=512,
|
||||
d_inner=2048,
|
||||
dropout=0.1,
|
||||
max_len=8 * 32,
|
||||
),
|
||||
decoder=dict(
|
||||
type='ABIVisionDecoder',
|
||||
in_channels=512,
|
||||
num_channels=64,
|
||||
attn_height=8,
|
||||
attn_width=32,
|
||||
attn_mode='nearest',
|
||||
use_result='feature',
|
||||
num_chars=num_chars,
|
||||
max_seq_len=max_seq_len,
|
||||
init_cfg=dict(type='Xavier', layer='Conv2d')),
|
||||
),
|
||||
loss=dict(
|
||||
type='ABILoss', enc_weight=1.0, dec_weight=1.0, fusion_weight=1.0),
|
||||
label_convertor=label_convertor,
|
||||
max_seq_len=max_seq_len,
|
||||
iter_size=1)
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=192,
|
||||
workers_per_gpu=8,
|
||||
val_dataloader=dict(samples_per_gpu=1),
|
||||
test_dataloader=dict(samples_per_gpu=1),
|
||||
train=dict(
|
||||
type='UniformConcatDataset',
|
||||
datasets=train_list,
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type='UniformConcatDataset',
|
||||
datasets=test_list,
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type='UniformConcatDataset',
|
||||
datasets=test_list,
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=1, metric='acc')
|
|
@ -0,0 +1,87 @@
|
|||
Collections:
|
||||
- Name: ABINet
|
||||
Metadata:
|
||||
Training Data: OCRDataset
|
||||
Training Techniques:
|
||||
- Adam
|
||||
Epochs: 20
|
||||
Batch Size: 1536
|
||||
Training Resources: 8x Tesla V100
|
||||
Architecture:
|
||||
- ResNetABI
|
||||
- ABIVisionModel
|
||||
- ABILanguageDecoder
|
||||
- ABIFuser
|
||||
Paper:
|
||||
URL: https://arxiv.org/pdf/2103.06495.pdf
|
||||
Title: 'Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition'
|
||||
README: configs/textrecog/abinet/README.md
|
||||
|
||||
Models:
|
||||
- Name: abinet_vision_only_academic
|
||||
In Collection: ABINet
|
||||
Config: configs/textrecog/abinet/abinet_vision_only_academic.py
|
||||
Metadata:
|
||||
Training Data:
|
||||
- SynthText
|
||||
- Syn90k
|
||||
Results:
|
||||
- Task: Text Recognition
|
||||
Dataset: IIIT5K
|
||||
Metrics:
|
||||
word_acc: 94.7
|
||||
- Task: Text Recognition
|
||||
Dataset: SVT
|
||||
Metrics:
|
||||
word_acc: 91.7
|
||||
- Task: Text Recognition
|
||||
Dataset: ICDAR2013
|
||||
Metrics:
|
||||
word_acc: 93.6
|
||||
- Task: Text Recognition
|
||||
Dataset: ICDAR2015
|
||||
Metrics:
|
||||
word_acc: 83.0
|
||||
- Task: Text Recognition
|
||||
Dataset: SVTP
|
||||
Metrics:
|
||||
word_acc: 85.1
|
||||
- Task: Text Recognition
|
||||
Dataset: CT80
|
||||
Metrics:
|
||||
word_acc: 86.5
|
||||
Weights: https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_vision_only_academic-e6b9ea89.pth
|
||||
|
||||
- Name: abinet_academic
|
||||
In Collection: ABINet
|
||||
Config: configs/textrecog/abinet/abinet_academic.py
|
||||
Metadata:
|
||||
Training Data:
|
||||
- SynthText
|
||||
- Syn90k
|
||||
Results:
|
||||
- Task: Text Recognition
|
||||
Dataset: IIIT5K
|
||||
Metrics:
|
||||
word_acc: 95.7
|
||||
- Task: Text Recognition
|
||||
Dataset: SVT
|
||||
Metrics:
|
||||
word_acc: 94.6
|
||||
- Task: Text Recognition
|
||||
Dataset: ICDAR2013
|
||||
Metrics:
|
||||
word_acc: 95.7
|
||||
- Task: Text Recognition
|
||||
Dataset: ICDAR2015
|
||||
Metrics:
|
||||
word_acc: 85.1
|
||||
- Task: Text Recognition
|
||||
Dataset: SVTP
|
||||
Metrics:
|
||||
word_acc: 90.4
|
||||
- Task: Text Recognition
|
||||
Dataset: CT80
|
||||
Metrics:
|
||||
word_acc: 90.3
|
||||
Weights: https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_academic-f718abf6.pth
|
|
@ -204,6 +204,7 @@ means that `batch_mode` and `print_result` are set to `True`)
|
|||
|
||||
| Name | Reference | `batch_mode` inference support |
|
||||
| ------------- | :--------------------------------------------------------------------------------------------------------------------------------: | :------------------: |
|
||||
| ABINet | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#read-like-humans-autonomous-bidirectional-and-iterative-language-modeling-for-scene-text-recognition) | :heavy_check_mark: |
|
||||
| CRNN | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#an-end-to-end-trainable-neural-network-for-image-based-sequence-recognition-and-its-application-to-scene-text-recognition) | :x: |
|
||||
| SAR | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
|
||||
| SAR_CN * | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Text Recognition
|
||||
Text Recognition
|
||||
|
||||
## Overview
|
||||
|
||||
|
@ -43,6 +43,7 @@
|
|||
│ │ ├── label.lmdb
|
||||
│ │ ├── mnt
|
||||
│ ├── SynthText
|
||||
│ │ ├── alphanumeric_labels.txt
|
||||
│ │ ├── shuffle_labels.txt
|
||||
│ │ ├── instances_train.txt
|
||||
│ │ ├── label.txt
|
||||
|
@ -86,7 +87,7 @@
|
|||
| svt |[homepage](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) | |
|
||||
| svtp | [unofficial homepage\[1\]](https://github.com/Jyouhou/Case-Sensitive-Scene-Text-Recognition-Datasets) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) | |
|
||||
| MJSynth (Syn90k) | [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Syn90k/shuffle_labels.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/Syn90k/label.txt) | - | |
|
||||
| SynthText (Synth800k) | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.txt) | - | |
|
||||
| SynthText (Synth800k) | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [alphanumeric_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/alphanumeric_labels.txt) \|[shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.txt) | - | |
|
||||
| SynthAdd | [SynthText_Add.zip](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt) | - | |
|
||||
| TextOCR | [homepage](https://textvqa.org/textocr/dataset) | - | - | |
|
||||
| Totaltext | [homepage](https://github.com/cs-chan/Total-Text-Dataset) | - | - | |
|
||||
|
@ -150,9 +151,14 @@
|
|||
### SynthText (Synth800k)
|
||||
- Step1: Download `SynthText.zip` from [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/)
|
||||
|
||||
- Step2: Download [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.txt) (7,266,686 annotations) and [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) (2,400,000 randomly sampled annotations). **Please make sure you're using the right annotation to train the model by checking its dataset specs in Model Zoo.**
|
||||
- Step2: According to your actual needs, download the most appropriate one from the following options: [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.txt) (7,266,686 annotations), [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) (2,400,000 randomly sampled annotations), [alphanumeric_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/alphanumeric_labels.txt) (7,239,272 annotations with alphanumeric characters only) and [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) (7,266,686 character-level annotations).
|
||||
|
||||
:::{warning}
|
||||
Please make sure you're using the right annotation to train the model by checking its dataset specs in Model Zoo.
|
||||
:::
|
||||
|
||||
- Step3:
|
||||
|
||||
```bash
|
||||
mkdir SynthText && cd SynthText
|
||||
mv /path/to/SynthText.zip .
|
||||
|
@ -161,11 +167,14 @@ mv SynthText synthtext
|
|||
|
||||
mv /path/to/shuffle_labels.txt .
|
||||
mv /path/to/label.txt .
|
||||
mv /path/to/alphanumeric_labels.txt .
|
||||
mv /path/to/instances_train.txt .
|
||||
|
||||
# create soft link
|
||||
cd /path/to/mmocr/data/mixture
|
||||
ln -s /path/to/SynthText SynthText
|
||||
```
|
||||
|
||||
- Step4:
|
||||
Generate cropped images and labels:
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@
|
|||
│ │ ├── label.lmdb
|
||||
│ │ ├── mnt
|
||||
│ ├── SynthText
|
||||
│ │ ├── alphanumeric_labels.txt
|
||||
│ │ ├── shuffle_labels.txt
|
||||
│ │ ├── instances_train.txt
|
||||
│ │ ├── label.txt
|
||||
|
@ -86,7 +87,7 @@
|
|||
| svt |[下载地址](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) | |
|
||||
| svtp | [非官方下载地址*](https://github.com/Jyouhou/Case-Sensitive-Scene-Text-Recognition-Datasets) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) | |
|
||||
| MJSynth (Syn90k) | [下载地址](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Syn90k/shuffle_labels.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/Syn90k/label.txt) | - | |
|
||||
| SynthText (Synth800k) | [下载地址](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.txt) | - | |
|
||||
| SynthText (Synth800k) | [下载地址](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) |[alphanumeric_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/alphanumeric_labels.txt) \| [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.txt) | - | |
|
||||
| SynthAdd | [SynthText_Add.zip](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt) | - | |
|
||||
| TextOCR | [下载地址](https://textvqa.org/textocr/dataset) | - | - | |
|
||||
| Totaltext | [下载地址](https://github.com/cs-chan/Total-Text-Dataset) | - | - | |
|
||||
|
@ -148,8 +149,11 @@ python tools/data/textrecog/svt_converter.py <download_svt_dir_path>
|
|||
```
|
||||
|
||||
### SynthText (Synth800k)
|
||||
- 第一步: 从 [下载地址](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) 下载 `SynthText.zip`
|
||||
- 第二步:
|
||||
- 第一步:下载 `SynthText.zip`: [下载地址](https://www.robots.ox.ac.uk/~vgg/data/scenetext/)
|
||||
|
||||
- 第二步:请根据你的实际需要,从下列标注中选择最适合的下载:[label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.txt) (7,266,686个标注); [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) (2,400,000个随机采样的标注);[alphanumeric_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/alphanumeric_labels.txt) (7,239,272个仅包含数字和字母的标注);[instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) (7,266,686个字符级别的标注)。
|
||||
|
||||
- 第三步:
|
||||
|
||||
```bash
|
||||
mkdir SynthText && cd SynthText
|
||||
|
@ -158,14 +162,16 @@ python tools/data/textrecog/svt_converter.py <download_svt_dir_path>
|
|||
mv SynthText synthtext
|
||||
|
||||
mv /path/to/shuffle_labels.txt .
|
||||
mv /path/to/label.txt .
|
||||
mv /path/to/label.txt .
|
||||
mv /path/to/alphanumeric_labels.txt .
|
||||
mv /path/to/instances_train.txt .
|
||||
|
||||
# 创建软链接
|
||||
cd /path/to/mmocr/data/mixture
|
||||
ln -s /path/to/SynthText SynthText
|
||||
```
|
||||
- 第三步:
|
||||
生成裁剪后的图像和标注:
|
||||
|
||||
- 第四步:生成裁剪后的图像和标注:
|
||||
|
||||
```bash
|
||||
cd /path/to/mmocr
|
||||
|
|
|
@ -12,10 +12,11 @@ from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
|
|||
from .test_time_aug import MultiRotateAugOCR
|
||||
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
|
||||
TextSnakeTargets)
|
||||
from .transforms import (ColorJitter, RandomCropFlip, RandomCropInstances,
|
||||
RandomCropPolyInstances, RandomRotatePolyInstances,
|
||||
RandomRotateTextDet, RandomScaling, ScaleAspectJitter,
|
||||
SquareResizePad)
|
||||
from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper
|
||||
from .transforms import (ColorJitter, PyramidRescale, RandomCropFlip,
|
||||
RandomCropInstances, RandomCropPolyInstances,
|
||||
RandomRotatePolyInstances, RandomRotateTextDet,
|
||||
RandomScaling, ScaleAspectJitter, SquareResizePad)
|
||||
|
||||
__all__ = [
|
||||
'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
|
||||
|
@ -27,5 +28,6 @@ __all__ = [
|
|||
'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
|
||||
'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets',
|
||||
'RandomScaling', 'RandomCropFlip', 'NerTransform', 'ToTensorNER',
|
||||
'ResizeNoImg'
|
||||
'ResizeNoImg', 'PyramidRescale', 'OneOfWrapper', 'RandomWrapper',
|
||||
'TorchVisionWrapper'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import inspect
|
||||
import random
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torchvision.transforms as torchvision_transforms
|
||||
from mmcv.utils import build_from_cfg
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from mmdet.datasets.pipelines import Compose
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class OneOfWrapper:
|
||||
"""Randomly select and apply one of the transforms, each with the equal
|
||||
chance.
|
||||
|
||||
Warning:
|
||||
Different from albumentations, this wrapper only runs the selected
|
||||
transform, but doesn't guarantee the transform can always be applied to
|
||||
the input if the transform comes with a probability to run.
|
||||
|
||||
Args:
|
||||
transforms (list[dict|callable]): Candidate transforms to be applied.
|
||||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
assert isinstance(transforms, list) or isinstance(transforms, tuple)
|
||||
assert len(transforms) > 0, 'Need at least one transform.'
|
||||
self.transforms = []
|
||||
for t in transforms:
|
||||
if isinstance(t, dict):
|
||||
self.transforms.append(build_from_cfg(t, PIPELINES))
|
||||
elif callable(t):
|
||||
self.transforms.append(t)
|
||||
else:
|
||||
raise TypeError('transform must be callable or a dict')
|
||||
|
||||
def __call__(self, results):
|
||||
return random.choice(self.transforms)(results)
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(transforms={self.transforms})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandomWrapper:
|
||||
"""Run a transform or a sequence of transforms with probability p.
|
||||
|
||||
Args:
|
||||
transforms (list[dict|callable]): Transform(s) to be applied.
|
||||
p (int|float): Probability of running transform(s).
|
||||
"""
|
||||
|
||||
def __init__(self, transforms, p):
|
||||
assert 0 <= p <= 1
|
||||
self.transforms = Compose(transforms)
|
||||
self.p = p
|
||||
|
||||
def __call__(self, results):
|
||||
return results if np.random.uniform() > self.p else self.transforms(
|
||||
results)
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(transforms={self.transforms}, '
|
||||
repr_str += f'p={self.p})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class TorchVisionWrapper:
|
||||
"""A wrapper of torchvision trasnforms. It applies specific transform to
|
||||
``img`` and updates ``img_shape`` accordingly.
|
||||
|
||||
Warning:
|
||||
This transform only affects the image but not its associated
|
||||
annotations, such as word bounding boxes and polygon masks. Therefore,
|
||||
it may only be applicable to text recognition tasks.
|
||||
|
||||
Args:
|
||||
op (str): The name of any transform class in
|
||||
:func:`torchvision.transforms`.
|
||||
**kwargs: Arguments that will be passed to initializer of torchvision
|
||||
transform.
|
||||
|
||||
:Required Keys:
|
||||
- | ``img`` (ndarray): The input image.
|
||||
|
||||
:Affected Keys:
|
||||
:Modified:
|
||||
- | ``img`` (ndarray): The modified image.
|
||||
:Added:
|
||||
- | ``img_shape`` (tuple(int)): Size of the modified image.
|
||||
"""
|
||||
|
||||
def __init__(self, op, **kwargs):
|
||||
assert type(op) is str
|
||||
|
||||
if mmcv.is_str(op):
|
||||
obj_cls = getattr(torchvision_transforms, op)
|
||||
elif inspect.isclass(op):
|
||||
obj_cls = op
|
||||
else:
|
||||
raise TypeError(
|
||||
f'type must be a str or valid type, but got {type(type)}')
|
||||
self.transform = obj_cls(**kwargs)
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self, results):
|
||||
assert 'img' in results
|
||||
# BGR -> RGB
|
||||
img = results['img'][..., ::-1]
|
||||
img = Image.fromarray(img)
|
||||
img = self.transform(img)
|
||||
img = np.asarray(img)
|
||||
img = img[..., ::-1]
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(transform={self.transform})'
|
||||
return repr_str
|
|
@ -967,3 +967,54 @@ class RandomCropFlip:
|
|||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
return h_axis, w_axis
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class PyramidRescale:
|
||||
"""Resize the image to the base shape, downsample it with gaussian pyramid,
|
||||
and rescale it back to original size.
|
||||
|
||||
Adapted from https://github.com/FangShancheng/ABINet.
|
||||
|
||||
Args:
|
||||
factor (int): The decay factor from base size, or the number of
|
||||
downsampling operations from the base layer.
|
||||
base_shape (tuple(int)): The shape of the base layer of the pyramid.
|
||||
randomize_factor (bool): If True, the final factor would be a random
|
||||
integer in [0, factor].
|
||||
|
||||
:Required Keys:
|
||||
- | ``img`` (ndarray): The input image.
|
||||
|
||||
:Affected Keys:
|
||||
:Modified:
|
||||
- | ``img`` (ndarray): The modified image.
|
||||
"""
|
||||
|
||||
def __init__(self, factor=4, base_shape=(128, 512), randomize_factor=True):
|
||||
assert isinstance(factor, int)
|
||||
assert isinstance(base_shape, list) or isinstance(base_shape, tuple)
|
||||
assert len(base_shape) == 2
|
||||
assert isinstance(randomize_factor, bool)
|
||||
self.factor = factor if not randomize_factor else np.random.randint(
|
||||
0, factor + 1)
|
||||
self.base_w, self.base_h = base_shape
|
||||
|
||||
def __call__(self, results):
|
||||
assert 'img' in results
|
||||
if self.factor == 0:
|
||||
return results
|
||||
img = results['img']
|
||||
src_h, src_w = img.shape[:2]
|
||||
scale_img = mmcv.imresize(img, (self.base_w, self.base_h))
|
||||
for _ in range(self.factor):
|
||||
scale_img = cv2.pyrDown(scale_img)
|
||||
scale_img = mmcv.imresize(scale_img, (src_w, src_h))
|
||||
results['img'] = scale_img
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(factor={self.factor}, '
|
||||
repr_str += f'basew={self.basew}, baseh={self.baseh})'
|
||||
return repr_str
|
||||
|
|
|
@ -21,6 +21,7 @@ DETECTORS = BACKBONES
|
|||
ROI_EXTRACTORS = BACKBONES
|
||||
HEADS = BACKBONES
|
||||
NECKS = BACKBONES
|
||||
FUSERS = BACKBONES
|
||||
|
||||
ACTIVATION_LAYERS = Registry('activation layer', parent=MMCV_ACTIVATION_LAYERS)
|
||||
|
||||
|
@ -81,6 +82,11 @@ def build_neck(cfg):
|
|||
return NECKS.build(cfg)
|
||||
|
||||
|
||||
def build_fuser(cfg):
|
||||
"""Build fuser."""
|
||||
return FUSERS.build(cfg)
|
||||
|
||||
|
||||
def build_upsample_layer(cfg, *args, **kwargs):
|
||||
"""Build upsample layer.
|
||||
|
||||
|
|
|
@ -129,9 +129,12 @@ class PositionwiseFeedForward(nn.Module):
|
|||
class PositionalEncoding(nn.Module):
|
||||
"""Fixed positional encoding with sine and cosine functions."""
|
||||
|
||||
def __init__(self, d_hid=512, n_position=200):
|
||||
def __init__(self, d_hid=512, n_position=200, dropout=0):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Not a parameter
|
||||
# Position table of shape (1, n_position, d_hid)
|
||||
self.register_buffer(
|
||||
'position_table',
|
||||
self._get_sinusoid_encoding_table(n_position, d_hid))
|
||||
|
@ -151,5 +154,10 @@ class PositionalEncoding(nn.Module):
|
|||
return sinusoid_table.unsqueeze(0)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): Tensor of shape (batch_size, pos_len, d_hid, ...)
|
||||
"""
|
||||
self.device = x.device
|
||||
return x + self.position_table[:, :x.size(1)].clone().detach()
|
||||
x = x + self.position_table[:, :x.size(1)].clone().detach()
|
||||
return self.dropout(x)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import (backbones, convertors, decoders, encoders, heads, losses, necks,
|
||||
preprocessor, recognizer)
|
||||
from . import (backbones, convertors, decoders, encoders, fusers, heads,
|
||||
losses, necks, preprocessor, recognizer)
|
||||
|
||||
from .backbones import * # NOQA
|
||||
from .convertors import * # NOQA
|
||||
|
@ -11,8 +11,9 @@ from .losses import * # NOQA
|
|||
from .necks import * # NOQA
|
||||
from .preprocessor import * # NOQA
|
||||
from .recognizer import * # NOQA
|
||||
from .fusers import * # NOQA
|
||||
|
||||
__all__ = (
|
||||
backbones.__all__ + convertors.__all__ + decoders.__all__ +
|
||||
encoders.__all__ + heads.__all__ + losses.__all__ + necks.__all__ +
|
||||
preprocessor.__all__ + recognizer.__all__)
|
||||
preprocessor.__all__ + recognizer.__all__ + fusers.__all__)
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .nrtr_modality_transformer import NRTRModalityTransform
|
||||
from .resnet31_ocr import ResNet31OCR
|
||||
from .resnet_abi import ResNetABI
|
||||
from .shallow_cnn import ShallowCNN
|
||||
from .very_deep_vgg import VeryDeepVgg
|
||||
|
||||
__all__ = ['ResNet31OCR', 'VeryDeepVgg', 'NRTRModalityTransform', 'ShallowCNN']
|
||||
__all__ = [
|
||||
'ResNet31OCR', 'VeryDeepVgg', 'NRTRModalityTransform', 'ShallowCNN',
|
||||
'ResNetABI'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule, Sequential
|
||||
|
||||
import mmocr.utils as utils
|
||||
from mmocr.models.builder import BACKBONES
|
||||
from mmocr.models.textrecog.layers import BasicBlock
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNetABI(BaseModule):
|
||||
"""Implement ResNet backbone for text recognition, modified from `ResNet.
|
||||
|
||||
<https://arxiv.org/pdf/1512.03385.pdf>`_ and
|
||||
`<https://github.com/FangShancheng/ABINet>`_
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels of input image tensor.
|
||||
stem_channels (int): Number of stem channels.
|
||||
base_channels (int): Number of base channels.
|
||||
arch_settings (list[int]): List of BasicBlock number for each stage.
|
||||
strides (Sequence[int]): Strides of the first block of each stage.
|
||||
out_indices (None | Sequence[int]): Indices of output stages. If not
|
||||
specified, only the last stage will be returned.
|
||||
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
stem_channels=32,
|
||||
base_channels=32,
|
||||
arch_settings=[3, 4, 6, 6, 3],
|
||||
strides=[2, 1, 2, 1, 1],
|
||||
out_indices=None,
|
||||
last_stage_pool=False,
|
||||
init_cfg=[
|
||||
dict(type='Xavier', layer='Conv2d'),
|
||||
dict(type='Constant', val=1, layer='BatchNorm2d')
|
||||
]):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(stem_channels, int)
|
||||
assert utils.is_type_list(arch_settings, int)
|
||||
assert utils.is_type_list(strides, int)
|
||||
assert len(arch_settings) == len(strides)
|
||||
assert out_indices is None or isinstance(out_indices, (list, tuple))
|
||||
assert isinstance(last_stage_pool, bool)
|
||||
|
||||
self.out_indices = out_indices
|
||||
self.last_stage_pool = last_stage_pool
|
||||
self.block = BasicBlock
|
||||
self.inplanes = stem_channels
|
||||
|
||||
self._make_stem_layer(in_channels, stem_channels)
|
||||
|
||||
self.res_layers = []
|
||||
planes = base_channels
|
||||
for i, num_blocks in enumerate(arch_settings):
|
||||
stride = strides[i]
|
||||
res_layer = self._make_layer(
|
||||
block=self.block,
|
||||
inplanes=self.inplanes,
|
||||
planes=planes,
|
||||
blocks=num_blocks,
|
||||
stride=stride)
|
||||
self.inplanes = planes * self.block.expansion
|
||||
planes *= 2
|
||||
layer_name = f'layer{i + 1}'
|
||||
self.add_module(layer_name, res_layer)
|
||||
self.res_layers.append(layer_name)
|
||||
|
||||
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
||||
layers = []
|
||||
downsample = None
|
||||
if stride != 1 or inplanes != planes:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(inplanes, planes, 1, stride, bias=False),
|
||||
nn.BatchNorm2d(planes),
|
||||
)
|
||||
layers.append(
|
||||
block(
|
||||
inplanes,
|
||||
planes,
|
||||
use_conv1x1=True,
|
||||
stride=stride,
|
||||
downsample=downsample))
|
||||
inplanes = planes
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(inplanes, planes, use_conv1x1=True))
|
||||
|
||||
return Sequential(*layers)
|
||||
|
||||
def _make_stem_layer(self, in_channels, stem_channels):
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels, stem_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.bn1 = nn.BatchNorm2d(stem_channels)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): Image tensor of shape :math:`(N, 3, H, W)`.
|
||||
|
||||
Returns:
|
||||
Tensor or list[Tensor]: Feature tensor. Its shape depends on
|
||||
ResNetABI's config. It can be a list of feature outputs at specific
|
||||
layers if ``out_indices`` is specified.
|
||||
"""
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.res_layers):
|
||||
res_layer = getattr(self, layer_name)
|
||||
x = res_layer(x)
|
||||
if self.out_indices and i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
return tuple(outs) if self.out_indices else x
|
|
@ -1,7 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .abi import ABIConvertor
|
||||
from .attn import AttnConvertor
|
||||
from .base import BaseConvertor
|
||||
from .ctc import CTCConvertor
|
||||
from .seg import SegConvertor
|
||||
|
||||
__all__ = ['BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor']
|
||||
__all__ = [
|
||||
'BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor',
|
||||
'ABIConvertor'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
import mmocr.utils as utils
|
||||
from mmocr.models.builder import CONVERTORS
|
||||
from .attn import AttnConvertor
|
||||
|
||||
|
||||
@CONVERTORS.register_module()
|
||||
class ABIConvertor(AttnConvertor):
|
||||
"""Convert between text, index and tensor for encoder-decoder based
|
||||
pipeline. Modified from AttnConvertor to get closer to ABINet's original
|
||||
implementation.
|
||||
|
||||
Args:
|
||||
dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
|
||||
dict_file (None|str): Character dict file path. If not none,
|
||||
higher priority than dict_type.
|
||||
dict_list (None|list[str]): Character list. If not none, higher
|
||||
priority than dict_type, but lower than dict_file.
|
||||
with_unknown (bool): If True, add `UKN` token to class.
|
||||
max_seq_len (int): Maximum sequence length of label.
|
||||
lower (bool): If True, convert original string to lower case.
|
||||
start_end_same (bool): Whether use the same index for
|
||||
start and end token or not. Default: True.
|
||||
"""
|
||||
|
||||
def str2tensor(self, strings):
|
||||
"""
|
||||
Convert text-string into tensor. Different from
|
||||
:obj:`mmocr.models.textrecog.convertors.AttnConvertor`, the targets
|
||||
field returns target index no longer than max_seq_len (EOS token
|
||||
included).
|
||||
|
||||
Args:
|
||||
strings (list[str]): For instance, ['hello', 'world']
|
||||
|
||||
Returns:
|
||||
dict: A dict with two tensors.
|
||||
|
||||
- | targets (list[Tensor]): [torch.Tensor([1,2,3,3,4,8]),
|
||||
torch.Tensor([5,4,6,3,7,8])]
|
||||
- | padded_targets (Tensor): Tensor of shape
|
||||
(bsz * max_seq_len)).
|
||||
"""
|
||||
assert utils.is_type_list(strings, str)
|
||||
|
||||
tensors, padded_targets = [], []
|
||||
indexes = self.str2idx(strings)
|
||||
for index in indexes:
|
||||
tensor = torch.LongTensor(index[:self.max_seq_len - 1] +
|
||||
[self.end_idx])
|
||||
tensors.append(tensor)
|
||||
# target tensor for loss
|
||||
src_target = torch.LongTensor(tensor.size(0) + 1).fill_(0)
|
||||
src_target[0] = self.start_idx
|
||||
src_target[1:] = tensor
|
||||
padded_target = (torch.ones(self.max_seq_len) *
|
||||
self.padding_idx).long()
|
||||
char_num = src_target.size(0)
|
||||
if char_num > self.max_seq_len:
|
||||
padded_target = src_target[:self.max_seq_len]
|
||||
else:
|
||||
padded_target[:char_num] = src_target
|
||||
padded_targets.append(padded_target)
|
||||
padded_targets = torch.stack(padded_targets, 0).long()
|
||||
|
||||
return {'targets': tensors, 'padded_targets': padded_targets}
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .abinet_language_decoder import ABILanguageDecoder
|
||||
from .abinet_vision_decoder import ABIVisionDecoder
|
||||
from .base_decoder import BaseDecoder
|
||||
from .crnn_decoder import CRNNDecoder
|
||||
from .nrtr_decoder import NRTRDecoder
|
||||
|
@ -12,5 +14,5 @@ __all__ = [
|
|||
'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder',
|
||||
'ParallelSARDecoderWithBS', 'NRTRDecoder', 'BaseDecoder',
|
||||
'SequenceAttentionDecoder', 'PositionAttentionDecoder',
|
||||
'RobustScannerDecoder'
|
||||
'RobustScannerDecoder', 'ABILanguageDecoder', 'ABIVisionDecoder'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,181 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
|
||||
from mmcv.runner import ModuleList
|
||||
|
||||
from mmocr.models.builder import DECODERS
|
||||
from mmocr.models.common.modules import PositionalEncoding
|
||||
from .base_decoder import BaseDecoder
|
||||
|
||||
|
||||
@DECODERS.register_module()
|
||||
class ABILanguageDecoder(BaseDecoder):
|
||||
r"""Transformer-based language model responsible for spell correction.
|
||||
Implementation of language model of \
|
||||
`ABINet <https://arxiv.org/abs/1910.04396>`_.
|
||||
|
||||
Args:
|
||||
d_model (int): Hidden size of input.
|
||||
n_head (int): Number of multi-attention heads.
|
||||
d_inner (int): Hidden size of feedforward network model.
|
||||
n_layers (int): The number of similar decoding layers.
|
||||
max_seq_len (int): Maximum text sequence length :math:`T`.
|
||||
dropout (float): Dropout rate.
|
||||
detach_tokens (bool): Whether to block the gradient flow at input
|
||||
tokens.
|
||||
num_chars (int): Number of text characters :math:`C`.
|
||||
use_self_attn (bool): If True, use self attention in decoder layers,
|
||||
otherwise cross attention will be used.
|
||||
pad_idx (bool): The index of the token indicating the end of output,
|
||||
which is used to compute the length of output. It is usually the
|
||||
index of `<EOS>` or `<PAD>` token.
|
||||
init_cfg (dict): Specifies the initialization method for model layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_model=512,
|
||||
n_head=8,
|
||||
d_inner=2048,
|
||||
n_layers=4,
|
||||
max_seq_len=40,
|
||||
dropout=0.1,
|
||||
detach_tokens=True,
|
||||
num_chars=90,
|
||||
use_self_attn=False,
|
||||
pad_idx=0,
|
||||
init_cfg=None,
|
||||
**kwargs):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.detach_tokens = detach_tokens
|
||||
|
||||
self.d_model = d_model
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
self.proj = nn.Linear(num_chars, d_model, False)
|
||||
self.token_encoder = PositionalEncoding(
|
||||
d_model, n_position=self.max_seq_len, dropout=0.1)
|
||||
self.pos_encoder = PositionalEncoding(
|
||||
d_model, n_position=self.max_seq_len)
|
||||
self.pad_idx = pad_idx
|
||||
|
||||
if use_self_attn:
|
||||
operation_order = ('self_attn', 'norm', 'cross_attn', 'norm',
|
||||
'ffn', 'norm')
|
||||
else:
|
||||
operation_order = ('cross_attn', 'norm', 'ffn', 'norm')
|
||||
|
||||
decoder_layer = BaseTransformerLayer(
|
||||
operation_order=operation_order,
|
||||
attn_cfgs=dict(
|
||||
type='MultiheadAttention',
|
||||
embed_dims=d_model,
|
||||
num_heads=n_head,
|
||||
attn_drop=dropout,
|
||||
dropout_layer=dict(type='Dropout', drop_prob=dropout),
|
||||
),
|
||||
ffn_cfgs=dict(
|
||||
type='FFN',
|
||||
embed_dims=d_model,
|
||||
feedforward_channels=d_inner,
|
||||
ffn_drop=dropout,
|
||||
),
|
||||
norm_cfg=dict(type='LN'),
|
||||
)
|
||||
self.decoder_layers = ModuleList(
|
||||
[copy.deepcopy(decoder_layer) for _ in range(n_layers)])
|
||||
|
||||
self.cls = nn.Linear(d_model, num_chars)
|
||||
|
||||
def forward_train(self, feat, logits, targets_dict, img_metas):
|
||||
"""
|
||||
Args:
|
||||
logits (Tensor): Raw language logitis. Shape (N, T, C).
|
||||
|
||||
Returns:
|
||||
A dict with keys ``feature`` and ``logits``.
|
||||
feature (Tensor): Shape (N, T, E). Raw textual features for vision
|
||||
language aligner.
|
||||
logits (Tensor): Shape (N, T, C). The raw logits for characters
|
||||
after spell correction.
|
||||
"""
|
||||
lengths = self._get_length(logits)
|
||||
lengths.clamp_(2, self.max_seq_len)
|
||||
tokens = torch.softmax(logits, dim=-1)
|
||||
if self.detach_tokens:
|
||||
tokens = tokens.detach()
|
||||
embed = self.proj(tokens) # (N, T, E)
|
||||
embed = self.token_encoder(embed) # (N, T, E)
|
||||
padding_mask = self._get_padding_mask(lengths, self.max_seq_len)
|
||||
|
||||
zeros = embed.new_zeros(*embed.shape)
|
||||
query = self.pos_encoder(zeros)
|
||||
query = query.permute(1, 0, 2) # (T, N, E)
|
||||
embed = embed.permute(1, 0, 2)
|
||||
location_mask = self._get_location_mask(self.max_seq_len,
|
||||
tokens.device)
|
||||
output = query
|
||||
for m in self.decoder_layers:
|
||||
output = m(
|
||||
query=output,
|
||||
key=embed,
|
||||
value=embed,
|
||||
attn_masks=location_mask,
|
||||
key_padding_mask=padding_mask)
|
||||
output = output.permute(1, 0, 2) # (N, T, E)
|
||||
|
||||
logits = self.cls(output) # (N, T, C)
|
||||
return {'feature': output, 'logits': logits}
|
||||
|
||||
def forward_test(self, feat, out_enc, img_metas):
|
||||
return self.forward_train(feat, out_enc, None, img_metas)
|
||||
|
||||
def _get_length(self, logit, dim=-1):
|
||||
"""Greedy decoder to obtain length from logit.
|
||||
|
||||
Returns the first location of padding index or the length of the entire
|
||||
tensor otherwise.
|
||||
"""
|
||||
# out as a boolean vector indicating the existence of end token(s)
|
||||
out = (logit.argmax(dim=-1) == self.pad_idx)
|
||||
abn = out.any(dim)
|
||||
# Get the first index of end token
|
||||
out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
|
||||
out = out + 1
|
||||
out = torch.where(abn, out, out.new_tensor(logit.shape[1]))
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _get_location_mask(seq_len, device=None):
|
||||
"""Generate location masks given input sequence length.
|
||||
|
||||
Args:
|
||||
seq_len (int): The length of input sequence to transformer.
|
||||
device (torch.device or str, optional): The device on which the
|
||||
masks will be placed.
|
||||
|
||||
Returns:
|
||||
Tensor: A mask tensor of shape (seq_len, seq_len) with -infs on
|
||||
diagonal and zeros elsewhere.
|
||||
"""
|
||||
mask = torch.eye(seq_len, device=device)
|
||||
mask = mask.float().masked_fill(mask == 1, float('-inf'))
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def _get_padding_mask(length, max_length):
|
||||
"""Generate padding masks.
|
||||
|
||||
Args:
|
||||
length (Tensor): Shape :math:`(N,)`.
|
||||
max_length (int): The maximum sequence length :math:`T`.
|
||||
|
||||
Returns:
|
||||
Tensor: A bool tensor of shape :math:`(N, T)` with Trues on
|
||||
elements located over the length, or Falses elsewhere.
|
||||
"""
|
||||
length = length.unsqueeze(-1)
|
||||
grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
|
||||
return grid >= length
|
|
@ -0,0 +1,167 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmocr.models.builder import DECODERS
|
||||
from mmocr.models.common.modules import PositionalEncoding
|
||||
from .base_decoder import BaseDecoder
|
||||
|
||||
|
||||
@DECODERS.register_module()
|
||||
class ABIVisionDecoder(BaseDecoder):
|
||||
"""Converts visual features into text characters.
|
||||
|
||||
Implementation of VisionEncoder in
|
||||
`ABINet <https://arxiv.org/abs/1910.04396>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels :math:`E` of input vector.
|
||||
num_channels (int): Number of channels of hidden vectors in mini U-Net.
|
||||
h (int): Height :math:`H` of input image features.
|
||||
w (int): Width :math:`W` of input image features.
|
||||
|
||||
in_channels (int): Number of channels of input image features.
|
||||
num_channels (int): Number of channels of hidden vectors in mini U-Net.
|
||||
attn_height (int): Height :math:`H` of input image features.
|
||||
attn_width (int): Width :math:`W` of input image features.
|
||||
attn_mode (str): Upsampling mode for :obj:`torch.nn.Upsample` in mini
|
||||
U-Net.
|
||||
max_seq_len (int): Maximum text sequence length :math:`T`.
|
||||
num_chars (int): Number of text characters :math:`C`.
|
||||
init_cfg (dict): Specifies the initialization method for model layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=512,
|
||||
num_channels=64,
|
||||
attn_height=8,
|
||||
attn_width=32,
|
||||
attn_mode='nearest',
|
||||
max_seq_len=40,
|
||||
num_chars=90,
|
||||
init_cfg=dict(type='Xavier', layer='Conv2d'),
|
||||
**kwargs):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
# For mini-Unet
|
||||
self.k_encoder = nn.Sequential(
|
||||
self._encoder_layer(in_channels, num_channels, stride=(1, 2)),
|
||||
self._encoder_layer(num_channels, num_channels, stride=(2, 2)),
|
||||
self._encoder_layer(num_channels, num_channels, stride=(2, 2)),
|
||||
self._encoder_layer(num_channels, num_channels, stride=(2, 2)))
|
||||
|
||||
self.k_decoder = nn.Sequential(
|
||||
self._decoder_layer(
|
||||
num_channels, num_channels, scale_factor=2, mode=attn_mode),
|
||||
self._decoder_layer(
|
||||
num_channels, num_channels, scale_factor=2, mode=attn_mode),
|
||||
self._decoder_layer(
|
||||
num_channels, num_channels, scale_factor=2, mode=attn_mode),
|
||||
self._decoder_layer(
|
||||
num_channels,
|
||||
in_channels,
|
||||
size=(attn_height, attn_width),
|
||||
mode=attn_mode))
|
||||
|
||||
self.pos_encoder = PositionalEncoding(in_channels, max_seq_len)
|
||||
self.project = nn.Linear(in_channels, in_channels)
|
||||
self.cls = nn.Linear(in_channels, num_chars)
|
||||
|
||||
def forward_train(self,
|
||||
feat,
|
||||
out_enc=None,
|
||||
targets_dict=None,
|
||||
img_metas=None):
|
||||
"""
|
||||
Args:
|
||||
feat (Tensor): Image features of shape (N, E, H, W).
|
||||
|
||||
Returns:
|
||||
dict: A dict with keys ``feature``, ``logits`` and ``attn_scores``.
|
||||
|
||||
- | feature (Tensor): Shape (N, T, E). Raw visual features for
|
||||
language decoder.
|
||||
- | logits (Tensor): Shape (N, T, C). The raw logits for
|
||||
characters.
|
||||
- | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result
|
||||
for vision-language aligner.
|
||||
"""
|
||||
# Position Attention
|
||||
N, E, H, W = feat.size()
|
||||
k, v = feat, feat # (N, E, H, W)
|
||||
|
||||
# Apply mini U-Net on k
|
||||
features = []
|
||||
for i in range(len(self.k_encoder)):
|
||||
k = self.k_encoder[i](k)
|
||||
features.append(k)
|
||||
for i in range(len(self.k_decoder) - 1):
|
||||
k = self.k_decoder[i](k)
|
||||
k = k + features[len(self.k_decoder) - 2 - i]
|
||||
k = self.k_decoder[-1](k)
|
||||
|
||||
# q = positional encoding
|
||||
zeros = feat.new_zeros((N, self.max_seq_len, E)) # (N, T, E)
|
||||
q = self.pos_encoder(zeros) # (N, T, E)
|
||||
q = self.project(q) # (N, T, E)
|
||||
|
||||
# Attention encoding
|
||||
attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
|
||||
attn_scores = attn_scores / (E**0.5)
|
||||
attn_scores = torch.softmax(attn_scores, dim=-1)
|
||||
v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
|
||||
attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
|
||||
|
||||
logits = self.cls(attn_vecs)
|
||||
result = {
|
||||
'feature': attn_vecs,
|
||||
'logits': logits,
|
||||
'attn_scores': attn_scores.view(N, -1, H, W)
|
||||
}
|
||||
return result
|
||||
|
||||
def forward_test(self, feat, out_enc=None, img_metas=None):
|
||||
return self.forward_train(feat, out_enc=out_enc, img_metas=img_metas)
|
||||
|
||||
def _encoder_layer(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1):
|
||||
return ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'))
|
||||
|
||||
def _decoder_layer(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
mode='nearest',
|
||||
scale_factor=None,
|
||||
size=None):
|
||||
align_corners = None if mode == 'nearest' else True
|
||||
return nn.Sequential(
|
||||
nn.Upsample(
|
||||
size=size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
align_corners=align_corners),
|
||||
ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU')))
|
|
@ -1,11 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .abinet_vision_model import ABIVisionModel
|
||||
from .base_encoder import BaseEncoder
|
||||
from .channel_reduction_encoder import ChannelReductionEncoder
|
||||
from .nrtr_encoder import NRTREncoder
|
||||
from .sar_encoder import SAREncoder
|
||||
from .satrn_encoder import SatrnEncoder
|
||||
from .transformer import TransformerEncoder
|
||||
|
||||
__all__ = [
|
||||
'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder',
|
||||
'SatrnEncoder'
|
||||
'SatrnEncoder', 'TransformerEncoder', 'ABIVisionModel'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmocr.models.builder import ENCODERS, build_decoder, build_encoder
|
||||
from .base_encoder import BaseEncoder
|
||||
|
||||
|
||||
@ENCODERS.register_module()
|
||||
class ABIVisionModel(BaseEncoder):
|
||||
"""A wrapper of visual feature encoder and language token decoder that
|
||||
converts visual features into text tokens.
|
||||
|
||||
Implementation of VisionEncoder in
|
||||
`ABINet <https://arxiv.org/abs/1910.04396>`_.
|
||||
|
||||
Args:
|
||||
encoder (dict): Config for image feature encoder.
|
||||
decoder (dict): Config for language token decoder.
|
||||
init_cfg (dict): Specifies the initialization method for model layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
encoder=dict(type='TransformerEncoder'),
|
||||
decoder=dict(type='ABIVisionDecoder'),
|
||||
init_cfg=dict(type='Xavier', layer='Conv2d'),
|
||||
**kwargs):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.encoder = build_encoder(encoder)
|
||||
self.decoder = build_decoder(decoder)
|
||||
|
||||
def forward(self, feat, img_metas=None):
|
||||
"""
|
||||
Args:
|
||||
feat (Tensor): Images of shape (N, E, H, W).
|
||||
|
||||
Returns:
|
||||
dict: A dict with keys ``feature``, ``logits`` and ``attn_scores``.
|
||||
|
||||
- | feature (Tensor): Shape (N, T, E). Raw visual features for
|
||||
language decoder.
|
||||
- | logits (Tensor): Shape (N, T, C). The raw logits for
|
||||
characters. C is the number of characters.
|
||||
- | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result
|
||||
for vision-language aligner.
|
||||
"""
|
||||
feat = self.encoder(feat)
|
||||
return self.decoder(feat=feat, out_enc=None)
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
|
||||
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
|
||||
from mmcv.runner import BaseModule, ModuleList
|
||||
|
||||
from mmocr.models.builder import ENCODERS
|
||||
from mmocr.models.common.modules import PositionalEncoding
|
||||
|
||||
|
||||
@ENCODERS.register_module()
|
||||
class TransformerEncoder(BaseModule):
|
||||
"""Implement transformer encoder for text recognition, modified from
|
||||
`<https://github.com/FangShancheng/ABINet>`.
|
||||
|
||||
Args:
|
||||
n_layers (int): Number of attention layers.
|
||||
n_head (int): Number of parallel attention heads.
|
||||
d_model (int): Dimension :math:`D_m` of the input from previous model.
|
||||
d_inner (int): Hidden dimension of feedforward layers.
|
||||
dropout (float): Dropout rate.
|
||||
max_len (int): Maximum output sequence length :math:`T`.
|
||||
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_layers=2,
|
||||
n_head=8,
|
||||
d_model=512,
|
||||
d_inner=2048,
|
||||
dropout=0.1,
|
||||
max_len=8 * 32,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
assert d_model % n_head == 0, 'd_model must be divisible by n_head'
|
||||
|
||||
self.pos_encoder = PositionalEncoding(d_model, n_position=max_len)
|
||||
encoder_layer = BaseTransformerLayer(
|
||||
operation_order=('self_attn', 'norm', 'ffn', 'norm'),
|
||||
attn_cfgs=dict(
|
||||
type='MultiheadAttention',
|
||||
embed_dims=d_model,
|
||||
num_heads=n_head,
|
||||
attn_drop=dropout,
|
||||
dropout_layer=dict(type='Dropout', drop_prob=dropout),
|
||||
),
|
||||
ffn_cfgs=dict(
|
||||
type='FFN',
|
||||
embed_dims=d_model,
|
||||
feedforward_channels=d_inner,
|
||||
ffn_drop=dropout,
|
||||
),
|
||||
norm_cfg=dict(type='LN'),
|
||||
)
|
||||
self.transformer = ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for _ in range(n_layers)])
|
||||
|
||||
def forward(self, feature):
|
||||
"""
|
||||
Args:
|
||||
feature (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
|
||||
|
||||
Returns:
|
||||
Tensor: Features of shape :math:`(N, D_m, H, W)`.
|
||||
"""
|
||||
n, c, h, w = feature.shape
|
||||
feature = feature.view(n, c, -1).transpose(1, 2) # (n, h*w, c)
|
||||
feature = self.pos_encoder(feature) # (n, h*w, c)
|
||||
feature = feature.transpose(0, 1) # (h*w, n, c)
|
||||
for m in self.transformer:
|
||||
feature = m(feature)
|
||||
feature = feature.permute(1, 2, 0).view(n, c, h, w)
|
||||
return feature
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .abi_fuser import ABIFuser
|
||||
|
||||
__all__ = ['ABIFuser']
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmocr.models.builder import FUSERS
|
||||
|
||||
|
||||
@FUSERS.register_module()
|
||||
class ABIFuser(BaseModule):
|
||||
"""Mix and align visual feature and linguistic feature Implementation of
|
||||
language model of `ABINet <https://arxiv.org/abs/1910.04396>`_.
|
||||
|
||||
Args:
|
||||
d_model (int): Hidden size of input.
|
||||
max_seq_len (int): Maximum text sequence length :math:`T`.
|
||||
num_chars (int): Number of text characters :math:`C`.
|
||||
init_cfg (dict): Specifies the initialization method for model layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_model=512,
|
||||
max_seq_len=40,
|
||||
num_chars=90,
|
||||
init_cfg=None,
|
||||
**kwargs):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.max_seq_len = max_seq_len + 1 # additional stop token
|
||||
self.w_att = nn.Linear(2 * d_model, d_model)
|
||||
self.cls = nn.Linear(d_model, num_chars)
|
||||
|
||||
def forward(self, l_feature, v_feature):
|
||||
"""
|
||||
Args:
|
||||
l_feature: (N, T, E) where T is length, N is batch size and
|
||||
d is dim of model.
|
||||
v_feature: (N, T, E) shape the same as l_feature.
|
||||
|
||||
Returns:
|
||||
A dict with key ``logits``
|
||||
The logits of shape (N, T, C) where N is batch size, T is length
|
||||
and C is the number of characters.
|
||||
"""
|
||||
f = torch.cat((l_feature, v_feature), dim=2)
|
||||
f_att = torch.sigmoid(self.w_att(f))
|
||||
output = f_att * v_feature + (1 - f_att) * l_feature
|
||||
|
||||
logits = self.cls(output) # (N, T, C)
|
||||
|
||||
return {'logits': logits}
|
|
@ -1,55 +1,36 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn.resnet import BasicBlock as MMCV_BasicBlock
|
||||
from mmcv.cnn.resnet import conv3x3
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False)
|
||||
in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
class BasicBlock(MMCV_BasicBlock):
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=False):
|
||||
super().__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
if downsample:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
inplanes, planes * self.expansion, 1, stride, bias=False),
|
||||
nn.BatchNorm2d(planes * self.expansion),
|
||||
)
|
||||
else:
|
||||
self.downsample = nn.Sequential()
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
downsample=None,
|
||||
use_conv1x1=False,
|
||||
style='pytorch',
|
||||
with_cp=False):
|
||||
super().__init__(
|
||||
inplanes,
|
||||
planes,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
downsample=downsample,
|
||||
style=style,
|
||||
with_cp=with_cp)
|
||||
if use_conv1x1:
|
||||
self.conv1 = conv1x1(inplanes, planes)
|
||||
self.conv2 = conv3x3(planes, planes, stride)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ce_loss import CELoss, SARLoss, TFLoss
|
||||
from .ctc_loss import CTCLoss
|
||||
from .mix_loss import ABILoss
|
||||
from .seg_loss import SegLoss
|
||||
|
||||
__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss']
|
||||
__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss', 'ABILoss']
|
||||
|
|
|
@ -13,22 +13,35 @@ class CELoss(nn.Module):
|
|||
ignore_index (int): Specifies a target value that is
|
||||
ignored and does not contribute to the input gradient.
|
||||
reduction (str): Specifies the reduction to apply to the output,
|
||||
should be one of the following: ("none", "mean", "sum").
|
||||
should be one of the following: ('none', 'mean', 'sum').
|
||||
ignore_first_char (bool): Whether to ignore the first token in target (
|
||||
usually the start token). If ``True``, the last token of the output
|
||||
sequence will also be removed to be aligned with the target length.
|
||||
"""
|
||||
|
||||
def __init__(self, ignore_index=-1, reduction='none'):
|
||||
def __init__(self,
|
||||
ignore_index=-1,
|
||||
reduction='none',
|
||||
ignore_first_char=False):
|
||||
super().__init__()
|
||||
assert isinstance(ignore_index, int)
|
||||
assert isinstance(reduction, str)
|
||||
assert reduction in ['none', 'mean', 'sum']
|
||||
assert isinstance(ignore_first_char, bool)
|
||||
|
||||
self.loss_ce = nn.CrossEntropyLoss(
|
||||
ignore_index=ignore_index, reduction=reduction)
|
||||
self.ignore_first_char = ignore_first_char
|
||||
|
||||
def format(self, outputs, targets_dict):
|
||||
targets = targets_dict['padded_targets']
|
||||
if self.ignore_first_char:
|
||||
targets = targets[:, 1:].contiguous()
|
||||
outputs = outputs[:, :-1, :]
|
||||
|
||||
return outputs.permute(0, 2, 1).contiguous(), targets
|
||||
outputs = outputs.permute(0, 2, 1).contiguous()
|
||||
|
||||
return outputs, targets
|
||||
|
||||
def forward(self, outputs, targets_dict, img_metas=None):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmocr.models.builder import LOSSES
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class ABILoss(nn.Module):
|
||||
"""Implementation of ABINet multiloss that allows mixing different types of
|
||||
losses with weights.
|
||||
|
||||
Args:
|
||||
enc_weight (float): The weight of encoder loss. Defaults to 1.0.
|
||||
dec_weight (float): The weight of decoder loss. Defaults to 1.0.
|
||||
fusion_weight (float): The weight of fuser (aligner) loss.
|
||||
Defaults to 1.0.
|
||||
num_classes (int): Number of unique output language tokens.
|
||||
|
||||
Returns:
|
||||
A dictionary whose key/value pairs are the losses of three modules.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
enc_weight=1.0,
|
||||
dec_weight=1.0,
|
||||
fusion_weight=1.0,
|
||||
num_classes=37,
|
||||
**kwargs):
|
||||
assert isinstance(enc_weight, float) or isinstance(enc_weight, int)
|
||||
assert isinstance(dec_weight, float) or isinstance(dec_weight, int)
|
||||
assert isinstance(fusion_weight, float) or \
|
||||
isinstance(fusion_weight, int)
|
||||
super().__init__()
|
||||
self.enc_weight = enc_weight
|
||||
self.dec_weight = dec_weight
|
||||
self.fusion_weight = fusion_weight
|
||||
self.num_classes = num_classes
|
||||
|
||||
def _flatten(self, logits, target_lens):
|
||||
flatten_logits = torch.cat(
|
||||
[s[:target_lens[i]] for i, s in enumerate((logits))])
|
||||
return flatten_logits
|
||||
|
||||
def _ce_loss(self, logits, targets):
|
||||
targets_one_hot = F.one_hot(targets, self.num_classes)
|
||||
log_prob = F.log_softmax(logits, dim=-1)
|
||||
loss = -(targets_one_hot.to(log_prob.device) * log_prob).sum(dim=-1)
|
||||
return loss.mean()
|
||||
|
||||
def _loss_over_iters(self, outputs, targets):
|
||||
"""
|
||||
Args:
|
||||
outputs (list[Tensor]): Each tensor has shape (N, T, C) where N is
|
||||
the batch size, T is the sequence length and C is the number of
|
||||
classes.
|
||||
targets_dicts (dict): The dictionary with at least `padded_targets`
|
||||
defined.
|
||||
"""
|
||||
iter_num = len(outputs)
|
||||
dec_outputs = torch.cat(outputs, dim=0)
|
||||
flatten_targets_iternum = targets.repeat(iter_num)
|
||||
return self._ce_loss(dec_outputs, flatten_targets_iternum)
|
||||
|
||||
def forward(self, outputs, targets_dict, img_metas=None):
|
||||
"""
|
||||
Args:
|
||||
outputs (dict): The output dictionary with at least one of
|
||||
``out_enc``, ``out_dec`` and ``out_fusers`` specified.
|
||||
targets_dict (dict): The target dictionary containing the key
|
||||
``padded_targets``, which represents target sequences in
|
||||
shape (batch_size, sequence_length).
|
||||
|
||||
Returns:
|
||||
A loss dictionary with ``loss_visual``, ``loss_lang`` and
|
||||
``loss_fusion``. Each should either be the loss tensor or ``0`` if
|
||||
the output of its corresponding module is not given.
|
||||
"""
|
||||
assert 'out_enc' in outputs or \
|
||||
'out_dec' in outputs or 'out_fusers' in outputs
|
||||
losses = {}
|
||||
|
||||
target_lens = [len(t) for t in targets_dict['targets']]
|
||||
flatten_targets = torch.cat([t for t in targets_dict['targets']])
|
||||
|
||||
if outputs.get('out_enc', None):
|
||||
enc_input = self._flatten(outputs['out_enc']['logits'],
|
||||
target_lens)
|
||||
enc_loss = self._ce_loss(enc_input,
|
||||
flatten_targets) * self.enc_weight
|
||||
losses['loss_visual'] = enc_loss
|
||||
if outputs.get('out_decs', None):
|
||||
dec_logits = [
|
||||
self._flatten(o['logits'], target_lens)
|
||||
for o in outputs['out_decs']
|
||||
]
|
||||
dec_loss = self._loss_over_iters(dec_logits,
|
||||
flatten_targets) * self.dec_weight
|
||||
losses['loss_lang'] = dec_loss
|
||||
if outputs.get('out_fusers', None):
|
||||
fusion_logits = [
|
||||
self._flatten(o['logits'], target_lens)
|
||||
for o in outputs['out_fusers']
|
||||
]
|
||||
fusion_loss = self._loss_over_iters(
|
||||
fusion_logits, flatten_targets) * self.fusion_weight
|
||||
losses['loss_fusion'] = fusion_loss
|
||||
return losses
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .abinet import ABINet
|
||||
from .base import BaseRecognizer
|
||||
from .crnn import CRNNNet
|
||||
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||
|
@ -10,5 +11,5 @@ from .seg_recognizer import SegRecognizer
|
|||
|
||||
__all__ = [
|
||||
'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR',
|
||||
'SegRecognizer', 'RobustScanner', 'SATRN'
|
||||
'SegRecognizer', 'RobustScanner', 'SATRN', 'ABINet'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,192 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.models.builder import (DETECTORS, build_backbone, build_convertor,
|
||||
build_decoder, build_encoder, build_fuser,
|
||||
build_loss, build_preprocessor)
|
||||
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class ABINet(EncodeDecodeRecognizer):
|
||||
"""Implementation of `Read Like Humans: Autonomous, Bidirectional and
|
||||
Iterative LanguageModeling for Scene Text Recognition.
|
||||
|
||||
<https://arxiv.org/pdf/2103.06495.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
preprocessor=None,
|
||||
backbone=None,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
iter_size=1,
|
||||
fuser=None,
|
||||
loss=None,
|
||||
label_convertor=None,
|
||||
train_cfg=None,
|
||||
test_cfg=None,
|
||||
max_seq_len=40,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super(EncodeDecodeRecognizer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
# Label convertor (str2tensor, tensor2str)
|
||||
assert label_convertor is not None
|
||||
label_convertor.update(max_seq_len=max_seq_len)
|
||||
self.label_convertor = build_convertor(label_convertor)
|
||||
|
||||
# Preprocessor module, e.g., TPS
|
||||
self.preprocessor = None
|
||||
if preprocessor is not None:
|
||||
self.preprocessor = build_preprocessor(preprocessor)
|
||||
|
||||
# Backbone
|
||||
assert backbone is not None
|
||||
self.backbone = build_backbone(backbone)
|
||||
|
||||
# Encoder module
|
||||
self.encoder = None
|
||||
if encoder is not None:
|
||||
self.encoder = build_encoder(encoder)
|
||||
|
||||
# Decoder module
|
||||
self.decoder = None
|
||||
if decoder is not None:
|
||||
decoder.update(num_classes=self.label_convertor.num_classes())
|
||||
decoder.update(start_idx=self.label_convertor.start_idx)
|
||||
decoder.update(padding_idx=self.label_convertor.padding_idx)
|
||||
decoder.update(max_seq_len=max_seq_len)
|
||||
self.decoder = build_decoder(decoder)
|
||||
|
||||
# Loss
|
||||
assert loss is not None
|
||||
self.loss = build_loss(loss)
|
||||
|
||||
self.train_cfg = train_cfg
|
||||
self.test_cfg = test_cfg
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
if pretrained is not None:
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated \
|
||||
key, please consider using init_cfg')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
|
||||
self.iter_size = iter_size
|
||||
|
||||
self.fuser = None
|
||||
if fuser is not None:
|
||||
self.fuser = build_fuser(fuser)
|
||||
|
||||
def forward_train(self, img, img_metas):
|
||||
"""
|
||||
Args:
|
||||
img (tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
img_metas (list[dict]): A list of image info dict where each dict
|
||||
contains: 'img_shape', 'filename', and may also contain
|
||||
'ori_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
:class:`mmdet.datasets.pipelines.Collect`.
|
||||
|
||||
Returns:
|
||||
dict[str, tensor]: A dictionary of loss components.
|
||||
"""
|
||||
for img_meta in img_metas:
|
||||
valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
|
||||
img_meta['valid_ratio'] = valid_ratio
|
||||
|
||||
feat = self.extract_feat(img)
|
||||
|
||||
gt_labels = [img_meta['text'] for img_meta in img_metas]
|
||||
|
||||
targets_dict = self.label_convertor.str2tensor(gt_labels)
|
||||
|
||||
text_logits = None
|
||||
out_enc = None
|
||||
if self.encoder is not None:
|
||||
out_enc = self.encoder(feat)
|
||||
text_logits = out_enc['logits']
|
||||
|
||||
out_decs = []
|
||||
out_fusers = []
|
||||
for _ in range(self.iter_size):
|
||||
if self.decoder is not None:
|
||||
out_dec = self.decoder(
|
||||
feat,
|
||||
text_logits,
|
||||
targets_dict,
|
||||
img_metas,
|
||||
train_mode=True)
|
||||
out_decs.append(out_dec)
|
||||
|
||||
if self.fuser is not None:
|
||||
out_fuser = self.fuser(out_enc['feature'], out_dec['feature'])
|
||||
text_logits = out_fuser['logits']
|
||||
out_fusers.append(out_fuser)
|
||||
|
||||
outputs = dict(
|
||||
out_enc=out_enc, out_decs=out_decs, out_fusers=out_fusers)
|
||||
|
||||
losses = self.loss(outputs, targets_dict, img_metas)
|
||||
|
||||
return losses
|
||||
|
||||
def simple_test(self, img, img_metas, **kwargs):
|
||||
"""Test function with test time augmentation.
|
||||
|
||||
Args:
|
||||
imgs (torch.Tensor): Image input tensor.
|
||||
img_metas (list[dict]): List of image information.
|
||||
|
||||
Returns:
|
||||
list[str]: Text label result of each image.
|
||||
"""
|
||||
for img_meta in img_metas:
|
||||
valid_ratio = 1.0 * img_meta['resize_shape'][1] / img.size(-1)
|
||||
img_meta['valid_ratio'] = valid_ratio
|
||||
|
||||
feat = self.extract_feat(img)
|
||||
|
||||
text_logits = None
|
||||
out_enc = None
|
||||
if self.encoder is not None:
|
||||
out_enc = self.encoder(feat)
|
||||
text_logits = out_enc['logits']
|
||||
|
||||
out_decs = []
|
||||
out_fusers = []
|
||||
for _ in range(self.iter_size):
|
||||
if self.decoder is not None:
|
||||
out_dec = self.decoder(
|
||||
feat, text_logits, img_metas=img_metas, train_mode=False)
|
||||
out_decs.append(out_dec)
|
||||
|
||||
if self.fuser is not None:
|
||||
out_fuser = self.fuser(out_enc['feature'], out_dec['feature'])
|
||||
text_logits = out_fuser['logits']
|
||||
out_fusers.append(out_fuser)
|
||||
|
||||
if len(out_fusers) > 0:
|
||||
ret = out_fusers[-1]
|
||||
elif len(out_decs) > 0:
|
||||
ret = out_decs[-1]
|
||||
else:
|
||||
ret = out_enc
|
||||
|
||||
# early return to avoid post processing
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
return ret['logits']
|
||||
|
||||
label_indexes, label_scores = self.label_convertor.tensor2idx(
|
||||
ret['logits'], img_metas)
|
||||
label_strings = self.label_convertor.idx2str(label_indexes)
|
||||
|
||||
# flatten batch results
|
||||
results = []
|
||||
for string, score in zip(label_strings, label_scores):
|
||||
results.append(dict(text=string, score=score))
|
||||
|
||||
return results
|
|
@ -302,6 +302,10 @@ class MMOCR:
|
|||
'config': 'satrn/satrn_small.py',
|
||||
'ckpt': 'satrn/satrn_small_20211009-2cf13355.pth'
|
||||
},
|
||||
'ABINet': {
|
||||
'config': 'abinet/abinet_academic.py',
|
||||
'ckpt': 'abinet/abinet_academic-f718abf6.pth'
|
||||
},
|
||||
'SEG': {
|
||||
'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
|
||||
'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
|
||||
|
|
|
@ -6,6 +6,7 @@ Import:
|
|||
- configs/textdet/panet/metafile.yml
|
||||
- configs/textdet/psenet/metafile.yml
|
||||
- configs/textdet/textsnake/metafile.yml
|
||||
- configs/textrecog/abinet/metafile.yml
|
||||
- configs/textrecog/crnn/metafile.yml
|
||||
- configs/textrecog/nrtr/metafile.yml
|
||||
- configs/textrecog/robust_scanner/metafile.yml
|
||||
|
|
|
@ -24,6 +24,7 @@ def build_model(config_file):
|
|||
|
||||
@pytest.mark.parametrize('cfg_file', [
|
||||
'../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
|
||||
'../configs/textrecog/abinet/abinet_academic.py',
|
||||
'../configs/textrecog/crnn/crnn_academic_dataset.py',
|
||||
'../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
|
||||
'../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import unittest.mock as mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmocr.datasets.pipelines import (OneOfWrapper, RandomWrapper,
|
||||
TorchVisionWrapper)
|
||||
from mmocr.datasets.pipelines.transforms import ColorJitter
|
||||
|
||||
|
||||
def test_torchvision_wrapper():
|
||||
x = {'img': np.ones((128, 100, 3), dtype=np.uint8)}
|
||||
# object not found error
|
||||
with pytest.raises(Exception):
|
||||
TorchVisionWrapper(op='NonExist')
|
||||
with pytest.raises(TypeError):
|
||||
TorchVisionWrapper()
|
||||
f = TorchVisionWrapper('Grayscale')
|
||||
with pytest.raises(AssertionError):
|
||||
f({})
|
||||
results = f(x)
|
||||
assert results['img'].shape == (128, 100)
|
||||
assert results['img_shape'] == (128, 100)
|
||||
|
||||
|
||||
@mock.patch('random.choice')
|
||||
def test_oneof(rand_choice):
|
||||
color_jitter = dict(type='TorchVisionWrapper', op='ColorJitter')
|
||||
gray_scale = dict(type='TorchVisionWrapper', op='Grayscale')
|
||||
x = {'img': np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)}
|
||||
f = OneOfWrapper([color_jitter, gray_scale])
|
||||
# Use color_jitter at the first call
|
||||
rand_choice.side_effect = lambda x: x[0]
|
||||
results = f(x)
|
||||
assert results['img'].shape == (128, 100, 3)
|
||||
# Use gray_scale at the second call
|
||||
rand_choice.side_effect = lambda x: x[1]
|
||||
results = f(x)
|
||||
assert results['img'].shape == (128, 100)
|
||||
|
||||
# Passing object
|
||||
f = OneOfWrapper([ColorJitter(), gray_scale])
|
||||
# Use color_jitter at the first call
|
||||
results = f(x)
|
||||
assert results['img'].shape == (128, 100)
|
||||
|
||||
# Test invalid inputs
|
||||
with pytest.raises(AssertionError):
|
||||
f = OneOfWrapper(None)
|
||||
with pytest.raises(AssertionError):
|
||||
f = OneOfWrapper([])
|
||||
with pytest.raises(AssertionError):
|
||||
f = OneOfWrapper({})
|
||||
|
||||
|
||||
@mock.patch('numpy.random.uniform')
|
||||
def test_runwithprob(np_random_uniform):
|
||||
np_random_uniform.side_effect = [0.1, 0.9]
|
||||
f = RandomWrapper([dict(type='TorchVisionWrapper', op='Grayscale')], 0.5)
|
||||
img = np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)
|
||||
results = f({'img': copy.deepcopy(img)})
|
||||
assert results['img'].shape == (128, 100)
|
||||
results = f({'img': copy.deepcopy(img)})
|
||||
assert results['img'].shape == (128, 100, 3)
|
|
@ -1,7 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import unittest.mock as mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torchvision.transforms as TF
|
||||
from mmdet.core import BitmapMasks, PolygonMasks
|
||||
from PIL import Image
|
||||
|
@ -343,3 +345,29 @@ def test_square_resize_pad(mock_sample):
|
|||
target[1::2] *= 8. / 3
|
||||
assert np.allclose(output['gt_masks'].masks[0][0], target)
|
||||
assert output['img'].shape == (40, 40, 3)
|
||||
|
||||
|
||||
def test_pyramid_rescale():
|
||||
img = np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)
|
||||
x = {'img': copy.deepcopy(img)}
|
||||
f = transforms.PyramidRescale()
|
||||
results = f(x)
|
||||
assert results['img'].shape == (128, 100, 3)
|
||||
|
||||
# Test invalid inputs
|
||||
with pytest.raises(AssertionError):
|
||||
transforms.PyramidRescale(base_shape=(128))
|
||||
with pytest.raises(AssertionError):
|
||||
transforms.PyramidRescale(base_shape=128)
|
||||
with pytest.raises(AssertionError):
|
||||
transforms.PyramidRescale(factor=[])
|
||||
with pytest.raises(AssertionError):
|
||||
transforms.PyramidRescale(randomize_factor=[])
|
||||
with pytest.raises(AssertionError):
|
||||
f({})
|
||||
|
||||
# Test factor = 0
|
||||
f_derandomized = transforms.PyramidRescale(
|
||||
factor=0, randomize_factor=False)
|
||||
results = f_derandomized({'img': copy.deepcopy(img)})
|
||||
assert np.all(results['img'] == img)
|
||||
|
|
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.models.textrecog.convertors import AttnConvertor
|
||||
from mmocr.models.textrecog.convertors import ABIConvertor, AttnConvertor
|
||||
|
||||
|
||||
def _create_dummy_dict_file(dict_file):
|
||||
|
@ -76,3 +76,30 @@ def test_attn_label_convertor():
|
|||
assert output_strings[0] == 'hell'
|
||||
|
||||
tmp_dir.cleanup()
|
||||
|
||||
|
||||
def test_abi_label_convertor():
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
# create dummy data
|
||||
dict_file = osp.join(tmp_dir.name, 'fake_dict.txt')
|
||||
_create_dummy_dict_file(dict_file)
|
||||
|
||||
label_convertor = ABIConvertor(dict_file=dict_file, max_seq_len=10)
|
||||
|
||||
label_convertor.end_idx
|
||||
# test encode str to tensor
|
||||
strings = ['hell']
|
||||
targets_dict = label_convertor.str2tensor(strings)
|
||||
assert torch.allclose(targets_dict['targets'][0],
|
||||
torch.LongTensor([0, 1, 2, 2, 8]))
|
||||
assert torch.allclose(targets_dict['padded_targets'][0],
|
||||
torch.LongTensor([8, 0, 1, 2, 2, 8, 9, 9, 9, 9]))
|
||||
|
||||
strings = ['hellhellhell']
|
||||
targets_dict = label_convertor.str2tensor(strings)
|
||||
assert torch.allclose(targets_dict['targets'][0],
|
||||
torch.LongTensor([0, 1, 2, 2, 0, 1, 2, 2, 0, 8]))
|
||||
assert torch.allclose(targets_dict['padded_targets'][0],
|
||||
torch.LongTensor([8, 0, 1, 2, 2, 0, 1, 2, 2, 0]))
|
||||
|
||||
tmp_dir.cleanup()
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.models.textrecog.backbones import (ResNet31OCR, ShallowCNN,
|
||||
VeryDeepVgg)
|
||||
from mmocr.models.textrecog.backbones import (ResNet31OCR, ResNetABI,
|
||||
ShallowCNN, VeryDeepVgg)
|
||||
|
||||
|
||||
def test_resnet31_ocr_backbone():
|
||||
|
@ -47,3 +47,26 @@ def test_shallow_cnn_ocr_backbone():
|
|||
imgs = torch.randn(1, 1, 32, 100)
|
||||
feat = model(imgs)
|
||||
assert feat.shape == torch.Size([1, 512, 8, 25])
|
||||
|
||||
|
||||
def test_resnet_abi():
|
||||
"""Test resnet backbone."""
|
||||
with pytest.raises(AssertionError):
|
||||
ResNetABI(2.5)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
ResNetABI(3, arch_settings=5)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
ResNetABI(3, stem_channels=None)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
ResNetABI(arch_settings=[3, 4, 6, 6], strides=[1, 2, 1, 2, 1])
|
||||
|
||||
# Test forwarding
|
||||
model = ResNetABI()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 32, 160)
|
||||
feat = model(imgs)
|
||||
assert feat.shape == torch.Size([1, 512, 8, 40])
|
||||
|
|
|
@ -4,8 +4,9 @@ import math
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.models.textrecog.decoders import (BaseDecoder, NRTRDecoder,
|
||||
ParallelSARDecoder,
|
||||
from mmocr.models.textrecog.decoders import (ABILanguageDecoder,
|
||||
ABIVisionDecoder, BaseDecoder,
|
||||
NRTRDecoder, ParallelSARDecoder,
|
||||
ParallelSARDecoderWithBS,
|
||||
SequentialSARDecoder)
|
||||
from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode
|
||||
|
@ -112,3 +113,22 @@ def test_transformer_decoder():
|
|||
|
||||
out_test = decoder(None, out_enc, tgt_dict, img_metas, False)
|
||||
assert out_test.shape == torch.Size([1, 5, 36])
|
||||
|
||||
|
||||
def test_abi_language_decoder():
|
||||
decoder = ABILanguageDecoder(max_seq_len=25)
|
||||
logits = torch.randn(2, 25, 90)
|
||||
result = decoder(
|
||||
feat=None, out_enc=logits, targets_dict=None, img_metas=None)
|
||||
assert result['feature'].shape == torch.Size([2, 25, 512])
|
||||
assert result['logits'].shape == torch.Size([2, 25, 90])
|
||||
|
||||
|
||||
def test_abi_vision_decoder():
|
||||
model = ABIVisionDecoder(
|
||||
in_channels=128, num_channels=16, max_seq_len=10, use_result=None)
|
||||
x = torch.randn(2, 128, 8, 32)
|
||||
result = model(x, None)
|
||||
assert result['feature'].shape == torch.Size([2, 10, 128])
|
||||
assert result['logits'].shape == torch.Size([2, 10, 90])
|
||||
assert result['attn_scores'].shape == torch.Size([2, 10, 8, 32])
|
||||
|
|
|
@ -2,8 +2,9 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.models.textrecog.encoders import (BaseEncoder, NRTREncoder,
|
||||
SAREncoder, SatrnEncoder)
|
||||
from mmocr.models.textrecog.encoders import (ABIVisionModel, BaseEncoder,
|
||||
NRTREncoder, SAREncoder,
|
||||
SatrnEncoder, TransformerEncoder)
|
||||
|
||||
|
||||
def test_sar_encoder():
|
||||
|
@ -33,7 +34,7 @@ def test_sar_encoder():
|
|||
assert out_enc.shape == torch.Size([1, 512])
|
||||
|
||||
|
||||
def test_transformer_encoder():
|
||||
def test_nrtr_encoder():
|
||||
tf_encoder = NRTREncoder()
|
||||
tf_encoder.init_weights()
|
||||
tf_encoder.train()
|
||||
|
@ -62,3 +63,19 @@ def test_base_encoder():
|
|||
feat = torch.randn(1, 256, 4, 40)
|
||||
out_enc = encoder(feat)
|
||||
assert out_enc.shape == torch.Size([1, 256, 4, 40])
|
||||
|
||||
|
||||
def test_transformer_encoder():
|
||||
model = TransformerEncoder()
|
||||
x = torch.randn(10, 512, 8, 32)
|
||||
assert model(x).shape == torch.Size([10, 512, 8, 32])
|
||||
|
||||
|
||||
def test_abi_vision_model():
|
||||
model = ABIVisionModel(
|
||||
decoder=dict(type='ABIVisionDecoder', max_seq_len=10, use_result=None))
|
||||
x = torch.randn(1, 512, 8, 32)
|
||||
result = model(x)
|
||||
assert result['feature'].shape == torch.Size([1, 10, 512])
|
||||
assert result['logits'].shape == torch.Size([1, 10, 90])
|
||||
assert result['attn_scores'].shape == torch.Size([1, 10, 8, 32])
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmocr.models.textrecog.fusers import ABIFuser
|
||||
|
||||
|
||||
def test_base_alignment():
|
||||
model = ABIFuser(d_model=512, num_chars=90, max_seq_len=40)
|
||||
l_feat = torch.randn(1, 40, 512)
|
||||
v_feat = torch.randn(1, 40, 512)
|
||||
result = model(l_feat, v_feat)
|
||||
assert result['logits'].shape == torch.Size([1, 40, 90])
|
|
@ -3,7 +3,8 @@ import pytest
|
|||
import torch
|
||||
|
||||
from mmocr.models.common.losses import DiceLoss
|
||||
from mmocr.models.textrecog.losses import CELoss, CTCLoss, SARLoss, TFLoss
|
||||
from mmocr.models.textrecog.losses import (ABILoss, CELoss, CTCLoss, SARLoss,
|
||||
TFLoss)
|
||||
|
||||
|
||||
def test_ctc_loss():
|
||||
|
@ -46,9 +47,17 @@ def test_ce_loss():
|
|||
losses = ce_loss(outputs, targets_dict)
|
||||
assert isinstance(losses, dict)
|
||||
assert 'loss_ce' in losses
|
||||
print(losses['loss_ce'].size())
|
||||
assert losses['loss_ce'].size(1) == 10
|
||||
|
||||
ce_loss = CELoss(ignore_first_char=True)
|
||||
outputs = torch.rand(1, 10, 37)
|
||||
targets_dict = {
|
||||
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
|
||||
}
|
||||
new_output, new_target = ce_loss.format(outputs, targets_dict)
|
||||
assert new_output.shape == torch.Size([1, 37, 9])
|
||||
assert new_target.shape == torch.Size([1, 9])
|
||||
|
||||
|
||||
def test_sar_loss():
|
||||
outputs = torch.rand(1, 10, 37)
|
||||
|
@ -89,3 +98,36 @@ def test_dice_loss():
|
|||
mask = torch.rand(1, 1, 1, 1)
|
||||
loss = dice_loss(pred, gt, mask)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
|
||||
def test_abi_loss():
|
||||
loss = ABILoss(num_classes=90)
|
||||
outputs = dict(
|
||||
out_enc=dict(logits=torch.randn(2, 10, 90)),
|
||||
out_decs=[
|
||||
dict(logits=torch.randn(2, 10, 90)),
|
||||
dict(logits=torch.randn(2, 10, 90))
|
||||
],
|
||||
out_fusers=[
|
||||
dict(logits=torch.randn(2, 10, 90)),
|
||||
dict(logits=torch.randn(2, 10, 90))
|
||||
])
|
||||
targets_dict = {
|
||||
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]]),
|
||||
'targets':
|
||||
[torch.LongTensor([1, 2, 3, 4]),
|
||||
torch.LongTensor([1, 2, 3])]
|
||||
}
|
||||
result = loss(outputs, targets_dict)
|
||||
assert isinstance(result, dict)
|
||||
assert isinstance(result['loss_visual'], torch.Tensor)
|
||||
assert isinstance(result['loss_lang'], torch.Tensor)
|
||||
assert isinstance(result['loss_fusion'], torch.Tensor)
|
||||
|
||||
outputs.pop('out_enc')
|
||||
loss(outputs, targets_dict)
|
||||
outputs.pop('out_decs')
|
||||
loss(outputs, targets_dict)
|
||||
outputs.pop('out_fusers')
|
||||
with pytest.raises(AssertionError):
|
||||
loss(outputs, targets_dict)
|
||||
|
|
Loading…
Reference in New Issue