add sar, seg and other components

pull/2/head
Hongbin Sun 2021-04-02 23:54:57 +08:00
parent af78ffb407
commit 4ecd0cea8a
95 changed files with 7998 additions and 0 deletions

@@ -0,0 +1,96 @@
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
gt_label_convertor = dict(
type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomPaddingOCR',
max_ratio=[0.15, 0.2, 0.15, 0.2],
box_type='char_quads'),
dict(type='OpencvToPil'),
dict(
type='RandomRotateImageBox',
min_angle=-17,
max_angle=17,
box_type='char_quads'),
dict(type='PilToOpencv'),
dict(
type='ResizeOCR',
height=64,
min_width=64,
max_width=512,
keep_aspect_ratio=True),
dict(
type='OCRSegTargets',
label_convertor=gt_label_convertor,
box_type='char_quads'),
dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
dict(type='ToTensorOCR'),
dict(type='FancyPCA'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='CustomFormatBundle',
keys=['gt_kernels'],
visualize=dict(flag=False, boundary_key=None),
call_super=False),
dict(
type='Collect',
keys=['img', 'gt_kernels'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeOCR',
height=64,
min_width=64,
max_width=None,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='CustomFormatBundle', call_super=False),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
]
prefix = 'tests/data/ocr_char_ann_toy_dataset/'
train = dict(
type='OCRSegDataset',
img_prefix=prefix + 'imgs',
ann_file=prefix + 'instances_train.txt',
loader=dict(
type='HardDiskLoader',
repeat=100,
parser=dict(
type='LineJsonParser', keys=['file_name', 'annotations', 'text'])),
pipeline=train_pipeline,
test_mode=True)
test = dict(
type='OCRDataset',
img_prefix=prefix + 'imgs',
ann_file=prefix + 'instances_test.txt',
loader=dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=test_pipeline,
test_mode=True)
data = dict(
samples_per_gpu=8,
workers_per_gpu=1,
train=dict(type='ConcatDataset', datasets=[train]),
val=dict(type='ConcatDataset', datasets=[test]),
test=dict(type='ConcatDataset', datasets=[test]))
evaluation = dict(interval=1, metric='acc')

@@ -0,0 +1,99 @@
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeOCR',
height=32,
min_width=32,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiRotateAugOCR',
rotate_degrees=[0, 90, 270],
transforms=[
dict(
type='ResizeOCR',
height=32,
min_width=32,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
]),
])
]
dataset_type = 'OCRDataset'
img_prefix = 'tests/data/ocr_toy_dataset/imgs'
train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
train1 = dict(
type=dataset_type,
img_prefix=img_prefix,
ann_file=train_anno_file1,
loader=dict(
type='HardDiskLoader',
repeat=100,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
train2 = dict(
type=dataset_type,
img_prefix=img_prefix,
ann_file=train_anno_file2,
loader=dict(
type='LmdbLoader',
repeat=100,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
test = dict(
type=dataset_type,
img_prefix=img_prefix,
ann_file=test_anno_file1,
loader=dict(
type='LmdbLoader',
repeat=1,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=test_pipeline,
test_mode=True)
data = dict(
samples_per_gpu=16,
workers_per_gpu=2,
train=dict(type='ConcatDataset', datasets=[train1, train2]),
val=dict(type='ConcatDataset', datasets=[test]),
test=dict(type='ConcatDataset', datasets=[test]))
evaluation = dict(interval=1, metric='acc')

@@ -0,0 +1,11 @@
label_convertor = dict(
type='CTCConvertor', dict_type='DICT90', with_unknown=False)
model = dict(
type='CRNNNet',
preprocessor=None,
backbone=dict(type='VeryDeepVgg', leakyRelu=False),
encoder=None,
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
loss=dict(type='CTCLoss', flatten=False),
label_convertor=label_convertor)

@@ -0,0 +1,24 @@
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
hybrid_decoder = dict(type='SequenceAttentionDecoder')
position_decoder = dict(type='PositionAttentionDecoder')
model = dict(
type='RobustScanner',
backbone=dict(type='ResNet31OCR'),
encoder=dict(
type='ChannelReductionEncoder',
in_channels=512,
out_channels=128,
),
decoder=dict(
type='RobustScannerDecoder',
dim_input=512,
dim_model=128,
hybrid_decoder=hybrid_decoder,
position_decoder=position_decoder),
loss=dict(type='SARLoss'),
label_convertor=label_convertor,
max_seq_len=30)

@@ -0,0 +1,24 @@
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
model = dict(
type='SARNet',
backbone=dict(type='ResNet31OCR'),
encoder=dict(
type='SAREncoder',
enc_bi_rnn=False,
enc_do_rnn=0.1,
enc_gru=False,
),
decoder=dict(
type='ParallelSARDecoder',
enc_bi_rnn=False,
dec_bi_rnn=False,
dec_do_rnn=0,
dec_gru=False,
pred_dropout=0.1,
d_k=512,
pred_concat=True),
loss=dict(type='SARLoss'),
label_convertor=label_convertor,
max_seq_len=30)

@@ -0,0 +1,11 @@
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=False)
model = dict(
type='TransformerNet',
backbone=dict(type='ResNet31OCR'),
encoder=dict(type='TFEncoder'),
decoder=dict(type='TFDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)

@@ -0,0 +1,65 @@
# Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition
## Introduction
[ALGORITHM]
```
@inproceedings{li2019show,
title={Show, attend and read: A simple and strong baseline for irregular text recognition},
author={Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={33},
number={01},
pages={8610--8617},
year={2019}
}
```
## Dataset
### Train Dataset
| trainset | instance_num | repeat_num | source |
| :--------: | :----------: | :--------: | :----------------------: |
| icdar_2011 | 3567 | 20 | real |
| icdar_2013 | 848 | 20 | real |
| icdar_2015 | 4468         | 20         | real                     |
| coco_text | 42142 | 20 | real |
| IIIT5K | 2000 | 20 | real |
| SynthText | 2400000 | 1 | synth |
| SynthAdd | 1216889 | 1 | synth, 1.6m in [[1]](#1) |
| Syn90k | 2400000 | 1 | synth |
### Test Dataset
| testset | instance_num | type |
| :-----: | :----------: | :-------------------------: |
| IIIT5K | 3000 | regular |
| SVT | 647 | regular |
| IC13 | 1015 | regular |
| IC15 | 2077 | irregular |
| SVTP | 645 | irregular, 639 in [[1]](#1) |
| CT80 | 288 | irregular |
## Results and Models
| Methods|Backbone|Decoder||Regular Text||||Irregular Text||download|
| :-------------: | :-----: | :-----: | :-----: | :------: | :-----: | :----: | :-----: | :-----: | :-----: |:-----: |
||||IIIT5K|SVT|IC13||IC15|SVTP|CT80|
|[SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py)|R31-1/8-1/4|ParallelSARDecoder|95.0|89.6|93.7||79.0|82.2|88.9|[model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210327_154129.log.json) |
|[SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py)|R31-1/8-1/4|SequentialSARDecoder|95.2|88.7|92.4||78.2|81.9|89.6|[model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic-d06c9a8e.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210330_105728.log.json)|
**Notes:**
- `R31-1/8-1/4` means the height of the feature map from the backbone is 1/8 of the input image, and the width is 1/4.
- We did not use beam search during decoding.
- We implemented two kinds of decoders, namely `ParallelSARDecoder` and `SequentialSARDecoder`; a config sketch for switching between them follows these notes.
  - `ParallelSARDecoder`: parallel decoding during training with an `LSTM` layer. It is faster.
  - `SequentialSARDecoder`: sequential decoding during training with `LSTMCell`. It is easier to understand.
- For the training dataset:
  - We did not construct distinct data groups (20 groups in [[1]](#1)) to train the model group by group, since that would make training too complicated.
  - Instead, we randomly selected `2.4m` patches from `Syn90k`, `2.4m` from `SynthText` and `1.2m` from `SynthAdd`, and grouped all the data together. See [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_academic.py) for details.
- We used 48 GPUs with `total_batch_size = 64 * 48` in the experiment above to speed up training, while keeping the `initial lr = 1e-3` unchanged.
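The following is an illustrative sketch (not a file in this commit) of how the two decoders are switched: it reuses the fields of the `sar.py` base model config added in this commit and only changes `decoder.type`.
```
# Illustrative sketch: pick the decoder by changing `type`; every other value
# below is copied from configs/_base_/recog_models/sar.py in this commit.
label_convertor = dict(
    type='AttnConvertor', dict_type='DICT90', with_unknown=True)
model = dict(
    type='SARNet',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(
        type='SAREncoder', enc_bi_rnn=False, enc_do_rnn=0.1, enc_gru=False),
    decoder=dict(
        type='ParallelSARDecoder',  # or 'SequentialSARDecoder'
        enc_bi_rnn=False,
        dec_bi_rnn=False,
        dec_do_rnn=0,
        dec_gru=False,
        pred_dropout=0.1,
        d_k=512,
        pred_concat=True),
    loss=dict(type='SARLoss'),
    label_convertor=label_convertor,
    max_seq_len=30)
```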
## References
<a id="1">[1]</a> Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu. Show, attend and read: A simple and strong baseline for irregular text recognition. In AAAI 2019.

@@ -0,0 +1,219 @@
_base_ = ['../../_base_/default_runtime.py']
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
model = dict(
type='SARNet',
backbone=dict(type='ResNet31OCR'),
encoder=dict(
type='SAREncoder',
enc_bi_rnn=False,
enc_do_rnn=0.1,
enc_gru=False,
),
decoder=dict(
type='ParallelSARDecoder',
enc_bi_rnn=False,
dec_bi_rnn=False,
dec_do_rnn=0,
dec_gru=False,
pred_dropout=0.1,
d_k=512,
pred_concat=True),
loss=dict(type='SARLoss'),
label_convertor=label_convertor,
max_seq_len=30)
# optimizer
optimizer = dict(type='Adam', lr=1e-3)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeOCR',
height=48,
min_width=48,
max_width=160,
keep_aspect_ratio=True,
width_downsample_ratio=0.25),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiRotateAugOCR',
rotate_degrees=[0, 90, 270],
transforms=[
dict(
type='ResizeOCR',
height=48,
min_width=48,
max_width=160,
keep_aspect_ratio=True,
width_downsample_ratio=0.25),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
]),
])
]
dataset_type = 'OCRDataset'
train_prefix = 'data/mixture/'
train_img_prefix1 = train_prefix + 'icdar_2011'
train_img_prefix2 = train_prefix + 'icdar_2013'
train_img_prefix3 = train_prefix + 'icdar_2015'
train_img_prefix4 = train_prefix + 'coco_text'
train_img_prefix5 = train_prefix + 'IIIT5K'
train_img_prefix6 = train_prefix + 'SynthText_Add'
train_img_prefix7 = train_prefix + 'SynthText'
train_img_prefix8 = train_prefix + 'Syn90k'
train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
train_ann_file5 = train_prefix + 'IIIT5K/train_label.txt'
train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'
train1 = dict(
type=dataset_type,
img_prefix=train_img_prefix1,
ann_file=train_ann_file1,
loader=dict(
type='HardDiskLoader',
repeat=20,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
train2 = {key: value for key, value in train1.items()}
train2['img_prefix'] = train_img_prefix2
train2['ann_file'] = train_ann_file2
train3 = {key: value for key, value in train1.items()}
train3['img_prefix'] = train_img_prefix3
train3['ann_file'] = train_ann_file3
train4 = {key: value for key, value in train1.items()}
train4['img_prefix'] = train_img_prefix4
train4['ann_file'] = train_ann_file4
train5 = {key: value for key, value in train1.items()}
train5['img_prefix'] = train_img_prefix5
train5['ann_file'] = train_ann_file5
train6 = dict(
type=dataset_type,
img_prefix=train_img_prefix6,
ann_file=train_ann_file6,
loader=dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
train7 = {key: value for key, value in train6.items()}
train7['img_prefix'] = train_img_prefix7
train7['ann_file'] = train_ann_file7
train8 = {key: value for key, value in train6.items()}
train8['img_prefix'] = train_img_prefix8
train8['ann_file'] = train_ann_file8
test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'
test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'
test1 = dict(
type=dataset_type,
img_prefix=test_img_prefix1,
ann_file=test_ann_file1,
loader=dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=test_pipeline,
test_mode=True)
test2 = {key: value for key, value in test1.items()}
test2['img_prefix'] = test_img_prefix2
test2['ann_file'] = test_ann_file2
test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3
test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4
test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5
test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
train=dict(
type='ConcatDataset',
datasets=[
train1, train2, train3, train4, train5, train6, train7, train8
]),
val=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]),
test=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]))
evaluation = dict(interval=1, metric='acc')

@@ -0,0 +1,110 @@
_base_ = [
'../../_base_/default_runtime.py', '../../_base_/recog_models/sar.py'
]
# optimizer
optimizer = dict(type='Adam', lr=1e-3)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeOCR',
height=48,
min_width=48,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiRotateAugOCR',
rotate_degrees=[0, 90, 270],
transforms=[
dict(
type='ResizeOCR',
height=48,
min_width=48,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
]),
])
]
dataset_type = 'OCRDataset'
img_prefix = 'tests/data/ocr_toy_dataset/imgs'
train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
train1 = dict(
type=dataset_type,
img_prefix=img_prefix,
ann_file=train_anno_file1,
loader=dict(
type='HardDiskLoader',
repeat=100,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
train2 = dict(
type=dataset_type,
img_prefix=img_prefix,
ann_file=train_anno_file2,
loader=dict(
type='LmdbLoader',
repeat=100,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
test = dict(
type=dataset_type,
img_prefix=img_prefix,
ann_file=test_anno_file1,
loader=dict(
type='LmdbLoader',
repeat=1,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=test_pipeline,
test_mode=True)
data = dict(
samples_per_gpu=16,
workers_per_gpu=2,
train=dict(type='ConcatDataset', datasets=[train1, train2]),
val=dict(type='ConcatDataset', datasets=[test]),
test=dict(type='ConcatDataset', datasets=[test]))
evaluation = dict(interval=1, metric='acc')

@@ -0,0 +1,219 @@
_base_ = ['../../_base_/default_runtime.py']
label_convertor = dict(
type='AttnConvertor', dict_type='DICT90', with_unknown=True)
model = dict(
type='SARNet',
backbone=dict(type='ResNet31OCR'),
encoder=dict(
type='SAREncoder',
enc_bi_rnn=False,
enc_do_rnn=0.1,
enc_gru=False,
),
decoder=dict(
type='SequentialSARDecoder',
enc_bi_rnn=False,
dec_bi_rnn=False,
dec_do_rnn=0,
dec_gru=False,
pred_dropout=0.1,
d_k=512,
pred_concat=True),
loss=dict(type='SARLoss'),
label_convertor=label_convertor,
max_seq_len=30)
# optimizer
optimizer = dict(type='Adam', lr=1e-3)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeOCR',
height=48,
min_width=48,
max_width=160,
keep_aspect_ratio=True,
width_downsample_ratio=0.25),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
]),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiRotateAugOCR',
rotate_degrees=[0, 90, 270],
transforms=[
dict(
type='ResizeOCR',
height=48,
min_width=48,
max_width=160,
keep_aspect_ratio=True,
width_downsample_ratio=0.25),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
]),
])
]
dataset_type = 'OCRDataset'
train_prefix = 'data/mixture/'
train_img_prefix1 = train_prefix + 'icdar_2011'
train_img_prefix2 = train_prefix + 'icdar_2013'
train_img_prefix3 = train_prefix + 'icdar_2015'
train_img_prefix4 = train_prefix + 'coco_text'
train_img_prefix5 = train_prefix + 'IIIT5K'
train_img_prefix6 = train_prefix + 'SynthText_Add'
train_img_prefix7 = train_prefix + 'SynthText'
train_img_prefix8 = train_prefix + 'Syn90k'
train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
train_ann_file5 = train_prefix + 'IIIT5K/train_label.txt'
train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'
train1 = dict(
type=dataset_type,
img_prefix=train_img_prefix1,
ann_file=train_ann_file1,
loader=dict(
type='HardDiskLoader',
repeat=20,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
train2 = {key: value for key, value in train1.items()}
train2['img_prefix'] = train_img_prefix2
train2['ann_file'] = train_ann_file2
train3 = {key: value for key, value in train1.items()}
train3['img_prefix'] = train_img_prefix3
train3['ann_file'] = train_ann_file3
train4 = {key: value for key, value in train1.items()}
train4['img_prefix'] = train_img_prefix4
train4['ann_file'] = train_ann_file4
train5 = {key: value for key, value in train1.items()}
train5['img_prefix'] = train_img_prefix5
train5['ann_file'] = train_ann_file5
train6 = dict(
type=dataset_type,
img_prefix=train_img_prefix6,
ann_file=train_ann_file6,
loader=dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
train7 = {key: value for key, value in train6.items()}
train7['img_prefix'] = train_img_prefix7
train7['ann_file'] = train_ann_file7
train8 = {key: value for key, value in train6.items()}
train8['img_prefix'] = train_img_prefix8
train8['ann_file'] = train_ann_file8
test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'
test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'
test1 = dict(
type=dataset_type,
img_prefix=test_img_prefix1,
ann_file=test_ann_file1,
loader=dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=test_pipeline,
test_mode=True)
test2 = {key: value for key, value in test1.items()}
test2['img_prefix'] = test_img_prefix2
test2['ann_file'] = test_ann_file2
test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3
test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4
test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5
test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
train=dict(
type='ConcatDataset',
datasets=[
train1, train2, train3, train4, train5, train6, train7, train8
]),
val=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]),
test=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]))
evaluation = dict(interval=1, metric='acc')

@@ -0,0 +1,36 @@
# A Baseline Method for Segmentation-Based Text Recognition
## Introduction
A baseline method for segmentation-based text recognition.
[ALGORITHM]
## Dataset
### Train Dataset
| trainset | instance_num | repeat_num | source |
| :-------: | :----------: | :--------: | :----: |
| SynthText | 7266686 | 1 | synth |
### Test Dataset
| testset | instance_num | type |
| :-----: | :----------: | :-------: |
| IIIT5K | 3000 | regular |
| SVT | 647 | regular |
| IC13 | 1015 | regular |
| CT80 | 288 | irregular |
## Results and Models
|Backbone|Neck|Head|||Regular Text|||Irregular Text|base_lr|batch_size/gpu|gpus|download
| :-------------: | :-----: | :-----: | :------: | :-----: | :----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
|||||IIIT5K|SVT|IC13||CT80|
|R31-1/16|FPNOCR|1x||90.9|81.8|90.7||80.9|1e-4|16|4|[model](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic-0c50e163.pth) &#124; [config](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic.py) &#124; [log](https://download.openmmlab.com/mmocr/textrecog/seg/20210325_112835.log.json) |
**Notes:**
- `R31-1/16` means the size (both height and width) of the feature map from the backbone is 1/16 of the input image.
- `1x` means the size (both height and width) of the feature map from the head is the same as the input image.

@@ -0,0 +1,160 @@
_base_ = ['../../_base_/default_runtime.py']
# optimizer
optimizer = dict(type='Adam', lr=1e-4)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5
label_convertor = dict(
type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
model = dict(
type='SegRecognizer',
backbone=dict(
type='ResNet31OCR',
layers=[1, 2, 5, 3],
channels=[32, 64, 128, 256, 512, 512],
out_indices=[0, 1, 2, 3],
stage4_pool_cfg=dict(kernel_size=2, stride=2),
last_stage_pool=True),
neck=dict(
type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
head=dict(
type='SegHead',
in_channels=256,
upsample_param=dict(scale_factor=2.0, mode='nearest')),
loss=dict(
type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=True),
label_convertor=label_convertor)
find_unused_parameters = True
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
gt_label_convertor = dict(
type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomPaddingOCR',
max_ratio=[0.15, 0.2, 0.15, 0.2],
box_type='char_quads'),
dict(type='OpencvToPil'),
dict(
type='RandomRotateImageBox',
min_angle=-17,
max_angle=17,
box_type='char_quads'),
dict(type='PilToOpencv'),
dict(
type='ResizeOCR',
height=64,
min_width=64,
max_width=512,
keep_aspect_ratio=True),
dict(
type='OCRSegTargets',
label_convertor=gt_label_convertor,
box_type='char_quads'),
dict(type='RandomRotateTextDet', rotate_ratio=0.5, max_angle=15),
dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
dict(type='ToTensorOCR'),
dict(type='FancyPCA'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='CustomFormatBundle',
keys=['gt_kernels'],
visualize=dict(flag=False, boundary_key=None),
call_super=False),
dict(
type='Collect',
keys=['img', 'gt_kernels'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeOCR',
height=64,
min_width=64,
max_width=None,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(type='CustomFormatBundle', call_super=False),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'img_shape'])
]
train_img_root = 'data/mixture/'
train_img_prefix = train_img_root + 'SynthText'
train_ann_file = train_img_root + 'SynthText/instances_train.txt'
train = dict(
type='OCRSegDataset',
img_prefix=train_img_prefix,
ann_file=train_ann_file,
loader=dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineJsonParser', keys=['file_name', 'annotations', 'text'])),
pipeline=train_pipeline,
test_mode=False)
dataset_type = 'OCRDataset'
test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'ct80/'
test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'ct80/test_label.txt'
test1 = dict(
type=dataset_type,
img_prefix=test_img_prefix1,
ann_file=test_ann_file1,
loader=dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=test_pipeline,
test_mode=True)
test2 = {key: value for key, value in test1.items()}
test2['img_prefix'] = test_img_prefix2
test2['ann_file'] = test_ann_file2
test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3
test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4
data = dict(
samples_per_gpu=16,
workers_per_gpu=2,
train=dict(type='ConcatDataset', datasets=[train]),
val=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]),
test=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]))
evaluation = dict(interval=1, metric='acc')

@@ -0,0 +1,35 @@
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/recog_datasets/seg_toy_dataset.py'
]
# optimizer
optimizer = dict(type='Adam', lr=1e-4)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5
label_convertor = dict(
type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
model = dict(
type='SegRecognizer',
backbone=dict(
type='ResNet31OCR',
layers=[1, 2, 5, 3],
channels=[32, 64, 128, 256, 512, 512],
out_indices=[0, 1, 2, 3],
stage4_pool_cfg=dict(kernel_size=2, stride=2),
last_stage_pool=True),
neck=dict(
type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
head=dict(
type='SegHead',
in_channels=256,
upsample_param=dict(scale_factor=2.0, mode='nearest')),
loss=dict(
type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=False),
label_convertor=label_convertor)
find_unused_parameters = True

@@ -0,0 +1,30 @@
## Dataset
### Train Dataset
| trainset | instance_num | repeat_num | note |
| :--------: | :----------: | :--------: | :---: |
| icdar_2011 | 3567 | 20 | real |
| icdar_2013 | 848 | 20 | real |
| icdar_2015 | 4468         | 20         | real  |
| coco_text | 42142 | 20 | real |
| IIIT5K | 2000 | 20 | real |
| SynthText | 2400000 | 1 | synth |
### Test Dataset
| testset | instance_num | note |
| :-----: | :----------: | :-------------------------: |
| IIIT5K | 3000 | regular |
| SVT | 647 | regular |
| IC13 | 1015 | regular |
| IC15 | 2077 | irregular |
| SVTP | 645 | irregular, 639 in [[1]](#1) |
| CT80 | 288 | irregular |
## Results and models
| methods | | Regular Text | | | | Irregular Text | | download |
| :---------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :------: |
| | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
| Transformer | 93.3 | 85.8 | 91.3 | | 73.2 | 76.6 | 87.8 | |

@@ -0,0 +1,12 @@
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/recog_models/transformer.py',
'../../_base_/recog_datasets/toy_dataset.py'
]
# optimizer
optimizer = dict(type='Adadelta', lr=1)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

@@ -0,0 +1,3 @@
from .version import __version__, short_version
__all__ = ['__version__', 'short_version']

@@ -0,0 +1,133 @@
import re
from difflib import SequenceMatcher
import Levenshtein
def cal_true_positive_char(pred, gt):
"""Calculate correct character number in prediction.
Args:
pred (str): Prediction text.
gt (str): Ground truth text.
Returns:
true_positive_char_num (int): The true positive number.
"""
all_opt = SequenceMatcher(None, pred, gt)
true_positive_char_num = 0
for opt, _, _, s2, e2 in all_opt.get_opcodes():
if opt == 'equal':
true_positive_char_num += (e2 - s2)
else:
pass
return true_positive_char_num
def count_matches(pred_texts, gt_texts):
"""Count the various match number for metric calculation.
Args:
pred_texts (list[str]): Predicted text string.
gt_texts (list[str]): Ground truth text string.
Returns:
match_res: (dict[str: int]): Match number used for
metric calculation.
"""
match_res = {
'gt_char_num': 0,
'pred_char_num': 0,
'true_positive_char_num': 0,
'gt_word_num': 0,
'match_word_num': 0,
'match_word_ignore_case': 0,
'match_word_ignore_case_symbol': 0
}
    comp = re.compile('[^A-Za-z0-9\u4e00-\u9fa5]')
norm_ed_sum = 0.0
for pred_text, gt_text in zip(pred_texts, gt_texts):
if gt_text == pred_text:
match_res['match_word_num'] += 1
gt_text_lower = gt_text.lower()
pred_text_lower = pred_text.lower()
if gt_text_lower == pred_text_lower:
match_res['match_word_ignore_case'] += 1
gt_text_lower_ignore = comp.sub('', gt_text_lower)
pred_text_lower_ignore = comp.sub('', pred_text_lower)
if gt_text_lower_ignore == pred_text_lower_ignore:
match_res['match_word_ignore_case_symbol'] += 1
match_res['gt_word_num'] += 1
# normalized edit distance
edit_dist = Levenshtein.distance(pred_text_lower_ignore,
gt_text_lower_ignore)
norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
len(pred_text_lower_ignore))
norm_ed_sum += norm_ed
# number to calculate char level recall & precision
match_res['gt_char_num'] += len(gt_text_lower_ignore)
match_res['pred_char_num'] += len(pred_text_lower_ignore)
true_positive_char_num = cal_true_positive_char(
pred_text_lower_ignore, gt_text_lower_ignore)
match_res['true_positive_char_num'] += true_positive_char_num
normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
match_res['ned'] = normalized_edit_distance
return match_res
def eval_ocr_metric(pred_texts, gt_texts):
"""Evaluate the text recognition performance with metric: word accuracy and
1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.
Args:
pred_texts (list[str]): Text strings of prediction.
gt_texts (list[str]): Text strings of ground truth.
Returns:
eval_res (dict[str: float]): Metric dict for text recognition, include:
- word_acc: Accuracy in word level.
- word_acc_ignore_case: Accuracy in word level, ignore letter case.
- word_acc_ignore_case_symbol: Accuracy in word level, ignore
letter case and symbol. (default metric for
academic evaluation)
- char_recall: Recall in character level, ignore
letter case and symbol.
- char_precision: Precision in character level, ignore
letter case and symbol.
- 1-N.E.D: 1 - normalized_edit_distance.
"""
assert isinstance(pred_texts, list)
assert isinstance(gt_texts, list)
assert len(pred_texts) == len(gt_texts)
match_res = count_matches(pred_texts, gt_texts)
eps = 1e-8
char_recall = 1.0 * match_res['true_positive_char_num'] / (
eps + match_res['gt_char_num'])
char_precision = 1.0 * match_res['true_positive_char_num'] / (
eps + match_res['pred_char_num'])
word_acc = 1.0 * match_res['match_word_num'] / (
eps + match_res['gt_word_num'])
word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / (
eps + match_res['gt_word_num'])
word_acc_ignore_case_symbol = 1.0 * match_res[
'match_word_ignore_case_symbol'] / (
eps + match_res['gt_word_num'])
eval_res = {}
eval_res['word_acc'] = word_acc
eval_res['word_acc_ignore_case'] = word_acc_ignore_case
eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol
eval_res['char_recall'] = char_recall
eval_res['char_precision'] = char_precision
eval_res['1-N.E.D'] = 1.0 - match_res['ned']
for key, value in eval_res.items():
eval_res[key] = float('{:.4f}'.format(value))
return eval_res
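
# Usage sketch (illustrative values only): eval_ocr_metric() takes two
# equal-length lists of strings and returns the metric dict documented in
# its docstring.
if __name__ == '__main__':
    demo_preds = ['hello', 'W0rld']
    demo_gts = ['Hello', 'world!']
    print(eval_ocr_metric(demo_preds, demo_gts))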

@@ -0,0 +1,15 @@
from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset
from .base_dataset import BaseDataset
from .icdar_dataset import IcdarDataset
from .kie_dataset import KIEDataset
from .ocr_dataset import OCRDataset
from .ocr_seg_dataset import OCRSegDataset
from .pipelines import CustomFormatBundle, DBNetTargets, DRRGTargets
from .text_det_dataset import TextDetDataset
from .utils import * # noqa: F401,F403
__all__ = [
'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
'DBNetTargets', 'OCRSegDataset', 'DRRGTargets', 'KIEDataset'
]

@@ -0,0 +1,166 @@
import numpy as np
from mmcv.utils import print_log
from torch.utils.data import Dataset
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.pipelines import Compose
from mmocr.datasets.builder import build_loader
@DATASETS.register_module()
class BaseDataset(Dataset):
"""Custom dataset for text detection, text recognition, and their
downstream tasks.
1. The text detection annotation format is as follows:
The `annotations` field is optional for testing
(this is one line of anno_file, with line-json-str
converted to dict for visualizing only).
{
"file_name": "sample.jpg",
"height": 1080,
"width": 960,
"annotations":
[
{
"iscrowd": 0,
"category_id": 1,
"bbox": [357.0, 667.0, 804.0, 100.0],
"segmentation": [[361, 667, 710, 670,
72, 767, 357, 763]]
}
]
}
2. The two text recognition annotation formats are as follows:
The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop
augmentation during training.
format1: sample.jpg hello
format2: sample.jpg 20 20 100 20 100 40 20 40 hello
Args:
ann_file (str): Annotation file path.
pipeline (list[dict]): Processing pipeline.
loader (dict): Dictionary to construct loader
to load annotation infos.
img_prefix (str, optional): Image prefix to generate full
image path.
test_mode (bool, optional): If set True, try...except will
be turned off in __getitem__.
"""
def __init__(self,
ann_file,
loader,
pipeline,
img_prefix='',
test_mode=False):
super().__init__()
self.test_mode = test_mode
self.img_prefix = img_prefix
self.ann_file = ann_file
# load annotations
loader.update(ann_file=ann_file)
self.data_infos = build_loader(loader)
# processing pipeline
self.pipeline = Compose(pipeline)
# set group flag and class, no meaning
# for text detect and recognize
self._set_group_flag()
self.CLASSES = 0
def __len__(self):
return len(self.data_infos)
def _set_group_flag(self):
"""Set flag."""
self.flag = np.zeros(len(self), dtype=np.uint8)
def pre_pipeline(self, results):
"""Prepare results dict for pipeline."""
results['img_prefix'] = self.img_prefix
def prepare_train_img(self, index):
"""Get training data and annotations from pipeline.
Args:
index (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys \
introduced by pipeline.
"""
img_info = self.data_infos[index]
results = dict(img_info=img_info)
self.pre_pipeline(results)
return self.pipeline(results)
def prepare_test_img(self, img_info):
"""Get testing data from pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Testing data after pipeline with new keys introduced by \
pipeline.
"""
return self.prepare_train_img(img_info)
def _log_error_index(self, index):
"""Logging data info of bad index."""
try:
data_info = self.data_infos[index]
img_prefix = self.img_prefix
print_log(f'Warning: skip broken file {data_info} '
f'with img_prefix {img_prefix}')
except Exception as e:
print_log(f'load index {index} with error {e}')
def _get_next_index(self, index):
"""Get next index from dataset."""
self._log_error_index(index)
index = (index + 1) % len(self)
return index
def __getitem__(self, index):
"""Get training/test data from pipeline.
Args:
index (int): Index of data.
Returns:
dict: Training/test data.
"""
if self.test_mode:
return self.prepare_test_img(index)
while True:
try:
data = self.prepare_train_img(index)
if data is None:
raise Exception('prepared train data empty')
break
except Exception as e:
print_log(f'prepare index {index} with error {e}')
index = self._get_next_index(index)
return data
def format_results(self, results, **kwargs):
"""Placeholder to format result to dataset-specific output."""
pass
def evaluate(self, results, metric=None, logger=None, **kwargs):
"""Evaluate the dataset.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str: float]
"""
raise NotImplementedError

@@ -0,0 +1,14 @@
from mmcv.utils import Registry, build_from_cfg
LOADERS = Registry('loader')
PARSERS = Registry('parser')
def build_loader(cfg):
"""Build anno file loader."""
return build_from_cfg(cfg, LOADERS)
def build_parser(cfg):
"""Build anno file parser."""
return build_from_cfg(cfg, PARSERS)
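
# Sketch of how the configs in this commit use these registries (the path is
# a placeholder): a dataset config carries a loader dict such as the one
# below, and BaseDataset calls build_loader() on it after injecting
# `ann_file`. HardDiskLoader and LineStrParser are registered elsewhere in
# this commit.
example_loader_cfg = dict(
    type='HardDiskLoader',
    repeat=1,
    ann_file='path/to/label.txt',  # placeholder path
    parser=dict(
        type='LineStrParser',
        keys=['filename', 'text'],
        keys_idx=[0, 1],
        separator=' '))
# data_infos = build_loader(example_loader_cfg)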

@@ -0,0 +1,34 @@
from mmdet.datasets.builder import DATASETS
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets.base_dataset import BaseDataset
@DATASETS.register_module()
class OCRDataset(BaseDataset):
def pre_pipeline(self, results):
results['img_prefix'] = self.img_prefix
results['text'] = results['img_info']['text']
def evaluate(self, results, metric='acc', logger=None, **kwargs):
"""Evaluate the dataset.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str: float]
"""
gt_texts = []
pred_texts = []
for i in range(len(self)):
item_info = self.data_infos[i]
text = item_info['text']
gt_texts.append(text)
pred_texts.append(results[i]['text'])
eval_results = eval_ocr_metric(pred_texts, gt_texts)
return eval_results
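
# Sketch of the `results` format expected by evaluate() (illustrative values):
# one dict per sample, in the same order as the dataset, each carrying the
# predicted transcription under the 'text' key, e.g.
#     results = [dict(text='hello'), dict(text='world')]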

@@ -0,0 +1,90 @@
import mmocr.utils as utils
from mmdet.datasets.builder import DATASETS
from mmocr.datasets.ocr_dataset import OCRDataset
@DATASETS.register_module()
class OCRSegDataset(OCRDataset):
def pre_pipeline(self, results):
results['img_prefix'] = self.img_prefix
def _parse_anno_info(self, annotations):
"""Parse char boxes annotations.
Args:
annotations (list[dict]): Annotations of one image, where
each dict is for one character.
Returns:
dict: A dict containing the following keys:
- chars (list[str]): List of character strings.
- char_rects (list[list[float]]): List of char box, with each
in style of rectangle: [x_min, y_min, x_max, y_max].
- char_quads (list[list[float]]): List of char box, with each
in style of quadrangle: [x1, y1, x2, y2, x3, y3, x4, y4].
"""
assert utils.is_type_list(annotations, dict)
assert 'char_box' in annotations[0]
assert 'char_text' in annotations[0]
assert len(annotations[0]['char_box']) == 4 or \
len(annotations[0]['char_box']) == 8
chars, char_rects, char_quads = [], [], []
for ann in annotations:
char_box = ann['char_box']
if len(char_box) == 4:
char_box_type = ann.get('char_box_type', 'xyxy')
if char_box_type == 'xyxy':
char_rects.append(char_box)
char_quads.append([
char_box[0], char_box[1], char_box[2], char_box[1],
char_box[2], char_box[3], char_box[0], char_box[3]
])
elif char_box_type == 'xywh':
x1, y1, w, h = char_box
x2 = x1 + w
y2 = y1 + h
char_rects.append([x1, y1, x2, y2])
char_quads.append([x1, y1, x2, y1, x2, y2, x1, y2])
else:
raise ValueError(f'invalid char_box_type {char_box_type}')
elif len(char_box) == 8:
x_list, y_list = [], []
for i in range(4):
x_list.append(char_box[2 * i])
y_list.append(char_box[2 * i + 1])
x_max, x_min = max(x_list), min(x_list)
y_max, y_min = max(y_list), min(y_list)
char_rects.append([x_min, y_min, x_max, y_max])
char_quads.append(char_box)
else:
raise Exception(
f'invalid num in char box: {len(char_box)} not in (4, 8)')
chars.append(ann['char_text'])
ann = dict(chars=chars, char_rects=char_rects, char_quads=char_quads)
return ann
def prepare_train_img(self, index):
"""Get training data and annotations from pipeline.
Args:
index (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys \
introduced by pipeline.
"""
img_ann_info = self.data_infos[index]
img_info = {
'filename': img_ann_info['file_name'],
}
ann_info = self._parse_anno_info(img_ann_info['annotations'])
results = dict(img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
return self.pipeline(results)
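
# Sketch of one annotation line consumed by this dataset (illustrative
# values): LineJsonParser turns each line of the ann_file into a dict like
# the one below; every `char_box` has either 4 values (rectangle) or 8 values
# (quadrangle), matching _parse_anno_info() above.
example_info = dict(
    file_name='sample.jpg',
    text='it',
    annotations=[
        dict(char_text='i', char_box=[10, 10, 20, 10, 20, 30, 10, 30]),
        dict(char_text='t', char_box=[22, 10, 32, 10, 32, 30, 22, 30]),
    ])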

@@ -0,0 +1,185 @@
import cv2
import numpy as np
from shapely.geometry import LineString, Point, Polygon
import mmocr.utils as utils
def sort_vertex(points_x, points_y):
"""Sort box vertices in clockwise order from left-top first.
Args:
points_x (list[float]): x of four vertices.
points_y (list[float]): y of four vertices.
Returns:
sorted_points_x (list[float]): x of sorted four vertices.
sorted_points_y (list[float]): y of sorted four vertices.
"""
assert utils.is_type_list(points_x, float) or utils.is_type_list(
points_x, int)
assert utils.is_type_list(points_y, float) or utils.is_type_list(
points_y, int)
assert len(points_x) == 4
assert len(points_y) == 4
x = np.array(points_x)
y = np.array(points_y)
center_x = np.sum(x) * 0.25
center_y = np.sum(y) * 0.25
x_arr = np.array(x - center_x)
y_arr = np.array(y - center_y)
angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
sort_idx = np.argsort(angle)
sorted_points_x, sorted_points_y = [], []
for i in range(4):
sorted_points_x.append(points_x[sort_idx[i]])
sorted_points_y.append(points_y[sort_idx[i]])
return convert_canonical(sorted_points_x, sorted_points_y)
def convert_canonical(points_x, points_y):
"""Make left-top be first.
Args:
points_x (list[float]): x of four vertices.
points_y (list[float]): y of four vertices.
Returns:
sorted_points_x (list[float]): x of sorted four vertices.
sorted_points_y (list[float]): y of sorted four vertices.
"""
assert utils.is_type_list(points_x, float) or utils.is_type_list(
points_x, int)
assert utils.is_type_list(points_y, float) or utils.is_type_list(
points_y, int)
assert len(points_x) == 4
assert len(points_y) == 4
points = [Point(points_x[i], points_y[i]) for i in range(4)]
polygon = Polygon([(p.x, p.y) for p in points])
min_x, min_y, _, _ = polygon.bounds
points_to_lefttop = [
LineString([points[i], Point(min_x, min_y)]) for i in range(4)
]
distances = np.array([line.length for line in points_to_lefttop])
sort_dist_idx = np.argsort(distances)
lefttop_idx = sort_dist_idx[0]
if lefttop_idx == 0:
point_orders = [0, 1, 2, 3]
elif lefttop_idx == 1:
point_orders = [1, 2, 3, 0]
elif lefttop_idx == 2:
point_orders = [2, 3, 0, 1]
else:
point_orders = [3, 0, 1, 2]
sorted_points_x = [points_x[i] for i in point_orders]
sorted_points_y = [points_y[j] for j in point_orders]
return sorted_points_x, sorted_points_y
def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1):
"""Jitter on the coordinates of bounding box.
Args:
        points_x (list[float | int]): List of x for four vertices.
        points_y (list[float | int]): List of y for four vertices.
jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
jitter_ratio_y (float): Vertical jitter ratio relative to the height.
"""
assert len(points_x) == 4
assert len(points_y) == 4
assert isinstance(jitter_ratio_x, float)
assert isinstance(jitter_ratio_y, float)
assert 0 <= jitter_ratio_x < 1
assert 0 <= jitter_ratio_y < 1
points = [Point(points_x[i], points_y[i]) for i in range(4)]
line_list = [
LineString([points[i], points[i + 1 if i < 3 else 0]])
for i in range(4)
]
tmp_h = max(line_list[1].length, line_list[3].length)
for i in range(4):
jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h
jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h
points_x[i] += jitter_pixel_x
points_y[i] += jitter_pixel_y
def warp_img(src_img,
box,
jitter_flag=False,
jitter_ratio_x=0.5,
jitter_ratio_y=0.1):
"""Crop box area from image using opencv warpPerspective w/o box jitter.
Args:
src_img (np.array): Image before cropping.
box (list[float | int]): Coordinates of quadrangle.
"""
assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
assert len(box) == 8
h, w = src_img.shape[:2]
points_x = [min(max(x, 0), w) for x in box[0:8:2]]
points_y = [min(max(y, 0), h) for y in box[1:9:2]]
points_x, points_y = sort_vertex(points_x, points_y)
if jitter_flag:
box_jitter(
points_x,
points_y,
jitter_ratio_x=jitter_ratio_x,
jitter_ratio_y=jitter_ratio_y)
points = [Point(points_x[i], points_y[i]) for i in range(4)]
edges = [
LineString([points[i], points[i + 1 if i < 3 else 0]])
for i in range(4)
]
pts1 = np.float32([[points[i].x, points[i].y] for i in range(4)])
box_width = max(edges[0].length, edges[2].length)
box_height = max(edges[1].length, edges[3].length)
pts2 = np.float32([[0, 0], [box_width, 0], [box_width, box_height],
[0, box_height]])
M = cv2.getPerspectiveTransform(pts1, pts2)
dst_img = cv2.warpPerspective(src_img, M,
(int(box_width), int(box_height)))
return dst_img
def crop_img(src_img, box):
"""Crop box area to rectangle.
Args:
src_img (np.array): Image before crop.
box (list[float | int]): Points of quadrangle.
"""
assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
assert len(box) == 8
h, w = src_img.shape[:2]
points_x = [min(max(x, 0), w) for x in box[0:8:2]]
points_y = [min(max(y, 0), h) for y in box[1:9:2]]
left = int(min(points_x))
top = int(min(points_y))
right = int(max(points_x))
bottom = int(max(points_y))
dst_img = src_img[top:bottom, left:right]
return dst_img
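
# Usage sketch (illustrative values only): both helpers take an image and an
# 8-value quadrangle [x1, y1, ..., x4, y4]; warp_img() rectifies the
# quadrangle with a perspective transform, while crop_img() cuts out its
# axis-aligned bounding box.
if __name__ == '__main__':
    demo_img = np.zeros((100, 200, 3), dtype=np.uint8)
    demo_box = [10., 10., 90., 20., 85., 60., 5., 50.]
    print(warp_img(demo_img, demo_box).shape)
    print(crop_img(demo_img, demo_box).shape)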

@@ -0,0 +1,201 @@
import cv2
import numpy as np
import mmocr.utils.check_argument as check_argument
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from mmocr.models.builder import build_convertor
@PIPELINES.register_module()
class OCRSegTargets:
"""Generate gt shrinked kernels for segmentation based OCR framework.
Args:
label_convertor (dict): Dictionary to construct label_convertor
to convert char to index.
attn_shrink_ratio (float): The area shrinked ratio
between attention kernels and gt text masks.
seg_shrink_ratio (float): The area shrinked ratio
between segmentation kernels and gt text masks.
box_type (str): Character box type, should be either
'char_rects' or 'char_quads', with 'char_rects'
for rectangle with ``xyxy`` style and 'char_quads'
for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
"""
def __init__(self,
label_convertor=None,
attn_shrink_ratio=0.5,
seg_shrink_ratio=0.25,
box_type='char_rects',
pad_val=255):
assert isinstance(attn_shrink_ratio, float)
assert isinstance(seg_shrink_ratio, float)
assert 0. < attn_shrink_ratio < 1.0
assert 0. < seg_shrink_ratio < 1.0
assert label_convertor is not None
assert box_type in ('char_rects', 'char_quads')
self.attn_shrink_ratio = attn_shrink_ratio
self.seg_shrink_ratio = seg_shrink_ratio
self.label_convertor = build_convertor(label_convertor)
self.box_type = box_type
self.pad_val = pad_val
def shrink_char_quad(self, char_quad, shrink_ratio):
"""Shrink char box in style of quadrangle.
Args:
char_quad (list[float]): Char box with format
[x1, y1, x2, y2, x3, y3, x4, y4].
            shrink_ratio (float): The area shrink ratio
                between gt kernels and gt text masks.
"""
points = [[char_quad[0], char_quad[1]], [char_quad[2], char_quad[3]],
[char_quad[4], char_quad[5]], [char_quad[6], char_quad[7]]]
shrink_points = []
for p_idx, point in enumerate(points):
p1 = points[(p_idx + 3) % 4]
p2 = points[(p_idx + 1) % 4]
dist1 = self.l2_dist_two_points(p1, point)
dist2 = self.l2_dist_two_points(p2, point)
min_dist = min(dist1, dist2)
v1 = [p1[0] - point[0], p1[1] - point[1]]
v2 = [p2[0] - point[0], p2[1] - point[1]]
temp_dist1 = (shrink_ratio * min_dist /
dist1) if min_dist != 0 else 0.
temp_dist2 = (shrink_ratio * min_dist /
dist2) if min_dist != 0 else 0.
v1 = [temp * temp_dist1 for temp in v1]
v2 = [temp * temp_dist2 for temp in v2]
shrink_point = [
round(point[0] + v1[0] + v2[0]),
round(point[1] + v1[1] + v2[1])
]
shrink_points.append(shrink_point)
poly = np.array(shrink_points)
return poly
def shrink_char_rect(self, char_rect, shrink_ratio):
"""Shrink char box in style of rectangle.
Args:
char_rect (list[float]): Char box with format
[x_min, y_min, x_max, y_max].
            shrink_ratio (float): The area shrink ratio
                between gt kernels and gt text masks.
"""
x_min, y_min, x_max, y_max = char_rect
w = x_max - x_min
h = y_max - y_min
x_min_s = round((x_min + x_max - w * shrink_ratio) / 2)
y_min_s = round((y_min + y_max - h * shrink_ratio) / 2)
x_max_s = round((x_min + x_max + w * shrink_ratio) / 2)
y_max_s = round((y_min + y_max + h * shrink_ratio) / 2)
poly = np.array([[x_min_s, y_min_s], [x_max_s, y_min_s],
[x_max_s, y_max_s], [x_min_s, y_max_s]])
return poly
def generate_kernels(self,
resize_shape,
pad_shape,
char_boxes,
char_inds,
shrink_ratio=0.5,
binary=True):
"""Generate char instance kernels for one shrink ratio.
Args:
resize_shape (tuple(int, int)): Image size (height, width)
after resizing.
pad_shape (tuple(int, int)): Image size (height, width)
after padding.
char_boxes (list[list[float]]): The list of char polygons.
char_inds (list[int]): List of char indexes.
shrink_ratio (float): The shrink ratio of kernel.
binary (bool): If True, return binary ndarray
containing 0 & 1 only.
Returns:
char_kernel (ndarray): The text kernel mask of (height, width).
"""
assert isinstance(resize_shape, tuple)
assert isinstance(pad_shape, tuple)
assert check_argument.is_2dlist(char_boxes)
assert check_argument.is_type_list(char_inds, int)
assert isinstance(shrink_ratio, float)
assert isinstance(binary, bool)
char_kernel = np.zeros(pad_shape, dtype=np.int32)
char_kernel[:resize_shape[0], resize_shape[1]:] = self.pad_val
for i, char_box in enumerate(char_boxes):
if self.box_type == 'char_rects':
poly = self.shrink_char_rect(char_box, shrink_ratio)
elif self.box_type == 'char_quads':
poly = self.shrink_char_quad(char_box, shrink_ratio)
fill_value = 1 if binary else char_inds[i]
cv2.fillConvexPoly(char_kernel, poly.astype(np.int32),
(fill_value))
return char_kernel
def l2_dist_two_points(self, p1, p2):
return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5
def __call__(self, results):
img_shape = results['img_shape']
resize_shape = results['resize_shape']
h_scale = 1.0 * resize_shape[0] / img_shape[0]
w_scale = 1.0 * resize_shape[1] / img_shape[1]
char_boxes, char_inds = [], []
char_num = len(results['ann_info'][self.box_type])
for i in range(char_num):
char_box = results['ann_info'][self.box_type][i]
num_points = 2 if self.box_type == 'char_rects' else 4
for j in range(num_points):
char_box[j * 2] = round(char_box[j * 2] * w_scale)
char_box[j * 2 + 1] = round(char_box[j * 2 + 1] * h_scale)
char_boxes.append(char_box)
char = results['ann_info']['chars'][i]
char_ind = self.label_convertor.str2idx([char])[0][0]
char_inds.append(char_ind)
resize_shape = tuple(results['resize_shape'][:2])
pad_shape = tuple(results['pad_shape'][:2])
binary_target = self.generate_kernels(
resize_shape,
pad_shape,
char_boxes,
char_inds,
shrink_ratio=self.attn_shrink_ratio,
binary=True)
seg_target = self.generate_kernels(
resize_shape,
pad_shape,
char_boxes,
char_inds,
shrink_ratio=self.seg_shrink_ratio,
binary=False)
mask = np.ones(pad_shape, dtype=np.int32)
mask[:resize_shape[0], resize_shape[1]:] = 0
results['gt_kernels'] = BitmapMasks([binary_target, seg_target, mask],
pad_shape[0], pad_shape[1])
results['mask_fields'] = ['gt_kernels']
return results

@@ -0,0 +1,447 @@
import math
import cv2
import mmcv
import numpy as np
import torch
import torchvision.transforms.functional as TF
from mmcv.runner.dist_utils import get_dist_info
from PIL import Image
from shapely.geometry import Polygon
from shapely.geometry import box as shapely_box
import mmocr.utils as utils
from mmdet.datasets.builder import PIPELINES
from mmocr.datasets.pipelines.crop import warp_img
@PIPELINES.register_module()
class ResizeOCR:
"""Image resizing and padding for OCR.
Args:
height (int | tuple(int)): Image height after resizing.
        min_width (None | int | tuple(int)): Minimum image width
            after resizing.
        max_width (None | int | tuple(int)): Maximum image width
            after resizing.
        keep_aspect_ratio (bool): Keep the image aspect ratio if True
            during resizing; otherwise resize it to the fixed size
            height * max_width.
        img_pad_value (int): Scalar to fill padding area.
        width_downsample_ratio (float): Downsample ratio in horizontal
            direction from input image to output feature.
    """
def __init__(self,
height,
min_width=None,
max_width=None,
keep_aspect_ratio=True,
img_pad_value=0,
width_downsample_ratio=1.0 / 16):
assert isinstance(height, (int, tuple))
assert utils.is_none_or_type(min_width, (int, tuple))
assert utils.is_none_or_type(max_width, (int, tuple))
if not keep_aspect_ratio:
assert max_width is not None, \
'"max_width" must assigned ' + \
'if "keep_aspect_ratio" is False'
assert isinstance(img_pad_value, int)
if isinstance(height, tuple):
assert isinstance(min_width, tuple)
assert isinstance(max_width, tuple)
assert len(height) == len(min_width) == len(max_width)
self.height = height
self.min_width = min_width
self.max_width = max_width
self.keep_aspect_ratio = keep_aspect_ratio
self.img_pad_value = img_pad_value
self.width_downsample_ratio = width_downsample_ratio
def __call__(self, results):
rank, _ = get_dist_info()
if isinstance(self.height, int):
dst_height = self.height
dst_min_width = self.min_width
dst_max_width = self.max_width
else:
"""Multi-scale resize used in distributed training.
Choose one (height, width) pair for one rank id.
"""
idx = rank % len(self.height)
dst_height = self.height[idx]
dst_min_width = self.min_width[idx]
dst_max_width = self.max_width[idx]
img_shape = results['img_shape']
ori_height, ori_width = img_shape[:2]
valid_ratio = 1.0
resize_shape = list(img_shape)
pad_shape = list(img_shape)
if self.keep_aspect_ratio:
new_width = math.ceil(float(dst_height) / ori_height * ori_width)
width_divisor = int(1 / self.width_downsample_ratio)
# make sure new_width is an integral multiple of width_divisor.
if new_width % width_divisor != 0:
new_width = round(new_width / width_divisor) * width_divisor
if dst_min_width is not None:
new_width = max(dst_min_width, new_width)
if dst_max_width is not None:
valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
resize_width = min(dst_max_width, new_width)
img_resize = cv2.resize(results['img'],
(resize_width, dst_height))
resize_shape = img_resize.shape
pad_shape = img_resize.shape
if new_width < dst_max_width:
img_resize = mmcv.impad(
img_resize,
shape=(dst_height, dst_max_width),
pad_val=self.img_pad_value)
pad_shape = img_resize.shape
else:
img_resize = cv2.resize(results['img'],
(new_width, dst_height))
resize_shape = img_resize.shape
pad_shape = img_resize.shape
else:
img_resize = cv2.resize(results['img'],
(dst_max_width, dst_height))
resize_shape = img_resize.shape
pad_shape = img_resize.shape
results['img'] = img_resize
results['resize_shape'] = resize_shape
results['pad_shape'] = pad_shape
results['valid_ratio'] = valid_ratio
return results
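# A worked example of the width logic above, assuming a hypothetical 48x100
# input with height=32, max_width=160 and the default width_downsample_ratio
# of 1/16: the resized width becomes 64 and valid_ratio 0.4 before the image
# is padded out to 32x160.
toy_h, toy_w, toy_dst_h, toy_max_w = 48, 100, 32, 160
toy_new_w = math.ceil(float(toy_dst_h) / toy_h * toy_w)       # 67
toy_divisor = int(1 / (1.0 / 16))                             # 16
if toy_new_w % toy_divisor != 0:
    toy_new_w = round(toy_new_w / toy_divisor) * toy_divisor  # 64
toy_valid_ratio = min(1.0, 1.0 * toy_new_w / toy_max_w)       # 0.4
assert toy_new_w == 64 and abs(toy_valid_ratio - 0.4) < 1e-9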
@PIPELINES.register_module()
class ToTensorOCR:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor."""
def __init__(self):
pass
def __call__(self, results):
results['img'] = TF.to_tensor(results['img'].copy())
return results
@PIPELINES.register_module()
class NormalizeOCR:
"""Normalize a tensor image with mean and standard deviation."""
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, results):
results['img'] = TF.normalize(results['img'], self.mean, self.std)
return results
@PIPELINES.register_module()
class OnlineCropOCR:
"""Crop text areas from whole image with bounding box jitter. If no bbox is
given, return directly.
Args:
box_keys (list[str]): Keys in results which correspond to RoI bbox.
jitter_prob (float): The probability of box jitter.
max_jitter_ratio_x (float): Maximum horizontal jitter ratio
relative to height.
max_jitter_ratio_y (float): Maximum vertical jitter ratio
relative to height.
"""
def __init__(self,
box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
jitter_prob=0.5,
max_jitter_ratio_x=0.05,
max_jitter_ratio_y=0.02):
assert utils.is_type_list(box_keys, str)
assert 0 <= jitter_prob <= 1
assert 0 <= max_jitter_ratio_x <= 1
assert 0 <= max_jitter_ratio_y <= 1
self.box_keys = box_keys
self.jitter_prob = jitter_prob
self.max_jitter_ratio_x = max_jitter_ratio_x
self.max_jitter_ratio_y = max_jitter_ratio_y
def __call__(self, results):
if 'img_info' not in results:
return results
crop_flag = True
box = []
for key in self.box_keys:
if key not in results['img_info']:
crop_flag = False
break
box.append(float(results['img_info'][key]))
if not crop_flag:
return results
jitter_flag = np.random.random() > self.jitter_prob
kwargs = dict(
jitter_flag=jitter_flag,
jitter_ratio_x=self.max_jitter_ratio_x,
jitter_ratio_y=self.max_jitter_ratio_y)
crop_img = warp_img(results['img'], box, **kwargs)
results['img'] = crop_img
results['img_shape'] = crop_img.shape
return results
@PIPELINES.register_module()
class FancyPCA:
"""Implementation of PCA based image augmentation, proposed in the paper
``Imagenet Classification With Deep Convolutional Neural Networks``.
It alters the intensities of RGB values along the principal components of
    the ImageNet dataset.
"""
def __init__(self, eig_vec=None, eig_val=None):
if eig_vec is None:
eig_vec = torch.Tensor([
[-0.5675, +0.7192, +0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, +0.4203],
]).t()
if eig_val is None:
eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])
self.eig_val = eig_val # 1*3
self.eig_vec = eig_vec # 3*3
def pca(self, tensor):
assert tensor.size(0) == 3
alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
reconst = torch.mm(self.eig_val * alpha, self.eig_vec)
tensor = tensor + reconst.view(3, 1, 1)
return tensor
def __call__(self, results):
img = results['img']
tensor = self.pca(img)
results['img'] = tensor
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
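# A minimal sketch of FancyPCA in isolation, assuming a random 3xHxW tensor
# (i.e. the output of ToTensorOCR): only pixel intensities are jittered, the
# shape is left untouched.
toy_aug = FancyPCA()
toy_img = torch.rand(3, 32, 100)
toy_out = toy_aug(dict(img=toy_img))['img']
assert toy_out.shape == toy_img.shape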
@PIPELINES.register_module()
class RandomPaddingOCR:
"""Pad the given image on all sides, as well as modify the coordinates of
character bounding box in image.
Args:
        max_ratio (list[float]): Maximum padding ratio of the
            [left, top, right, bottom] sides.
box_type (None|str): Character box type. If not none,
should be either 'char_rects' or 'char_quads', with
'char_rects' for rectangle with ``xyxy`` style and
'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
"""
def __init__(self, max_ratio=None, box_type=None):
if max_ratio is None:
max_ratio = [0.1, 0.2, 0.1, 0.2]
else:
assert utils.is_type_list(max_ratio, float)
assert len(max_ratio) == 4
assert box_type is None or box_type in ('char_rects', 'char_quads')
self.max_ratio = max_ratio
self.box_type = box_type
def __call__(self, results):
img_shape = results['img_shape']
ori_height, ori_width = img_shape[:2]
random_padding_left = round(
np.random.uniform(0, self.max_ratio[0]) * ori_width)
random_padding_top = round(
np.random.uniform(0, self.max_ratio[1]) * ori_height)
random_padding_right = round(
np.random.uniform(0, self.max_ratio[2]) * ori_width)
random_padding_bottom = round(
np.random.uniform(0, self.max_ratio[3]) * ori_height)
img = np.copy(results['img'])
img = cv2.copyMakeBorder(img, random_padding_top,
random_padding_bottom, random_padding_left,
random_padding_right, cv2.BORDER_REPLICATE)
results['img'] = img
results['img_shape'] = img.shape
if self.box_type is not None:
num_points = 2 if self.box_type == 'char_rects' else 4
char_num = len(results['ann_info'][self.box_type])
for i in range(char_num):
for j in range(num_points):
results['ann_info'][self.box_type][i][
j * 2] += random_padding_left
results['ann_info'][self.box_type][i][
j * 2 + 1] += random_padding_top
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
@PIPELINES.register_module()
class RandomRotateImageBox:
"""Rotate augmentation for segmentation based text recognition.
Args:
min_angle (int): Minimum rotation angle for image and box.
max_angle (int): Maximum rotation angle for image and box.
box_type (str): Character box type, should be either
'char_rects' or 'char_quads', with 'char_rects'
for rectangle with ``xyxy`` style and 'char_quads'
for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
"""
def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'):
assert box_type in ('char_rects', 'char_quads')
self.min_angle = min_angle
self.max_angle = max_angle
self.box_type = box_type
def __call__(self, results):
in_img = results['img']
in_chars = results['ann_info']['chars']
in_boxes = results['ann_info'][self.box_type]
img_width, img_height = in_img.size
rotate_center = [img_width / 2., img_height / 2.]
tan_temp_max_angle = rotate_center[1] / rotate_center[0]
temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi
random_angle = np.random.uniform(
max(self.min_angle, -temp_max_angle),
min(self.max_angle, temp_max_angle))
random_angle_radian = random_angle * np.pi / 180.
img_box = shapely_box(0, 0, img_width, img_height)
out_img = TF.rotate(
in_img,
random_angle,
resample=False,
expand=False,
center=rotate_center)
out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars,
random_angle_radian,
rotate_center, img_box)
results['img'] = out_img
results['ann_info']['chars'] = out_chars
results['ann_info'][self.box_type] = out_boxes
return results
@staticmethod
def rotate_bbox(boxes, chars, angle, center, img_box):
out_boxes = []
out_chars = []
for idx, bbox in enumerate(boxes):
temp_bbox = []
for i in range(len(bbox) // 2):
point = [bbox[2 * i], bbox[2 * i + 1]]
temp_bbox.append(
RandomRotateImageBox.rotate_point(point, angle, center))
poly_temp_bbox = Polygon(temp_bbox).buffer(0)
if poly_temp_bbox.is_valid:
if img_box.intersects(poly_temp_bbox) and (
not img_box.touches(poly_temp_bbox)):
temp_bbox_area = poly_temp_bbox.area
intersect_area = img_box.intersection(poly_temp_bbox).area
intersect_ratio = intersect_area / temp_bbox_area
if intersect_ratio >= 0.7:
out_box = []
for p in temp_bbox:
out_box.extend(p)
out_boxes.append(out_box)
out_chars.append(chars[idx])
return out_boxes, out_chars
@staticmethod
def rotate_point(point, angle, center):
cos_theta = math.cos(-angle)
sin_theta = math.sin(-angle)
c_x = center[0]
c_y = center[1]
new_x = (point[0] - c_x) * cos_theta - (point[1] -
c_y) * sin_theta + c_x
new_y = (point[0] - c_x) * sin_theta + (point[1] -
c_y) * cos_theta + c_y
return [new_x, new_y]
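# A small numeric check of rotate_point, assuming a quarter-turn (pi/2
# radians) about the origin: the angle is negated internally, so (1, 0)
# maps to (0, -1) in image coordinates.
toy_p = RandomRotateImageBox.rotate_point([1.0, 0.0], math.pi / 2, [0.0, 0.0])
assert abs(toy_p[0]) < 1e-6 and abs(toy_p[1] + 1.0) < 1e-6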
@PIPELINES.register_module()
class OpencvToPil:
"""Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb)."""
def __init__(self, **kwargs):
pass
def __call__(self, results):
img = results['img'][..., ::-1]
img = Image.fromarray(img)
results['img'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
@PIPELINES.register_module()
class PilToOpencv:
"""Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr)."""
def __init__(self, **kwargs):
pass
def __call__(self, results):
img = np.asarray(results['img'])
img = img[..., ::-1]
results['img'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str

View File

@ -0,0 +1,108 @@
import mmcv
import numpy as np
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.compose import Compose
@PIPELINES.register_module()
class MultiRotateAugOCR:
"""Test-time augmentation with multiple rotations in the case that
img_height > img_width.
An example configuration is as follows:
.. code-block::
rotate_degrees=[0, 90, 270],
transforms=[
dict(
type='ResizeOCR',
height=32,
min_width=32,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
]),
]
After MultiRotateAugOCR with above configuration, the results are wrapped
into lists of the same length as follows:
.. code-block::
dict(
img=[...],
img_shape=[...]
...
)
Args:
transforms (list[dict]): Transformation applied for each augmentation.
rotate_degrees (list[int] | None): Degrees of anti-clockwise rotation.
        force_rotate (bool): If True, rotate the image by 'rotate_degrees'
            regardless of the image aspect ratio.
"""
def __init__(self, transforms, rotate_degrees=None, force_rotate=False):
self.transforms = Compose(transforms)
self.force_rotate = force_rotate
if rotate_degrees is not None:
self.rotate_degrees = rotate_degrees if isinstance(
rotate_degrees, list) else [rotate_degrees]
assert mmcv.is_list_of(self.rotate_degrees, int)
for degree in self.rotate_degrees:
assert 0 <= degree < 360
assert degree % 90 == 0
if 0 not in self.rotate_degrees:
self.rotate_degrees.append(0)
else:
self.rotate_degrees = [0]
def __call__(self, results):
"""Call function to apply test time augment transformation to results.
Args:
results (dict): Result dict contains the data to be transformed.
Returns:
dict[str: list]: The augmented data, where each value is wrapped
into a list.
"""
img_shape = results['img_shape']
ori_height, ori_width = img_shape[:2]
if not self.force_rotate and ori_height <= ori_width:
rotate_degrees = [0]
else:
rotate_degrees = self.rotate_degrees
aug_data = []
for degree in set(rotate_degrees):
_results = results.copy()
if degree == 0:
pass
elif degree == 90:
_results['img'] = np.rot90(_results['img'], 1)
elif degree == 180:
_results['img'] = np.rot90(_results['img'], 2)
elif degree == 270:
_results['img'] = np.rot90(_results['img'], 3)
data = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list
aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data:
for key, val in data.items():
aug_data_dict[key].append(val)
return aug_data_dict
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, '
repr_str += f'rotate_degrees={self.rotate_degrees})'
return repr_str
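# A shape-only sketch of the rotations applied above, assuming a hypothetical
# 100x32 (taller than wide) image: each np.rot90 call swaps the two spatial
# dimensions, and the inner pipeline then resizes every rotated copy.
toy_img = np.zeros((100, 32, 3), dtype=np.uint8)
assert np.rot90(toy_img, 1).shape == (32, 100, 3)
assert np.rot90(toy_img, 3).shape == (32, 100, 3)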

View File

@ -0,0 +1,121 @@
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmocr.core.evaluation.hmean import eval_hmean
from mmocr.datasets.base_dataset import BaseDataset
@DATASETS.register_module()
class TextDetDataset(BaseDataset):
def _parse_anno_info(self, annotations):
"""Parse bbox and mask annotation.
Args:
annotations (dict): Annotations of one image.
Returns:
dict: A dict containing the following keys: bboxes, bboxes_ignore,
labels, masks, masks_ignore. "masks" and
"masks_ignore" are represented by polygon boundary
point sequences.
"""
gt_bboxes, gt_bboxes_ignore = [], []
gt_masks, gt_masks_ignore = [], []
gt_labels = []
for ann in annotations:
if ann.get('iscrowd', False):
gt_bboxes_ignore.append(ann['bbox'])
gt_masks_ignore.append(ann.get('segmentation', None))
else:
gt_bboxes.append(ann['bbox'])
gt_labels.append(ann['category_id'])
gt_masks.append(ann.get('segmentation', None))
if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
gt_labels = np.array(gt_labels, dtype=np.int64)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)
if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
ann = dict(
bboxes=gt_bboxes,
labels=gt_labels,
bboxes_ignore=gt_bboxes_ignore,
masks_ignore=gt_masks_ignore,
masks=gt_masks)
return ann
def prepare_train_img(self, index):
"""Get training data and annotations from pipeline.
Args:
index (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys \
introduced by pipeline.
"""
img_ann_info = self.data_infos[index]
img_info = {
'filename': img_ann_info['file_name'],
'height': img_ann_info['height'],
'width': img_ann_info['width']
}
ann_info = self._parse_anno_info(img_ann_info['annotations'])
results = dict(img_info=img_info, ann_info=ann_info)
results['bbox_fields'] = []
results['mask_fields'] = []
results['seg_fields'] = []
self.pre_pipeline(results)
return self.pipeline(results)
def evaluate(self,
results,
metric='hmean-iou',
score_thr=0.3,
rank_list=None,
logger=None,
**kwargs):
"""Evaluate the dataset.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
score_thr (float): Score threshold for prediction map.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
rank_list (str): json file used to save eval result
of each image after ranking.
Returns:
dict[str: float]
"""
metrics = metric if isinstance(metric, list) else [metric]
allowed_metrics = ['hmean-iou', 'hmean-ic13']
metrics = set(metrics) & set(allowed_metrics)
img_infos = []
ann_infos = []
for i in range(len(self)):
img_ann_info = self.data_infos[i]
img_info = {'filename': img_ann_info['file_name']}
ann_info = self._parse_anno_info(img_ann_info['annotations'])
img_infos.append(img_info)
ann_infos.append(ann_info)
eval_results = eval_hmean(
results,
img_infos,
ann_infos,
metrics=metrics,
score_thr=score_thr,
logger=logger,
rank_list=rank_list)
return eval_results
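# A minimal sketch of _parse_anno_info with a single hypothetical
# axis-aligned box; the method never touches `self`, so it can be exercised
# without constructing a full dataset.
toy_anno = [
    dict(
        bbox=[0, 0, 10, 10],
        category_id=1,
        segmentation=[[0, 0, 10, 0, 10, 10, 0, 10]])
]
toy_parsed = TextDetDataset._parse_anno_info(None, toy_anno)
assert toy_parsed['bboxes'].shape == (1, 4)
assert toy_parsed['labels'].tolist() == [1]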

View File

@ -0,0 +1,4 @@
from .loader import HardDiskLoader, LmdbLoader
from .parser import LineJsonParser, LineStrParser
__all__ = ['HardDiskLoader', 'LmdbLoader', 'LineStrParser', 'LineJsonParser']

View File

@ -0,0 +1,108 @@
import os.path as osp
from mmocr.datasets.builder import LOADERS, build_parser
@LOADERS.register_module()
class Loader:
"""Load annotation from annotation file, and parse instance information to
dict format with parser.
Args:
ann_file (str): Annotation file path.
parser (dict): Dictionary to construct parser
to parse original annotation infos.
        repeat (int): Number of times to repeat the annotations.
"""
def __init__(self, ann_file, parser, repeat=1):
assert isinstance(ann_file, str)
assert isinstance(repeat, int)
assert isinstance(parser, dict)
assert repeat > 0
        assert osp.exists(ann_file), f'{ann_file} does not exist'
self.ori_data_infos = self._load(ann_file)
self.parser = build_parser(parser)
self.repeat = repeat
def __len__(self):
return len(self.ori_data_infos) * self.repeat
def _load(self, ann_file):
"""Load annotation file."""
raise NotImplementedError
def __getitem__(self, index):
"""Retrieve anno info of one instance with dict format."""
return self.parser.get_item(self.ori_data_infos, index)
@LOADERS.register_module()
class HardDiskLoader(Loader):
"""Load annotation file from hard disk to RAM.
Args:
ann_file (str): Annotation file path.
"""
def _load(self, ann_file):
data_ret = []
with open(ann_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
data_ret.append(line)
return data_ret
@LOADERS.register_module()
class LmdbLoader(Loader):
"""Load annotation file with lmdb storage backend."""
def _load(self, ann_file):
lmdb_anno_obj = LmdbAnnFileBackend(ann_file)
return lmdb_anno_obj
class LmdbAnnFileBackend:
"""Lmdb storage backend for annotation file.
Args:
lmdb_path (str): Lmdb file path.
"""
def __init__(self, lmdb_path, coding='utf8'):
self.lmdb_path = lmdb_path
self.coding = coding
env = self._get_env()
with env.begin(write=False) as txn:
self.total_number = int(
txn.get('total_number'.encode(self.coding)).decode(
self.coding))
def __getitem__(self, index):
"""Retrieval one line from lmdb file by index."""
        # Only attach env to self when __getitem__ is called,
        # because the env object cannot be pickled.
if not hasattr(self, 'env'):
self.env = self._get_env()
with self.env.begin(write=False) as txn:
line = txn.get(str(index).encode(self.coding)).decode(self.coding)
return line
def __len__(self):
return self.total_number
def _get_env(self):
import lmdb
return lmdb.open(
self.lmdb_path,
max_readers=1,
readonly=True,
lock=False,
readahead=False,
meminit=False,
)
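# A minimal sketch of the key layout this backend expects, assuming a
# hypothetical lmdb written to '/tmp/toy_ocr_anno_lmdb': a 'total_number'
# entry plus one utf-8 encoded line per integer index.
import lmdb
toy_env = lmdb.open('/tmp/toy_ocr_anno_lmdb', map_size=2**20)
with toy_env.begin(write=True) as txn:
    txn.put(b'total_number', b'1')
    txn.put(b'0', b'img_1.jpg hello')
toy_env.close()
toy_backend = LmdbAnnFileBackend('/tmp/toy_ocr_anno_lmdb')
assert len(toy_backend) == 1
assert toy_backend[0] == 'img_1.jpg hello'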

View File

@ -0,0 +1,69 @@
import json
from mmocr.datasets.builder import PARSERS
@PARSERS.register_module()
class LineStrParser:
"""Parse string of one line in annotation file to dict format.
Args:
keys (list[str]): Keys in result dict.
keys_idx (list[int]): Value index in sub-string list
for each key above.
separator (str): Separator to separate string to list of sub-string.
"""
def __init__(self,
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' '):
assert isinstance(keys, list)
assert isinstance(keys_idx, list)
assert isinstance(separator, str)
assert len(keys) > 0
assert len(keys) == len(keys_idx)
self.keys = keys
self.keys_idx = keys_idx
self.separator = separator
def get_item(self, data_ret, index):
map_index = index % len(data_ret)
line_str = data_ret[map_index]
for split_key in self.separator:
if split_key != ' ':
line_str = line_str.replace(split_key, ' ')
line_str = line_str.split()
if len(line_str) <= max(self.keys_idx):
raise Exception(
f'key index: {max(self.keys_idx)} out of range: {line_str}')
line_info = {}
for i, key in enumerate(self.keys):
line_info[key] = line_str[self.keys_idx[i]]
return line_info
@PARSERS.register_module()
class LineJsonParser:
"""Parse json-string of one line in annotation file to dict format.
Args:
keys (list[str]): Keys in both json-string and result dict.
"""
def __init__(self, keys=[], **kwargs):
assert isinstance(keys, list)
assert len(keys) > 0
self.keys = keys
def get_item(self, data_ret, index):
map_index = index % len(data_ret)
line_json_obj = json.loads(data_ret[map_index])
line_info = {}
for key in self.keys:
if key not in line_json_obj:
raise Exception(f'key {key} not in line json {line_json_obj}')
line_info[key] = line_json_obj[key]
return line_info
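# Minimal sketches of both parsers on hypothetical annotation lines: each
# maps one raw line to a dict keyed as configured.
toy_str_parser = LineStrParser(keys=['filename', 'text'], keys_idx=[0, 1])
assert toy_str_parser.get_item(['img_1.jpg hello'], 0) == {
    'filename': 'img_1.jpg',
    'text': 'hello'
}
toy_json_parser = LineJsonParser(keys=['filename', 'text'])
toy_line = '{"filename": "img_1.jpg", "text": "hi"}'
assert toy_json_parser.get_item([toy_line], 0) == {
    'filename': 'img_1.jpg',
    'text': 'hi'
}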

View File

@ -0,0 +1,8 @@
from .backbones import * # noqa: F401,F403
from .convertors import * # noqa: F401,F403
from .decoders import * # noqa: F401,F403
from .encoders import * # noqa: F401,F403
from .heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .recognizer import * # noqa: F401,F403

View File

@ -0,0 +1,4 @@
from .resnet31_ocr import ResNet31OCR
from .very_deep_vgg import VeryDeepVgg
__all__ = ['ResNet31OCR', 'VeryDeepVgg']

View File

@ -0,0 +1,149 @@
import torch.nn as nn
from mmcv.cnn import kaiming_init, uniform_init
import mmocr.utils as utils
from mmdet.models.builder import BACKBONES
from mmocr.models.textrecog.layers import BasicBlock
@BACKBONES.register_module()
class ResNet31OCR(nn.Module):
"""Implement ResNet backbone for text recognition, modified from
`ResNet <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
base_channels (int): Number of channels of input image tensor.
layers (list[int]): List of BasicBlock number for each stage.
channels (list[int]): List of out_channels of Conv2d layer.
        out_indices (None | Sequence[int]): Indices of output stages.
stage4_pool_cfg (dict): Dictionary to construct and configure
pooling layer in stage 4.
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
"""
def __init__(self,
base_channels=3,
layers=[1, 2, 5, 3],
channels=[64, 128, 256, 256, 512, 512, 512],
out_indices=None,
stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
last_stage_pool=False):
super().__init__()
assert isinstance(base_channels, int)
assert utils.is_type_list(layers, int)
assert utils.is_type_list(channels, int)
assert out_indices is None or (isinstance(out_indices, list)
or isinstance(out_indices, tuple))
assert isinstance(last_stage_pool, bool)
self.out_indices = out_indices
self.last_stage_pool = last_stage_pool
# conv 1 (Conv, Conv)
self.conv1_1 = nn.Conv2d(
base_channels, channels[0], kernel_size=3, stride=1, padding=1)
self.bn1_1 = nn.BatchNorm2d(channels[0])
self.relu1_1 = nn.ReLU(inplace=True)
self.conv1_2 = nn.Conv2d(
channels[0], channels[1], kernel_size=3, stride=1, padding=1)
self.bn1_2 = nn.BatchNorm2d(channels[1])
self.relu1_2 = nn.ReLU(inplace=True)
# conv 2 (Max-pooling, Residual block, Conv)
self.pool2 = nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
self.conv2 = nn.Conv2d(
channels[2], channels[2], kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(channels[2])
self.relu2 = nn.ReLU(inplace=True)
# conv 3 (Max-pooling, Residual block, Conv)
self.pool3 = nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
self.conv3 = nn.Conv2d(
channels[3], channels[3], kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2d(channels[3])
self.relu3 = nn.ReLU(inplace=True)
# conv 4 (Max-pooling, Residual block, Conv)
self.pool4 = nn.MaxPool2d(padding=0, ceil_mode=True, **stage4_pool_cfg)
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
self.conv4 = nn.Conv2d(
channels[4], channels[4], kernel_size=3, stride=1, padding=1)
self.bn4 = nn.BatchNorm2d(channels[4])
self.relu4 = nn.ReLU(inplace=True)
# conv 5 ((Max-pooling), Residual block, Conv)
self.pool5 = None
if self.last_stage_pool:
self.pool5 = nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True) # 1/16
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
self.conv5 = nn.Conv2d(
channels[5], channels[5], kernel_size=3, stride=1, padding=1)
self.bn5 = nn.BatchNorm2d(channels[5])
self.relu5 = nn.ReLU(inplace=True)
def init_weights(self, pretrained=None):
# initialize weight and bias
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
uniform_init(m)
def _make_layer(self, input_channels, output_channels, blocks):
layers = []
for _ in range(blocks):
downsample = None
if input_channels != output_channels:
downsample = nn.Sequential(
nn.Conv2d(
input_channels,
output_channels,
kernel_size=1,
stride=1,
bias=False),
nn.BatchNorm2d(output_channels),
)
layers.append(
BasicBlock(
input_channels, output_channels, downsample=downsample))
input_channels = output_channels
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu1_1(x)
x = self.conv1_2(x)
x = self.bn1_2(x)
x = self.relu1_2(x)
outs = []
for i in range(4):
layer_index = i + 2
pool_layer = getattr(self, f'pool{layer_index}')
block_layer = getattr(self, f'block{layer_index}')
conv_layer = getattr(self, f'conv{layer_index}')
bn_layer = getattr(self, f'bn{layer_index}')
relu_layer = getattr(self, f'relu{layer_index}')
if pool_layer is not None:
x = pool_layer(x)
x = block_layer(x)
x = conv_layer(x)
x = bn_layer(x)
x = relu_layer(x)
outs.append(x)
if self.out_indices is not None:
return tuple([outs[i] for i in self.out_indices])
return x
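# A shape sketch with the default configuration, assuming a hypothetical
# 32x160 input: the height is pooled 32 -> 16 -> 8 -> 4 while stage 4 keeps
# the width at 40, so the final feature map should be 512 x 4 x 40.
import torch
toy_backbone = ResNet31OCR()
toy_feat = toy_backbone(torch.rand(1, 3, 32, 160))
assert toy_feat.shape == (1, 512, 4, 40)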

View File

@ -0,0 +1,6 @@
from .attn import AttnConvertor
from .base import BaseConvertor
from .ctc import CTCConvertor
from .seg import SegConvertor
__all__ = ['BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor']

View File

@ -0,0 +1,140 @@
import torch
import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor
@CONVERTORS.register_module()
class AttnConvertor(BaseConvertor):
"""Convert between text, index and tensor for encoder-decoder based
pipeline.
Args:
dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
dict_file (None|str): Character dict file path. If not none,
higher priority than dict_type.
dict_list (None|list[str]): Character list. If not none, higher
priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add a `UKN` token to the dictionary.
max_seq_len (int): Maximum sequence length of label.
lower (bool): If True, convert original string to lower case.
        start_end_same (bool): Whether to use the same index for
            the start and end tokens. Default: True.
"""
def __init__(self,
dict_type='DICT90',
dict_file=None,
dict_list=None,
with_unknown=True,
max_seq_len=40,
lower=False,
start_end_same=True,
**kwargs):
super().__init__(dict_type, dict_file, dict_list)
assert isinstance(with_unknown, bool)
assert isinstance(max_seq_len, int)
assert isinstance(lower, bool)
self.with_unknown = with_unknown
self.max_seq_len = max_seq_len
self.lower = lower
self.start_end_same = start_end_same
self.update_dict()
def update_dict(self):
start_end_token = '<BOS/EOS>'
unknown_token = '<UKN>'
padding_token = '<PAD>'
# unknown
self.unknown_idx = None
if self.with_unknown:
self.idx2char.append(unknown_token)
self.unknown_idx = len(self.idx2char) - 1
# BOS/EOS
self.idx2char.append(start_end_token)
self.start_idx = len(self.idx2char) - 1
if not self.start_end_same:
self.idx2char.append(start_end_token)
self.end_idx = len(self.idx2char) - 1
# padding
self.idx2char.append(padding_token)
self.padding_idx = len(self.idx2char) - 1
# update char2idx
self.char2idx = {}
for idx, char in enumerate(self.idx2char):
self.char2idx[char] = idx
def str2tensor(self, strings):
"""
Convert text-string into tensor.
Args:
strings (list[str]): ['hello', 'world']
Returns:
dict (str: Tensor | list[tensor]):
tensors (list[Tensor]): [torch.Tensor([1,2,3,3,4]),
torch.Tensor([5,4,6,3,7])]
padded_targets (Tensor(bsz * max_seq_len))
"""
assert utils.is_type_list(strings, str)
tensors, padded_targets = [], []
indexes = self.str2idx(strings)
for index in indexes:
tensor = torch.LongTensor(index)
tensors.append(tensor)
# target tensor for loss
src_target = torch.LongTensor(tensor.size(0) + 2).fill_(0)
src_target[-1] = self.end_idx
src_target[0] = self.start_idx
src_target[1:-1] = tensor
padded_target = (torch.ones(self.max_seq_len) *
self.padding_idx).long()
char_num = src_target.size(0)
if char_num > self.max_seq_len:
padded_target = src_target[:self.max_seq_len]
else:
padded_target[:char_num] = src_target
padded_targets.append(padded_target)
padded_targets = torch.stack(padded_targets, 0).long()
return {'targets': tensors, 'padded_targets': padded_targets}
def tensor2idx(self, outputs, img_metas=None):
"""
Convert output tensor to text-index
Args:
outputs (tensor): model outputs with size: N * T * C
img_metas (list[dict]): Each dict contains one image info.
Returns:
indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]]
scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
[0.9,0.9,0.98,0.97,0.96]]
"""
batch_size = outputs.size(0)
ignore_indexes = [self.padding_idx]
indexes, scores = [], []
for idx in range(batch_size):
seq = outputs[idx, :, :]
max_value, max_idx = torch.max(seq, -1)
str_index, str_score = [], []
output_index = max_idx.cpu().detach().numpy().tolist()
output_score = max_value.cpu().detach().numpy().tolist()
for char_index, char_score in zip(output_index, output_score):
if char_index in ignore_indexes:
continue
if char_index == self.end_idx:
break
str_index.append(char_index)
str_score.append(char_score)
indexes.append(str_index)
scores.append(str_score)
return indexes, scores
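# A minimal sketch with the default DICT90 dictionary: <UKN> takes index 90,
# <BOS/EOS> index 91 and <PAD> index 92, so a hypothetical label 'ab'
# (character indices 10 and 11) is padded to
# [91, 10, 11, 91, 92, 92, ...] of length max_seq_len.
toy_convertor = AttnConvertor()
toy_out = toy_convertor.str2tensor(['ab'])
assert toy_out['padded_targets'].shape == (1, 40)
assert toy_out['padded_targets'][0, :4].tolist() == [91, 10, 11, 91]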

View File

@ -0,0 +1,115 @@
from mmocr.models.builder import CONVERTORS
@CONVERTORS.register_module()
class BaseConvertor:
"""Convert between text, index and tensor for text recognize pipeline.
Args:
dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
dict_file (None|str): Character dict file path. If not none,
the dict_file is of higher priority than dict_type.
dict_list (None|list[str]): Character list. If not none, the list
is of higher priority than dict_type, but lower than dict_file.
"""
start_idx = end_idx = padding_idx = 0
unknown_idx = None
lower = False
DICT36 = tuple('0123456789abcdefghijklmnopqrstuvwxyz')
DICT90 = tuple('0123456789abcdefghijklmnopqrstuvwxyz'
'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
'*+,-./:;<=>?@[\\]_`~')
def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None):
assert dict_type in ('DICT36', 'DICT90')
assert dict_file is None or isinstance(dict_file, str)
assert dict_list is None or isinstance(dict_list, list)
self.idx2char = []
if dict_file is not None:
with open(dict_file, encoding='utf-8') as fr:
for line in fr:
line = line.strip()
if line != '':
self.idx2char.append(line)
elif dict_list is not None:
self.idx2char = dict_list
else:
if dict_type == 'DICT36':
self.idx2char = list(self.DICT36)
else:
self.idx2char = list(self.DICT90)
self.char2idx = {}
for idx, char in enumerate(self.idx2char):
self.char2idx[char] = idx
def num_classes(self):
"""Number of output classes."""
return len(self.idx2char)
def str2idx(self, strings):
"""Convert strings to indexes.
Args:
strings (list[str]): ['hello', 'world'].
Returns:
indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
"""
assert isinstance(strings, list)
indexes = []
for string in strings:
if self.lower:
string = string.lower()
index = []
for char in string:
char_idx = self.char2idx.get(char, self.unknown_idx)
if char_idx is None:
                    raise Exception(f'Character: {char} not in dict;'
                                    f' please check gt_label and use'
                                    f' a custom dict file,'
                                    f' or set "with_unknown=True"')
index.append(char_idx)
indexes.append(index)
return indexes
def str2tensor(self, strings):
"""Convert text-string to input tensor.
Args:
strings (list[str]): ['hello', 'world'].
Returns:
tensors (list[torch.Tensor]): [torch.Tensor([1,2,3,3,4]),
torch.Tensor([5,4,6,3,7])].
"""
raise NotImplementedError
def idx2str(self, indexes):
"""Convert indexes to text strings.
Args:
indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
Returns:
strings (list[str]): ['hello', 'world'].
"""
assert isinstance(indexes, list)
strings = []
for index in indexes:
string = [self.idx2char[i] for i in index]
strings.append(''.join(string))
return strings
def tensor2idx(self, output):
"""Convert model output tensor to character indexes and scores.
Args:
output (tensor): The model outputs with size: N * T * C
Returns:
indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
[0.9,0.9,0.98,0.97,0.96]].
"""
raise NotImplementedError
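# A round-trip sketch with DICT36: indices start with the ten digits, so 'a'
# maps to 10 and 'z' to 35.
toy_convertor = BaseConvertor(dict_type='DICT36')
assert toy_convertor.str2idx(['0abz']) == [[0, 10, 11, 35]]
assert toy_convertor.idx2str([[0, 10, 11, 35]]) == ['0abz']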

View File

@ -0,0 +1,144 @@
import math
import torch
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor
@CONVERTORS.register_module()
class CTCConvertor(BaseConvertor):
"""Convert between text, index and tensor for CTC loss-based pipeline.
Args:
dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
dict_file (None|str): Character dict file path. If not none, the file
is of higher priority than dict_type.
dict_list (None|list[str]): Character list. If not none, the list
is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add a `UKN` token to the dictionary.
lower (bool): If True, convert original string to lower case.
"""
def __init__(self,
dict_type='DICT90',
dict_file=None,
dict_list=None,
with_unknown=True,
lower=False,
**kwargs):
super().__init__(dict_type, dict_file, dict_list)
assert isinstance(with_unknown, bool)
assert isinstance(lower, bool)
self.with_unknown = with_unknown
self.lower = lower
self.update_dict()
def update_dict(self):
# CTC-blank
blank_token = '<BLK>'
self.blank_idx = 0
self.idx2char.insert(0, blank_token)
# unknown
self.unknown_idx = None
if self.with_unknown:
self.idx2char.append('<UKN>')
self.unknown_idx = len(self.idx2char) - 1
# update char2idx
self.char2idx = {}
for idx, char in enumerate(self.idx2char):
self.char2idx[char] = idx
def str2tensor(self, strings):
"""Convert text-string to ctc-loss input tensor.
Args:
strings (list[str]): ['hello', 'world'].
Returns:
dict (str: tensor | list[tensor]):
tensors (list[tensor]): [torch.Tensor([1,2,3,3,4]),
torch.Tensor([5,4,6,3,7])].
flatten_targets (tensor): torch.Tensor([1,2,3,3,4,5,4,6,3,7]).
                target_lengths (tensor): torch.IntTensor([5,5]).
"""
assert utils.is_type_list(strings, str)
tensors = []
indexes = self.str2idx(strings)
for index in indexes:
tensor = torch.IntTensor(index)
tensors.append(tensor)
target_lengths = torch.IntTensor([len(t) for t in tensors])
flatten_target = torch.cat(tensors)
return {
'targets': tensors,
'flatten_targets': flatten_target,
'target_lengths': target_lengths
}
def tensor2idx(self, output, img_metas, topk=1, return_topk=False):
"""Convert model output tensor to index-list.
Args:
output (tensor): The model outputs with size: N * T * C.
img_metas (list[dict]): Each dict contains one image info.
topk (int): The highest k classes to be returned.
return_topk (bool): Whether to return topk or just top1.
Returns:
indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
[0.9,0.9,0.98,0.97,0.96]]
(
indexes_topk (list[list[list[int]->len=topk]]):
scores_topk (list[list[list[float]->len=topk]])
).
"""
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == output.size(0)
assert isinstance(topk, int)
assert topk >= 1
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
batch_size = output.size(0)
output = F.softmax(output, dim=2)
output = output.cpu().detach()
batch_topk_value, batch_topk_idx = output.topk(topk, dim=2)
batch_max_idx = batch_topk_idx[:, :, 0]
scores_topk, indexes_topk = [], []
scores, indexes = [], []
feat_len = output.size(1)
for b in range(batch_size):
valid_ratio = valid_ratios[b]
decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
pred = batch_max_idx[b, :]
select_idx = []
prev_idx = self.blank_idx
for t in range(decode_len):
tmp_value = pred[t].item()
if tmp_value not in (prev_idx, self.blank_idx):
select_idx.append(t)
prev_idx = tmp_value
select_idx = torch.LongTensor(select_idx)
topk_value = torch.index_select(batch_topk_value[b, :, :], 0,
select_idx) # valid_seqlen * topk
topk_idx = torch.index_select(batch_topk_idx[b, :, :], 0,
select_idx)
topk_idx_list, topk_value_list = topk_idx.numpy().tolist(
), topk_value.numpy().tolist()
indexes_topk.append(topk_idx_list)
scores_topk.append(topk_value_list)
indexes.append([x[0] for x in topk_idx_list])
scores.append([x[0] for x in topk_value_list])
if return_topk:
return indexes_topk, scores_topk
return indexes, scores
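# A greedy-decoding sketch: blanks (index 0) and consecutive repeats are
# dropped, so a hypothetical frame-wise argmax of [0, 5, 5, 0, 5, 6]
# collapses to [5, 5, 6].
toy_convertor = CTCConvertor(dict_type='DICT36', with_unknown=False)
toy_logits = torch.full((1, 6, toy_convertor.num_classes()), -10.0)
for toy_t, toy_c in enumerate([0, 5, 5, 0, 5, 6]):
    toy_logits[0, toy_t, toy_c] = 10.0
toy_indexes, _ = toy_convertor.tensor2idx(toy_logits, [dict(valid_ratio=1.0)])
assert toy_indexes == [[5, 5, 6]]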

View File

@ -0,0 +1,123 @@
import cv2
import numpy as np
import torch
import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor
@CONVERTORS.register_module()
class SegConvertor(BaseConvertor):
"""Convert between text, index and tensor for segmentation based pipeline.
Args:
dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
dict_file (None|str): Character dict file path. If not none, the
file is of higher priority than dict_type.
dict_list (None|list[str]): Character list. If not none, the list
is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add a `UKN` token to the dictionary.
lower (bool): If True, convert original string to lower case.
"""
def __init__(self,
dict_type='DICT36',
dict_file=None,
dict_list=None,
with_unknown=True,
lower=False,
**kwargs):
super().__init__(dict_type, dict_file, dict_list)
assert isinstance(with_unknown, bool)
assert isinstance(lower, bool)
self.with_unknown = with_unknown
self.lower = lower
self.update_dict()
def update_dict(self):
# background
self.idx2char.insert(0, '<BG>')
# unknown
self.unknown_idx = None
if self.with_unknown:
self.idx2char.append('<UKN>')
self.unknown_idx = len(self.idx2char) - 1
# update char2idx
self.char2idx = {}
for idx, char in enumerate(self.idx2char):
self.char2idx[char] = idx
def tensor2str(self, output, img_metas=None):
"""Convert model output tensor to string labels.
Args:
output (tensor): Model outputs with size: N * C * H * W
img_metas (list[dict]): Each dict contains one image info.
Returns:
texts (list[str]): Decoded text labels.
scores (list[list[float]]): Decoded chars scores.
"""
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == output.size(0)
texts, scores = [], []
for b in range(output.size(0)):
seg_pred = output[b].detach()
seg_res = torch.argmax(
seg_pred, dim=0).cpu().numpy().astype(np.int32)
seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
_, labels, stats, centroids = cv2.connectedComponentsWithStats(
seg_thr)
component_num = stats.shape[0]
all_res = []
for i in range(component_num):
temp_loc = (labels == i)
temp_value = seg_res[temp_loc]
temp_center = centroids[i]
temp_max_num = 0
temp_max_cls = -1
temp_total_num = 0
for c in range(len(self.idx2char)):
c_num = np.sum(temp_value == c)
temp_total_num += c_num
if c_num > temp_max_num:
temp_max_num = c_num
temp_max_cls = c
if temp_max_cls == 0:
continue
temp_max_score = 1.0 * temp_max_num / temp_total_num
all_res.append(
[temp_max_cls, temp_center, temp_max_num, temp_max_score])
all_res = sorted(all_res, key=lambda s: s[1][0])
chars, char_scores = [], []
for res in all_res:
temp_area = res[2]
if temp_area < 20:
continue
temp_char_index = res[0]
if temp_char_index >= len(self.idx2char):
temp_char = ''
elif temp_char_index <= 0:
temp_char = ''
elif temp_char_index == self.unknown_idx:
temp_char = ''
else:
temp_char = self.idx2char[temp_char_index]
chars.append(temp_char)
char_scores.append(res[3])
text = ''.join(chars)
texts.append(text)
scores.append(char_scores)
return texts, scores
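# A minimal sketch of the decoding rule above, assuming a hypothetical 8x16
# prediction map with one 6x6 component voting for class 1 ('0'): components
# smaller than 20 pixels are dropped and the rest are read left to right.
toy_convertor = SegConvertor(dict_type='DICT36', with_unknown=False)
toy_out = torch.zeros(1, toy_convertor.num_classes(), 8, 16)
toy_out[0, 0] = 1.0            # background score everywhere...
toy_out[0, 1, 1:7, 1:7] = 2.0  # ...except a 6x6 blob voting for '0'
toy_texts, toy_scores = toy_convertor.tensor2str(toy_out, [dict()])
assert toy_texts == ['0'] and toy_scores == [[1.0]]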

View File

@ -0,0 +1,15 @@
from .base_decoder import BaseDecoder
from .crnn_decoder import CRNNDecoder
from .position_attention_decoder import PositionAttentionDecoder
from .robust_scanner_decoder import RobustScannerDecoder
from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder
from .sar_decoder_with_bs import ParallelSARDecoderWithBS
from .sequence_attention_decoder import SequenceAttentionDecoder
from .transformer_decoder import TFDecoder
__all__ = [
'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder',
'ParallelSARDecoderWithBS', 'TFDecoder', 'BaseDecoder',
'SequenceAttentionDecoder', 'PositionAttentionDecoder',
'RobustScannerDecoder'
]

View File

@ -0,0 +1,32 @@
import torch.nn as nn
from mmocr.models.builder import DECODERS
@DECODERS.register_module()
class BaseDecoder(nn.Module):
"""Base decoder class for text recognition."""
def __init__(self, **kwargs):
super().__init__()
def init_weights(self):
pass
def forward_train(self, feat, out_enc, targets_dict, img_metas):
raise NotImplementedError
def forward_test(self, feat, out_enc, img_metas):
raise NotImplementedError
def forward(self,
feat,
out_enc,
targets_dict=None,
img_metas=None,
train_mode=True):
self.train_mode = train_mode
if train_mode:
return self.forward_train(feat, out_enc, targets_dict, img_metas)
return self.forward_test(feat, out_enc, img_metas)

View File

@ -0,0 +1,49 @@
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import BidirectionalLSTM
from .base_decoder import BaseDecoder
@DECODERS.register_module()
class CRNNDecoder(BaseDecoder):
def __init__(self,
in_channels=None,
num_classes=None,
rnn_flag=False,
**kwargs):
super().__init__()
self.num_classes = num_classes
self.rnn_flag = rnn_flag
if rnn_flag:
self.decoder = nn.Sequential(
BidirectionalLSTM(in_channels, 256, 256),
BidirectionalLSTM(256, 256, num_classes))
else:
self.decoder = nn.Conv2d(
in_channels, num_classes, kernel_size=1, stride=1)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m)
def forward_train(self, feat, out_enc, targets_dict, img_metas):
assert feat.size(2) == 1, 'feature height must be 1'
if self.rnn_flag:
x = feat.squeeze(2) # [N, C, W]
x = x.permute(2, 0, 1) # [W, N, C]
x = self.decoder(x) # [W, N, C]
outputs = x.permute(1, 0, 2).contiguous()
else:
x = self.decoder(feat)
x = x.permute(0, 3, 1, 2).contiguous()
n, w, c, h = x.size()
outputs = x.view(n, w, c * h)
return outputs
def forward_test(self, feat, out_enc, img_metas):
return self.forward_train(feat, out_enc, None, img_metas)
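# A shape sketch for the convolutional (rnn_flag=False) variant, assuming a
# hypothetical height-1 feature map of size 512x1x40 and 37 classes: the 1x1
# conv turns it into a (batch, width, num_classes) sequence of logits.
import torch
toy_decoder = CRNNDecoder(in_channels=512, num_classes=37, rnn_flag=False)
toy_out = toy_decoder(torch.rand(2, 512, 1, 40), None, train_mode=False)
assert toy_out.shape == (2, 40, 37)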

View File

@ -0,0 +1,407 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.models.builder import DECODERS
from .base_decoder import BaseDecoder
@DECODERS.register_module()
class ParallelSARDecoder(BaseDecoder):
"""Implementation Parallel Decoder module in `SAR.
<https://arxiv.org/abs/1811.00751>`_
Args:
        num_classes (int): Output class number.
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
dec_do_rnn (float): Dropout of RNN layer in decoder.
dec_gru (bool): If True, use GRU, else LSTM in decoder.
d_model (int): Dim of channels from backbone.
d_enc (int): Dim of encoder RNN layer.
d_k (int): Dim of channels of attention module.
pred_dropout (float): Dropout probability of prediction layer.
max_seq_len (int): Maximum sequence length for decoding.
mask (bool): If True, mask padding in feature map.
start_idx (int): Index of start token.
padding_idx (int): Index of padding token.
pred_concat (bool): If True, concat glimpse feature from
attention with holistic feature and hidden state.
"""
def __init__(self,
num_classes=37,
enc_bi_rnn=False,
dec_bi_rnn=False,
dec_do_rnn=0.0,
dec_gru=False,
d_model=512,
d_enc=512,
d_k=64,
pred_dropout=0.0,
max_seq_len=40,
mask=True,
start_idx=0,
padding_idx=92,
pred_concat=False,
**kwargs):
super().__init__()
self.num_classes = num_classes
self.enc_bi_rnn = enc_bi_rnn
self.d_k = d_k
self.start_idx = start_idx
self.max_seq_len = max_seq_len
self.mask = mask
self.pred_concat = pred_concat
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
# 2D attention layer
self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
self.conv3x3_1 = nn.Conv2d(
d_model, d_k, kernel_size=3, stride=1, padding=1)
self.conv1x1_2 = nn.Linear(d_k, 1)
# Decoder RNN layer
kwargs = dict(
input_size=encoder_rnn_out_size,
hidden_size=encoder_rnn_out_size,
num_layers=2,
batch_first=True,
dropout=dec_do_rnn,
bidirectional=dec_bi_rnn)
if dec_gru:
self.rnn_decoder = nn.GRU(**kwargs)
else:
self.rnn_decoder = nn.LSTM(**kwargs)
# Decoder input embedding
self.embedding = nn.Embedding(
self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)
# Prediction layer
self.pred_dropout = nn.Dropout(pred_dropout)
pred_num_classes = num_classes - 1 # ignore padding_idx in prediction
if pred_concat:
fc_in_channel = decoder_rnn_out_size + d_model + d_enc
else:
fc_in_channel = d_model
self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
def _2d_attention(self,
decoder_input,
feat,
holistic_feat,
valid_ratios=None):
y = self.rnn_decoder(decoder_input)[0]
# y: bsz * (seq_len + 1) * hidden_size
attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
bsz, seq_len, attn_size = attn_query.size()
attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1)
attn_key = self.conv3x3_1(feat)
# bsz * attn_size * h * w
attn_key = attn_key.unsqueeze(1)
# bsz * 1 * attn_size * h * w
attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
# bsz * (seq_len + 1) * attn_size * h * w
attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
# bsz * (seq_len + 1) * h * w * attn_size
attn_weight = self.conv1x1_2(attn_weight)
# bsz * (seq_len + 1) * h * w * 1
bsz, T, h, w, c = attn_weight.size()
assert c == 1
if valid_ratios is not None:
# cal mask of attention weight
attn_mask = torch.zeros_like(attn_weight)
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio))
attn_mask[i, :, :, valid_width:, :] = 1
attn_weight = attn_weight.masked_fill(attn_mask.bool(),
float('-inf'))
attn_weight = attn_weight.view(bsz, T, -1)
attn_weight = F.softmax(attn_weight, dim=-1)
attn_weight = attn_weight.view(bsz, T, h, w,
c).permute(0, 1, 4, 2, 3).contiguous()
attn_feat = torch.sum(
torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
# bsz * (seq_len + 1) * C
# linear transformation
if self.pred_concat:
hf_c = holistic_feat.size(-1)
holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c)
y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2))
else:
y = self.prediction(attn_feat)
# bsz * (seq_len + 1) * num_classes
if self.train_mode:
y = self.pred_dropout(y)
return y
def forward_train(self, feat, out_enc, targets_dict, img_metas):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
targets = targets_dict['padded_targets'].to(feat.device)
tgt_embedding = self.embedding(targets)
# bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1)
# bsz * 1 * emb_dim
in_dec = torch.cat((out_enc, tgt_embedding), dim=1)
# bsz * (seq_len + 1) * C
out_dec = self._2d_attention(
in_dec, feat, out_enc, valid_ratios=valid_ratios)
# bsz * (seq_len + 1) * num_classes
return out_dec[:, 1:, :] # bsz * seq_len * num_classes
def forward_test(self, feat, out_enc, img_metas):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
seq_len = self.max_seq_len
bsz = feat.size(0)
start_token = torch.full((bsz, ),
self.start_idx,
device=feat.device,
dtype=torch.long)
# bsz
start_token = self.embedding(start_token)
# bsz * emb_dim
start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
# bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1)
# bsz * 1 * emb_dim
decoder_input = torch.cat((out_enc, start_token), dim=1)
# bsz * (seq_len + 1) * emb_dim
outputs = []
for i in range(1, seq_len + 1):
decoder_output = self._2d_attention(
decoder_input, feat, out_enc, valid_ratios=valid_ratios)
char_output = decoder_output[:, i, :] # bsz * num_classes
char_output = F.softmax(char_output, -1)
outputs.append(char_output)
_, max_idx = torch.max(char_output, dim=1, keepdim=False)
char_embedding = self.embedding(max_idx) # bsz * emb_dim
if i < seq_len:
decoder_input[:, i + 1, :] = char_embedding
outputs = torch.stack(outputs, 1) # bsz * seq_len * num_classes
return outputs
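# A shape sketch at test time, using hypothetical values chosen so that
# padding_idx < num_classes (as nn.Embedding requires), overriding the
# defaults above: feat is the backbone map, out_enc the holistic feature
# from the SAR encoder.
toy_decoder = ParallelSARDecoder(num_classes=37, padding_idx=36)
toy_feat = torch.rand(1, 512, 8, 25)  # bsz * d_model * H * W
toy_out_enc = torch.rand(1, 512)      # bsz * encoder_rnn_out_size
toy_logits = toy_decoder(
    toy_feat, toy_out_enc, img_metas=[dict(valid_ratio=1.0)],
    train_mode=False)
assert toy_logits.shape == (1, 40, 36)  # bsz * max_seq_len * (classes - 1)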
@DECODERS.register_module()
class SequentialSARDecoder(BaseDecoder):
"""Implementation Sequential Decoder module in `SAR.
<https://arxiv.org/abs/1811.00751>`_.
Args:
        num_classes (int): Number of output classes.
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
dec_do_rnn (float): Dropout of RNN layer in decoder.
dec_gru (bool): If True, use GRU, else LSTM in decoder.
d_k (int): Dim of conv layers in attention module.
d_model (int): Dim of channels from backbone.
d_enc (int): Dim of encoder RNN layer.
pred_dropout (float): Dropout probability of prediction layer.
max_seq_len (int): Maximum sequence length during decoding.
mask (bool): If True, mask padding in feature map.
start_idx (int): Index of start token.
padding_idx (int): Index of padding token.
pred_concat (bool): If True, concat glimpse feature from
attention with holistic feature and hidden state.
"""
def __init__(self,
num_classes=37,
enc_bi_rnn=False,
dec_bi_rnn=False,
dec_gru=False,
d_k=64,
d_model=512,
d_enc=512,
pred_dropout=0.0,
mask=True,
max_seq_len=40,
start_idx=0,
padding_idx=92,
pred_concat=False,
**kwargs):
super().__init__()
self.num_classes = num_classes
self.enc_bi_rnn = enc_bi_rnn
self.d_k = d_k
self.start_idx = start_idx
self.dec_gru = dec_gru
self.max_seq_len = max_seq_len
self.mask = mask
self.pred_concat = pred_concat
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
# 2D attention layer
self.conv1x1_1 = nn.Conv2d(
decoder_rnn_out_size, d_k, kernel_size=1, stride=1)
self.conv3x3_1 = nn.Conv2d(
d_model, d_k, kernel_size=3, stride=1, padding=1)
self.conv1x1_2 = nn.Conv2d(d_k, 1, kernel_size=1, stride=1)
# Decoder rnn layer
if dec_gru:
self.rnn_decoder_layer1 = nn.GRUCell(encoder_rnn_out_size,
encoder_rnn_out_size)
self.rnn_decoder_layer2 = nn.GRUCell(encoder_rnn_out_size,
encoder_rnn_out_size)
else:
self.rnn_decoder_layer1 = nn.LSTMCell(encoder_rnn_out_size,
encoder_rnn_out_size)
self.rnn_decoder_layer2 = nn.LSTMCell(encoder_rnn_out_size,
encoder_rnn_out_size)
# Decoder input embedding
self.embedding = nn.Embedding(
self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)
# Prediction layer
self.pred_dropout = nn.Dropout(pred_dropout)
pred_num_class = num_classes - 1 # ignore padding index
if pred_concat:
fc_in_channel = decoder_rnn_out_size + d_model + d_enc
else:
fc_in_channel = d_model
self.prediction = nn.Linear(fc_in_channel, pred_num_class)
def _2d_attention(self,
y_prev,
feat,
holistic_feat,
hx1,
cx1,
hx2,
cx2,
valid_ratios=None):
_, _, h_feat, w_feat = feat.size()
if self.dec_gru:
hx1 = cx1 = self.rnn_decoder_layer1(y_prev, hx1)
hx2 = cx2 = self.rnn_decoder_layer2(hx1, hx2)
else:
hx1, cx1 = self.rnn_decoder_layer1(y_prev, (hx1, cx1))
hx2, cx2 = self.rnn_decoder_layer2(hx1, (hx2, cx2))
tile_hx2 = hx2.view(hx2.size(0), hx2.size(1), 1, 1)
attn_query = self.conv1x1_1(tile_hx2) # bsz * attn_size * 1 * 1
attn_query = attn_query.expand(-1, -1, h_feat, w_feat)
attn_key = self.conv3x3_1(feat)
attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
attn_weight = self.conv1x1_2(attn_weight)
bsz, c, h, w = attn_weight.size()
assert c == 1
if valid_ratios is not None:
# cal mask of attention weight
attn_mask = torch.zeros_like(attn_weight)
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio))
attn_mask[i, :, :, valid_width:] = 1
attn_weight = attn_weight.masked_fill(attn_mask.bool(),
float('-inf'))
attn_weight = F.softmax(attn_weight.view(bsz, -1), dim=-1)
attn_weight = attn_weight.view(bsz, c, h, w)
attn_feat = torch.sum(
torch.mul(feat, attn_weight), (2, 3), keepdim=False) # n * c
# linear transformation
if self.pred_concat:
y = self.prediction(torch.cat((hx2, attn_feat, holistic_feat), 1))
else:
y = self.prediction(attn_feat)
return y, hx1, hx1, hx2, hx2
def forward_train(self, feat, out_enc, targets_dict, img_metas=None):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
if self.train_mode:
targets = targets_dict['padded_targets'].to(feat.device)
tgt_embedding = self.embedding(targets)
outputs = []
start_token = torch.full((feat.size(0), ),
self.start_idx,
device=feat.device,
dtype=torch.long)
start_token = self.embedding(start_token)
for i in range(-1, self.max_seq_len):
if i == -1:
if self.dec_gru:
hx1 = cx1 = self.rnn_decoder_layer1(out_enc)
hx2 = cx2 = self.rnn_decoder_layer2(hx1)
else:
hx1, cx1 = self.rnn_decoder_layer1(out_enc)
hx2, cx2 = self.rnn_decoder_layer2(hx1)
if not self.train_mode:
y_prev = start_token
else:
if self.train_mode:
y_prev = tgt_embedding[:, i, :]
y, hx1, cx1, hx2, cx2 = self._2d_attention(
y_prev,
feat,
out_enc,
hx1,
cx1,
hx2,
cx2,
valid_ratios=valid_ratios)
if self.train_mode:
y = self.pred_dropout(y)
else:
y = F.softmax(y, -1)
_, max_idx = torch.max(y, dim=1, keepdim=False)
char_embedding = self.embedding(max_idx)
y_prev = char_embedding
outputs.append(y)
outputs = torch.stack(outputs, 1)
return outputs
def forward_test(self, feat, out_enc, img_metas):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
return self.forward_train(feat, out_enc, None, img_metas)

View File

@ -0,0 +1,148 @@
from queue import PriorityQueue
import torch
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.models.builder import DECODERS
from . import ParallelSARDecoder
class DecodeNode:
"""Node class to save decoded char indices and scores.
Args:
        indexes (list[int]): Char indices decoded so far.
        scores (list[float]): Char scores decoded so far.
"""
def __init__(self, indexes=[1], scores=[0.9]):
assert utils.is_type_list(indexes, int)
assert utils.is_type_list(scores, float)
assert utils.equal_len(indexes, scores)
self.indexes = indexes
self.scores = scores
def eval(self):
"""Calculate accumulated score."""
accu_score = sum(self.scores)
return accu_score
@DECODERS.register_module()
class ParallelSARDecoderWithBS(ParallelSARDecoder):
"""Parallel Decoder module with beam-search in SAR.
Args:
beam_width (int): Width for beam search.
"""
def __init__(self,
beam_width=5,
num_classes=37,
enc_bi_rnn=False,
dec_bi_rnn=False,
dec_do_rnn=0,
dec_gru=False,
d_model=512,
d_enc=512,
d_k=64,
pred_dropout=0.0,
max_seq_len=40,
mask=True,
start_idx=0,
padding_idx=0,
pred_concat=False,
**kwargs):
super().__init__(num_classes, enc_bi_rnn, dec_bi_rnn, dec_do_rnn,
dec_gru, d_model, d_enc, d_k, pred_dropout,
max_seq_len, mask, start_idx, padding_idx,
pred_concat)
assert isinstance(beam_width, int)
assert beam_width > 0
self.beam_width = beam_width
def forward_test(self, feat, out_enc, img_metas):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
seq_len = self.max_seq_len
bsz = feat.size(0)
assert bsz == 1, 'batch size must be 1 for beam search.'
start_token = torch.full((bsz, ),
self.start_idx,
device=feat.device,
dtype=torch.long)
# bsz
start_token = self.embedding(start_token)
# bsz * emb_dim
start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
# bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1)
# bsz * 1 * emb_dim
decoder_input = torch.cat((out_enc, start_token), dim=1)
# bsz * (seq_len + 1) * emb_dim
# Initialize beam-search queue
q = PriorityQueue()
init_node = DecodeNode([self.start_idx], [0.0])
q.put((-init_node.eval(), init_node))
for i in range(1, seq_len + 1):
next_nodes = []
beam_width = self.beam_width if i > 1 else 1
for _ in range(beam_width):
_, node = q.get()
input_seq = torch.clone(decoder_input) # bsz * T * emb_dim
# fill previous input tokens (step 1...i) in input_seq
for t, index in enumerate(node.indexes):
input_token = torch.full((bsz, ),
index,
device=input_seq.device,
dtype=torch.long)
input_token = self.embedding(input_token) # bsz * emb_dim
input_seq[:, t + 1, :] = input_token
output_seq = self._2d_attention(
input_seq, feat, out_enc, valid_ratios=valid_ratios)
output_char = output_seq[:, i, :] # bsz * num_classes
output_char = F.softmax(output_char, -1)
topk_value, topk_idx = output_char.topk(self.beam_width, dim=1)
topk_value, topk_idx = topk_value.squeeze(0), topk_idx.squeeze(
0)
for k in range(self.beam_width):
kth_score = topk_value[k].item()
kth_idx = topk_idx[k].item()
next_node = DecodeNode(node.indexes + [kth_idx],
node.scores + [kth_score])
delta = k * 1e-6
next_nodes.append(
(-node.eval() - kth_score - delta, next_node))
                    # Use minus since the priority queue
                    # sorts in ascending order
while not q.empty():
q.get()
# Put all candidates to queue
for next_node in next_nodes:
q.put(next_node)
best_node = q.get()
num_classes = self.num_classes - 1 # ignore padding index
outputs = torch.zeros(bsz, seq_len, num_classes)
for i in range(seq_len):
idx = best_node[1].indexes[i + 1]
outputs[0, i, idx] = best_node[1].scores[i + 1]
return outputs
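# Illustrative config sketch (an assumption, not taken from this commit): the
# decoder registered above can be selected by type in a recognizer config.
# Only structural arguments are listed and their values are illustrative;
# num_classes, start_idx, padding_idx and max_seq_len are filled in by the
# recognizer from the label convertor. Note that this beam-search decoder
# only supports batch size 1 at test time.
#
#     decoder = dict(
#         type='ParallelSARDecoderWithBS',
#         beam_width=5,
#         enc_bi_rnn=False,
#         dec_bi_rnn=False,
#         dec_do_rnn=0,
#         dec_gru=False,
#         d_model=512,
#         d_enc=512,
#         d_k=64,
#         pred_dropout=0.1,
#         pred_concat=True)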

View File

@ -0,0 +1,99 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import (DecoderLayer, PositionalEncoding,
get_pad_mask, get_subsequent_mask)
from .base_decoder import BaseDecoder
@DECODERS.register_module()
class TFDecoder(BaseDecoder):
"""Transformer Decoder block with self attention mechanism."""
def __init__(self,
n_layers=6,
d_embedding=512,
n_head=8,
d_k=64,
d_v=64,
d_model=512,
d_inner=256,
n_position=200,
dropout=0.1,
num_classes=93,
max_seq_len=40,
start_idx=1,
padding_idx=92,
**kwargs):
super().__init__()
self.padding_idx = padding_idx
self.start_idx = start_idx
self.max_seq_len = max_seq_len
self.trg_word_emb = nn.Embedding(
num_classes, d_embedding, padding_idx=padding_idx)
self.position_enc = PositionalEncoding(
d_embedding, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
pred_num_class = num_classes - 1 # ignore padding_idx
self.classifier = nn.Linear(d_model, pred_num_class)
def _attention(self, trg_seq, src, src_mask=None):
trg_embedding = self.trg_word_emb(trg_seq)
trg_pos_encoded = self.position_enc(trg_embedding)
tgt = self.dropout(trg_pos_encoded)
trg_mask = get_pad_mask(
trg_seq, pad_idx=self.padding_idx) & get_subsequent_mask(trg_seq)
output = tgt
for dec_layer in self.layer_stack:
output = dec_layer(
output,
src,
slf_attn_mask=trg_mask,
dec_enc_attn_mask=src_mask)
output = self.layer_norm(output)
return output
def forward_train(self, feat, out_enc, targets_dict, img_metas):
targets = targets_dict['padded_targets'].to(out_enc.device)
attn_output = self._attention(targets, out_enc, src_mask=None)
outputs = self.classifier(attn_output)
return outputs
def forward_test(self, feat, out_enc, img_metas):
bsz = out_enc.size(0)
init_target_seq = torch.full((bsz, self.max_seq_len + 1),
self.padding_idx,
device=out_enc.device,
dtype=torch.long)
# bsz * seq_len
init_target_seq[:, 0] = self.start_idx
outputs = []
for step in range(0, self.max_seq_len):
decoder_output = self._attention(
init_target_seq, out_enc, src_mask=None)
# bsz * seq_len * 512
step_result = F.softmax(
self.classifier(decoder_output[:, step, :]), dim=-1)
# bsz * num_classes
outputs.append(step_result)
_, step_max_index = torch.max(step_result, dim=-1)
init_target_seq[:, step + 1] = step_max_index
outputs = torch.stack(outputs, dim=1)
return outputs
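# Illustrative config sketch (an assumption, not taken from this commit):
# structural arguments for the transformer decoder above, with illustrative
# values; num_classes, start_idx, padding_idx and max_seq_len are injected by
# the recognizer from the label convertor.
#
#     decoder = dict(
#         type='TFDecoder',
#         n_layers=6,
#         d_embedding=512,
#         n_head=8,
#         d_k=64,
#         d_v=64,
#         d_model=512,
#         d_inner=256,
#         n_position=200,
#         dropout=0.1)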

View File

@ -0,0 +1,6 @@
from .base_encoder import BaseEncoder
from .channel_reduction_encoder import ChannelReductionEncoder
from .sar_encoder import SAREncoder
from .transformer_encoder import TFEncoder
__all__ = ['SAREncoder', 'TFEncoder', 'BaseEncoder', 'ChannelReductionEncoder']

View File

@ -0,0 +1,14 @@
import torch.nn as nn
from mmocr.models.builder import ENCODERS
@ENCODERS.register_module()
class BaseEncoder(nn.Module):
"""Base Encoder class for text recognition."""
def init_weights(self):
pass
def forward(self, feat, **kwargs):
return feat

View File

@ -0,0 +1,102 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import uniform_init, xavier_init
import mmocr.utils as utils
from mmocr.models.builder import ENCODERS
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
class SAREncoder(BaseEncoder):
"""Implementation of encoder module in `SAR.
<https://arxiv.org/abs/1811.00751>`_
Args:
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
enc_do_rnn (float): Dropout probability of RNN layer in encoder.
enc_gru (bool): If True, use GRU, else LSTM in encoder.
d_model (int): Dim of channels from backbone.
d_enc (int): Dim of encoder RNN layer.
mask (bool): If True, mask padding in RNN sequence.
"""
def __init__(self,
enc_bi_rnn=False,
enc_do_rnn=0.0,
enc_gru=False,
d_model=512,
d_enc=512,
mask=True,
**kwargs):
super().__init__()
assert isinstance(enc_bi_rnn, bool)
assert isinstance(enc_do_rnn, (int, float))
assert 0 <= enc_do_rnn < 1.0
assert isinstance(enc_gru, bool)
assert isinstance(d_model, int)
assert isinstance(d_enc, int)
assert isinstance(mask, bool)
self.enc_bi_rnn = enc_bi_rnn
self.enc_do_rnn = enc_do_rnn
self.mask = mask
# LSTM Encoder
kwargs = dict(
input_size=d_model,
hidden_size=d_enc,
num_layers=2,
batch_first=True,
dropout=enc_do_rnn,
bidirectional=enc_bi_rnn)
if enc_gru:
self.rnn_encoder = nn.GRU(**kwargs)
else:
self.rnn_encoder = nn.LSTM(**kwargs)
# global feature transformation
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
def init_weights(self):
# initialize weight and bias
for m in self.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m)
elif isinstance(m, nn.BatchNorm2d):
uniform_init(m)
def forward(self, feat, img_metas=None):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
h_feat = feat.size(2)
feat_v = F.max_pool2d(
feat, kernel_size=(h_feat, 1), stride=1, padding=0)
feat_v = feat_v.squeeze(2) # bsz * C * W
feat_v = feat_v.permute(0, 2, 1).contiguous() # bsz * W * C
holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
if valid_ratios is not None:
valid_hf = []
T = holistic_feat.size(1)
for i, valid_ratio in enumerate(valid_ratios):
valid_step = min(T, math.ceil(T * valid_ratio)) - 1
valid_hf.append(holistic_feat[i, valid_step, :])
valid_hf = torch.stack(valid_hf, dim=0)
else:
valid_hf = holistic_feat[:, -1, :] # bsz * C
holistic_feat = self.linear(valid_hf) # bsz * C
return holistic_feat
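# Illustrative worked example: for a width-padded sample with valid_ratio 0.5
# and T = 20 encoder time steps, the holistic feature is taken from step
# min(20, ceil(20 * 0.5)) - 1 = 9, i.e. the last step covering valid
# (non-padded) image content, instead of the final padded step.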

View File

@ -0,0 +1,16 @@
from mmocr.models.builder import ENCODERS
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
class TFEncoder(BaseEncoder):
"""Encode 2d feature map to 1d sequence."""
def __init__(self, **kwargs):
super().__init__()
def forward(self, feat, img_metas=None):
n, c, _, _ = feat.size()
enc_output = feat.view(n, c, -1).transpose(2, 1).contiguous()
return enc_output

View File

@ -0,0 +1,3 @@
from .seg_head import SegHead
__all__ = ['SegHead']

View File

@ -0,0 +1,50 @@
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from torch import nn
from mmdet.models.builder import HEADS
@HEADS.register_module()
class SegHead(nn.Module):
"""Head for segmentation based text recognition.
Args:
in_channels (int): Number of input channels.
num_classes (int): Number of output classes.
        upsample_param (dict | None): Config dict for the interpolation layer,
            e.g. `dict(scale_factor=1.0, mode='nearest')`.
            Default: None (no upsampling).
"""
def __init__(self, in_channels=128, num_classes=37, upsample_param=None):
super().__init__()
assert isinstance(num_classes, int)
assert num_classes > 0
assert upsample_param is None or isinstance(upsample_param, dict)
self.upsample_param = upsample_param
self.seg_conv = ConvModule(
in_channels,
in_channels,
3,
stride=1,
padding=1,
norm_cfg=dict(type='BN'))
# prediction
self.pred_conv = nn.Conv2d(
in_channels, num_classes, kernel_size=1, stride=1, padding=0)
def init_weights(self):
pass
def forward(self, out_neck):
seg_map = self.seg_conv(out_neck[-1])
seg_map = self.pred_conv(seg_map)
if self.upsample_param is not None:
seg_map = F.interpolate(seg_map, **self.upsample_param)
return seg_map
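# Illustrative config sketch (an assumption, not taken from this commit): a
# head entry using the optional upsampling; num_classes is filled in by
# SegRecognizer from the label convertor, and the values are illustrative.
#
#     head = dict(
#         type='SegHead',
#         in_channels=128,
#         upsample_param=dict(scale_factor=2.0, mode='nearest'))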

View File

@ -0,0 +1,15 @@
from .conv_layer import BasicBlock, Bottleneck
from .dot_product_attention_layer import DotProductAttentionLayer
from .lstm_layer import BidirectionalLSTM
from .position_aware_layer import PositionAwareLayer
from .robust_scanner_fusion_layer import RobustScannerFusionLayer
from .transformer_layer import (DecoderLayer, MultiHeadAttention,
PositionalEncoding, PositionwiseFeedForward,
get_pad_mask, get_subsequent_mask)
__all__ = [
'BidirectionalLSTM', 'MultiHeadAttention', 'PositionalEncoding',
'PositionwiseFeedForward', 'BasicBlock', 'Bottleneck',
'RobustScannerFusionLayer', 'DotProductAttentionLayer',
'PositionAwareLayer', 'DecoderLayer', 'get_pad_mask', 'get_subsequent_mask'
]

View File

@ -0,0 +1,93 @@
import torch.nn as nn
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=False):
super().__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
if downsample:
self.downsample = nn.Sequential(
nn.Conv2d(
inplanes, planes * self.expansion, 1, stride, bias=False),
nn.BatchNorm2d(planes * self.expansion),
)
else:
self.downsample = nn.Sequential()
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=False):
super().__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
if downsample:
self.downsample = nn.Sequential(
nn.Conv2d(
inplanes, planes * self.expansion, 1, stride, bias=False),
nn.BatchNorm2d(planes * self.expansion),
)
else:
self.downsample = nn.Sequential()
def forward(self, x):
residual = self.downsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out += residual
out = self.relu(out)
return out
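# Illustrative shape sketch (assuming torch is imported): a strided
# BasicBlock halves the spatial resolution while its downsample branch keeps
# the residual addition shape-compatible.
#
#     block = BasicBlock(64, 64, stride=2, downsample=True)
#     out = block(torch.rand(1, 64, 32, 128))  # -> (1, 64, 16, 64)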

View File

@ -0,0 +1,20 @@
import torch.nn as nn
class BidirectionalLSTM(nn.Module):
def __init__(self, nIn, nHidden, nOut):
super().__init__()
self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
self.embedding = nn.Linear(nHidden * 2, nOut)
def forward(self, input):
recurrent, _ = self.rnn(input)
T, b, h = recurrent.size()
t_rec = recurrent.view(T * b, h)
output = self.embedding(t_rec) # [T * b, nOut]
output = output.view(T, b, -1)
return output
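# Illustrative shape sketch (assuming torch is imported): the module expects
# sequence-first input of shape (T, N, nIn) and returns (T, N, nOut), e.g. a
# CRNN feature sequence of length 26 over a batch of 8 with 512 channels.
#
#     rnn = BidirectionalLSTM(nIn=512, nHidden=256, nOut=37)
#     out = rnn(torch.rand(26, 8, 512))  # -> (26, 8, 37)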

View File

@ -0,0 +1,172 @@
"""This code is from https://github.com/jadore801120/attention-is-all-you-need-
pytorch."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class DecoderLayer(nn.Module):
"""Compose with three layers
1. MultiHeadSelfAttn
2. MultiHeadEncoderDecoderAttn
3. PositionwiseFeedForward
"""
def __init__(self,
d_model=512,
d_inner=256,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1):
super().__init__()
self.slf_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout)
self.enc_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout)
def forward(self,
dec_input,
enc_output,
slf_attn_mask=None,
dec_enc_attn_mask=None):
dec_output = self.slf_attn(
dec_input, dec_input, dec_input, mask=slf_attn_mask)
dec_output = self.enc_attn(
dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
dec_output = self.pos_ffn(dec_output)
return dec_output
class ScaledDotProductAttention(nn.Module):
"""Scaled Dot-Product Attention."""
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
def forward(self, q, k, v, mask=None):
attn = torch.matmul(q / self.temperature, k.transpose(1, 2))
if mask is not None:
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = self.dropout(F.softmax(attn, dim=-1))
output = torch.matmul(attn, v)
return output, attn
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention module."""
def __init__(self, n_head=8, d_model=512, d_k=64, d_v=64, dropout=0.1):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs_list = nn.ModuleList(
[nn.Linear(d_model, d_k, bias=False) for _ in range(n_head)])
self.w_ks_list = nn.ModuleList(
[nn.Linear(d_model, d_k, bias=False) for _ in range(n_head)])
self.w_vs_list = nn.ModuleList(
[nn.Linear(d_model, d_v, bias=False) for _ in range(n_head)])
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
self.attention = ScaledDotProductAttention(temperature=d_k**0.5)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, q, k, v, mask=None):
residual = q
q = self.layer_norm(q)
attention_q_list = []
for head_index in range(self.n_head):
q_each = self.w_qs_list[head_index](q) # bsz * seq_len * d_k
k_each = self.w_ks_list[head_index](k) # bsz * seq_len * d_k
v_each = self.w_vs_list[head_index](v) # bsz * seq_len * d_v
attention_q_each, _ = self.attention(
q_each, k_each, v_each, mask=mask)
attention_q_list.append(attention_q_each)
q = torch.cat(attention_q_list, dim=-1)
q = self.dropout(self.fc(q))
q += residual
return q
class PositionwiseFeedForward(nn.Module):
"""A two-feed-forward-layer module."""
def __init__(self, d_in, d_hid, dropout=0.1):
super().__init__()
self.w_1 = nn.Linear(d_in, d_hid) # position-wise
self.w_2 = nn.Linear(d_hid, d_in) # position-wise
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = self.layer_norm(x)
x = self.w_2(F.relu(self.w_1(x)))
x = self.dropout(x)
x += residual
return x
class PositionalEncoding(nn.Module):
def __init__(self, d_hid=512, n_position=200):
super().__init__()
# Not a parameter
self.register_buffer(
'position_table',
self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table."""
denominator = torch.Tensor([
1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
])
denominator = denominator.view(1, -1)
pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
sinusoid_table = pos_tensor * denominator
sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
return sinusoid_table.unsqueeze(0)
def forward(self, x):
self.device = x.device
return x + self.position_table[:, :x.size(1)].clone().detach()
def get_pad_mask(seq, pad_idx):
return (seq != pad_idx).unsqueeze(-2)
def get_subsequent_mask(seq):
"""For masking out the subsequent info."""
len_s = seq.size(1)
subsequent_mask = 1 - torch.triu(
torch.ones((len_s, len_s), device=seq.device), diagonal=1)
subsequent_mask = subsequent_mask.unsqueeze(0).bool()
return subsequent_mask
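# Illustrative worked example: for a target sequence of length 3,
# get_subsequent_mask yields a lower-triangular boolean mask of shape
# (1, 3, 3), so position i can only attend to positions <= i:
#
#     [[[True, False, False],
#       [True, True,  False],
#       [True, True,  True ]]]
#
# get_pad_mask additionally masks out padded positions; the two masks are
# combined with '&' in TFDecoder._attention.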

View File

@ -0,0 +1,5 @@
from .ce_loss import CELoss, SARLoss, TFLoss
from .ctc_loss import CTCLoss
from .seg_loss import CAFCNLoss, SegLoss
__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss', 'CAFCNLoss']

View File

@ -0,0 +1,94 @@
import torch.nn as nn
from mmdet.models.builder import LOSSES
@LOSSES.register_module()
class CELoss(nn.Module):
"""Implementation of loss module for encoder-decoder based text recognition
method with CrossEntropy loss.
Args:
ignore_index (int): Specifies a target value that is
ignored and does not contribute to the input gradient.
reduction (str): Specifies the reduction to apply to the output,
should be one of the following: ('none', 'mean', 'sum').
"""
def __init__(self, ignore_index=-1, reduction='none'):
super().__init__()
assert isinstance(ignore_index, int)
assert isinstance(reduction, str)
assert reduction in ['none', 'mean', 'sum']
self.loss_ce = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction=reduction)
def format(self, outputs, targets_dict):
targets = targets_dict['padded_targets']
return outputs.permute(0, 2, 1).contiguous(), targets
def forward(self, outputs, targets_dict):
outputs, targets = self.format(outputs, targets_dict)
loss_ce = self.loss_ce(outputs, targets.to(outputs.device))
losses = dict(loss_ce=loss_ce)
return losses
@LOSSES.register_module()
class SARLoss(CELoss):
"""Implementation of loss module in `SAR.
<https://arxiv.org/abs/1811.00751>`_.
Args:
ignore_index (int): Specifies a target value that is
ignored and does not contribute to the input gradient.
reduction (str): Specifies the reduction to apply to the output,
should be one of the following: ('none', 'mean', 'sum').
"""
def __init__(self, ignore_index=0, reduction='mean', **kwargs):
super().__init__(ignore_index, reduction)
def format(self, outputs, targets_dict):
targets = targets_dict['padded_targets']
# targets[0, :], [start_idx, idx1, idx2, ..., end_idx, pad_idx...]
# outputs[0, :, 0], [idx1, idx2, ..., end_idx, ...]
# ignore first index of target in loss calculation
targets = targets[:, 1:].contiguous()
# ignore last index of outputs to be in same seq_len with targets
outputs = outputs[:, :-1, :].permute(0, 2, 1).contiguous()
return outputs, targets
@LOSSES.register_module()
class TFLoss(CELoss):
"""Implementation of loss module for transformer."""
def __init__(self,
ignore_index=-1,
reduction='none',
flatten=True,
**kwargs):
super().__init__(ignore_index, reduction)
assert isinstance(flatten, bool)
self.flatten = flatten
def format(self, outputs, targets_dict):
outputs = outputs[:, :-1, :].contiguous()
targets = targets_dict['padded_targets']
targets = targets[:, 1:].contiguous()
if self.flatten:
outputs = outputs.view(-1, outputs.size(-1))
targets = targets.view(-1)
else:
outputs = outputs.permute(0, 2, 1).contiguous()
return outputs, targets
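# Illustrative note on the target/prediction alignment above: SARLoss drops
# the leading <start> token from padded_targets and the final step from the
# predictions, so the prediction at step t is scored against the (t + 1)-th
# target token; TFLoss applies the same shift and can additionally flatten
# both tensors to (N * seq_len, ...) for the cross-entropy call.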

View File

@ -0,0 +1,63 @@
import torch
import torch.nn as nn
from mmdet.models.builder import LOSSES
@LOSSES.register_module()
class CTCLoss(nn.Module):
"""Implementation of loss module for CTC-loss based text recognition.
Args:
flatten (bool): If True, use flattened targets, else padded targets.
blank (int): Blank label. Default 0.
reduction (str): Specifies the reduction to apply to the output,
should be one of the following: ('none', 'mean', 'sum').
zero_infinity (bool): Whether to zero infinite losses and
the associated gradients. Default: False.
Infinite losses mainly occur when the inputs
are too short to be aligned to the targets.
"""
def __init__(self,
flatten=True,
blank=0,
reduction='mean',
zero_infinity=False,
**kwargs):
super().__init__()
assert isinstance(flatten, bool)
assert isinstance(blank, int)
assert isinstance(reduction, str)
assert isinstance(zero_infinity, bool)
self.flatten = flatten
self.blank = blank
self.ctc_loss = nn.CTCLoss(
blank=blank, reduction=reduction, zero_infinity=zero_infinity)
def forward(self, outputs, targets_dict):
outputs = torch.log_softmax(outputs, dim=2)
bsz, seq_len = outputs.size(0), outputs.size(1)
input_lengths = torch.full(
size=(bsz, ), fill_value=seq_len, dtype=torch.long)
outputs_for_loss = outputs.permute(1, 0, 2).contiguous() # T * N * C
if self.flatten:
targets = targets_dict['flatten_targets']
else:
targets = torch.full(
size=(bsz, seq_len), fill_value=self.blank, dtype=torch.long)
for idx, tensor in enumerate(targets_dict['targets']):
valid_len = min(tensor.size(0), seq_len)
targets[idx, :valid_len] = tensor[:valid_len]
target_lengths = targets_dict['target_lengths']
loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths,
target_lengths)
losses = dict(loss_ctc=loss_ctc)
return losses
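# Illustrative note: with flatten=True the label convertor is expected to
# provide 'flatten_targets' as a single 1-D tensor of concatenated label
# indices plus per-sample 'target_lengths', which is the layout nn.CTCLoss
# accepts; input_lengths is constant and equal to the output sequence length.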

View File

@ -0,0 +1,176 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.builder import LOSSES
from mmocr.models.common.losses import DiceLoss
@LOSSES.register_module()
class SegLoss(nn.Module):
"""Implementation of loss module for segmentation based text recognition
method.
Args:
seg_downsample_ratio (float): Downsample ratio of
segmentation map.
seg_with_loss_weight (bool): If True, set weight for
segmentation loss.
ignore_index (int): Specifies a target value that is ignored
and does not contribute to the input gradient.
"""
def __init__(self,
seg_downsample_ratio=0.5,
seg_with_loss_weight=True,
ignore_index=255,
**kwargs):
super().__init__()
assert isinstance(seg_downsample_ratio, (int, float))
assert 0 < seg_downsample_ratio <= 1
assert isinstance(ignore_index, int)
self.seg_downsample_ratio = seg_downsample_ratio
self.seg_with_loss_weight = seg_with_loss_weight
self.ignore_index = ignore_index
def seg_loss(self, out_head, gt_kernels):
seg_map = out_head # bsz * num_classes * H/2 * W/2
seg_target = [
item[1].rescale(self.seg_downsample_ratio).to_tensor(
torch.long, seg_map.device) for item in gt_kernels
]
seg_target = torch.stack(seg_target).squeeze(1)
loss_weight = None
if self.seg_with_loss_weight:
N = torch.sum(seg_target != self.ignore_index)
N_neg = torch.sum(seg_target == 0)
weight_val = 1.0 * N_neg / (N - N_neg)
loss_weight = torch.ones(seg_map.size(1), device=seg_map.device)
loss_weight[1:] = weight_val
loss_seg = F.cross_entropy(
seg_map,
seg_target,
weight=loss_weight,
ignore_index=self.ignore_index)
return loss_seg
def forward(self, out_neck, out_head, gt_kernels):
losses = {}
loss_seg = self.seg_loss(out_head, gt_kernels)
losses['loss_seg'] = loss_seg
return losses
@LOSSES.register_module()
class CAFCNLoss(SegLoss):
"""Implementation of loss module in `CA-FCN.
<https://arxiv.org/pdf/1809.06508.pdf>`_
Args:
alpha (float): Weight ratio of attention loss.
attn_s2_downsample_ratio (float): Downsample ratio
of attention map from output stage 2.
attn_s3_downsample_ratio (float): Downsample ratio
of attention map from output stage 3.
seg_downsample_ratio (float): Downsample ratio of
segmentation map.
attn_with_dice_loss (bool): If True, use dice_loss for attention,
else BCELoss.
with_attn (bool): If True, include attention loss, else
segmentation loss only.
seg_with_loss_weight (bool): If True, set weight for
segmentation loss.
ignore_index (int): Specifies a target value that is ignored
and does not contribute to the input gradient.
"""
def __init__(self,
alpha=1.0,
attn_s2_downsample_ratio=0.25,
attn_s3_downsample_ratio=0.125,
seg_downsample_ratio=0.5,
attn_with_dice_loss=False,
with_attn=True,
seg_with_loss_weight=True,
ignore_index=255):
super().__init__(seg_downsample_ratio, seg_with_loss_weight,
ignore_index)
assert isinstance(alpha, (int, float))
assert isinstance(attn_s2_downsample_ratio, (int, float))
assert isinstance(attn_s3_downsample_ratio, (int, float))
assert 0 < attn_s2_downsample_ratio <= 1
assert 0 < attn_s3_downsample_ratio <= 1
self.alpha = alpha
self.attn_s2_downsample_ratio = attn_s2_downsample_ratio
self.attn_s3_downsample_ratio = attn_s3_downsample_ratio
self.with_attn = with_attn
self.attn_with_dice_loss = attn_with_dice_loss
# attention loss
if with_attn:
if attn_with_dice_loss:
self.criterion_attn = DiceLoss()
else:
self.criterion_attn = nn.BCELoss(reduction='none')
def attn_loss(self, out_neck, gt_kernels):
attn_map_s2 = out_neck[0] # bsz * 2 * H/4 * W/4
mask_s2 = torch.stack([
item[2].rescale(self.attn_s2_downsample_ratio).to_tensor(
torch.float, attn_map_s2.device) for item in gt_kernels
])
attn_target_s2 = torch.stack([
item[0].rescale(self.attn_s2_downsample_ratio).to_tensor(
torch.float, attn_map_s2.device) for item in gt_kernels
])
mask_s3 = torch.stack([
item[2].rescale(self.attn_s3_downsample_ratio).to_tensor(
torch.float, attn_map_s2.device) for item in gt_kernels
])
attn_target_s3 = torch.stack([
item[0].rescale(self.attn_s3_downsample_ratio).to_tensor(
torch.float, attn_map_s2.device) for item in gt_kernels
])
targets = [
attn_target_s2, attn_target_s3, attn_target_s3, attn_target_s3
]
masks = [mask_s2, mask_s3, mask_s3, mask_s3]
loss_attn = 0.
for i in range(len(out_neck) - 1):
pred = out_neck[i]
if self.attn_with_dice_loss:
loss_attn += self.criterion_attn(pred, targets[i], masks[i])
else:
loss_attn += torch.sum(
self.criterion_attn(pred, targets[i]) *
masks[i]) / torch.sum(masks[i])
return loss_attn
def forward(self, out_neck, out_head, gt_kernels):
losses = super().forward(out_neck, out_head, gt_kernels)
if self.with_attn:
loss_attn = self.attn_loss(out_neck, gt_kernels)
losses['loss_attn'] = loss_attn
return losses
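# Illustrative worked example for the class weighting in SegLoss.seg_loss:
# if 90% of the valid (non-ignored) pixels are background (label 0), then
# N_neg / (N - N_neg) = 9, so every character class is weighted 9x relative
# to the background class in the cross-entropy term.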

View File

@ -0,0 +1,5 @@
from .cafcn_neck import CAFCNNeck
from .fpn_ocr import FPNOCR
from .fpn_seg import FPNSeg
__all__ = ['CAFCNNeck', 'FPNSeg', 'FPNOCR']

View File

@ -0,0 +1,223 @@
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.ops import DeformConv2dPack
from torch import nn
from mmdet.models.builder import NECKS
class CharAttn(nn.Module):
"""Implementation of Character attention module in `CA-FCN.
<https://arxiv.org/pdf/1809.06508.pdf>`_
"""
def __init__(self, in_channels=128, out_channels=128, deformable=False):
super().__init__()
assert isinstance(in_channels, int)
assert isinstance(deformable, bool)
self.in_channels = in_channels
self.out_channels = out_channels
self.deformable = deformable
# attention layers
self.attn_layer = nn.Sequential(
ConvModule(
in_channels,
in_channels,
3,
stride=1,
padding=1,
norm_cfg=dict(type='BN')),
ConvModule(
in_channels,
1,
3,
stride=1,
padding=1,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='Sigmoid')))
conv_kwargs = dict(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=1,
padding=1)
if self.deformable:
self.conv = DeformConv2dPack(**conv_kwargs)
else:
self.conv = nn.Conv2d(**conv_kwargs)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, in_feat):
# Calculate attn map
attn_map = self.attn_layer(in_feat) # N * 1 * H * W
in_feat = self.relu(self.bn(self.conv(in_feat)))
out_feat_map = self._upsample_mul(in_feat, 1 + attn_map)
return out_feat_map, attn_map
def _upsample_add(self, x, y):
return F.interpolate(x, size=y.size()[2:]) + y
def _upsample_mul(self, x, y):
return F.interpolate(x, size=y.size()[2:]) * y
class FeatGenerator(nn.Module):
"""Generate attention-augmented stage feature from backbone stage
feature."""
def __init__(self,
in_channels=512,
out_channels=128,
deformable=True,
concat=False,
upsample=False,
with_attn=True):
super().__init__()
self.concat = concat
self.upsample = upsample
self.with_attn = with_attn
if with_attn:
self.char_attn = CharAttn(
in_channels=in_channels,
out_channels=out_channels,
deformable=deformable)
else:
self.char_attn = ConvModule(
in_channels,
out_channels,
3,
stride=1,
padding=1,
norm_cfg=dict(type='BN'))
if concat:
self.conv_to_concat = ConvModule(
out_channels,
out_channels,
3,
stride=1,
padding=1,
norm_cfg=dict(type='BN'))
kernel_size = (3, 1) if deformable else 3
padding = (1, 0) if deformable else 1
tmp_in_channels = out_channels * 2 if concat else out_channels
self.conv_after_concat = ConvModule(
tmp_in_channels,
out_channels,
kernel_size,
stride=1,
padding=padding,
norm_cfg=dict(type='BN'))
def forward(self, x, y=None, size=None):
if self.with_attn:
feat_map, attn_map = self.char_attn(x)
else:
feat_map = self.char_attn(x)
attn_map = feat_map
if self.concat:
y = self.conv_to_concat(y)
feat_map = torch.cat((y, feat_map), dim=1)
feat_map = self.conv_after_concat(feat_map)
if self.upsample:
feat_map = F.interpolate(feat_map, size)
return attn_map, feat_map
@NECKS.register_module()
class CAFCNNeck(nn.Module):
"""Implementation of neck module in `CA-FCN.
<https://arxiv.org/pdf/1809.06508.pdf>`_
Args:
in_channels (list[int]): Number of input channels for each scale.
out_channels (int): Number of output channels for each scale.
deformable (bool): If True, use deformable conv.
with_attn (bool): If True, add attention for each output feature map.
"""
def __init__(self,
in_channels=[128, 256, 512, 512],
out_channels=128,
deformable=True,
with_attn=True):
super().__init__()
self.deformable = deformable
self.with_attn = with_attn
# stage_in5_to_out5
self.s5 = FeatGenerator(
in_channels=in_channels[-1],
out_channels=out_channels,
deformable=deformable,
concat=False,
with_attn=with_attn)
# stage_in4_to_out4
self.s4 = FeatGenerator(
in_channels=in_channels[-2],
out_channels=out_channels,
deformable=deformable,
concat=True,
with_attn=with_attn)
# stage_in3_to_out3
self.s3 = FeatGenerator(
in_channels=in_channels[-3],
out_channels=out_channels,
deformable=False,
concat=True,
upsample=True,
with_attn=with_attn)
# stage_in2_to_out2
self.s2 = FeatGenerator(
in_channels=in_channels[-4],
out_channels=out_channels,
deformable=False,
concat=True,
upsample=True,
with_attn=with_attn)
def init_weights(self):
pass
def forward(self, feats):
in_stage1 = feats[0]
in_stage2, in_stage3 = feats[1], feats[2]
in_stage4, in_stage5 = feats[3], feats[4]
# out stage 5
out_s5_attn_map, out_s5 = self.s5(in_stage5)
# out stage 4
out_s4_attn_map, out_s4 = self.s4(in_stage4, out_s5)
# out stage 3
out_s3_attn_map, out_s3 = self.s3(in_stage3, out_s4,
in_stage2.size()[2:])
# out stage 2
out_s2_attn_map, out_s2 = self.s2(in_stage2, out_s3,
in_stage1.size()[2:])
return (out_s2_attn_map, out_s3_attn_map, out_s4_attn_map,
out_s5_attn_map, out_s2)
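# Illustrative config sketch (an assumption, not taken from this commit),
# mirroring the constructor defaults above:
#
#     neck = dict(
#         type='CAFCNNeck',
#         in_channels=[128, 256, 512, 512],
#         out_channels=128,
#         deformable=True,
#         with_attn=True)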

View File

@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmdet.models.builder import NECKS
@NECKS.register_module()
class FPNOCR(nn.Module):
"""FPN-like Network for segmentation based text recognition.
Args:
in_channels (list[int]): Number of input channels for each scale.
out_channels (int): Number of output channels for each scale.
last_stage_only (bool): If True, output last stage only.
"""
def __init__(self, in_channels, out_channels, last_stage_only=True):
super(FPNOCR, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.last_stage_only = last_stage_only
self.lateral_convs = nn.ModuleList()
self.smooth_convs_1x1 = nn.ModuleList()
self.smooth_convs_3x3 = nn.ModuleList()
for i in range(self.num_ins):
l_conv = ConvModule(
in_channels[i], out_channels, 1, norm_cfg=dict(type='BN'))
self.lateral_convs.append(l_conv)
for i in range(self.num_ins - 1):
s_conv_1x1 = ConvModule(
out_channels * 2, out_channels, 1, norm_cfg=dict(type='BN'))
s_conv_3x3 = ConvModule(
out_channels,
out_channels,
3,
padding=1,
norm_cfg=dict(type='BN'))
self.smooth_convs_1x1.append(s_conv_1x1)
self.smooth_convs_3x3.append(s_conv_3x3)
def init_weights(self):
pass
def _upsample_x2(self, x):
return F.interpolate(x, scale_factor=2, mode='bilinear')
def forward(self, inputs):
lateral_features = [
l_conv(inputs[i]) for i, l_conv in enumerate(self.lateral_convs)
]
outs = []
for i in range(len(self.smooth_convs_3x3), 0, -1): # 3, 2, 1
last_out = lateral_features[-1] if len(outs) == 0 else outs[-1]
upsample = self._upsample_x2(last_out)
upsample_cat = torch.cat((upsample, lateral_features[i - 1]),
dim=1)
smooth_1x1 = self.smooth_convs_1x1[i - 1](upsample_cat)
smooth_3x3 = self.smooth_convs_3x3[i - 1](smooth_1x1)
outs.append(smooth_3x3)
return tuple(outs[-1:]) if self.last_stage_only else tuple(outs)

View File

@ -0,0 +1,43 @@
import torch.nn.functional as F
from mmcv.runner import auto_fp16
from mmdet.models.builder import NECKS
from mmdet.models.necks import FPN
@NECKS.register_module()
class FPNSeg(FPN):
"""Feature Pyramid Network for segmentation based text recognition.
Args:
in_channels (list[int]): Number of input channels for each scale.
out_channels (int): Number of output channels for each scale.
num_outs (int): Number of output scales.
        upsample_param (dict | None): Config dict for the interpolation layer,
            e.g. `dict(scale_factor=1.0, mode='nearest')`.
            Default: None (no upsampling).
last_stage_only (bool): If True, output last stage of FPN only.
"""
def __init__(self,
in_channels,
out_channels,
num_outs,
upsample_param=None,
last_stage_only=True):
super().__init__(in_channels, out_channels, num_outs)
self.upsample_param = upsample_param
self.last_stage_only = last_stage_only
@auto_fp16()
def forward(self, inputs):
outs = super().forward(inputs)
outs = list(outs)
if self.upsample_param is not None:
outs[0] = F.interpolate(outs[0], **self.upsample_param)
if self.last_stage_only:
return tuple(outs[0:1])
return tuple(outs[::-1])

View File

@ -0,0 +1,13 @@
from .base import BaseRecognizer
from .cafcn import CAFCNNet
from .crnn import CRNNNet
from .encode_decode_recognizer import EncodeDecodeRecognizer
from .robust_scanner import RobustScanner
from .sar import SARNet
from .seg_recognizer import SegRecognizer
from .transformer import TransformerNet
__all__ = [
'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet',
'TransformerNet', 'SegRecognizer', 'RobustScanner', 'CAFCNNet'
]

View File

@ -0,0 +1,236 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.utils import print_log
from mmdet.utils import get_root_logger
from mmocr.core import imshow_text_label
class BaseRecognizer(nn.Module, metaclass=ABCMeta):
"""Base class for text recognition."""
def __init__(self):
super().__init__()
self.fp16_enabled = False
@abstractmethod
def extract_feat(self, imgs):
"""Extract features from images."""
pass
@abstractmethod
def forward_train(self, imgs, img_metas, **kwargs):
"""
Args:
            imgs (tensor): Tensors with shape (N, C, H, W).
Typically should be mean centered and std scaled.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details of the values of these keys, see
:class:`mmdet.datasets.pipelines.Collect`.
kwargs (keyword arguments): Specific to concrete implementation.
"""
pass
@abstractmethod
def simple_test(self, img, img_metas, **kwargs):
pass
@abstractmethod
def aug_test(self, imgs, img_metas, **kwargs):
"""Test function with test time augmentation.
Args:
imgs (list[tensor]): Tensor should have shape NxCxHxW,
which contains all images in the batch.
img_metas (list[list[dict]]): The metadata of images.
"""
pass
def init_weights(self, pretrained=None):
"""Initialize the weights for detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if pretrained is not None:
logger = get_root_logger()
print_log(f'load model from: {pretrained}', logger=logger)
def forward_test(self, imgs, img_metas, **kwargs):
"""
Args:
imgs (tensor | list[tensor]): Tensor should have shape NxCxHxW,
which contains all images in the batch.
img_metas (list[dict] | list[list[dict]]):
The outer list indicates images in a batch.
"""
if isinstance(imgs, list):
assert len(imgs) == len(img_metas)
assert len(imgs) > 0
assert imgs[0].size(0) == 1, 'aug test does not support ' \
'inference with batch size ' \
f'{imgs[0].size(0)}'
return self.aug_test(imgs, img_metas, **kwargs)
return self.simple_test(imgs, img_metas, **kwargs)
@auto_fp16(apply_to=('img', ))
def forward(self, img, img_metas, return_loss=True, **kwargs):
"""Calls either :func:`forward_train` or :func:`forward_test` depending
on whether ``return_loss`` is ``True``.
Note that img and img_meta are single-nested (i.e. tensor and
list[dict]).
"""
if return_loss:
return self.forward_train(img, img_metas, **kwargs)
return self.forward_test(img, img_metas, **kwargs)
def _parse_losses(self, losses):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw outputs of the network, which usually contain
                losses and other necessary information.
Returns:
tuple[tensor, dict]: (loss, log_vars), loss is the loss tensor \
which may be a weighted sum of all losses, log_vars contains \
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
loss = sum(_value for _key, _value in log_vars.items()
if 'loss' in _key)
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
def train_step(self, data, optimizer):
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer update, which are done by an optimizer
hook. Note that in some complicated cases or models (e.g. GAN),
the whole process (including the back propagation and optimizer update)
is also defined by this method.
Args:
data (dict): The outputs of dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
runner is passed to ``train_step()``. This argument is unused
and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
``num_samples``.
- ``loss`` is a tensor for back propagation, which is a \
weighted sum of multiple losses.
- ``log_vars`` contains all the variables to be sent to the
logger.
- ``num_samples`` indicates the batch size used for \
averaging the logs (Note: for the \
DDP model, num_samples refers to the batch size for each GPU).
"""
losses = self(**data)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
return outputs
def val_step(self, data, optimizer):
"""The iteration step during validation.
This method shares the same signature as :func:`train_step`, but is
used during val epochs. Note that the evaluation after training epochs
is not implemented by this method, but by an evaluation hook.
"""
losses = self(**data)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
return outputs
def show_result(self,
img,
result,
gt_label='',
win_name='',
show=False,
wait_time=0,
out_file=None,
**kwargs):
"""Draw `result` on `img`.
Args:
            img (str or ndarray): The image to be displayed.
result (dict): The results to draw on `img`.
gt_label (str): Ground truth label of img.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The output filename.
Default: None.
Returns:
            img (ndarray): The drawn image; only if not `show` or `out_file`.
"""
img = mmcv.imread(img)
img = img.copy()
pred_label = None
if 'text' in result.keys():
pred_label = result['text']
# if out_file specified, do not show image in window
if out_file is not None:
show = False
# draw text label
if pred_label is not None:
img = imshow_text_label(
img,
pred_label,
gt_label,
show=show,
win_name=win_name,
wait_time=wait_time,
out_file=out_file)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
return img

View File

@ -0,0 +1,7 @@
from mmdet.models.builder import DETECTORS
from .seg_recognizer import SegRecognizer
@DETECTORS.register_module()
class CAFCNNet(SegRecognizer):
"""Implementation of `CAFCN <https://arxiv.org/pdf/1809.06508.pdf>`_"""

View File

@ -0,0 +1,18 @@
import torch
import torch.nn.functional as F
from mmdet.models.builder import DETECTORS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@DETECTORS.register_module()
class CRNNNet(EncodeDecodeRecognizer):
"""CTC-loss based recognizer."""
def forward_conversion(self, params, img):
x = self.extract_feat(img)
x = self.encoder(x)
outs = self.decoder(x)
outs = F.softmax(outs, dim=2)
params = torch.pow(params, 1)
return outs, params

View File

@ -0,0 +1,173 @@
from mmdet.models.builder import DETECTORS, build_backbone, build_loss
from mmocr.models.builder import (build_convertor, build_decoder,
build_encoder, build_preprocessor)
from .base import BaseRecognizer
@DETECTORS.register_module()
class EncodeDecodeRecognizer(BaseRecognizer):
"""Base class for encode-decode recognizer."""
def __init__(self,
preprocessor=None,
backbone=None,
encoder=None,
decoder=None,
loss=None,
label_convertor=None,
train_cfg=None,
test_cfg=None,
max_seq_len=40,
pretrained=None):
super().__init__()
# Label convertor (str2tensor, tensor2str)
assert label_convertor is not None
label_convertor.update(max_seq_len=max_seq_len)
self.label_convertor = build_convertor(label_convertor)
# Preprocessor module, e.g., TPS
self.preprocessor = None
if preprocessor is not None:
self.preprocessor = build_preprocessor(preprocessor)
# Backbone
assert backbone is not None
self.backbone = build_backbone(backbone)
# Encoder module
self.encoder = None
if encoder is not None:
self.encoder = build_encoder(encoder)
# Decoder module
assert decoder is not None
decoder.update(num_classes=self.label_convertor.num_classes())
decoder.update(start_idx=self.label_convertor.start_idx)
decoder.update(padding_idx=self.label_convertor.padding_idx)
decoder.update(max_seq_len=max_seq_len)
self.decoder = build_decoder(decoder)
# Loss
assert loss is not None
loss.update(ignore_index=self.label_convertor.padding_idx)
self.loss = build_loss(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.max_seq_len = max_seq_len
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
"""Initialize the weights of recognizer."""
super().init_weights(pretrained)
if self.preprocessor is not None:
self.preprocessor.init_weights()
self.backbone.init_weights()
if self.encoder is not None:
self.encoder.init_weights()
self.decoder.init_weights()
def extract_feat(self, img):
"""Directly extract features from the backbone."""
if self.preprocessor is not None:
img = self.preprocessor(img)
x = self.backbone(img)
return x
def forward_train(self, img, img_metas):
"""
Args:
img (tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A list of image info dict where each dict
contains: 'img_shape', 'filename', and may also contain
'ori_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
Returns:
dict[str, tensor]: A dictionary of loss components.
"""
feat = self.extract_feat(img)
gt_labels = [img_meta['text'] for img_meta in img_metas]
targets_dict = self.label_convertor.str2tensor(gt_labels)
out_enc = None
if self.encoder is not None:
out_enc = self.encoder(feat, img_metas)
out_dec = self.decoder(
feat, out_enc, targets_dict, img_metas, train_mode=True)
loss_inputs = (
out_dec,
targets_dict,
)
losses = self.loss(*loss_inputs)
return losses
def simple_test(self, img, img_metas, **kwargs):
"""Test function with test time augmentation.
Args:
imgs (torch.Tensor): Image input tensor.
img_metas (list[dict]): List of image information.
Returns:
list[str]: Text label result of each image.
"""
feat = self.extract_feat(img)
out_enc = None
if self.encoder is not None:
out_enc = self.encoder(feat, img_metas)
out_dec = self.decoder(
feat, out_enc, None, img_metas, train_mode=False)
label_indexes, label_scores = \
self.label_convertor.tensor2idx(out_dec, img_metas)
label_strings = self.label_convertor.idx2str(label_indexes)
# flatten batch results
results = []
for string, score in zip(label_strings, label_scores):
results.append(dict(text=string, score=score))
return results
def merge_aug_results(self, aug_results):
out_text, out_score = '', -1
for result in aug_results:
text = result[0]['text']
score = sum(result[0]['score']) / max(1, len(text))
if score > out_score:
out_text = text
out_score = score
out_results = [dict(text=out_text, score=out_score)]
return out_results
def aug_test(self, imgs, img_metas, **kwargs):
"""Test function as well as time augmentation.
Args:
imgs (list[tensor]): Tensor should have shape NxCxHxW,
which contains all images in the batch.
img_metas (list[list[dict]]): The metadata of images.
"""
aug_results = []
for img, img_meta in zip(imgs, img_metas):
result = self.simple_test(img, img_meta, **kwargs)
aug_results.append(result)
return self.merge_aug_results(aug_results)
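# Illustrative config sketch (an assumption): how an encode-decode recognizer
# could be assembled from the components in this commit. 'ResNet31OCR' and
# 'AttnConvertor' are hypothetical names standing in for whichever backbone
# and label convertor a real config registers; values are illustrative.
#
#     model = dict(
#         type='SARNet',
#         backbone=dict(type='ResNet31OCR'),
#         encoder=dict(type='SAREncoder'),
#         decoder=dict(type='ParallelSARDecoder'),
#         loss=dict(type='SARLoss'),
#         label_convertor=dict(
#             type='AttnConvertor', dict_type='DICT90', with_unknown=True),
#         max_seq_len=40)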

View File

@ -0,0 +1,7 @@
from mmdet.models.builder import DETECTORS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@DETECTORS.register_module()
class SARNet(EncodeDecodeRecognizer):
"""Implementation of `SAR <https://arxiv.org/abs/1811.00751>`_"""

View File

@ -0,0 +1,153 @@
from mmdet.models.builder import (DETECTORS, build_backbone, build_head,
build_loss, build_neck)
from mmocr.models.builder import build_convertor, build_preprocessor
from .base import BaseRecognizer
@DETECTORS.register_module()
class SegRecognizer(BaseRecognizer):
"""Base class for segmentation based recognizer."""
def __init__(self,
preprocessor=None,
backbone=None,
neck=None,
head=None,
loss=None,
label_convertor=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super().__init__()
# Label_convertor
assert label_convertor is not None
self.label_convertor = build_convertor(label_convertor)
# Preprocessor module, e.g., TPS
self.preprocessor = None
if preprocessor is not None:
self.preprocessor = build_preprocessor(preprocessor)
# Backbone
assert backbone is not None
self.backbone = build_backbone(backbone)
# Neck
assert neck is not None
self.neck = build_neck(neck)
# Head
assert head is not None
head.update(num_classes=self.label_convertor.num_classes())
self.head = build_head(head)
# Loss
assert loss is not None
self.loss = build_loss(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
"""Initialize the weights of recognizer."""
super().init_weights(pretrained)
if self.preprocessor is not None:
self.preprocessor.init_weights()
self.backbone.init_weights(pretrained=pretrained)
if self.neck is not None:
self.neck.init_weights()
self.head.init_weights()
def extract_feat(self, img):
"""Directly extract features from the backbone."""
if self.preprocessor is not None:
img = self.preprocessor(img)
x = self.backbone(img)
return x
def forward_train(self, img, img_metas, gt_kernels=None):
"""
Args:
img (tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A list of image info dict where each dict
contains: 'img_shape', 'filename', and may also contain
'ori_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
Returns:
dict[str, tensor]: A dictionary of loss components.
"""
feats = self.extract_feat(img)
out_neck = self.neck(feats)
out_head = self.head(out_neck)
loss_inputs = (out_neck, out_head, gt_kernels)
losses = self.loss(*loss_inputs)
return losses
def simple_test(self, img, img_metas, **kwargs):
"""Test function without test time augmentation.
Args:
            img (torch.Tensor): Image input tensor.
img_metas (list[dict]): List of image information.
Returns:
list[str]: Text label result of each image.
"""
feat = self.extract_feat(img)
out_neck = self.neck(feat)
out_head = self.head(out_neck)
texts, scores = self.label_convertor.tensor2str(out_head, img_metas)
# flatten batch results
results = []
for text, score in zip(texts, scores):
results.append(dict(text=text, score=score))
return results
def merge_aug_results(self, aug_results):
out_text, out_score = '', -1
for result in aug_results:
text = result[0]['text']
score = sum(result[0]['score']) / max(1, len(text))
if score > out_score:
out_text = text
out_score = score
out_results = [dict(text=out_text, score=out_score)]
return out_results
def aug_test(self, imgs, img_metas, **kwargs):
"""Test function with test time augmentation.
Args:
imgs (list[tensor]): Tensor should have shape NxCxHxW,
which contains all images in the batch.
img_metas (list[list[dict]]): The metadata of images.
"""
aug_results = []
for img, img_meta in zip(imgs, img_metas):
result = self.simple_test(img, img_meta, **kwargs)
aug_results.append(result)
return self.merge_aug_results(aug_results)
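# Illustrative config sketch (an assumption): a segmentation based recognizer
# wiring together the neck, head and loss added in this commit. The backbone
# and label convertor types are placeholders, and the values are illustrative.
#
#     model = dict(
#         type='SegRecognizer',
#         backbone=dict(type='ResNet31OCR'),
#         neck=dict(
#             type='FPNOCR', in_channels=[128, 256, 512, 512],
#             out_channels=256),
#         head=dict(
#             type='SegHead',
#             in_channels=256,
#             upsample_param=dict(scale_factor=2.0, mode='nearest')),
#         loss=dict(type='SegLoss', seg_downsample_ratio=1.0),
#         label_convertor=dict(
#             type='SegConvertor', dict_type='DICT36', with_unknown=True))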

View File

@ -0,0 +1,7 @@
from mmdet.models.builder import DETECTORS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@DETECTORS.register_module()
class TransformerNet(EncodeDecodeRecognizer):
"""Implementation of Transformer based OCR."""

View File

@ -0,0 +1,74 @@
import os.path as osp
import tempfile
import numpy as np
import pytest
from mmocr.datasets.base_dataset import BaseDataset
def _create_dummy_ann_file(ann_file):
ann_info1 = 'sample1.jpg hello'
ann_info2 = 'sample2.jpg world'
with open(ann_file, 'w') as fw:
for ann_info in [ann_info1, ann_info2]:
fw.write(ann_info + '\n')
def _create_dummy_loader():
loader = dict(
type='HardDiskLoader',
repeat=1,
parser=dict(type='LineStrParser', keys=['file_name', 'text']))
return loader
def test_custom_dataset():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
_create_dummy_ann_file(ann_file)
loader = _create_dummy_loader()
for mode in [True, False]:
dataset = BaseDataset(ann_file, loader, pipeline=[], test_mode=mode)
# test len
assert len(dataset) == len(dataset.data_infos)
# test set group flag
assert np.allclose(dataset.flag, [0, 0])
# test prepare_train_img
expect_results = {
'img_info': {
'file_name': 'sample1.jpg',
'text': 'hello'
},
'img_prefix': ''
}
assert dataset.prepare_train_img(0) == expect_results
# test prepare_test_img
assert dataset.prepare_test_img(0) == expect_results
# test __getitem__
assert dataset[0] == expect_results
# test get_next_index
assert dataset._get_next_index(0) == 1
        # test format_results
expect_results_copy = {
key: value
for key, value in expect_results.items()
}
dataset.format_results(expect_results)
assert expect_results_copy == expect_results
# test evaluate
with pytest.raises(NotImplementedError):
dataset.evaluate(expect_results)
tmp_dir.cleanup()

View File

@ -0,0 +1,96 @@
import math
import numpy as np
import pytest
from mmocr.datasets.pipelines.crop import (box_jitter, convert_canonical,
crop_img, sort_vertex, warp_img)
def test_order_vertex():
dummy_points_x = [20, 20, 120, 120]
dummy_points_y = [20, 40, 40, 20]
with pytest.raises(AssertionError):
sort_vertex([], dummy_points_y)
with pytest.raises(AssertionError):
sort_vertex(dummy_points_x, [])
ordered_points_x, ordered_points_y = sort_vertex(dummy_points_x,
dummy_points_y)
expect_points_x = [20, 120, 120, 20]
expect_points_y = [20, 20, 40, 40]
assert np.allclose(ordered_points_x, expect_points_x)
assert np.allclose(ordered_points_y, expect_points_y)
def test_convert_canonical():
dummy_points_x = [120, 120, 20, 20]
dummy_points_y = [20, 40, 40, 20]
with pytest.raises(AssertionError):
convert_canonical([], dummy_points_y)
with pytest.raises(AssertionError):
convert_canonical(dummy_points_x, [])
ordered_points_x, ordered_points_y = convert_canonical(
dummy_points_x, dummy_points_y)
expect_points_x = [20, 120, 120, 20]
expect_points_y = [20, 20, 40, 40]
assert np.allclose(ordered_points_x, expect_points_x)
assert np.allclose(ordered_points_y, expect_points_y)
def test_box_jitter():
dummy_points_x = [20, 120, 120, 20]
dummy_points_y = [20, 20, 40, 40]
kwargs = dict(jitter_ratio_x=0.0, jitter_ratio_y=0.0)
with pytest.raises(AssertionError):
box_jitter([], dummy_points_y)
with pytest.raises(AssertionError):
box_jitter(dummy_points_x, [])
with pytest.raises(AssertionError):
box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_x=1.)
with pytest.raises(AssertionError):
box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_y=1.)
box_jitter(dummy_points_x, dummy_points_y, **kwargs)
assert np.allclose(dummy_points_x, [20, 120, 120, 20])
assert np.allclose(dummy_points_y, [20, 20, 40, 40])
def test_opencv_crop():
dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
cropped_img = warp_img(dummy_img, dummy_box)
with pytest.raises(AssertionError):
warp_img(dummy_img, [])
with pytest.raises(AssertionError):
warp_img(dummy_img, [20, 40, 40, 20])
assert math.isclose(cropped_img.shape[0], 20)
assert math.isclose(cropped_img.shape[1], 100)
def test_min_rect_crop():
dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
cropped_img = crop_img(dummy_img, dummy_box)
with pytest.raises(AssertionError):
crop_img(dummy_img, [])
with pytest.raises(AssertionError):
crop_img(dummy_img, [20, 40, 40, 20])
assert math.isclose(cropped_img.shape[0], 20)
assert math.isclose(cropped_img.shape[1], 100)

View File

@ -0,0 +1,83 @@
import json
import os.path as osp
import tempfile
import numpy as np
from mmocr.datasets.text_det_dataset import TextDetDataset
def _create_dummy_ann_file(ann_file):
ann_info1 = {
'file_name':
'sample1.jpg',
'height':
640,
'width':
640,
'annotations': [{
'iscrowd': 0,
'category_id': 1,
'bbox': [50, 70, 80, 100],
'segmentation': [[50, 70, 80, 70, 80, 100, 50, 100]]
}, {
'iscrowd':
1,
'category_id':
1,
'bbox': [120, 140, 200, 200],
'segmentation': [[120, 140, 200, 140, 200, 200, 120, 200]]
}]
}
with open(ann_file, 'w') as fw:
fw.write(json.dumps(ann_info1) + '\n')
def _create_dummy_loader():
loader = dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineJsonParser',
keys=['file_name', 'height', 'width', 'annotations']))
return loader
def test_detect_dataset():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
_create_dummy_ann_file(ann_file)
# test initialization
loader = _create_dummy_loader()
dataset = TextDetDataset(ann_file, loader, pipeline=[])
    # test _parse_anno_info
img_ann_info = dataset.data_infos[0]
ann = dataset._parse_anno_info(img_ann_info['annotations'])
print(ann['bboxes'])
assert np.allclose(ann['bboxes'], [[50., 70., 80., 100.]])
assert np.allclose(ann['labels'], [1])
assert np.allclose(ann['bboxes_ignore'], [[120, 140, 200, 200]])
assert np.allclose(ann['masks'], [[[50, 70, 80, 70, 80, 100, 50, 100]]])
assert np.allclose(ann['masks_ignore'],
[[[120, 140, 200, 140, 200, 200, 120, 200]]])
tmp_dir.cleanup()
# test prepare_train_img
pipeline_results = dataset.prepare_train_img(0)
assert np.allclose(pipeline_results['bbox_fields'], [])
assert np.allclose(pipeline_results['mask_fields'], [])
assert np.allclose(pipeline_results['seg_fields'], [])
expect_img_info = {'filename': 'sample1.jpg', 'height': 640, 'width': 640}
assert pipeline_results['img_info'] == expect_img_info
    # test evaluation
metrics = 'hmean-iou'
results = [{'boundary_result': [[50, 70, 80, 70, 80, 100, 50, 100, 1]]}]
eval_res = dataset.evaluate(results, metrics)
assert eval_res['hmean-iou:hmean'] == 1

View File

@ -0,0 +1,71 @@
import json
import os.path as osp
import tempfile
import pytest
from tools.data.utils.txt2lmdb import converter
from mmocr.datasets.utils.loader import HardDiskLoader, LmdbLoader, Loader
def _create_dummy_line_str_file(ann_file):
ann_info1 = 'sample1.jpg hello'
ann_info2 = 'sample2.jpg world'
with open(ann_file, 'w') as fw:
for ann_info in [ann_info1, ann_info2]:
fw.write(ann_info + '\n')
def _create_dummy_line_json_file(ann_file):
ann_info1 = {'filename': 'sample1.jpg', 'text': 'hello'}
ann_info2 = {'filename': 'sample2.jpg', 'text': 'world'}
with open(ann_file, 'w') as fw:
for ann_info in [ann_info1, ann_info2]:
fw.write(json.dumps(ann_info) + '\n')
def test_loader():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
_create_dummy_line_str_file(ann_file)
parser = dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')
with pytest.raises(AssertionError):
Loader(ann_file, parser, repeat=0)
with pytest.raises(AssertionError):
Loader(ann_file, [], repeat=1)
with pytest.raises(AssertionError):
Loader('sample.txt', parser, repeat=1)
with pytest.raises(NotImplementedError):
loader = Loader(ann_file, parser, repeat=1)
print(loader)
# test text loader and line str parser
text_loader = HardDiskLoader(ann_file, parser, repeat=1)
assert len(text_loader) == 2
assert text_loader.ori_data_infos[0] == 'sample1.jpg hello'
assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}
    # test text loader and line json parser
_create_dummy_line_json_file(ann_file)
json_parser = dict(type='LineJsonParser', keys=['filename', 'text'])
text_loader = HardDiskLoader(ann_file, json_parser, repeat=1)
assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}
# test lmdb loader and line str parser
_create_dummy_line_str_file(ann_file)
lmdb_file = osp.join(tmp_dir.name, 'fake_data.lmdb')
converter(ann_file, lmdb_file)
lmdb_loader = LmdbLoader(lmdb_file, parser, repeat=1)
assert lmdb_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}
tmp_dir.cleanup()

View File

@ -0,0 +1,51 @@
import math
import os.path as osp
import tempfile
from mmocr.datasets.ocr_dataset import OCRDataset
def _create_dummy_ann_file(ann_file):
ann_info1 = 'sample1.jpg hello'
ann_info2 = 'sample2.jpg world'
with open(ann_file, 'w') as fw:
for ann_info in [ann_info1, ann_info2]:
fw.write(ann_info + '\n')
def _create_dummy_loader():
loader = dict(
type='HardDiskLoader',
repeat=1,
parser=dict(type='LineStrParser', keys=['file_name', 'text']))
return loader
def test_detect_dataset():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
_create_dummy_ann_file(ann_file)
# test initialization
loader = _create_dummy_loader()
dataset = OCRDataset(ann_file, loader, pipeline=[])
tmp_dir.cleanup()
# test pre_pipeline
img_info = dataset.data_infos[0]
results = dict(img_info=img_info)
dataset.pre_pipeline(results)
assert results['img_prefix'] == dataset.img_prefix
assert results['text'] == img_info['text']
# test evaluation
metric = 'acc'
results = [{'text': 'hello'}, {'text': 'worl'}]
eval_res = dataset.evaluate(results, metric)
assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
assert math.isclose(eval_res['char_recall'], 0.9, abs_tol=1e-4)
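# the expected numbers follow from simple counting: 'hello' is an exact
# match while 'worl' drops the final 'd', so word_acc = 1/2 = 0.5,
# char_recall = 9/10 = 0.9 and char_precision = 9/9 = 1.0.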

View File

@ -0,0 +1,127 @@
import json
import math
import os.path as osp
import tempfile
import pytest
from mmocr.datasets.ocr_seg_dataset import OCRSegDataset
def _create_dummy_ann_file(ann_file):
ann_info1 = {
'file_name':
'sample1.png',
'annotations': [{
'char_text':
'F',
'char_box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
}, {
'char_text':
'r',
'char_box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]
}, {
'char_text':
'o',
'char_box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]
}, {
'char_text':
'm',
'char_box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]
}, {
'char_text':
':',
'char_box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]
}],
'text':
'From:'
}
ann_info2 = {
'file_name':
'sample2.png',
'annotations': [{
'char_text': 'o',
'char_box': [0.0, 5.0, 7.0, 5.0, 9.0, 15.0, 2.0, 15.0]
}, {
'char_text':
'u',
'char_box': [7.0, 4.0, 14.0, 4.0, 18.0, 18.0, 11.0, 18.0]
}, {
'char_text':
't',
'char_box': [13.0, 1.0, 19.0, 2.0, 24.0, 18.0, 17.0, 18.0]
}],
'text':
'out'
}
with open(ann_file, 'w') as fw:
for ann_info in [ann_info1, ann_info2]:
fw.write(json.dumps(ann_info) + '\n')
return ann_info1, ann_info2
def _create_dummy_loader():
loader = dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineJsonParser', keys=['file_name', 'text', 'annotations']))
return loader
def test_ocr_seg_dataset():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
ann_info1, ann_info2 = _create_dummy_ann_file(ann_file)
# test initialization
loader = _create_dummy_loader()
dataset = OCRSegDataset(ann_file, loader, pipeline=[])
tmp_dir.cleanup()
# test pre_pipeline
img_info = dataset.data_infos[0]
results = dict(img_info=img_info)
dataset.pre_pipeline(results)
assert results['img_prefix'] == dataset.img_prefix
# test _parse_anno_info
annos = ann_info1['annotations']
with pytest.raises(AssertionError):
dataset._parse_anno_info(annos[0])
annos2 = ann_info2['annotations']
with pytest.raises(AssertionError):
dataset._parse_anno_info([{'char_text': 'i'}])
with pytest.raises(AssertionError):
dataset._parse_anno_info([{'char_box': [1, 2, 3, 4, 5, 6, 7, 8]}])
annos2[0]['char_box'] = [1, 2, 3]
with pytest.raises(AssertionError):
dataset._parse_anno_info(annos2)
return_anno = dataset._parse_anno_info(annos)
assert return_anno['chars'] == ['F', 'r', 'o', 'm', ':']
assert len(return_anno['char_rects']) == 5
# test prepare_train_img
expect_results = {
'img_info': {
'filename': 'sample1.png'
},
'img_prefix': '',
'ann_info': return_anno
}
data = dataset.prepare_train_img(0)
assert data == expect_results
# test evaluation
metric = 'acc'
results = [{'text': 'From:'}, {'text': 'ou'}]
eval_res = dataset.evaluate(results, metric)
assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
assert math.isclose(eval_res['char_recall'], 0.857, abs_tol=1e-4)
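# the char-level scores appear to be computed on alphanumeric characters
# only: with the ':' excluded the ground truth has 7 characters ('From' +
# 'out'), of which 6 are predicted ('From' + 'ou'), giving recall 6/7 ~ 0.857
# and precision 6/6 = 1.0; word_acc = 1/2 since only 'From:' matches exactly.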

View File

@ -0,0 +1,93 @@
import os.path as osp
import tempfile
import numpy as np
import pytest
from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets
def _create_dummy_dict_file(dict_file):
chars = list('0123456789')
with open(dict_file, 'w') as fw:
for char in chars:
fw.write(char + '\n')
def test_ocr_segm_targets():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy dict file
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
_create_dummy_dict_file(dict_file)
# dummy label convertor
label_convertor = dict(
type='SegConvertor',
dict_file=dict_file,
with_unknown=True,
lower=True)
# test init
with pytest.raises(AssertionError):
OCRSegTargets(None, 0.5, 0.5)
with pytest.raises(AssertionError):
OCRSegTargets(label_convertor, '1by2', 0.5)
with pytest.raises(AssertionError):
OCRSegTargets(label_convertor, 0.5, 2)
ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)
# test generate kernels
img_size = (8, 8)
pad_size = (8, 10)
char_boxes = [[2, 2, 6, 6]]
char_idxs = [2]
with pytest.raises(AssertionError):
ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs, 0.5,
True)
with pytest.raises(AssertionError):
ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
char_idxs, 0.5, True)
with pytest.raises(AssertionError):
ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2, 0.5,
True)
attn_tgt = ocr_seg_tgt.generate_kernels(
img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
[0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
[0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
assert np.allclose(attn_tgt, np.array(expect_attn_tgt, dtype=np.int32))
segm_tgt = ocr_seg_tgt.generate_kernels(
img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
[0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
[0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
[0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
assert np.allclose(segm_tgt, np.array(expect_segm_tgt, dtype=np.int32))
# test __call__
results = {}
results['img_shape'] = (4, 4, 3)
results['resize_shape'] = (8, 8, 3)
results['pad_shape'] = (8, 10)
results['ann_info'] = {}
results['ann_info']['char_rects'] = [[1, 1, 3, 3]]
results['ann_info']['chars'] = ['1']
results = ocr_seg_tgt(results)
assert results['mask_fields'] == ['gt_kernels']
assert np.allclose(results['gt_kernels'].masks[0],
np.array(expect_attn_tgt, dtype=np.int32))
assert np.allclose(results['gt_kernels'].masks[1],
np.array(expect_segm_tgt, dtype=np.int32))
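# a note on the expected maps above: the image is 8 pixels wide but padded
# to 10, so the two right-most columns are filled with 255, presumably used
# as an ignore value; the char box [2, 2, 6, 6] shrunk by a ratio of 0.5
# covers the central 3x3 region, marked 1 in the binary attention target and
# with the character index 2 in the segmentation target.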
tmp_dir.cleanup()

View File

@ -0,0 +1,94 @@
import math
import unittest.mock as mock
import numpy as np
import torch
import torchvision.transforms.functional as TF
import mmocr.datasets.pipelines.ocr_transforms as transforms
def test_resize_ocr():
input_img = np.ones((64, 256, 3), dtype=np.uint8)
rci = transforms.ResizeOCR(
32, min_width=32, max_width=160, keep_aspect_ratio=True)
results = {'img_shape': input_img.shape, 'img': input_img}
# test call
results = rci(results)
assert np.allclose([32, 160, 3], results['pad_shape'])
assert np.allclose([32, 160, 3], results['img'].shape)
assert 'valid_ratio' in results
assert math.isclose(results['valid_ratio'], 0.8)
assert math.isclose(np.sum(results['img'][:, 129:, :]), 0)
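# with keep_aspect_ratio=True the 64x256 input is resized to 32x128 and then
# padded up to max_width=160, so valid_ratio = 128 / 160 = 0.8 and everything
# to the right of column 128 is zero padding.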
rci = transforms.ResizeOCR(
32, min_width=32, max_width=160, keep_aspect_ratio=False)
results = {'img_shape': input_img.shape, 'img': input_img}
results = rci(results)
assert math.isclose(results['valid_ratio'], 1)
def test_to_tensor():
input_img = np.ones((64, 256, 3), dtype=np.uint8)
expect_output = TF.to_tensor(input_img)
rci = transforms.ToTensorOCR()
results = {'img': input_img}
results = rci(results)
assert np.allclose(results['img'].numpy(), expect_output.numpy())
def test_normalize():
inputs = torch.zeros(3, 10, 10)
expect_output = torch.ones_like(inputs) * (-1)
rci = transforms.NormalizeOCR(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
results = {'img': inputs}
results = rci(results)
assert np.allclose(results['img'].numpy(), expect_output.numpy())
@mock.patch('%s.transforms.np.random.random' % __name__)
def test_online_crop(mock_random):
kwargs = dict(
box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
jitter_prob=0.5,
max_jitter_ratio_x=0.05,
max_jitter_ratio_y=0.02)
mock_random.side_effect = [0.1, 1, 1, 1]
src_img = np.ones((100, 100, 3), dtype=np.uint8)
results = {
'img': src_img,
'img_info': {
'x1': '20',
'y1': '20',
'x2': '40',
'y2': '20',
'x3': '40',
'y3': '40',
'x4': '20',
'y4': '40'
}
}
rci = transforms.OnlineCropOCR(**kwargs)
results = rci(results)
assert np.allclose(results['img_shape'], [20, 20, 3])
# test not crop
mock_random.side_effect = [0.1, 1, 1, 1]
results['img_info'] = {}
results['img'] = src_img
results = rci(results)
assert np.allclose(results['img'].shape, [100, 100, 3])

View File

@ -0,0 +1,59 @@
import json
import pytest
from mmocr.datasets.utils.parser import LineJsonParser, LineStrParser
def test_line_str_parser():
data_ret = ['sample1.jpg hello', 'sample2.jpg world']
keys = ['filename', 'text']
keys_idx = [0, 1]
separator = ' '
# test init
with pytest.raises(AssertionError):
parser = LineStrParser('filename', keys_idx, separator)
with pytest.raises(AssertionError):
parser = LineStrParser(keys, keys_idx, [' '])
with pytest.raises(AssertionError):
parser = LineStrParser(keys, [0], separator)
# test get_item
parser = LineStrParser(keys, keys_idx, separator)
assert parser.get_item(data_ret, 0) == \
{'filename': 'sample1.jpg', 'text': 'hello'}
with pytest.raises(Exception):
parser = LineStrParser(['filename', 'text', 'ignore'], [0, 1, 2],
separator)
parser.get_item(data_ret, 0)
def test_line_dict_parser():
data_ret = [
json.dumps({
'filename': 'sample1.jpg',
'text': 'hello'
}),
json.dumps({
'filename': 'sample2.jpg',
'text': 'world'
})
]
keys = ['filename', 'text']
# test init
with pytest.raises(AssertionError):
parser = LineJsonParser('filename')
with pytest.raises(AssertionError):
parser = LineJsonParser([])
# test get_item
parser = LineJsonParser(keys)
assert parser.get_item(data_ret, 0) == \
{'filename': 'sample1.jpg', 'text': 'hello'}
with pytest.raises(Exception):
parser = LineJsonParser(['img_name', 'text'])
parser.get_item(data_ret, 0)

View File

@ -0,0 +1,33 @@
import numpy as np
import pytest
from mmocr.datasets.pipelines.test_time_aug import MultiRotateAugOCR
def test_multi_rotate_aug_ocr():
input_img1 = np.ones((64, 256, 3), dtype=np.uint8)
input_img2 = np.ones((64, 32, 3), dtype=np.uint8)
rci = MultiRotateAugOCR(transforms=[], rotate_degrees=[0, 90, 270])
# test invalid arguments
with pytest.raises(AssertionError):
MultiRotateAugOCR(transforms=[], rotate_degrees=[45])
with pytest.raises(AssertionError):
MultiRotateAugOCR(transforms=[], rotate_degrees=[20.5])
# test call with input_img1
results = {'img_shape': input_img1.shape, 'img': input_img1}
results = rci(results)
assert np.allclose([64, 256, 3], results['img_shape'])
assert len(results['img']) == 1
assert len(results['img_shape']) == 1
assert np.allclose([64, 256, 3], results['img_shape'][0])
# test call with input_img2
results = {'img_shape': input_img2.shape, 'img': input_img2}
results = rci(results)
assert np.allclose([64, 32, 3], results['img_shape'])
assert len(results['img']) == 3
assert len(results['img_shape']) == 3
assert np.allclose([64, 32, 3], results['img_shape'][0])
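# the wide 64x256 image is kept as a single view, while the 64x32 image also
# gets 90- and 270-degree rotated views, which suggests rotation is only
# applied to images whose width is not much larger than their height (i.e.
# possibly vertical text).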

View File

@ -0,0 +1,77 @@
import os.path as osp
import tempfile
import numpy as np
import pytest
import torch
from mmocr.models.textrecog.convertors import AttnConvertor
def _create_dummy_dict_file(dict_file):
characters = list('helowrd')
with open(dict_file, 'w') as fw:
for char in characters:
fw.write(char + '\n')
def test_attn_label_convertor():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
dict_file = osp.join(tmp_dir.name, 'fake_dict.txt')
_create_dummy_dict_file(dict_file)
# test invalid arguments
with pytest.raises(AssertionError):
AttnConvertor(5)
with pytest.raises(AssertionError):
AttnConvertor('DICT90', dict_file, '1')
with pytest.raises(AssertionError):
AttnConvertor('DICT90', dict_file, True, '1')
label_convertor = AttnConvertor(dict_file=dict_file, max_seq_len=10)
# test init and parse_dict
assert label_convertor.num_classes() == 10
assert len(label_convertor.idx2char) == 10
assert label_convertor.idx2char[0] == 'h'
assert label_convertor.idx2char[1] == 'e'
assert label_convertor.idx2char[-3] == '<UKN>'
assert label_convertor.char2idx['h'] == 0
assert label_convertor.unknown_idx == 7
# test encode str to tensor
strings = ['hell']
targets_dict = label_convertor.str2tensor(strings)
assert torch.allclose(targets_dict['targets'][0],
torch.LongTensor([0, 1, 2, 2]))
assert torch.allclose(targets_dict['padded_targets'][0],
torch.LongTensor([8, 0, 1, 2, 2, 8, 9, 9, 9, 9]))
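# given the dictionary 'helowrd', indices 0-6 are the characters and 7 is
# <UKN>; the padded target [8, 0, 1, 2, 2, 8, 9, ...] then indicates that 8
# serves as the start/end token and 9 as the padding index, consistent with
# num_classes() == 10 above.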
# test decode output to index
dummy_output = torch.Tensor([[[100, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 100, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 100, 4, 5, 6, 7, 8, 9],
[1, 2, 100, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6, 7, 8, 100],
[1, 2, 3, 4, 5, 6, 7, 100, 9],
[1, 2, 3, 4, 5, 6, 7, 100, 9],
[1, 2, 3, 4, 5, 6, 7, 100, 9],
[1, 2, 3, 4, 5, 6, 7, 100, 9],
[1, 2, 3, 4, 5, 6, 7, 100, 9]]])
indexes, scores = label_convertor.tensor2idx(dummy_output)
assert np.allclose(indexes, [[0, 1, 2, 2]])
# test encode_str_label_to_index
with pytest.raises(AssertionError):
label_convertor.str2idx('hell')
tmp_indexes = label_convertor.str2idx(strings)
assert np.allclose(tmp_indexes, [[0, 1, 2, 2]])
# test decode_index to str_label
input_indexes = [[0, 1, 2, 2]]
with pytest.raises(AssertionError):
label_convertor.idx2str('hell')
output_strings = label_convertor.idx2str(input_indexes)
assert output_strings[0] == 'hell'
tmp_dir.cleanup()

View File

@ -0,0 +1,79 @@
import os.path as osp
import tempfile
import numpy as np
import pytest
import torch
from mmocr.models.textrecog.convertors import BaseConvertor, CTCConvertor
def _create_dummy_dict_file(dict_file):
chars = list('helowrd')
with open(dict_file, 'w') as fw:
for char in chars:
fw.write(char + '\n')
def test_ctc_label_convertor():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
_create_dummy_dict_file(dict_file)
# test invalid arguments
with pytest.raises(AssertionError):
CTCConvertor(5)
label_convertor = CTCConvertor(dict_file=dict_file, with_unknown=False)
# test init and parse_chars
assert label_convertor.num_classes() == 8
assert len(label_convertor.idx2char) == 8
assert label_convertor.idx2char[0] == '<BLK>'
assert label_convertor.char2idx['h'] == 1
assert label_convertor.unknown_idx is None
# test encode str to tensor
strings = ['hell']
expect_tensor = torch.IntTensor([1, 2, 3, 3])
targets_dict = label_convertor.str2tensor(strings)
assert torch.allclose(targets_dict['targets'][0], expect_tensor)
assert torch.allclose(targets_dict['flatten_targets'], expect_tensor)
assert torch.allclose(targets_dict['target_lengths'], torch.IntTensor([4]))
# test decode output to index
dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
[100, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 100, 4, 5, 6, 7, 8],
[1, 2, 100, 4, 5, 6, 7, 8],
[100, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 100, 5, 6, 7, 8],
[100, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 100, 5, 6, 7, 8]]])
indexes, scores = label_convertor.tensor2idx(
dummy_output, img_metas=[{
'valid_ratio': 1.0
}])
assert np.allclose(indexes, [[1, 2, 3, 3]])
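# the per-step argmax of dummy_output is [1, 0, 2, 2, 0, 3, 0, 3]; merging
# consecutive repeats and then removing the blank index 0 gives [1, 2, 3, 3],
# i.e. 'hell'.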
# test encode_str_label_to_index
with pytest.raises(AssertionError):
label_convertor.str2idx('hell')
tmp_indexes = label_convertor.str2idx(strings)
assert np.allclose(tmp_indexes, [[1, 2, 3, 3]])
# test decode_index_to_str_label
input_indexes = [[1, 2, 3, 3]]
with pytest.raises(AssertionError):
label_convertor.idx2str('hell')
output_strings = label_convertor.idx2str(input_indexes)
assert output_strings[0] == 'hell'
tmp_dir.cleanup()
def test_base_label_convertor():
with pytest.raises(NotImplementedError):
label_convertor = BaseConvertor()
label_convertor.str2tensor(None)
label_convertor.tensor2idx(None)

View File

@ -0,0 +1,36 @@
import pytest
import torch
from mmocr.models.textrecog.backbones import ResNet31OCR, VeryDeepVgg
def test_resnet31_ocr_backbone():
"""Test resnet backbone."""
with pytest.raises(AssertionError):
ResNet31OCR(2.5)
with pytest.raises(AssertionError):
ResNet31OCR(3, layers=5)
with pytest.raises(AssertionError):
ResNet31OCR(3, channels=5)
# Test ResNet31OCR forward
model = ResNet31OCR()
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 32, 160)
feat = model(imgs)
assert feat.shape == torch.Size([1, 512, 4, 40])
def test_very_deep_vgg_ocr_backbone():
model = VeryDeepVgg()
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 32, 160)
feats = model(imgs)
assert feats.shape == torch.Size([1, 512, 1, 41])

View File

@ -0,0 +1,112 @@
import math
import pytest
import torch
from mmocr.models.textrecog.decoders import (BaseDecoder, ParallelSARDecoder,
ParallelSARDecoderWithBS,
SequentialSARDecoder, TFDecoder)
from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode
def _create_dummy_input():
feat = torch.rand(1, 512, 4, 40)
out_enc = torch.rand(1, 512)
tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
img_metas = [{'valid_ratio': 1.0}]
return feat, out_enc, tgt_dict, img_metas
def test_base_decoder():
decoder = BaseDecoder()
with pytest.raises(NotImplementedError):
decoder.forward_train(None, None, None, None)
with pytest.raises(NotImplementedError):
decoder.forward_test(None, None, None)
def test_parallel_sar_decoder():
# test parallel sar decoder
decoder = ParallelSARDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
decoder.init_weights()
decoder.train()
feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
with pytest.raises(AssertionError):
decoder(feat, out_enc, tgt_dict, [], True)
with pytest.raises(AssertionError):
decoder(feat, out_enc, tgt_dict, img_metas * 2, True)
out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
assert out_train.shape == torch.Size([1, 5, 36])
out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
assert out_test.shape == torch.Size([1, 5, 36])
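# the last dimension is 36 although num_classes is 37, which suggests the
# padding/start class (index 36 here) is excluded from the prediction layer.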
def test_sequential_sar_decoder():
# test sequential sar decoder
decoder = SequentialSARDecoder(
num_classes=37, padding_idx=36, max_seq_len=5)
decoder.init_weights()
decoder.train()
feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
with pytest.raises(AssertionError):
decoder(feat, out_enc, tgt_dict, [])
with pytest.raises(AssertionError):
decoder(feat, out_enc, tgt_dict, img_metas * 2)
out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
assert out_train.shape == torch.Size([1, 5, 36])
out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
assert out_test.shape == torch.Size([1, 5, 36])
def test_parallel_sar_decoder_with_beam_search():
with pytest.raises(AssertionError):
ParallelSARDecoderWithBS(beam_width='beam')
with pytest.raises(AssertionError):
ParallelSARDecoderWithBS(beam_width=0)
feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
decoder = ParallelSARDecoderWithBS(
beam_width=1, num_classes=37, padding_idx=36, max_seq_len=5)
decoder.init_weights()
decoder.train()
with pytest.raises(AssertionError):
decoder(feat, out_enc, tgt_dict, [])
with pytest.raises(AssertionError):
decoder(feat, out_enc, tgt_dict, img_metas * 2)
out_test = decoder(feat, out_enc, tgt_dict, img_metas, train_mode=False)
assert out_test.shape == torch.Size([1, 5, 36])
# test decodenode
with pytest.raises(AssertionError):
DecodeNode(1, 1)
with pytest.raises(AssertionError):
DecodeNode([1, 2], ['4', '3'])
with pytest.raises(AssertionError):
DecodeNode([1, 2], [0.5])
decode_node = DecodeNode([1, 2], [0.7, 0.8])
assert math.isclose(decode_node.eval(), 1.5)
def test_transformer_decoder():
decoder = TFDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
decoder.init_weights()
decoder.train()
out_enc = torch.rand(1, 128, 512)
tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
img_metas = [{'valid_ratio': 1.0}]
out_train = decoder(None, out_enc, tgt_dict, img_metas, True)
assert out_train.shape == torch.Size([1, 5, 36])
out_test = decoder(None, out_enc, tgt_dict, img_metas, False)
assert out_test.shape == torch.Size([1, 5, 36])

View File

@ -0,0 +1,53 @@
import pytest
import torch
from mmocr.models.textrecog.encoders import BaseEncoder, SAREncoder, TFEncoder
def test_sar_encoder():
with pytest.raises(AssertionError):
SAREncoder(enc_bi_rnn='bi')
with pytest.raises(AssertionError):
SAREncoder(enc_do_rnn=2)
with pytest.raises(AssertionError):
SAREncoder(enc_gru='gru')
with pytest.raises(AssertionError):
SAREncoder(d_model=512.5)
with pytest.raises(AssertionError):
SAREncoder(d_enc=200.5)
with pytest.raises(AssertionError):
SAREncoder(mask='mask')
encoder = SAREncoder()
encoder.init_weights()
encoder.train()
feat = torch.randn(1, 512, 4, 40)
with pytest.raises(AssertionError):
encoder(feat)
img_metas = [{'valid_ratio': 1.0}]
with pytest.raises(AssertionError):
encoder(feat, img_metas * 2)
out_enc = encoder(feat, img_metas)
assert out_enc.shape == torch.Size([1, 512])
def test_transformer_encoder():
tf_encoder = TFEncoder()
tf_encoder.init_weights()
tf_encoder.train()
feat = torch.randn(1, 512, 4, 40)
out_enc = tf_encoder(feat)
assert out_enc.shape == torch.Size([1, 160, 512])
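# the 4x40 feature map is flattened into 4 * 40 = 160 positions with 512
# channels each, hence the (1, 160, 512) output shape.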
def test_base_encoder():
encoder = BaseEncoder()
encoder.init_weights()
encoder.train()
feat = torch.randn(1, 256, 4, 40)
out_enc = encoder(feat)
assert out_enc.shape == torch.Size([1, 256, 4, 40])

View File

@ -0,0 +1,16 @@
import pytest
import torch
from mmocr.models.textrecog import SegHead
def test_cafcn_head():
with pytest.raises(AssertionError):
SegHead(num_classes='100')
with pytest.raises(AssertionError):
SegHead(num_classes=-1)
cafcn_head = SegHead(num_classes=37)
out_neck = (torch.rand(1, 128, 32, 32), )
out_head = cafcn_head(out_neck)
assert out_head.shape == torch.Size([1, 37, 32, 32])

View File

@ -0,0 +1,55 @@
import torch
from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck,
DecoderLayer, PositionalEncoding,
get_pad_mask, get_subsequent_mask)
from mmocr.models.textrecog.layers.conv_layer import conv3x3
def test_conv_layer():
conv3by3 = conv3x3(3, 6)
assert conv3by3.in_channels == 3
assert conv3by3.out_channels == 6
assert conv3by3.kernel_size == (3, 3)
x = torch.rand(1, 64, 224, 224)
# test basic block
basic_block = BasicBlock(64, 64)
assert basic_block.expansion == 1
out = basic_block(x)
assert out.shape == torch.Size([1, 64, 224, 224])
# test bottle neck
bottle_neck = Bottleneck(64, 64, downsample=True)
assert bottle_neck.expansion == 4
out = bottle_neck(x)
assert out.shape == torch.Size([1, 256, 224, 224])
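# Bottleneck has expansion 4, so the 64 input planes become 64 * 4 = 256
# output channels while the spatial size is preserved.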
def test_transformer_layer():
# test decoder_layer
decoder_layer = DecoderLayer()
in_dec = torch.rand(1, 30, 512)
out_enc = torch.rand(1, 128, 512)
out_dec = decoder_layer(in_dec, out_enc)
assert out_dec.shape == torch.Size([1, 30, 512])
# test positional_encoding
pos_encoder = PositionalEncoding()
x = torch.rand(1, 30, 512)
out = pos_encoder(x)
assert out.size() == x.size()
# test get pad mask
seq = torch.rand(1, 30)
pad_idx = 0
out = get_pad_mask(seq, pad_idx)
assert out.shape == torch.Size([1, 1, 30])
# test get_subsequent_mask
out_mask = get_subsequent_mask(seq)
assert out_mask.shape == torch.Size([1, 30, 30])

View File

@ -0,0 +1,123 @@
import numpy as np
import pytest
import torch
from mmdet.core import BitmapMasks
from mmocr.models.common.losses import DiceLoss
from mmocr.models.textrecog.losses import (CAFCNLoss, CELoss, CTCLoss, SARLoss,
TFLoss)
def test_ctc_loss():
# test CTCLoss
ctc_loss = CTCLoss()
outputs = torch.zeros(2, 40, 37)
targets_dict = {
'flatten_targets': torch.IntTensor([1, 2, 3, 4, 5]),
'target_lengths': torch.LongTensor([2, 3])
}
losses = ctc_loss(outputs, targets_dict)
assert isinstance(losses, dict)
assert 'loss_ctc' in losses
assert torch.allclose(losses['loss_ctc'],
torch.tensor(losses['loss_ctc'].item()).float())
def test_ce_loss():
with pytest.raises(AssertionError):
CELoss(ignore_index='ignore')
with pytest.raises(AssertionError):
CELoss(reduction=1)
with pytest.raises(AssertionError):
CELoss(reduction='avg')
ce_loss = CELoss(ignore_index=0)
outputs = torch.rand(1, 10, 37)
targets_dict = {
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
}
losses = ce_loss(outputs, targets_dict)
assert isinstance(losses, dict)
assert 'loss_ce' in losses
print(losses['loss_ce'].size())
assert losses['loss_ce'].size(1) == 10
def test_sar_loss():
outputs = torch.rand(1, 10, 37)
targets_dict = {
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
}
sar_loss = SARLoss()
new_output, new_target = sar_loss.format(outputs, targets_dict)
assert new_output.shape == torch.Size([1, 37, 9])
assert new_target.shape == torch.Size([1, 9])
def test_tf_loss():
with pytest.raises(AssertionError):
TFLoss(flatten=1.0)
outputs = torch.rand(1, 10, 37)
targets_dict = {
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
}
tf_loss = TFLoss(flatten=False)
new_output, new_target = tf_loss.format(outputs, targets_dict)
assert new_output.shape == torch.Size([1, 37, 9])
assert new_target.shape == torch.Size([1, 9])
def test_cafcn_loss():
with pytest.raises(AssertionError):
CAFCNLoss(alpha='1')
with pytest.raises(AssertionError):
CAFCNLoss(attn_s2_downsample_ratio='2')
with pytest.raises(AssertionError):
CAFCNLoss(attn_s3_downsample_ratio='1.5')
with pytest.raises(AssertionError):
CAFCNLoss(seg_downsample_ratio='1.5')
with pytest.raises(AssertionError):
CAFCNLoss(attn_s2_downsample_ratio=2)
with pytest.raises(AssertionError):
CAFCNLoss(attn_s3_downsample_ratio=1.5)
with pytest.raises(AssertionError):
CAFCNLoss(seg_downsample_ratio=1.5)
bsz = 1
H = W = 64
out_neck = (torch.ones(bsz, 1, H // 4, W // 4) * 0.5,
torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
torch.ones(bsz, 1, H // 2, W // 2) * 0.5)
out_head = torch.rand(bsz, 37, H // 2, W // 2)
attn_tgt = np.zeros((H, W), dtype=np.float32)
segm_tgt = np.zeros((H, W), dtype=np.float32)
mask = np.ones((H, W), dtype=np.float32)
gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], H, W)
cafcn_loss = CAFCNLoss()
losses = cafcn_loss(out_neck, out_head, [gt_kernels])
assert isinstance(losses, dict)
assert 'loss_seg' in losses
assert torch.allclose(losses['loss_seg'],
torch.tensor(losses['loss_seg'].item()).float())
def test_dice_loss():
with pytest.raises(AssertionError):
DiceLoss(eps='1')
dice_loss = DiceLoss()
pred = torch.rand(1, 1, 32, 32)
gt = torch.rand(1, 1, 32, 32)
loss = dice_loss(pred, gt, None)
assert isinstance(loss, torch.Tensor)
mask = torch.rand(1, 1, 1, 1)
loss = dice_loss(pred, gt, mask)
assert isinstance(loss, torch.Tensor)

View File

@ -0,0 +1,48 @@
import pytest
import torch
from mmocr.models.textrecog.necks.cafcn_neck import (CAFCNNeck, CharAttn,
FeatGenerator)
def test_char_attn():
with pytest.raises(AssertionError):
CharAttn(in_channels=5.0)
with pytest.raises(AssertionError):
CharAttn(deformable='deformabel')
in_feat = torch.rand(1, 128, 32, 32)
char_attn = CharAttn()
out_feat_map, attn_map = char_attn(in_feat)
assert attn_map.shape == torch.Size([1, 1, 32, 32])
assert out_feat_map.shape == torch.Size([1, 128, 32, 32])
@pytest.mark.skip(reason='TODO: re-enable after CI support pytorch>1.4')
def test_feat_generator():
in_feat = torch.rand(1, 128, 32, 32)
feat_generator = FeatGenerator(in_channels=128, out_channels=128)
attn_map, feat_map = feat_generator(in_feat)
assert attn_map.shape == torch.Size([1, 1, 32, 32])
assert feat_map.shape == torch.Size([1, 128, 32, 32])
@pytest.mark.skip(reason='TODO: re-enable after CI support pytorch>1.4')
def test_cafcn_neck():
in_s1 = torch.rand(1, 64, 64, 64)
in_s2 = torch.rand(1, 128, 32, 32)
in_s3 = torch.rand(1, 256, 16, 16)
in_s4 = torch.rand(1, 512, 16, 16)
in_s5 = torch.rand(1, 512, 16, 16)
cafcn_neck = CAFCNNeck()
cafcn_neck.init_weights()
cafcn_neck.train()
out_neck = cafcn_neck((in_s1, in_s2, in_s3, in_s4, in_s5))
assert out_neck[0].shape == torch.Size([1, 1, 32, 32])
assert out_neck[1].shape == torch.Size([1, 1, 16, 16])
assert out_neck[2].shape == torch.Size([1, 1, 16, 16])
assert out_neck[3].shape == torch.Size([1, 1, 16, 16])
assert out_neck[4].shape == torch.Size([1, 128, 64, 64])

View File

@ -0,0 +1,156 @@
import os.path as osp
import tempfile
import numpy as np
import pytest
import torch
from mmdet.core import BitmapMasks
from mmocr.models.textrecog.recognizer import (EncodeDecodeRecognizer,
SegRecognizer)
def _create_dummy_dict_file(dict_file):
chars = list('helowrd')
with open(dict_file, 'w') as fw:
for char in chars:
fw.write(char + '\n')
def test_base_recognizer():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
_create_dummy_dict_file(dict_file)
label_convertor = dict(
type='CTCConvertor', dict_file=dict_file, with_unknown=False)
preprocessor = None
backbone = dict(type='VeryDeepVgg', leakyRelu=False)
encoder = None
decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
loss = dict(type='CTCLoss')
with pytest.raises(AssertionError):
EncodeDecodeRecognizer(backbone=None)
with pytest.raises(AssertionError):
EncodeDecodeRecognizer(decoder=None)
with pytest.raises(AssertionError):
EncodeDecodeRecognizer(loss=None)
with pytest.raises(AssertionError):
EncodeDecodeRecognizer(label_convertor=None)
recognizer = EncodeDecodeRecognizer(
preprocessor=preprocessor,
backbone=backbone,
encoder=encoder,
decoder=decoder,
loss=loss,
label_convertor=label_convertor)
recognizer.init_weights()
recognizer.train()
imgs = torch.rand(1, 3, 32, 160)
# test extract feat
feat = recognizer.extract_feat(imgs)
assert feat.shape == torch.Size([1, 512, 1, 41])
# test forward train
img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
losses = recognizer.forward_train(imgs, img_metas)
assert isinstance(losses, dict)
assert 'loss_ctc' in losses
# test simple test
results = recognizer.simple_test(imgs, img_metas)
assert isinstance(results, list)
assert isinstance(results[0], dict)
assert 'text' in results[0]
assert 'score' in results[0]
# test aug_test
aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
assert isinstance(aug_results, list)
assert isinstance(aug_results[0], dict)
assert 'text' in aug_results[0]
assert 'score' in aug_results[0]
tmp_dir.cleanup()
@pytest.mark.skip(reason='TODO: re-enable after CI support pytorch>1.4')
def test_seg_recognizer():
tmp_dir = tempfile.TemporaryDirectory()
# create dummy data
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
_create_dummy_dict_file(dict_file)
label_convertor = dict(
type='SegConvertor', dict_file=dict_file, with_unknown=False)
preprocessor = None
backbone = dict(type='ResNet31OCR')
neck = dict(type='FPNOCR')
head = dict(type='SegHead')
loss = dict(type='SegLoss')
with pytest.raises(AssertionError):
SegRecognizer(backbone=None)
with pytest.raises(AssertionError):
SegRecognizer(neck=None)
with pytest.raises(AssertionError):
SegRecognizer(head=None)
with pytest.raises(AssertionError):
SegRecognizer(loss=None)
with pytest.raises(AssertionError):
SegRecognizer(label_convertor=None)
recognizer = SegRecognizer(
preprocessor=preprocessor,
backbone=backbone,
neck=neck,
head=head,
loss=loss,
label_convertor=label_convertor)
recognizer.init_weights()
recognizer.train()
imgs = torch.rand(1, 3, 64, 256)
# test extract feat
feats = recognizer.extract_feat(imgs)
assert len(feats) == 5
assert feats[0].shape == torch.Size([1, 64, 32, 128])
assert feats[1].shape == torch.Size([1, 128, 16, 64])
assert feats[2].shape == torch.Size([1, 256, 8, 32])
assert feats[3].shape == torch.Size([1, 512, 8, 32])
assert feats[4].shape == torch.Size([1, 512, 8, 32])
attn_tgt = np.zeros((64, 256), dtype=np.float32)
segm_tgt = np.zeros((64, 256), dtype=np.float32)
gt_kernels = BitmapMasks([attn_tgt, segm_tgt], 64, 256)
# test forward train
img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
losses = recognizer.forward_train(imgs, img_metas, gt_kernels=[gt_kernels])
assert isinstance(losses, dict)
# test simple test
results = recognizer.simple_test(imgs, img_metas)
assert isinstance(results, list)
assert isinstance(results[0], dict)
assert 'text' in results[0]
assert 'score' in results[0]
# test aug_test
aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
assert isinstance(aug_results, list)
assert isinstance(aug_results[0], dict)
assert 'text' in aug_results[0]
assert 'score' in aug_results[0]
tmp_dir.cleanup()

View File

@ -0,0 +1,66 @@
"""Test text label visualize."""
import os.path as osp
import random
import tempfile
from unittest import mock
import numpy as np
import pytest
import mmocr.core.visualize as visualize_utils
def test_tile_image():
dummy_imgs, heights, widths = [], [], []
for _ in range(3):
h = random.randint(100, 300)
w = random.randint(100, 300)
heights.append(h)
widths.append(w)
# dummy_img = Image.new('RGB', (w, h), Image.ANTIALIAS)
dummy_img = np.ones((h, w, 3), dtype=np.uint8)
dummy_imgs.append(dummy_img)
joint_img = visualize_utils.tile_image(dummy_imgs)
assert joint_img.shape[0] == sum(heights)
assert joint_img.shape[1] == max(widths)
# test invalid arguments
with pytest.raises(AssertionError):
visualize_utils.tile_image(dummy_imgs[0])
with pytest.raises(AssertionError):
visualize_utils.tile_image([])
@mock.patch('%s.visualize_utils.mmcv.imread' % __name__)
@mock.patch('%s.visualize_utils.mmcv.imshow' % __name__)
@mock.patch('%s.visualize_utils.mmcv.imwrite' % __name__)
def test_show_text_label(mock_imwrite, mock_imshow, mock_imread):
img = np.ones((32, 160), dtype=np.uint8)
pred_label = 'hello'
gt_label = 'world'
tmp_dir = tempfile.TemporaryDirectory()
out_file = osp.join(tmp_dir.name, 'tmp.jpg')
# test invalid arguments
with pytest.raises(AssertionError):
visualize_utils.imshow_text_label(5, pred_label, gt_label)
with pytest.raises(AssertionError):
visualize_utils.imshow_text_label(img, pred_label, 4)
with pytest.raises(AssertionError):
visualize_utils.imshow_text_label(img, 3, gt_label)
with pytest.raises(AssertionError):
visualize_utils.imshow_text_label(
img, pred_label, gt_label, show=True, wait_time=0.1)
mock_imread.side_effect = [img, img]
visualize_utils.imshow_text_label(
img, pred_label, gt_label, out_file=out_file)
visualize_utils.imshow_text_label(
img, pred_label, gt_label, out_file=None, show=True)
# test showing img
mock_imshow.assert_called_once()
mock_imwrite.assert_called_once()
tmp_dir.cleanup()

View File

@ -0,0 +1,93 @@
import argparse
import codecs
import json
import os.path as osp
import cv2
def read_json(fpath):
with codecs.open(fpath, 'r', 'utf-8') as f:
obj = json.load(f)
return obj
def parse_old_label(img_prefix, in_path):
imgid2imgname = {}
imgid2anno = {}
idx = 0
with open(in_path, 'r') as fr:
for line in fr:
line = line.strip().split()
img_full_path = osp.join(img_prefix, line[0])
if not osp.exists(img_full_path):
continue
img = cv2.imread(img_full_path)
h, w = img.shape[:2]
img_info = {}
img_info['file_name'] = line[0]
img_info['height'] = h
img_info['width'] = w
imgid2imgname[idx] = img_info
imgid2anno[idx] = []
for i in range(len(line[1:]) // 8):
seg = [int(x) for x in line[(1 + i * 8):(1 + (i + 1) * 8)]]
# collect the x and y coordinates of the 4 corner points
points_x = seg[0:8:2]
points_y = seg[1:8:2]
box = [
min(points_x),
min(points_y),
max(points_x),
max(points_y)
]
new_anno = {}
new_anno['iscrowd'] = 0
new_anno['category_id'] = 1
new_anno['bbox'] = box
new_anno['segmentation'] = [seg]
imgid2anno[idx].append(new_anno)
idx += 1
return imgid2imgname, imgid2anno
def gen_line_dict_file(out_path, imgid2imgname, imgid2anno):
with codecs.open(out_path, 'w', 'utf-8') as fw:
for key, value in imgid2imgname.items():
if key in imgid2anno:
anno = imgid2anno[key]
line_dict = {}
line_dict['file_name'] = value['file_name']
line_dict['height'] = value['height']
line_dict['width'] = value['width']
line_dict['annotations'] = anno
line_dict_str = json.dumps(line_dict)
fw.write(line_dict_str + '\n')
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--img-prefix',
help='image prefix, to generate full image path with "image_name"')
parser.add_argument(
'--in-path',
help='mapping file of image_name and ann_file,'
' "image_name ann_file" in each line')
parser.add_argument(
'--out-path', help='output txt path with line-json format')
args = parser.parse_args()
return args
def main():
args = parse_args()
imgid2imgname, imgid2anno = parse_old_label(args.img_prefix, args.in_path)
gen_line_dict_file(args.out_path, imgid2imgname, imgid2anno)
print('finish')
if __name__ == '__main__':
main()
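# an illustrative invocation (the script filename below is hypothetical; the
# flags are the ones defined in parse_args above):
#   python gen_line_dict_file.py --img-prefix data/imgs \
#       --in-path old_labels.txt --out-path instances_train.txt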

View File

@ -0,0 +1,74 @@
import argparse
import shutil
import sys
import time
from pathlib import Path
import lmdb
def converter(imglist, output, batch_size=1000, coding='utf-8'):
# read imglist
with open(imglist) as f:
lines = f.readlines()
# create lmdb database
if Path(output).is_dir():
while True:
print('%s already exists, delete or not? [Y/n]' % output)
Yn = input().strip()
if Yn in ['Y', 'y']:
shutil.rmtree(output)
break
elif Yn in ['N', 'n']:
return
print('create database %s' % output)
Path(output).mkdir(parents=True, exist_ok=False)
env = lmdb.open(output, map_size=1099511627776)
# build lmdb
beg_time = time.strftime('%H:%M:%S')
for beg_index in range(0, len(lines), batch_size):
end_index = min(beg_index + batch_size, len(lines))
sys.stdout.write('\r[%s-%s], processing [%d-%d] / %d' %
(beg_time, time.strftime('%H:%M:%S'), beg_index,
end_index, len(lines)))
sys.stdout.flush()
batch = [(str(index).encode(coding), lines[index].encode(coding))
for index in range(beg_index, end_index)]
with env.begin(write=True) as txn:
cursor = txn.cursor()
cursor.putmulti(batch, dupdata=False, overwrite=True)
sys.stdout.write('\n')
with env.begin(write=True) as txn:
key = 'total_number'.encode(coding)
value = str(len(lines)).encode(coding)
txn.put(key, value)
print('done', flush=True)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--imglist', '-i', required=True, help='input imglist path')
parser.add_argument(
'--output', '-o', required=True, help='output lmdb path')
parser.add_argument(
'--batch_size',
'-b',
type=int,
default=10000,
help='processing batch size, default 10000')
parser.add_argument(
'--coding',
'-c',
default='utf8',
help='bytes coding scheme, default utf8')
opt = parser.parse_args()
converter(
opt.imglist, opt.output, batch_size=opt.batch_size, coding=opt.coding)
if __name__ == '__main__':
main()
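# an illustrative command-line call (the annotation and output paths are only
# examples; the module path matches the import used in the loader tests):
#   python tools/data/utils/txt2lmdb.py -i label.txt -o label.lmdb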

View File

@ -0,0 +1,132 @@
import os.path as osp
import shutil
import time
from argparse import ArgumentParser
import mmcv
import torch
from mmcv.utils import ProgressBar
from mmdet.apis import init_detector
from mmdet.utils import get_root_logger
from mmocr.apis import model_inference
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets import build_dataset # noqa: F401
from mmocr.models import build_detector # noqa: F401
def save_results(img_paths, pred_labels, gt_labels, res_dir):
"""Save predicted results to txt file.
Args:
img_paths (list[str])
pred_labels (list[str])
gt_labels (list[str])
res_dir (str)
"""
assert len(img_paths) == len(pred_labels) == len(gt_labels)
res_file = osp.join(res_dir, 'results.txt')
correct_file = osp.join(res_dir, 'correct.txt')
wrong_file = osp.join(res_dir, 'wrong.txt')
with open(res_file, 'w') as fw, \
open(correct_file, 'w') as fw_correct, \
open(wrong_file, 'w') as fw_wrong:
for img_path, pred_label, gt_label in zip(img_paths, pred_labels,
gt_labels):
fw.write(img_path + ' ' + pred_label + ' ' + gt_label + '\n')
if pred_label == gt_label:
fw_correct.write(img_path + ' ' + pred_label + ' ' + gt_label +
'\n')
else:
fw_wrong.write(img_path + ' ' + pred_label + ' ' + gt_label +
'\n')
def main():
parser = ArgumentParser()
parser.add_argument('--img_root_path', type=str, help='Image root path')
parser.add_argument('--img_list', type=str, help='Image path list file')
parser.add_argument('--config', type=str, help='Config file')
parser.add_argument('--checkpoint', type=str, help='Checkpoint file')
parser.add_argument(
'--out_dir', type=str, default='./results', help='Dir to save results')
parser.add_argument(
'--show', action='store_true', help='show image or save')
args = parser.parse_args()
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(args.out_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level='INFO')
# build the model from a config file and a checkpoint file
device = 'cuda:' + str(torch.cuda.current_device())
model = init_detector(args.config, args.checkpoint, device=device)
if hasattr(model, 'module'):
model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
# Start Inference
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
mmcv.mkdir_or_exist(out_vis_dir)
correct_vis_dir = osp.join(args.out_dir, 'correct')
mmcv.mkdir_or_exist(correct_vis_dir)
wrong_vis_dir = osp.join(args.out_dir, 'wrong')
mmcv.mkdir_or_exist(wrong_vis_dir)
img_paths, pred_labels, gt_labels = [], [], []
total_img_num = sum([1 for _ in open(args.img_list)])
progressbar = ProgressBar(task_num=total_img_num)
num_gt_label = 0
with open(args.img_list, 'r') as fr:
for line in fr:
progressbar.update()
item_list = line.strip().split()
img_file = item_list[0]
gt_label = ''
if len(item_list) >= 2:
gt_label = item_list[1]
num_gt_label += 1
img_path = osp.join(args.img_root_path, img_file)
if not osp.exists(img_path):
raise FileNotFoundError(img_path)
# Test a single image
result = model_inference(model, img_path)
pred_label = result['text']
out_img_name = '_'.join(img_file.split('/'))
out_file = osp.join(out_vis_dir, out_img_name)
kwargs_dict = {
'gt_label': gt_label,
'show': args.show,
'out_file': '' if args.show else out_file
}
model.show_result(img_path, result, **kwargs_dict)
if gt_label != '':
if gt_label == pred_label:
dst_file = osp.join(correct_vis_dir, out_img_name)
else:
dst_file = osp.join(wrong_vis_dir, out_img_name)
shutil.copy(out_file, dst_file)
img_paths.append(img_path)
gt_labels.append(gt_label)
pred_labels.append(pred_label)
# Save results
save_results(img_paths, pred_labels, gt_labels, args.out_dir)
if num_gt_label == len(pred_labels):
# eval
eval_results = eval_ocr_metric(pred_labels, gt_labels)
logger.info('\n' + '-' * 100)
info = 'eval on testset with img_root_path ' + \
f'{args.img_root_path} and img_list {args.img_list}\n'
logger.info(info)
logger.info(eval_results)
print(f'\nInference done, and results saved in {args.out_dir}\n')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,25 @@
#!/bin/bash
DATE=`date +%Y-%m-%d`
TIME=`date +"%H-%M-%S"`
if [ $# -lt 5 ]
then
echo "Usage: bash $0 CONFIG CHECKPOINT IMG_PREFIX IMG_LIST RESULTS_DIR"
exit
fi
CONFIG_FILE=$1
CHECKPOINT=$2
IMG_ROOT_PATH=$3
IMG_LIST=$4
OUT_DIR=$5_${DATE}_${TIME}
mkdir -p ${OUT_DIR} &&
python tools/ocr_test_imgs.py \
--img_root_path ${IMG_ROOT_PATH} \
--img_list ${IMG_LIST} \
--config ${CONFIG_FILE} \
--checkpoint ${CHECKPOINT} \
--out_dir ${OUT_DIR}