diff --git a/configs/_base_/recog_datasets/seg_toy_dataset.py b/configs/_base_/recog_datasets/seg_toy_dataset.py
new file mode 100644
index 00000000..d6c49dbf
--- /dev/null
+++ b/configs/_base_/recog_datasets/seg_toy_dataset.py
@@ -0,0 +1,96 @@
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+gt_label_convertor = dict(
+ type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomPaddingOCR',
+ max_ratio=[0.15, 0.2, 0.15, 0.2],
+ box_type='char_quads'),
+ dict(type='OpencvToPil'),
+ dict(
+ type='RandomRotateImageBox',
+ min_angle=-17,
+ max_angle=17,
+ box_type='char_quads'),
+ dict(type='PilToOpencv'),
+ dict(
+ type='ResizeOCR',
+ height=64,
+ min_width=64,
+ max_width=512,
+ keep_aspect_ratio=True),
+ dict(
+ type='OCRSegTargets',
+ label_convertor=gt_label_convertor,
+ box_type='char_quads'),
+ dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+ dict(type='ToTensorOCR'),
+ dict(type='FancyPCA'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels'],
+ visualize=dict(flag=False, boundary_key=None),
+ call_super=False),
+ dict(
+ type='Collect',
+ keys=['img', 'gt_kernels'],
+ meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=64,
+ min_width=64,
+ max_width=None,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(type='CustomFormatBundle', call_super=False),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+prefix = 'tests/data/ocr_char_ann_toy_dataset/'
+train = dict(
+ type='OCRSegDataset',
+ img_prefix=prefix + 'imgs',
+ ann_file=prefix + 'instances_train.txt',
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=100,
+ parser=dict(
+ type='LineJsonParser', keys=['file_name', 'annotations', 'text'])),
+ pipeline=train_pipeline,
+ test_mode=True)
+
+test = dict(
+ type='OCRDataset',
+ img_prefix=prefix + 'imgs',
+ ann_file=prefix + 'instances_test.txt',
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=1,
+ train=dict(type='ConcatDataset', datasets=[train]),
+ val=dict(type='ConcatDataset', datasets=[test]),
+ test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/_base_/recog_datasets/toy_dataset.py b/configs/_base_/recog_datasets/toy_dataset.py
new file mode 100755
index 00000000..83848863
--- /dev/null
+++ b/configs/_base_/recog_datasets/toy_dataset.py
@@ -0,0 +1,99 @@
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+img_prefix = 'tests/data/ocr_toy_dataset/imgs'
+train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
+train1 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
+train2 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file2,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
+test = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=test_anno_file1,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=2,
+ train=dict(type='ConcatDataset', datasets=[train1, train2]),
+ val=dict(type='ConcatDataset', datasets=[test]),
+ test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/_base_/recog_models/crnn.py b/configs/_base_/recog_models/crnn.py
new file mode 100644
index 00000000..6b98c3d9
--- /dev/null
+++ b/configs/_base_/recog_models/crnn.py
@@ -0,0 +1,11 @@
+label_convertor = dict(
+ type='CTCConvertor', dict_type='DICT90', with_unknown=False)
+
+model = dict(
+ type='CRNNNet',
+ preprocessor=None,
+ backbone=dict(type='VeryDeepVgg', leakyRelu=False),
+ encoder=None,
+ decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
+ loss=dict(type='CTCLoss', flatten=False),
+ label_convertor=label_convertor)
diff --git a/configs/_base_/recog_models/robust_scanner.py b/configs/_base_/recog_models/robust_scanner.py
new file mode 100644
index 00000000..4cc2fa10
--- /dev/null
+++ b/configs/_base_/recog_models/robust_scanner.py
@@ -0,0 +1,24 @@
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+hybrid_decoder = dict(type='SequenceAttentionDecoder')
+
+position_decoder = dict(type='PositionAttentionDecoder')
+
+model = dict(
+ type='RobustScanner',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(
+ type='ChannelReductionEncoder',
+ in_channels=512,
+ out_channels=128,
+ ),
+ decoder=dict(
+ type='RobustScannerDecoder',
+ dim_input=512,
+ dim_model=128,
+ hybrid_decoder=hybrid_decoder,
+ position_decoder=position_decoder),
+ loss=dict(type='SARLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=30)
diff --git a/configs/_base_/recog_models/sar.py b/configs/_base_/recog_models/sar.py
new file mode 100755
index 00000000..8438d9b9
--- /dev/null
+++ b/configs/_base_/recog_models/sar.py
@@ -0,0 +1,24 @@
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+model = dict(
+ type='SARNet',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(
+ type='SAREncoder',
+ enc_bi_rnn=False,
+ enc_do_rnn=0.1,
+ enc_gru=False,
+ ),
+ decoder=dict(
+ type='ParallelSARDecoder',
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_do_rnn=0,
+ dec_gru=False,
+ pred_dropout=0.1,
+ d_k=512,
+ pred_concat=True),
+ loss=dict(type='SARLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=30)
diff --git a/configs/_base_/recog_models/transformer.py b/configs/_base_/recog_models/transformer.py
new file mode 100644
index 00000000..476643fa
--- /dev/null
+++ b/configs/_base_/recog_models/transformer.py
@@ -0,0 +1,11 @@
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=False)
+
+model = dict(
+ type='TransformerNet',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(type='TFEncoder'),
+ decoder=dict(type='TFDecoder'),
+ loss=dict(type='TFLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=40)
diff --git a/configs/textrecog/sar/README.md b/configs/textrecog/sar/README.md
new file mode 100644
index 00000000..517c6985
--- /dev/null
+++ b/configs/textrecog/sar/README.md
@@ -0,0 +1,65 @@
+# Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition
+
+## Introduction
+
+[ALGORITHM]
+
+```
+@inproceedings{li2019show,
+ title={Show, attend and read: A simple and strong baseline for irregular text recognition},
+ author={Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={33},
+ number={01},
+ pages={8610--8617},
+ year={2019}
+}
+```
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :--------: | :----------: | :--------: | :----------------------: |
+| icdar_2011 | 3567 | 20 | real |
+| icdar_2013 | 848 | 20 | real |
+| icdar_2015 | 4468 | 20 | real |
+| coco_text | 42142 | 20 | real |
+| IIIT5K | 2000 | 20 | real |
+| SynthText | 2400000 | 1 | synth |
+| SynthAdd | 1216889 | 1 | synth, 1.6m in [[1]](#1) |
+| Syn90k | 2400000 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | type |
+| :-----: | :----------: | :-------------------------: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+| IC15 | 2077 | irregular |
+| SVTP | 645 | irregular, 639 in [[1]](#1) |
+| CT80 | 288 | irregular |
+
+## Results and Models
+
+| Methods | Backbone | Decoder | IIIT5K | SVT | IC13 | IC15 | SVTP | CT80 | download |
+| :-----: | :------: | :-----: | :----: | :-: | :--: | :--: | :--: | :--: | :------: |
+| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 95.0 | 89.6 | 93.7 | 79.0 | 82.2 | 88.9 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210327_154129.log.json) |
+| [SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 95.2 | 88.7 | 92.4 | 78.2 | 81.9 | 89.6 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic-d06c9a8e.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210330_105728.log.json) |
+
+**Notes:**
+- `R31-1/8-1/4` means the height of the feature map from the backbone is 1/8 of the input image and the width is 1/4 of the input image.
+- We did not use beam search during decoding.
+- We implemented two kinds of decoder, `ParallelSARDecoder` and `SequentialSARDecoder`; the config excerpt below shows how to switch between them.
+  - `ParallelSARDecoder`: Parallel decoding during training with an `LSTM` layer. It is faster.
+  - `SequentialSARDecoder`: Sequential decoding during training with `LSTMCell`. It is easier to understand.
+- For the training dataset:
+  - We did not construct distinct data groups (20 groups in [[1]](#1)) to train the model group-by-group, since that would make training overly complicated.
+  - Instead, we randomly selected `2.4m` patches from `Syn90k`, `2.4m` from `SynthText` and `1.2m` from `SynthAdd`, and grouped all the data together. See [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_academic.py) for details.
+- We used 48 GPUs with `total_batch_size = 64 * 48` in the experiment above to speed up training, while keeping the `initial lr = 1e-3` unchanged.
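+
+The two released models above differ only in the decoder type; a minimal excerpt from the configs added in this PR:
+
+```python
+decoder = dict(
+    type='ParallelSARDecoder',  # or 'SequentialSARDecoder'
+    enc_bi_rnn=False,
+    dec_bi_rnn=False,
+    dec_do_rnn=0,
+    dec_gru=False,
+    pred_dropout=0.1,
+    d_k=512,
+    pred_concat=True)
+```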
+
+## References
+
+[1] Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu. Show, attend and read: A simple and strong baseline for irregular text recognition. In AAAI 2019.
diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
new file mode 100644
index 00000000..4e405227
--- /dev/null
+++ b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
@@ -0,0 +1,219 @@
+_base_ = ['../../_base_/default_runtime.py']
+
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+model = dict(
+ type='SARNet',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(
+ type='SAREncoder',
+ enc_bi_rnn=False,
+ enc_do_rnn=0.1,
+ enc_gru=False,
+ ),
+ decoder=dict(
+ type='ParallelSARDecoder',
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_do_rnn=0,
+ dec_gru=False,
+ pred_dropout=0.1,
+ d_k=512,
+ pred_concat=True),
+ loss=dict(type='SARLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=30)
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+
+train_prefix = 'data/mixture/'
+
+train_img_prefix1 = train_prefix + 'icdar_2011'
+train_img_prefix2 = train_prefix + 'icdar_2013'
+train_img_prefix3 = train_prefix + 'icdar_2015'
+train_img_prefix4 = train_prefix + 'coco_text'
+train_img_prefix5 = train_prefix + 'IIIT5K'
+train_img_prefix6 = train_prefix + 'SynthText_Add'
+train_img_prefix7 = train_prefix + 'SynthText'
+train_img_prefix8 = train_prefix + 'Syn90k'
+
+train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
+train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
+train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
+train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
+train_ann_file5 = train_prefix + 'IIIT5K/train_label.txt'
+train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
+train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
+train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'
+
+train1 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix1,
+ ann_file=train_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=20,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train2 = {key: value for key, value in train1.items()}
+train2['img_prefix'] = train_img_prefix2
+train2['ann_file'] = train_ann_file2
+
+train3 = {key: value for key, value in train1.items()}
+train3['img_prefix'] = train_img_prefix3
+train3['ann_file'] = train_ann_file3
+
+train4 = {key: value for key, value in train1.items()}
+train4['img_prefix'] = train_img_prefix4
+train4['ann_file'] = train_ann_file4
+
+train5 = {key: value for key, value in train1.items()}
+train5['img_prefix'] = train_img_prefix5
+train5['ann_file'] = train_ann_file5
+
+train6 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix6,
+ ann_file=train_ann_file6,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train7 = {key: value for key, value in train6.items()}
+train7['img_prefix'] = train_img_prefix7
+train7['ann_file'] = train_ann_file7
+
+train8 = {key: value for key, value in train6.items()}
+train8['img_prefix'] = train_img_prefix8
+train8['ann_file'] = train_ann_file8
+
+test_prefix = 'data/mixture/'
+test_img_prefix1 = test_prefix + 'IIIT5K/'
+test_img_prefix2 = test_prefix + 'svt/'
+test_img_prefix3 = test_prefix + 'icdar_2013/'
+test_img_prefix4 = test_prefix + 'icdar_2015/'
+test_img_prefix5 = test_prefix + 'svtp/'
+test_img_prefix6 = test_prefix + 'ct80/'
+
+test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file2 = test_prefix + 'svt/test_label.txt'
+test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
+test_ann_file5 = test_prefix + 'svtp/test_label.txt'
+test_ann_file6 = test_prefix + 'ct80/test_label.txt'
+
+test1 = dict(
+ type=dataset_type,
+ img_prefix=test_img_prefix1,
+ ann_file=test_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+test4 = {key: value for key, value in test1.items()}
+test4['img_prefix'] = test_img_prefix4
+test4['ann_file'] = test_ann_file4
+
+test5 = {key: value for key, value in test1.items()}
+test5['img_prefix'] = test_img_prefix5
+test5['ann_file'] = test_ann_file5
+
+test6 = {key: value for key, value in test1.items()}
+test6['img_prefix'] = test_img_prefix6
+test6['ann_file'] = test_ann_file6
+
+data = dict(
+ samples_per_gpu=64,
+ workers_per_gpu=2,
+ train=dict(
+ type='ConcatDataset',
+ datasets=[
+ train1, train2, train3, train4, train5, train6, train7, train8
+ ]),
+ val=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]),
+ test=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py
new file mode 100755
index 00000000..0c8b53e2
--- /dev/null
+++ b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py
@@ -0,0 +1,110 @@
+_base_ = [
+ '../../_base_/default_runtime.py', '../../_base_/recog_models/sar.py'
+]
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+img_prefix = 'tests/data/ocr_toy_dataset/imgs'
+train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
+train1 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
+train2 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file2,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
+test = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=test_anno_file1,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=2,
+ train=dict(type='ConcatDataset', datasets=[train1, train2]),
+ val=dict(type='ConcatDataset', datasets=[test]),
+ test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
new file mode 100644
index 00000000..6fa00dd7
--- /dev/null
+++ b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
@@ -0,0 +1,219 @@
+_base_ = ['../../_base_/default_runtime.py']
+
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+model = dict(
+ type='SARNet',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(
+ type='SAREncoder',
+ enc_bi_rnn=False,
+ enc_do_rnn=0.1,
+ enc_gru=False,
+ ),
+ decoder=dict(
+ type='SequentialSARDecoder',
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_do_rnn=0,
+ dec_gru=False,
+ pred_dropout=0.1,
+ d_k=512,
+ pred_concat=True),
+ loss=dict(type='SARLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=30)
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+
+train_prefix = 'data/mixture/'
+
+train_img_prefix1 = train_prefix + 'icdar_2011'
+train_img_prefix2 = train_prefix + 'icdar_2013'
+train_img_prefix3 = train_prefix + 'icdar_2015'
+train_img_prefix4 = train_prefix + 'coco_text'
+train_img_prefix5 = train_prefix + 'IIIT5K'
+train_img_prefix6 = train_prefix + 'SynthText_Add'
+train_img_prefix7 = train_prefix + 'SynthText'
+train_img_prefix8 = train_prefix + 'Syn90k'
+
+train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
+train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
+train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
+train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
+train_ann_file5 = train_prefix + 'IIIT5K/train_label.txt'
+train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
+train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
+train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'
+
+train1 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix1,
+ ann_file=train_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=20,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train2 = {key: value for key, value in train1.items()}
+train2['img_prefix'] = train_img_prefix2
+train2['ann_file'] = train_ann_file2
+
+train3 = {key: value for key, value in train1.items()}
+train3['img_prefix'] = train_img_prefix3
+train3['ann_file'] = train_ann_file3
+
+train4 = {key: value for key, value in train1.items()}
+train4['img_prefix'] = train_img_prefix4
+train4['ann_file'] = train_ann_file4
+
+train5 = {key: value for key, value in train1.items()}
+train5['img_prefix'] = train_img_prefix5
+train5['ann_file'] = train_ann_file5
+
+train6 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix6,
+ ann_file=train_ann_file6,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train7 = {key: value for key, value in train6.items()}
+train7['img_prefix'] = train_img_prefix7
+train7['ann_file'] = train_ann_file7
+
+train8 = {key: value for key, value in train6.items()}
+train8['img_prefix'] = train_img_prefix8
+train8['ann_file'] = train_ann_file8
+
+test_prefix = 'data/mixture/'
+test_img_prefix1 = test_prefix + 'IIIT5K/'
+test_img_prefix2 = test_prefix + 'svt/'
+test_img_prefix3 = test_prefix + 'icdar_2013/'
+test_img_prefix4 = test_prefix + 'icdar_2015/'
+test_img_prefix5 = test_prefix + 'svtp/'
+test_img_prefix6 = test_prefix + 'ct80/'
+
+test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file2 = test_prefix + 'svt/test_label.txt'
+test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
+test_ann_file5 = test_prefix + 'svtp/test_label.txt'
+test_ann_file6 = test_prefix + 'ct80/test_label.txt'
+
+test1 = dict(
+ type=dataset_type,
+ img_prefix=test_img_prefix1,
+ ann_file=test_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+test4 = {key: value for key, value in test1.items()}
+test4['img_prefix'] = test_img_prefix4
+test4['ann_file'] = test_ann_file4
+
+test5 = {key: value for key, value in test1.items()}
+test5['img_prefix'] = test_img_prefix5
+test5['ann_file'] = test_ann_file5
+
+test6 = {key: value for key, value in test1.items()}
+test6['img_prefix'] = test_img_prefix6
+test6['ann_file'] = test_ann_file6
+
+data = dict(
+ samples_per_gpu=64,
+ workers_per_gpu=2,
+ train=dict(
+ type='ConcatDataset',
+ datasets=[
+ train1, train2, train3, train4, train5, train6, train7, train8
+ ]),
+ val=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]),
+ test=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/seg/README.md b/configs/textrecog/seg/README.md
new file mode 100644
index 00000000..49ca2014
--- /dev/null
+++ b/configs/textrecog/seg/README.md
@@ -0,0 +1,36 @@
+# A Baseline Method for Segmentation-based Text Recognition
+
+## Introduction
+
+A baseline method for segmentation-based text recognition.
+
+[ALGORITHM]
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :-------: | :----------: | :--------: | :----: |
+| SynthText | 7266686 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | type |
+| :-----: | :----------: | :-------: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+| CT80 | 288 | irregular |
+
+## Results and Models
+
+| Backbone | Neck | Head | IIIT5K | SVT | IC13 | CT80 | base_lr | batch_size/gpu | gpus | download |
+| :------: | :--: | :--: | :----: | :-: | :--: | :--: | :-----: | :------------: | :--: | :------: |
+| R31-1/16 | FPNOCR | 1x | 90.9 | 81.8 | 90.7 | 80.9 | 1e-4 | 16 | 4 | [model](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic-0c50e163.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/seg/20210325_112835.log.json) |
+
+
+**Notes:**
+
+- `R31-1/16` means the size (both height and width) of the feature map from the backbone is 1/16 of the input image; see the config excerpt below.
+- `1x` means the size (both height and width) of the feature map from the head is the same as that of the input image.
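+
+For reference, the model section that realizes this backbone/neck/head combination, excerpted from the academic config added in this PR (`seg_r31_1by16_fpnocr_academic.py`; the loss and label convertor are omitted for brevity):
+
+```python
+model = dict(
+    type='SegRecognizer',
+    backbone=dict(
+        type='ResNet31OCR',
+        layers=[1, 2, 5, 3],
+        channels=[32, 64, 128, 256, 512, 512],
+        out_indices=[0, 1, 2, 3],
+        stage4_pool_cfg=dict(kernel_size=2, stride=2),
+        last_stage_pool=True),
+    neck=dict(
+        type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
+    head=dict(
+        type='SegHead',
+        in_channels=256,
+        upsample_param=dict(scale_factor=2.0, mode='nearest')))
+```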
diff --git a/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py b/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py
new file mode 100644
index 00000000..8a568f8e
--- /dev/null
+++ b/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py
@@ -0,0 +1,160 @@
+_base_ = ['../../_base_/default_runtime.py']
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+label_convertor = dict(
+ type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+model = dict(
+ type='SegRecognizer',
+ backbone=dict(
+ type='ResNet31OCR',
+ layers=[1, 2, 5, 3],
+ channels=[32, 64, 128, 256, 512, 512],
+ out_indices=[0, 1, 2, 3],
+ stage4_pool_cfg=dict(kernel_size=2, stride=2),
+ last_stage_pool=True),
+ neck=dict(
+ type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
+ head=dict(
+ type='SegHead',
+ in_channels=256,
+ upsample_param=dict(scale_factor=2.0, mode='nearest')),
+ loss=dict(
+ type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=True),
+ label_convertor=label_convertor)
+
+find_unused_parameters = True
+
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+gt_label_convertor = dict(
+ type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomPaddingOCR',
+ max_ratio=[0.15, 0.2, 0.15, 0.2],
+ box_type='char_quads'),
+ dict(type='OpencvToPil'),
+ dict(
+ type='RandomRotateImageBox',
+ min_angle=-17,
+ max_angle=17,
+ box_type='char_quads'),
+ dict(type='PilToOpencv'),
+ dict(
+ type='ResizeOCR',
+ height=64,
+ min_width=64,
+ max_width=512,
+ keep_aspect_ratio=True),
+ dict(
+ type='OCRSegTargets',
+ label_convertor=gt_label_convertor,
+ box_type='char_quads'),
+ dict(type='RandomRotateTextDet', rotate_ratio=0.5, max_angle=15),
+ dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+ dict(type='ToTensorOCR'),
+ dict(type='FancyPCA'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels'],
+ visualize=dict(flag=False, boundary_key=None),
+ call_super=False),
+ dict(
+ type='Collect',
+ keys=['img', 'gt_kernels'],
+ meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=64,
+ min_width=64,
+ max_width=None,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(type='CustomFormatBundle', call_super=False),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+train_img_root = 'data/mixture/'
+
+train_img_prefix = train_img_root + 'SynthText'
+
+train_ann_file = train_img_root + 'SynthText/instances_train.txt'
+
+train = dict(
+ type='OCRSegDataset',
+ img_prefix=train_img_prefix,
+ ann_file=train_ann_file,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineJsonParser', keys=['file_name', 'annotations', 'text'])),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+dataset_type = 'OCRDataset'
+test_prefix = 'data/mixture/'
+
+test_img_prefix1 = test_prefix + 'IIIT5K/'
+test_img_prefix2 = test_prefix + 'svt/'
+test_img_prefix3 = test_prefix + 'icdar_2013/'
+test_img_prefix4 = test_prefix + 'ct80/'
+
+test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file2 = test_prefix + 'svt/test_label.txt'
+test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file4 = test_prefix + 'ct80/test_label.txt'
+
+test1 = dict(
+ type=dataset_type,
+ img_prefix=test_img_prefix1,
+ ann_file=test_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+test4 = {key: value for key, value in test1.items()}
+test4['img_prefix'] = test_img_prefix4
+test4['ann_file'] = test_ann_file4
+
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=2,
+ train=dict(type='ConcatDataset', datasets=[train]),
+ val=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]),
+ test=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py b/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py
new file mode 100644
index 00000000..63b3d08c
--- /dev/null
+++ b/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py
@@ -0,0 +1,35 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+ '../../_base_/recog_datasets/seg_toy_dataset.py'
+]
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+label_convertor = dict(
+ type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+model = dict(
+ type='SegRecognizer',
+ backbone=dict(
+ type='ResNet31OCR',
+ layers=[1, 2, 5, 3],
+ channels=[32, 64, 128, 256, 512, 512],
+ out_indices=[0, 1, 2, 3],
+ stage4_pool_cfg=dict(kernel_size=2, stride=2),
+ last_stage_pool=True),
+ neck=dict(
+ type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
+ head=dict(
+ type='SegHead',
+ in_channels=256,
+ upsample_param=dict(scale_factor=2.0, mode='nearest')),
+ loss=dict(
+ type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=False),
+ label_convertor=label_convertor)
+
+find_unused_parameters = True
diff --git a/configs/textrecog/transformer/README.md b/configs/textrecog/transformer/README.md
new file mode 100644
index 00000000..46a132ad
--- /dev/null
+++ b/configs/textrecog/transformer/README.md
@@ -0,0 +1,30 @@
+## Introduction
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | note |
+| :--------: | :----------: | :--------: | :---: |
+| icdar_2011 | 3567 | 20 | real |
+| icdar_2013 | 848 | 20 | real |
+| icdar_2015 | 4468 | 20 | real |
+| coco_text | 42142 | 20 | real |
+| IIIT5K | 2000 | 20 | real |
+| SynthText | 2400000 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | note |
+| :-----: | :----------: | :-------------------------: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+| IC15 | 2077 | irregular |
+| SVTP | 645 | irregular, 639 in [[1]](#1) |
+| CT80 | 288 | irregular |
+
+## Results and Models
+
+| Methods | IIIT5K | SVT | IC13 | IC15 | SVTP | CT80 | download |
+| :---------: | :----: | :-: | :--: | :--: | :--: | :--: | :------: |
+| Transformer | 93.3 | 85.8 | 91.3 | 73.2 | 76.6 | 87.8 | |
diff --git a/configs/textrecog/transformer/transformer_r31_toy_dataset.py b/configs/textrecog/transformer/transformer_r31_toy_dataset.py
new file mode 100755
index 00000000..9fc4a3ab
--- /dev/null
+++ b/configs/textrecog/transformer/transformer_r31_toy_dataset.py
@@ -0,0 +1,12 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+ '../../_base_/recog_models/transformer.py',
+ '../../_base_/recog_datasets/toy_dataset.py'
+]
+
+# optimizer
+optimizer = dict(type='Adadelta', lr=1)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
diff --git a/mmocr/__init__.py b/mmocr/__init__.py
new file mode 100644
index 00000000..1c4f7e8f
--- /dev/null
+++ b/mmocr/__init__.py
@@ -0,0 +1,3 @@
+from .version import __version__, short_version
+
+__all__ = ['__version__', 'short_version']
diff --git a/mmocr/core/evaluation/ocr_metric.py b/mmocr/core/evaluation/ocr_metric.py
new file mode 100644
index 00000000..5c5124f0
--- /dev/null
+++ b/mmocr/core/evaluation/ocr_metric.py
@@ -0,0 +1,133 @@
+import re
+from difflib import SequenceMatcher
+
+import Levenshtein
+
+
+def cal_true_positive_char(pred, gt):
+ """Calculate correct character number in prediction.
+
+ Args:
+ pred (str): Prediction text.
+ gt (str): Ground truth text.
+
+ Returns:
+ true_positive_char_num (int): The true positive number.
+ """
+
+ all_opt = SequenceMatcher(None, pred, gt)
+ true_positive_char_num = 0
+ for opt, _, _, s2, e2 in all_opt.get_opcodes():
+ if opt == 'equal':
+ true_positive_char_num += (e2 - s2)
+ else:
+ pass
+ return true_positive_char_num
+
+
+def count_matches(pred_texts, gt_texts):
+ """Count the various match number for metric calculation.
+
+ Args:
+ pred_texts (list[str]): Predicted text string.
+ gt_texts (list[str]): Ground truth text string.
+
+ Returns:
+ match_res: (dict[str: int]): Match number used for
+ metric calculation.
+ """
+ match_res = {
+ 'gt_char_num': 0,
+ 'pred_char_num': 0,
+ 'true_positive_char_num': 0,
+ 'gt_word_num': 0,
+ 'match_word_num': 0,
+ 'match_word_ignore_case': 0,
+ 'match_word_ignore_case_symbol': 0
+ }
+ comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
+ norm_ed_sum = 0.0
+ for pred_text, gt_text in zip(pred_texts, gt_texts):
+ if gt_text == pred_text:
+ match_res['match_word_num'] += 1
+ gt_text_lower = gt_text.lower()
+ pred_text_lower = pred_text.lower()
+ if gt_text_lower == pred_text_lower:
+ match_res['match_word_ignore_case'] += 1
+ gt_text_lower_ignore = comp.sub('', gt_text_lower)
+ pred_text_lower_ignore = comp.sub('', pred_text_lower)
+ if gt_text_lower_ignore == pred_text_lower_ignore:
+ match_res['match_word_ignore_case_symbol'] += 1
+ match_res['gt_word_num'] += 1
+
+ # normalized edit distance
+ edit_dist = Levenshtein.distance(pred_text_lower_ignore,
+ gt_text_lower_ignore)
+ norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
+ len(pred_text_lower_ignore))
+ norm_ed_sum += norm_ed
+
+ # number to calculate char level recall & precision
+ match_res['gt_char_num'] += len(gt_text_lower_ignore)
+ match_res['pred_char_num'] += len(pred_text_lower_ignore)
+ true_positive_char_num = cal_true_positive_char(
+ pred_text_lower_ignore, gt_text_lower_ignore)
+ match_res['true_positive_char_num'] += true_positive_char_num
+
+ normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
+ match_res['ned'] = normalized_edit_distance
+
+ return match_res
+
+
+def eval_ocr_metric(pred_texts, gt_texts):
+ """Evaluate the text recognition performance with metric: word accuracy and
+ 1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.
+
+ Args:
+ pred_texts (list[str]): Text strings of prediction.
+ gt_texts (list[str]): Text strings of ground truth.
+
+ Returns:
+ eval_res (dict[str: float]): Metric dict for text recognition, include:
+ - word_acc: Accuracy in word level.
+ - word_acc_ignore_case: Accuracy in word level, ignore letter case.
+ - word_acc_ignore_case_symbol: Accuracy in word level, ignore
+ letter case and symbol. (default metric for
+ academic evaluation)
+ - char_recall: Recall in character level, ignore
+ letter case and symbol.
+ - char_precision: Precision in character level, ignore
+ letter case and symbol.
+ - 1-N.E.D: 1 - normalized_edit_distance.
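+
+ Example:
+ A toy call (one exact word match out of two):
+
+ >>> eval_ocr_metric(['hello', 'world'], ['hello', 'word'])['word_acc']
+ 0.5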
+ """
+ assert isinstance(pred_texts, list)
+ assert isinstance(gt_texts, list)
+ assert len(pred_texts) == len(gt_texts)
+
+ match_res = count_matches(pred_texts, gt_texts)
+ eps = 1e-8
+ char_recall = 1.0 * match_res['true_positive_char_num'] / (
+ eps + match_res['gt_char_num'])
+ char_precision = 1.0 * match_res['true_positive_char_num'] / (
+ eps + match_res['pred_char_num'])
+ word_acc = 1.0 * match_res['match_word_num'] / (
+ eps + match_res['gt_word_num'])
+ word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / (
+ eps + match_res['gt_word_num'])
+ word_acc_ignore_case_symbol = 1.0 * match_res[
+ 'match_word_ignore_case_symbol'] / (
+ eps + match_res['gt_word_num'])
+
+ eval_res = {}
+ eval_res['word_acc'] = word_acc
+ eval_res['word_acc_ignore_case'] = word_acc_ignore_case
+ eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol
+ eval_res['char_recall'] = char_recall
+ eval_res['char_precision'] = char_precision
+ eval_res['1-N.E.D'] = 1.0 - match_res['ned']
+
+ for key, value in eval_res.items():
+ eval_res[key] = float('{:.4f}'.format(value))
+
+ return eval_res
diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py
new file mode 100644
index 00000000..dd219226
--- /dev/null
+++ b/mmocr/datasets/__init__.py
@@ -0,0 +1,15 @@
+from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset
+from .base_dataset import BaseDataset
+from .icdar_dataset import IcdarDataset
+from .kie_dataset import KIEDataset
+from .ocr_dataset import OCRDataset
+from .ocr_seg_dataset import OCRSegDataset
+from .pipelines import CustomFormatBundle, DBNetTargets, DRRGTargets
+from .text_det_dataset import TextDetDataset
+from .utils import * # noqa: F401,F403
+
+__all__ = [
+ 'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
+ 'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
+ 'DBNetTargets', 'OCRSegDataset', 'DRRGTargets', 'KIEDataset'
+]
diff --git a/mmocr/datasets/base_dataset.py b/mmocr/datasets/base_dataset.py
new file mode 100644
index 00000000..c91497d9
--- /dev/null
+++ b/mmocr/datasets/base_dataset.py
@@ -0,0 +1,166 @@
+import numpy as np
+from mmcv.utils import print_log
+from torch.utils.data import Dataset
+
+from mmdet.datasets.builder import DATASETS
+from mmdet.datasets.pipelines import Compose
+from mmocr.datasets.builder import build_loader
+
+
+@DATASETS.register_module()
+class BaseDataset(Dataset):
+ """Custom dataset for text detection, text recognition, and their
+ downstream tasks.
+
+ 1. The text detection annotation format is as follows:
+ The `annotations` field is optional for testing.
+ (The example below shows one line of the annotation file, with its
+ line-JSON string converted to a dict for readability only.)
+
+ {
+ "file_name": "sample.jpg",
+ "height": 1080,
+ "width": 960,
+ "annotations":
+ [
+ {
+ "iscrowd": 0,
+ "category_id": 1,
+ "bbox": [357.0, 667.0, 804.0, 100.0],
+ "segmentation": [[361, 667, 710, 670,
+ 72, 767, 357, 763]]
+ }
+ ]
+ }
+
+ 2. The two text recognition annotation formats are as follows:
+ The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop
+ augmentation during training.
+
+ format1: sample.jpg hello
+ format2: sample.jpg 20 20 100 20 100 40 20 40 hello
+
+ Args:
+ ann_file (str): Annotation file path.
+ pipeline (list[dict]): Processing pipeline.
+ loader (dict): Dictionary to construct loader
+ to load annotation infos.
+ img_prefix (str, optional): Image prefix to generate full
+ image path.
+ test_mode (bool, optional): If set True, try...except will
+ be turned off in __getitem__.
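+
+ Example:
+ A loader config that reads recognition annotation format1 above,
+ as used by the recognition configs in this PR:
+
+ loader = dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' '))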
+ """
+
+ def __init__(self,
+ ann_file,
+ loader,
+ pipeline,
+ img_prefix='',
+ test_mode=False):
+ super().__init__()
+ self.test_mode = test_mode
+ self.img_prefix = img_prefix
+ self.ann_file = ann_file
+ # load annotations
+ loader.update(ann_file=ann_file)
+ self.data_infos = build_loader(loader)
+ # processing pipeline
+ self.pipeline = Compose(pipeline)
+ # set group flag and class, no meaning
+ # for text detect and recognize
+ self._set_group_flag()
+ self.CLASSES = 0
+
+ def __len__(self):
+ return len(self.data_infos)
+
+ def _set_group_flag(self):
+ """Set flag."""
+ self.flag = np.zeros(len(self), dtype=np.uint8)
+
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['img_prefix'] = self.img_prefix
+
+ def prepare_train_img(self, index):
+ """Get training data and annotations from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys \
+ introduced by pipeline.
+ """
+ img_info = self.data_infos[index]
+ results = dict(img_info=img_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, index):
+ """Get testing data from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by \
+ pipeline.
+ """
+ return self.prepare_train_img(index)
+
+ def _log_error_index(self, index):
+ """Logging data info of bad index."""
+ try:
+ data_info = self.data_infos[index]
+ img_prefix = self.img_prefix
+ print_log(f'Warning: skip broken file {data_info} '
+ f'with img_prefix {img_prefix}')
+ except Exception as e:
+ print_log(f'load index {index} with error {e}')
+
+ def _get_next_index(self, index):
+ """Get next index from dataset."""
+ self._log_error_index(index)
+ index = (index + 1) % len(self)
+ return index
+
+ def __getitem__(self, index):
+ """Get training/test data from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training/test data.
+ """
+ if self.test_mode:
+ return self.prepare_test_img(index)
+
+ while True:
+ try:
+ data = self.prepare_train_img(index)
+ if data is None:
+ raise Exception('prepared train data empty')
+ break
+ except Exception as e:
+ print_log(f'prepare index {index} with error {e}')
+ index = self._get_next_index(index)
+ return data
+
+ def format_results(self, results, **kwargs):
+ """Placeholder to format result to dataset-specific output."""
+ pass
+
+ def evaluate(self, results, metric=None, logger=None, **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ Returns:
+ dict[str: float]
+ """
+ raise NotImplementedError
diff --git a/mmocr/datasets/builder.py b/mmocr/datasets/builder.py
new file mode 100644
index 00000000..e7bcf423
--- /dev/null
+++ b/mmocr/datasets/builder.py
@@ -0,0 +1,14 @@
+from mmcv.utils import Registry, build_from_cfg
+
+LOADERS = Registry('loader')
+PARSERS = Registry('parser')
+
+
+def build_loader(cfg):
+ """Build anno file loader."""
+ return build_from_cfg(cfg, LOADERS)
+
+
+def build_parser(cfg):
+ """Build anno file parser."""
+ return build_from_cfg(cfg, PARSERS)
diff --git a/mmocr/datasets/ocr_dataset.py b/mmocr/datasets/ocr_dataset.py
new file mode 100644
index 00000000..4ec6d962
--- /dev/null
+++ b/mmocr/datasets/ocr_dataset.py
@@ -0,0 +1,34 @@
+from mmdet.datasets.builder import DATASETS
+from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
+from mmocr.datasets.base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class OCRDataset(BaseDataset):
+
+ def pre_pipeline(self, results):
+ results['img_prefix'] = self.img_prefix
+ results['text'] = results['img_info']['text']
+
+ def evaluate(self, results, metric='acc', logger=None, **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ Returns:
+ dict[str: float]
+ """
+ gt_texts = []
+ pred_texts = []
+ for i in range(len(self)):
+ item_info = self.data_infos[i]
+ text = item_info['text']
+ gt_texts.append(text)
+ pred_texts.append(results[i]['text'])
+
+ eval_results = eval_ocr_metric(pred_texts, gt_texts)
+
+ return eval_results
diff --git a/mmocr/datasets/ocr_seg_dataset.py b/mmocr/datasets/ocr_seg_dataset.py
new file mode 100644
index 00000000..0b149af1
--- /dev/null
+++ b/mmocr/datasets/ocr_seg_dataset.py
@@ -0,0 +1,90 @@
+import mmocr.utils as utils
+from mmdet.datasets.builder import DATASETS
+from mmocr.datasets.ocr_dataset import OCRDataset
+
+
+@DATASETS.register_module()
+class OCRSegDataset(OCRDataset):
+
+ def pre_pipeline(self, results):
+ results['img_prefix'] = self.img_prefix
+
+ def _parse_anno_info(self, annotations):
+ """Parse char boxes annotations.
+ Args:
+ annotations (list[dict]): Annotations of one image, where
+ each dict is for one character.
+
+ Returns:
+ dict: A dict containing the following keys:
+
+ - chars (list[str]): List of character strings.
+ - char_rects (list[list[float]]): List of char box, with each
+ in style of rectangle: [x_min, y_min, x_max, y_max].
+ - char_quads (list[list[float]]): List of char box, with each
+ in style of quadrangle: [x1, y1, x2, y2, x3, y3, x4, y4].
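+
+ Example:
+ A hypothetical single-character annotation
+ [dict(char_text='h', char_box=[0, 0, 10, 10])]
+ is parsed into
+ dict(chars=['h'], char_rects=[[0, 0, 10, 10]],
+ char_quads=[[0, 0, 10, 0, 10, 10, 0, 10]]).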
+ """
+
+ assert utils.is_type_list(annotations, dict)
+ assert 'char_box' in annotations[0]
+ assert 'char_text' in annotations[0]
+ assert len(annotations[0]['char_box']) == 4 or \
+ len(annotations[0]['char_box']) == 8
+
+ chars, char_rects, char_quads = [], [], []
+ for ann in annotations:
+ char_box = ann['char_box']
+ if len(char_box) == 4:
+ char_box_type = ann.get('char_box_type', 'xyxy')
+ if char_box_type == 'xyxy':
+ char_rects.append(char_box)
+ char_quads.append([
+ char_box[0], char_box[1], char_box[2], char_box[1],
+ char_box[2], char_box[3], char_box[0], char_box[3]
+ ])
+ elif char_box_type == 'xywh':
+ x1, y1, w, h = char_box
+ x2 = x1 + w
+ y2 = y1 + h
+ char_rects.append([x1, y1, x2, y2])
+ char_quads.append([x1, y1, x2, y1, x2, y2, x1, y2])
+ else:
+ raise ValueError(f'invalid char_box_type {char_box_type}')
+ elif len(char_box) == 8:
+ x_list, y_list = [], []
+ for i in range(4):
+ x_list.append(char_box[2 * i])
+ y_list.append(char_box[2 * i + 1])
+ x_max, x_min = max(x_list), min(x_list)
+ y_max, y_min = max(y_list), min(y_list)
+ char_rects.append([x_min, y_min, x_max, y_max])
+ char_quads.append(char_box)
+ else:
+ raise ValueError(
+ f'invalid num in char box: {len(char_box)} not in (4, 8)')
+ chars.append(ann['char_text'])
+
+ ann = dict(chars=chars, char_rects=char_rects, char_quads=char_quads)
+
+ return ann
+
+ def prepare_train_img(self, index):
+ """Get training data and annotations from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys \
+ introduced by pipeline.
+ """
+ img_ann_info = self.data_infos[index]
+ img_info = {
+ 'filename': img_ann_info['file_name'],
+ }
+ ann_info = self._parse_anno_info(img_ann_info['annotations'])
+ results = dict(img_info=img_info, ann_info=ann_info)
+
+ self.pre_pipeline(results)
+
+ return self.pipeline(results)
diff --git a/mmocr/datasets/pipelines/crop.py b/mmocr/datasets/pipelines/crop.py
new file mode 100644
index 00000000..c2e57bdd
--- /dev/null
+++ b/mmocr/datasets/pipelines/crop.py
@@ -0,0 +1,185 @@
+import cv2
+import numpy as np
+from shapely.geometry import LineString, Point, Polygon
+
+import mmocr.utils as utils
+
+
+def sort_vertex(points_x, points_y):
+ """Sort box vertices in clockwise order from left-top first.
+
+ Args:
+ points_x (list[float]): x of four vertices.
+ points_y (list[float]): y of four vertices.
+ Returns:
+ sorted_points_x (list[float]): x of sorted four vertices.
+ sorted_points_y (list[float]): y of sorted four vertices.
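+
+ Example:
+ A unit square given in arbitrary order (hypothetical values):
+
+ >>> sort_vertex([1, 0, 1, 0], [0, 0, 1, 1])
+ ([0, 1, 1, 0], [0, 0, 1, 1])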
+ """
+ assert utils.is_type_list(points_x, float) or utils.is_type_list(
+ points_x, int)
+ assert utils.is_type_list(points_y, float) or utils.is_type_list(
+ points_y, int)
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+
+ x = np.array(points_x)
+ y = np.array(points_y)
+ center_x = np.sum(x) * 0.25
+ center_y = np.sum(y) * 0.25
+
+ x_arr = np.array(x - center_x)
+ y_arr = np.array(y - center_y)
+
+ angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
+ sort_idx = np.argsort(angle)
+
+ sorted_points_x, sorted_points_y = [], []
+ for i in range(4):
+ sorted_points_x.append(points_x[sort_idx[i]])
+ sorted_points_y.append(points_y[sort_idx[i]])
+
+ return convert_canonical(sorted_points_x, sorted_points_y)
+
+
+def convert_canonical(points_x, points_y):
+ """Make left-top be first.
+
+ Args:
+ points_x (list[float]): x of four vertices.
+ points_y (list[float]): y of four vertices.
+ Returns:
+ sorted_points_x (list[float]): x of sorted four vertices.
+ sorted_points_y (list[float]): y of sorted four vertices.
+ """
+ assert utils.is_type_list(points_x, float) or utils.is_type_list(
+ points_x, int)
+ assert utils.is_type_list(points_y, float) or utils.is_type_list(
+ points_y, int)
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+
+ polygon = Polygon([(p.x, p.y) for p in points])
+ min_x, min_y, _, _ = polygon.bounds
+ points_to_lefttop = [
+ LineString([points[i], Point(min_x, min_y)]) for i in range(4)
+ ]
+ distances = np.array([line.length for line in points_to_lefttop])
+ sort_dist_idx = np.argsort(distances)
+ lefttop_idx = sort_dist_idx[0]
+
+ if lefttop_idx == 0:
+ point_orders = [0, 1, 2, 3]
+ elif lefttop_idx == 1:
+ point_orders = [1, 2, 3, 0]
+ elif lefttop_idx == 2:
+ point_orders = [2, 3, 0, 1]
+ else:
+ point_orders = [3, 0, 1, 2]
+
+ sorted_points_x = [points_x[i] for i in point_orders]
+ sorted_points_y = [points_y[j] for j in point_orders]
+
+ return sorted_points_x, sorted_points_y
+
+
+def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1):
+ """Jitter on the coordinates of bounding box.
+
+ Args:
+ points_x (list[float | int]): List of x for four vertices.
+ points_y (list[float | int]): List of y for four vertices.
+ jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
+ jitter_ratio_y (float): Vertical jitter ratio relative to the height.
+ """
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+ assert isinstance(jitter_ratio_x, float)
+ assert isinstance(jitter_ratio_y, float)
+ assert 0 <= jitter_ratio_x < 1
+ assert 0 <= jitter_ratio_y < 1
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+ line_list = [
+ LineString([points[i], points[i + 1 if i < 3 else 0]])
+ for i in range(4)
+ ]
+
+ tmp_h = max(line_list[1].length, line_list[3].length)
+
+ for i in range(4):
+ jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h
+ jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h
+ points_x[i] += jitter_pixel_x
+ points_y[i] += jitter_pixel_y
+
+
+def warp_img(src_img,
+ box,
+ jitter_flag=False,
+ jitter_ratio_x=0.5,
+ jitter_ratio_y=0.1):
+ """Crop box area from image using opencv warpPerspective w/o box jitter.
+
+ Args:
+ src_img (np.array): Image before cropping.
+ box (list[float | int]): Coordinates of quadrangle.
+ """
+ assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
+ assert len(box) == 8
+
+ h, w = src_img.shape[:2]
+ points_x = [min(max(x, 0), w) for x in box[0:8:2]]
+ points_y = [min(max(y, 0), h) for y in box[1:9:2]]
+
+ points_x, points_y = sort_vertex(points_x, points_y)
+
+ if jitter_flag:
+ box_jitter(
+ points_x,
+ points_y,
+ jitter_ratio_x=jitter_ratio_x,
+ jitter_ratio_y=jitter_ratio_y)
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+ edges = [
+ LineString([points[i], points[i + 1 if i < 3 else 0]])
+ for i in range(4)
+ ]
+
+ pts1 = np.float32([[points[i].x, points[i].y] for i in range(4)])
+ box_width = max(edges[0].length, edges[2].length)
+ box_height = max(edges[1].length, edges[3].length)
+
+ pts2 = np.float32([[0, 0], [box_width, 0], [box_width, box_height],
+ [0, box_height]])
+ M = cv2.getPerspectiveTransform(pts1, pts2)
+ dst_img = cv2.warpPerspective(src_img, M,
+ (int(box_width), int(box_height)))
+
+ return dst_img
+
+
+def crop_img(src_img, box):
+ """Crop box area to rectangle.
+
+ Args:
+ src_img (np.array): Image before crop.
+ box (list[float | int]): Points of quadrangle.
+ """
+ assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
+ assert len(box) == 8
+
+ h, w = src_img.shape[:2]
+ points_x = [min(max(x, 0), w) for x in box[0:8:2]]
+ points_y = [min(max(y, 0), h) for y in box[1:9:2]]
+
+ left = int(min(points_x))
+ top = int(min(points_y))
+ right = int(max(points_x))
+ bottom = int(max(points_y))
+
+ dst_img = src_img[top:bottom, left:right]
+
+ return dst_img
diff --git a/mmocr/datasets/pipelines/ocr_seg_targets.py b/mmocr/datasets/pipelines/ocr_seg_targets.py
new file mode 100644
index 00000000..cbd9b869
--- /dev/null
+++ b/mmocr/datasets/pipelines/ocr_seg_targets.py
@@ -0,0 +1,201 @@
+import cv2
+import numpy as np
+
+import mmocr.utils.check_argument as check_argument
+from mmdet.core import BitmapMasks
+from mmdet.datasets.builder import PIPELINES
+from mmocr.models.builder import build_convertor
+
+
+@PIPELINES.register_module()
+class OCRSegTargets:
+ """Generate gt shrinked kernels for segmentation based OCR framework.
+
+ Args:
+ label_convertor (dict): Dictionary to construct label_convertor
+ to convert char to index.
+ attn_shrink_ratio (float): The shrink ratio of attention
+ kernels relative to gt text masks.
+ seg_shrink_ratio (float): The shrink ratio of segmentation
+ kernels relative to gt text masks.
+ box_type (str): Character box type, should be either
+ 'char_rects' or 'char_quads', with 'char_rects'
+ for rectangle with ``xyxy`` style and 'char_quads'
+ for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
+ pad_val (int): Value filled in the padded region of the
+ generated kernel maps.
+ """
+
+ def __init__(self,
+ label_convertor=None,
+ attn_shrink_ratio=0.5,
+ seg_shrink_ratio=0.25,
+ box_type='char_rects',
+ pad_val=255):
+
+ assert isinstance(attn_shrink_ratio, float)
+ assert isinstance(seg_shrink_ratio, float)
+ assert 0. < attn_shrink_ratio < 1.0
+ assert 0. < seg_shrink_ratio < 1.0
+ assert label_convertor is not None
+ assert box_type in ('char_rects', 'char_quads')
+
+ self.attn_shrink_ratio = attn_shrink_ratio
+ self.seg_shrink_ratio = seg_shrink_ratio
+ self.label_convertor = build_convertor(label_convertor)
+ self.box_type = box_type
+ self.pad_val = pad_val
+
+ def shrink_char_quad(self, char_quad, shrink_ratio):
+ """Shrink char box in style of quadrangle.
+
+ Args:
+ char_quad (list[float]): Char box with format
+ [x1, y1, x2, y2, x3, y3, x4, y4].
+ shrink_ratio (float): The shrink ratio of gt kernels
+ relative to gt text masks.
+ """
+ points = [[char_quad[0], char_quad[1]], [char_quad[2], char_quad[3]],
+ [char_quad[4], char_quad[5]], [char_quad[6], char_quad[7]]]
+ shrink_points = []
+ for p_idx, point in enumerate(points):
+ p1 = points[(p_idx + 3) % 4]
+ p2 = points[(p_idx + 1) % 4]
+
+ dist1 = self.l2_dist_two_points(p1, point)
+ dist2 = self.l2_dist_two_points(p2, point)
+ min_dist = min(dist1, dist2)
+
+ v1 = [p1[0] - point[0], p1[1] - point[1]]
+ v2 = [p2[0] - point[0], p2[1] - point[1]]
+
+ temp_dist1 = (shrink_ratio * min_dist /
+ dist1) if min_dist != 0 else 0.
+ temp_dist2 = (shrink_ratio * min_dist /
+ dist2) if min_dist != 0 else 0.
+
+ v1 = [temp * temp_dist1 for temp in v1]
+ v2 = [temp * temp_dist2 for temp in v2]
+
+ shrink_point = [
+ round(point[0] + v1[0] + v2[0]),
+ round(point[1] + v1[1] + v2[1])
+ ]
+ shrink_points.append(shrink_point)
+
+ poly = np.array(shrink_points)
+
+ return poly
+
+ def shrink_char_rect(self, char_rect, shrink_ratio):
+ """Shrink char box in style of rectangle.
+
+ Args:
+ char_rect (list[float]): Char box with format
+ [x_min, y_min, x_max, y_max].
+ shrink_ratio (float): The shrink ratio of gt kernels
+ relative to gt text masks; both the width and the height
+ of the box are scaled by this ratio.
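+
+ A worked example (numbers are illustrative only): with
+ ``char_rect = [0, 0, 8, 4]`` and ``shrink_ratio = 0.5``, the 8 x 4 box
+ keeps its centre (4, 2) and shrinks to the 4 x 2 quadrangle
+ ``[[2, 1], [6, 1], [6, 3], [2, 3]]``.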
+ """
+ x_min, y_min, x_max, y_max = char_rect
+ w = x_max - x_min
+ h = y_max - y_min
+ x_min_s = round((x_min + x_max - w * shrink_ratio) / 2)
+ y_min_s = round((y_min + y_max - h * shrink_ratio) / 2)
+ x_max_s = round((x_min + x_max + w * shrink_ratio) / 2)
+ y_max_s = round((y_min + y_max + h * shrink_ratio) / 2)
+ poly = np.array([[x_min_s, y_min_s], [x_max_s, y_min_s],
+ [x_max_s, y_max_s], [x_min_s, y_max_s]])
+
+ return poly
+
+ def generate_kernels(self,
+ resize_shape,
+ pad_shape,
+ char_boxes,
+ char_inds,
+ shrink_ratio=0.5,
+ binary=True):
+ """Generate char instance kernels for one shrink ratio.
+
+ Args:
+ resize_shape (tuple(int, int)): Image size (height, width)
+ after resizing.
+ pad_shape (tuple(int, int)): Image size (height, width)
+ after padding.
+ char_boxes (list[list[float]]): The list of char polygons.
+ char_inds (list[int]): List of char indexes.
+ shrink_ratio (float): The shrink ratio of kernel.
+ binary (bool): If True, return binary ndarray
+ containing 0 & 1 only.
+ Returns:
+ char_kernel (ndarray): The text kernel mask of (height, width).
+ """
+ assert isinstance(resize_shape, tuple)
+ assert isinstance(pad_shape, tuple)
+ assert check_argument.is_2dlist(char_boxes)
+ assert check_argument.is_type_list(char_inds, int)
+ assert isinstance(shrink_ratio, float)
+ assert isinstance(binary, bool)
+
+ char_kernel = np.zeros(pad_shape, dtype=np.int32)
+ char_kernel[:resize_shape[0], resize_shape[1]:] = self.pad_val
+
+ for i, char_box in enumerate(char_boxes):
+ if self.box_type == 'char_rects':
+ poly = self.shrink_char_rect(char_box, shrink_ratio)
+ elif self.box_type == 'char_quads':
+ poly = self.shrink_char_quad(char_box, shrink_ratio)
+
+ fill_value = 1 if binary else char_inds[i]
+ cv2.fillConvexPoly(char_kernel, poly.astype(np.int32),
+ (fill_value))
+
+ return char_kernel
+
+ def l2_dist_two_points(self, p1, p2):
+ return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5
+
+ def __call__(self, results):
+ img_shape = results['img_shape']
+ resize_shape = results['resize_shape']
+
+ h_scale = 1.0 * resize_shape[0] / img_shape[0]
+ w_scale = 1.0 * resize_shape[1] / img_shape[1]
+
+ char_boxes, char_inds = [], []
+ char_num = len(results['ann_info'][self.box_type])
+ for i in range(char_num):
+ char_box = results['ann_info'][self.box_type][i]
+ num_points = 2 if self.box_type == 'char_rects' else 4
+ for j in range(num_points):
+ char_box[j * 2] = round(char_box[j * 2] * w_scale)
+ char_box[j * 2 + 1] = round(char_box[j * 2 + 1] * h_scale)
+ char_boxes.append(char_box)
+ char = results['ann_info']['chars'][i]
+ char_ind = self.label_convertor.str2idx([char])[0][0]
+ char_inds.append(char_ind)
+
+ resize_shape = tuple(results['resize_shape'][:2])
+ pad_shape = tuple(results['pad_shape'][:2])
+ binary_target = self.generate_kernels(
+ resize_shape,
+ pad_shape,
+ char_boxes,
+ char_inds,
+ shrink_ratio=self.attn_shrink_ratio,
+ binary=True)
+
+ seg_target = self.generate_kernels(
+ resize_shape,
+ pad_shape,
+ char_boxes,
+ char_inds,
+ shrink_ratio=self.seg_shrink_ratio,
+ binary=False)
+
+ mask = np.ones(pad_shape, dtype=np.int32)
+ mask[:resize_shape[0], resize_shape[1]:] = 0
+
+ results['gt_kernels'] = BitmapMasks([binary_target, seg_target, mask],
+ pad_shape[0], pad_shape[1])
+ results['mask_fields'] = ['gt_kernels']
+
+ return results
diff --git a/mmocr/datasets/pipelines/ocr_transforms.py b/mmocr/datasets/pipelines/ocr_transforms.py
new file mode 100644
index 00000000..afddd40e
--- /dev/null
+++ b/mmocr/datasets/pipelines/ocr_transforms.py
@@ -0,0 +1,447 @@
+import math
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from mmcv.runner.dist_utils import get_dist_info
+from PIL import Image
+from shapely.geometry import Polygon
+from shapely.geometry import box as shapely_box
+
+import mmocr.utils as utils
+from mmdet.datasets.builder import PIPELINES
+from mmocr.datasets.pipelines.crop import warp_img
+
+
+@PIPELINES.register_module()
+class ResizeOCR:
+ """Image resizing and padding for OCR.
+
+ Args:
+ height (int | tuple(int)): Image height after resizing.
+ min_width (none | int | tuple(int)): Image minimum width
+ after resizing.
+ max_width (none | int | tuple(int)): Image maximum width
+ after resizing.
+ keep_aspect_ratio (bool): Keep image aspect ratio if True
+ during resizing; otherwise resize the image to
+ (height, max_width).
+ img_pad_value (int): Scalar to fill padding area.
+ width_downsample_ratio (float): Downsample ratio in horizontal
+ direction from input image to output feature.
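+
+ Example (illustrative sketch; the input size is made up):
+ >>> import numpy as np
+ >>> resize = ResizeOCR(
+ ... height=32, min_width=32, max_width=160, keep_aspect_ratio=True)
+ >>> results = dict(
+ ... img=np.zeros((48, 120, 3), dtype=np.uint8), img_shape=(48, 120, 3))
+ >>> results = resize(results)
+ >>> # the image is resized to height 32, the new width (80, a multiple of
+ >>> # 16) is padded up to max_width, and valid_ratio records 80 / 160
+ >>> results['img'].shape, results['valid_ratio']
+ ((32, 160, 3), 0.5)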
+ """
+
+ def __init__(self,
+ height,
+ min_width=None,
+ max_width=None,
+ keep_aspect_ratio=True,
+ img_pad_value=0,
+ width_downsample_ratio=1.0 / 16):
+ assert isinstance(height, (int, tuple))
+ assert utils.is_none_or_type(min_width, (int, tuple))
+ assert utils.is_none_or_type(max_width, (int, tuple))
+ if not keep_aspect_ratio:
+ assert max_width is not None, \
+ '"max_width" must assigned ' + \
+ 'if "keep_aspect_ratio" is False'
+ assert isinstance(img_pad_value, int)
+ if isinstance(height, tuple):
+ assert isinstance(min_width, tuple)
+ assert isinstance(max_width, tuple)
+ assert len(height) == len(min_width) == len(max_width)
+
+ self.height = height
+ self.min_width = min_width
+ self.max_width = max_width
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.img_pad_value = img_pad_value
+ self.width_downsample_ratio = width_downsample_ratio
+
+ def __call__(self, results):
+ rank, _ = get_dist_info()
+ if isinstance(self.height, int):
+ dst_height = self.height
+ dst_min_width = self.min_width
+ dst_max_width = self.max_width
+ else:
+ """Multi-scale resize used in distributed training.
+
+ Choose one (height, width) pair for one rank id.
+ """
+ idx = rank % len(self.height)
+ dst_height = self.height[idx]
+ dst_min_width = self.min_width[idx]
+ dst_max_width = self.max_width[idx]
+
+ img_shape = results['img_shape']
+ ori_height, ori_width = img_shape[:2]
+ valid_ratio = 1.0
+ resize_shape = list(img_shape)
+ pad_shape = list(img_shape)
+
+ if self.keep_aspect_ratio:
+ new_width = math.ceil(float(dst_height) / ori_height * ori_width)
+ width_divisor = int(1 / self.width_downsample_ratio)
+ # make sure new_width is an integral multiple of width_divisor.
+ if new_width % width_divisor != 0:
+ new_width = round(new_width / width_divisor) * width_divisor
+ if dst_min_width is not None:
+ new_width = max(dst_min_width, new_width)
+ if dst_max_width is not None:
+ valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
+ resize_width = min(dst_max_width, new_width)
+ img_resize = cv2.resize(results['img'],
+ (resize_width, dst_height))
+ resize_shape = img_resize.shape
+ pad_shape = img_resize.shape
+ if new_width < dst_max_width:
+ img_resize = mmcv.impad(
+ img_resize,
+ shape=(dst_height, dst_max_width),
+ pad_val=self.img_pad_value)
+ pad_shape = img_resize.shape
+ else:
+ img_resize = cv2.resize(results['img'],
+ (new_width, dst_height))
+ resize_shape = img_resize.shape
+ pad_shape = img_resize.shape
+ else:
+ img_resize = cv2.resize(results['img'],
+ (dst_max_width, dst_height))
+ resize_shape = img_resize.shape
+ pad_shape = img_resize.shape
+
+ results['img'] = img_resize
+ results['resize_shape'] = resize_shape
+ results['pad_shape'] = pad_shape
+ results['valid_ratio'] = valid_ratio
+
+ return results
+
+
+@PIPELINES.register_module()
+class ToTensorOCR:
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor."""
+
+ def __init__(self):
+ pass
+
+ def __call__(self, results):
+ results['img'] = TF.to_tensor(results['img'].copy())
+
+ return results
+
+
+@PIPELINES.register_module()
+class NormalizeOCR:
+ """Normalize a tensor image with mean and standard deviation."""
+
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, results):
+ results['img'] = TF.normalize(results['img'], self.mean, self.std)
+
+ return results
+
+
+@PIPELINES.register_module()
+class OnlineCropOCR:
+ """Crop text areas from whole image with bounding box jitter. If no bbox is
+ given, return directly.
+
+ Args:
+ box_keys (list[str]): Keys in results which correspond to RoI bbox.
+ jitter_prob (float): The probability of box jitter.
+ max_jitter_ratio_x (float): Maximum horizontal jitter ratio
+ relative to height.
+ max_jitter_ratio_y (float): Maximum vertical jitter ratio
+ relative to height.
+ """
+
+ def __init__(self,
+ box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
+ jitter_prob=0.5,
+ max_jitter_ratio_x=0.05,
+ max_jitter_ratio_y=0.02):
+ assert utils.is_type_list(box_keys, str)
+ assert 0 <= jitter_prob <= 1
+ assert 0 <= max_jitter_ratio_x <= 1
+ assert 0 <= max_jitter_ratio_y <= 1
+
+ self.box_keys = box_keys
+ self.jitter_prob = jitter_prob
+ self.max_jitter_ratio_x = max_jitter_ratio_x
+ self.max_jitter_ratio_y = max_jitter_ratio_y
+
+ def __call__(self, results):
+
+ if 'img_info' not in results:
+ return results
+
+ crop_flag = True
+ box = []
+ for key in self.box_keys:
+ if key not in results['img_info']:
+ crop_flag = False
+ break
+
+ box.append(float(results['img_info'][key]))
+
+ if not crop_flag:
+ return results
+
+ # apply box jitter with probability ``jitter_prob``
+ jitter_flag = np.random.random() < self.jitter_prob
+
+ kwargs = dict(
+ jitter_flag=jitter_flag,
+ jitter_ratio_x=self.max_jitter_ratio_x,
+ jitter_ratio_y=self.max_jitter_ratio_y)
+ crop_img = warp_img(results['img'], box, **kwargs)
+
+ results['img'] = crop_img
+ results['img_shape'] = crop_img.shape
+
+ return results
+
+
+@PIPELINES.register_module()
+class FancyPCA:
+ """Implementation of PCA based image augmentation, proposed in the paper
+ ``ImageNet Classification with Deep Convolutional Neural Networks``.
+
+ It alters the intensities of RGB values along the principal components of
+ ImageNet dataset.
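+
+ Example (illustrative sketch; the tensor size is arbitrary). The transform
+ expects a 3 x H x W float tensor, e.g. the output of ``ToTensorOCR``:
+ >>> import torch
+ >>> fancy_pca = FancyPCA()
+ >>> results = fancy_pca(dict(img=torch.rand(3, 32, 100)))
+ >>> results['img'].shape
+ torch.Size([3, 32, 100])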
+ """
+
+ def __init__(self, eig_vec=None, eig_val=None):
+ if eig_vec is None:
+ eig_vec = torch.Tensor([
+ [-0.5675, +0.7192, +0.4009],
+ [-0.5808, -0.0045, -0.8140],
+ [-0.5836, -0.6948, +0.4203],
+ ]).t()
+ if eig_val is None:
+ eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])
+ self.eig_val = eig_val # 1*3
+ self.eig_vec = eig_vec # 3*3
+
+ def pca(self, tensor):
+ assert tensor.size(0) == 3
+ alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
+ reconst = torch.mm(self.eig_val * alpha, self.eig_vec)
+ tensor = tensor + reconst.view(3, 1, 1)
+
+ return tensor
+
+ def __call__(self, results):
+ img = results['img']
+ tensor = self.pca(img)
+ results['img'] = tensor
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomPaddingOCR:
+ """Pad the given image on all sides, as well as modify the coordinates of
+ the character bounding boxes in the image.
+
+ Args:
+ max_ratio (list[float]): Maximum padding ratios [left, top, right,
+ bottom], relative to the image width (left/right) and
+ height (top/bottom).
+ box_type (None|str): Character box type. If not none,
+ should be either 'char_rects' or 'char_quads', with
+ 'char_rects' for rectangle with ``xyxy`` style and
+ 'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
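+
+ Example (illustrative sketch; the image size is made up):
+ >>> import numpy as np
+ >>> pad = RandomPaddingOCR(max_ratio=[0.1, 0.2, 0.1, 0.2])
+ >>> results = dict(
+ ... img=np.zeros((32, 100, 3), dtype=np.uint8), img_shape=(32, 100, 3))
+ >>> results = pad(results)
+ >>> # each border is replicated by a random amount, here up to about
+ >>> # 10 extra columns left/right and 6 extra rows top/bottom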
+ """
+
+ def __init__(self, max_ratio=None, box_type=None):
+ if max_ratio is None:
+ max_ratio = [0.1, 0.2, 0.1, 0.2]
+ else:
+ assert utils.is_type_list(max_ratio, float)
+ assert len(max_ratio) == 4
+ assert box_type is None or box_type in ('char_rects', 'char_quads')
+
+ self.max_ratio = max_ratio
+ self.box_type = box_type
+
+ def __call__(self, results):
+
+ img_shape = results['img_shape']
+ ori_height, ori_width = img_shape[:2]
+
+ random_padding_left = round(
+ np.random.uniform(0, self.max_ratio[0]) * ori_width)
+ random_padding_top = round(
+ np.random.uniform(0, self.max_ratio[1]) * ori_height)
+ random_padding_right = round(
+ np.random.uniform(0, self.max_ratio[2]) * ori_width)
+ random_padding_bottom = round(
+ np.random.uniform(0, self.max_ratio[3]) * ori_height)
+
+ img = np.copy(results['img'])
+ img = cv2.copyMakeBorder(img, random_padding_top,
+ random_padding_bottom, random_padding_left,
+ random_padding_right, cv2.BORDER_REPLICATE)
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ if self.box_type is not None:
+ num_points = 2 if self.box_type == 'char_rects' else 4
+ char_num = len(results['ann_info'][self.box_type])
+ for i in range(char_num):
+ for j in range(num_points):
+ results['ann_info'][self.box_type][i][
+ j * 2] += random_padding_left
+ results['ann_info'][self.box_type][i][
+ j * 2 + 1] += random_padding_top
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomRotateImageBox:
+ """Rotate augmentation for segmentation based text recognition.
+
+ Args:
+ min_angle (int): Minimum rotation angle for image and box.
+ max_angle (int): Maximum rotation angle for image and box.
+ box_type (str): Character box type, should be either
+ 'char_rects' or 'char_quads', with 'char_rects'
+ for rectangle with ``xyxy`` style and 'char_quads'
+ for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
+ """
+
+ def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'):
+ assert box_type in ('char_rects', 'char_quads')
+
+ self.min_angle = min_angle
+ self.max_angle = max_angle
+ self.box_type = box_type
+
+ def __call__(self, results):
+ in_img = results['img']
+ in_chars = results['ann_info']['chars']
+ in_boxes = results['ann_info'][self.box_type]
+
+ img_width, img_height = in_img.size
+ rotate_center = [img_width / 2., img_height / 2.]
+
+ tan_temp_max_angle = rotate_center[1] / rotate_center[0]
+ temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi
+
+ random_angle = np.random.uniform(
+ max(self.min_angle, -temp_max_angle),
+ min(self.max_angle, temp_max_angle))
+ random_angle_radian = random_angle * np.pi / 180.
+
+ img_box = shapely_box(0, 0, img_width, img_height)
+
+ out_img = TF.rotate(
+ in_img,
+ random_angle,
+ resample=False,
+ expand=False,
+ center=rotate_center)
+
+ out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars,
+ random_angle_radian,
+ rotate_center, img_box)
+
+ results['img'] = out_img
+ results['ann_info']['chars'] = out_chars
+ results['ann_info'][self.box_type] = out_boxes
+
+ return results
+
+ @staticmethod
+ def rotate_bbox(boxes, chars, angle, center, img_box):
+ out_boxes = []
+ out_chars = []
+ for idx, bbox in enumerate(boxes):
+ temp_bbox = []
+ for i in range(len(bbox) // 2):
+ point = [bbox[2 * i], bbox[2 * i + 1]]
+ temp_bbox.append(
+ RandomRotateImageBox.rotate_point(point, angle, center))
+ poly_temp_bbox = Polygon(temp_bbox).buffer(0)
+ if poly_temp_bbox.is_valid:
+ if img_box.intersects(poly_temp_bbox) and (
+ not img_box.touches(poly_temp_bbox)):
+ temp_bbox_area = poly_temp_bbox.area
+
+ intersect_area = img_box.intersection(poly_temp_bbox).area
+ intersect_ratio = intersect_area / temp_bbox_area
+
+ if intersect_ratio >= 0.7:
+ out_box = []
+ for p in temp_bbox:
+ out_box.extend(p)
+ out_boxes.append(out_box)
+ out_chars.append(chars[idx])
+
+ return out_boxes, out_chars
+
+ @staticmethod
+ def rotate_point(point, angle, center):
+ cos_theta = math.cos(-angle)
+ sin_theta = math.sin(-angle)
+ c_x = center[0]
+ c_y = center[1]
+ new_x = (point[0] - c_x) * cos_theta - (point[1] -
+ c_y) * sin_theta + c_x
+ new_y = (point[0] - c_x) * sin_theta + (point[1] -
+ c_y) * cos_theta + c_y
+
+ return [new_x, new_y]
+
+
+@PIPELINES.register_module()
+class OpencvToPil:
+ """Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb)."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, results):
+ img = results['img'][..., ::-1]
+ img = Image.fromarray(img)
+ results['img'] = img
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class PilToOpencv:
+ """Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr)."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, results):
+ img = np.asarray(results['img'])
+ img = img[..., ::-1]
+ results['img'] = img
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
diff --git a/mmocr/datasets/pipelines/test_time_aug.py b/mmocr/datasets/pipelines/test_time_aug.py
new file mode 100644
index 00000000..5c8c1a60
--- /dev/null
+++ b/mmocr/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,108 @@
+import mmcv
+import numpy as np
+
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines.compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiRotateAugOCR:
+ """Test-time augmentation with multiple rotations in the case that
+ img_height > img_width.
+
+ An example configuration is as follows:
+
+ .. code-block::
+
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ]
+
+ After MultiRotateAugOCR with above configuration, the results are wrapped
+ into lists of the same length as follows:
+
+ .. code-block::
+
+ dict(
+ img=[...],
+ img_shape=[...]
+ ...
+ )
+
+ Args:
+ transforms (list[dict]): Transformation applied for each augmentation.
+ rotate_degrees (list[int] | None): Degrees of anti-clockwise rotation.
+ force_rotate (bool): If True, rotate image by 'rotate_degrees'
+ regardless of the image aspect ratio.
+ """
+
+ def __init__(self, transforms, rotate_degrees=None, force_rotate=False):
+ self.transforms = Compose(transforms)
+ self.force_rotate = force_rotate
+ if rotate_degrees is not None:
+ self.rotate_degrees = rotate_degrees if isinstance(
+ rotate_degrees, list) else [rotate_degrees]
+ assert mmcv.is_list_of(self.rotate_degrees, int)
+ for degree in self.rotate_degrees:
+ assert 0 <= degree < 360
+ assert degree % 90 == 0
+ if 0 not in self.rotate_degrees:
+ self.rotate_degrees.append(0)
+ else:
+ self.rotate_degrees = [0]
+
+ def __call__(self, results):
+ """Call function to apply test time augment transformation to results.
+
+ Args:
+ results (dict): Result dict contains the data to be transformed.
+
+ Returns:
+ dict[str: list]: The augmented data, where each value is wrapped
+ into a list.
+ """
+ img_shape = results['img_shape']
+ ori_height, ori_width = img_shape[:2]
+ if not self.force_rotate and ori_height <= ori_width:
+ rotate_degrees = [0]
+ else:
+ rotate_degrees = self.rotate_degrees
+ aug_data = []
+ for degree in set(rotate_degrees):
+ _results = results.copy()
+ if degree == 0:
+ pass
+ elif degree == 90:
+ _results['img'] = np.rot90(_results['img'], 1)
+ elif degree == 180:
+ _results['img'] = np.rot90(_results['img'], 2)
+ elif degree == 270:
+ _results['img'] = np.rot90(_results['img'], 3)
+ data = self.transforms(_results)
+ aug_data.append(data)
+ # list of dict to dict of list
+ aug_data_dict = {key: [] for key in aug_data[0]}
+ for data in aug_data:
+ for key, val in data.items():
+ aug_data_dict[key].append(val)
+ return aug_data_dict
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={self.transforms}, '
+ repr_str += f'rotate_degrees={self.rotate_degrees})'
+ return repr_str
diff --git a/mmocr/datasets/text_det_dataset.py b/mmocr/datasets/text_det_dataset.py
new file mode 100644
index 00000000..fb5d3fa4
--- /dev/null
+++ b/mmocr/datasets/text_det_dataset.py
@@ -0,0 +1,121 @@
+import numpy as np
+
+from mmdet.datasets.builder import DATASETS
+from mmocr.core.evaluation.hmean import eval_hmean
+from mmocr.datasets.base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class TextDetDataset(BaseDataset):
+
+ def _parse_anno_info(self, annotations):
+ """Parse bbox and mask annotation.
+ Args:
+ annotations (dict): Annotations of one image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,
+ labels, masks, masks_ignore. "masks" and
+ "masks_ignore" are represented by polygon boundary
+ point sequences.
+ """
+ gt_bboxes, gt_bboxes_ignore = [], []
+ gt_masks, gt_masks_ignore = [], []
+ gt_labels = []
+ for ann in annotations:
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(ann['bbox'])
+ gt_masks_ignore.append(ann.get('segmentation', None))
+ else:
+ gt_bboxes.append(ann['bbox'])
+ gt_labels.append(ann['category_id'])
+ gt_masks.append(ann.get('segmentation', None))
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks_ignore=gt_masks_ignore,
+ masks=gt_masks)
+
+ return ann
+
+ def prepare_train_img(self, index):
+ """Get training data and annotations from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys \
+ introduced by pipeline.
+ """
+ img_ann_info = self.data_infos[index]
+ img_info = {
+ 'filename': img_ann_info['file_name'],
+ 'height': img_ann_info['height'],
+ 'width': img_ann_info['width']
+ }
+ ann_info = self._parse_anno_info(img_ann_info['annotations'])
+ results = dict(img_info=img_info, ann_info=ann_info)
+ results['bbox_fields'] = []
+ results['mask_fields'] = []
+ results['seg_fields'] = []
+ self.pre_pipeline(results)
+
+ return self.pipeline(results)
+
+ def evaluate(self,
+ results,
+ metric='hmean-iou',
+ score_thr=0.3,
+ rank_list=None,
+ logger=None,
+ **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ score_thr (float): Score threshold for prediction map.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ rank_list (str): json file used to save eval result
+ of each image after ranking.
+ Returns:
+ dict[str: float]
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['hmean-iou', 'hmean-ic13']
+ metrics = set(metrics) & set(allowed_metrics)
+
+ img_infos = []
+ ann_infos = []
+ for i in range(len(self)):
+ img_ann_info = self.data_infos[i]
+ img_info = {'filename': img_ann_info['file_name']}
+ ann_info = self._parse_anno_info(img_ann_info['annotations'])
+ img_infos.append(img_info)
+ ann_infos.append(ann_info)
+
+ eval_results = eval_hmean(
+ results,
+ img_infos,
+ ann_infos,
+ metrics=metrics,
+ score_thr=score_thr,
+ logger=logger,
+ rank_list=rank_list)
+
+ return eval_results
diff --git a/mmocr/datasets/utils/__init__.py b/mmocr/datasets/utils/__init__.py
new file mode 100644
index 00000000..f014de7e
--- /dev/null
+++ b/mmocr/datasets/utils/__init__.py
@@ -0,0 +1,4 @@
+from .loader import HardDiskLoader, LmdbLoader
+from .parser import LineJsonParser, LineStrParser
+
+__all__ = ['HardDiskLoader', 'LmdbLoader', 'LineStrParser', 'LineJsonParser']
diff --git a/mmocr/datasets/utils/loader.py b/mmocr/datasets/utils/loader.py
new file mode 100644
index 00000000..55a4d075
--- /dev/null
+++ b/mmocr/datasets/utils/loader.py
@@ -0,0 +1,108 @@
+import os.path as osp
+
+from mmocr.datasets.builder import LOADERS, build_parser
+
+
+@LOADERS.register_module()
+class Loader:
+ """Load annotation from annotation file, and parse instance information to
+ dict format with parser.
+
+ Args:
+ ann_file (str): Annotation file path.
+ parser (dict): Dictionary to construct parser
+ to parse original annotation infos.
+ repeat (int): Repeated times of annotations.
+ """
+
+ def __init__(self, ann_file, parser, repeat=1):
+ assert isinstance(ann_file, str)
+ assert isinstance(repeat, int)
+ assert isinstance(parser, dict)
+ assert repeat > 0
+ assert osp.exists(ann_file), f'{ann_file} does not exist'
+
+ self.ori_data_infos = self._load(ann_file)
+ self.parser = build_parser(parser)
+ self.repeat = repeat
+
+ def __len__(self):
+ return len(self.ori_data_infos) * self.repeat
+
+ def _load(self, ann_file):
+ """Load annotation file."""
+ raise NotImplementedError
+
+ def __getitem__(self, index):
+ """Retrieve anno info of one instance with dict format."""
+ return self.parser.get_item(self.ori_data_infos, index)
+
+
+@LOADERS.register_module()
+class HardDiskLoader(Loader):
+ """Load annotation file from hard disk to RAM.
+
+ Args:
+ ann_file (str): Annotation file path.
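+
+ Example (illustrative sketch; the file name and its content are made up):
+ >>> # suppose 'toy.txt' contains lines such as 'img_1.jpg Hello'
+ >>> loader = HardDiskLoader(
+ ... ann_file='toy.txt',
+ ... parser=dict(type='LineStrParser', keys=['filename', 'text']))
+ >>> loader[0]
+ {'filename': 'img_1.jpg', 'text': 'Hello'}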
+ """
+
+ def _load(self, ann_file):
+ data_ret = []
+ with open(ann_file, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ data_ret.append(line)
+
+ return data_ret
+
+
+@LOADERS.register_module()
+class LmdbLoader(Loader):
+ """Load annotation file with lmdb storage backend."""
+
+ def _load(self, ann_file):
+ lmdb_anno_obj = LmdbAnnFileBackend(ann_file)
+
+ return lmdb_anno_obj
+
+
+class LmdbAnnFileBackend:
+ """Lmdb storage backend for annotation file.
+
+ Args:
+ lmdb_path (str): Lmdb file path.
+ """
+
+ def __init__(self, lmdb_path, coding='utf8'):
+ self.lmdb_path = lmdb_path
+ self.coding = coding
+ env = self._get_env()
+ with env.begin(write=False) as txn:
+ self.total_number = int(
+ txn.get('total_number'.encode(self.coding)).decode(
+ self.coding))
+
+ def __getitem__(self, index):
+ """Retrieval one line from lmdb file by index."""
+ # only attach env to self when __getitem__ is called
+ # because the env object cannot be pickled
+ if not hasattr(self, 'env'):
+ self.env = self._get_env()
+
+ with self.env.begin(write=False) as txn:
+ line = txn.get(str(index).encode(self.coding)).decode(self.coding)
+ return line
+
+ def __len__(self):
+ return self.total_number
+
+ def _get_env(self):
+ import lmdb
+ return lmdb.open(
+ self.lmdb_path,
+ max_readers=1,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
diff --git a/mmocr/datasets/utils/parser.py b/mmocr/datasets/utils/parser.py
new file mode 100644
index 00000000..a895e217
--- /dev/null
+++ b/mmocr/datasets/utils/parser.py
@@ -0,0 +1,69 @@
+import json
+
+from mmocr.datasets.builder import PARSERS
+
+
+@PARSERS.register_module()
+class LineStrParser:
+ """Parse string of one line in annotation file to dict format.
+
+ Args:
+ keys (list[str]): Keys in result dict.
+ keys_idx (list[int]): Value index in sub-string list
+ for each key above.
+ separator (str): Separator to separate string to list of sub-string.
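+
+ Example (illustrative; the annotation line is made up):
+ >>> parser = LineStrParser(keys=['filename', 'text'], keys_idx=[0, 1])
+ >>> parser.get_item(['img_1.jpg Hello'], 0)
+ {'filename': 'img_1.jpg', 'text': 'Hello'}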
+ """
+
+ def __init__(self,
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' '):
+ assert isinstance(keys, list)
+ assert isinstance(keys_idx, list)
+ assert isinstance(separator, str)
+ assert len(keys) > 0
+ assert len(keys) == len(keys_idx)
+ self.keys = keys
+ self.keys_idx = keys_idx
+ self.separator = separator
+
+ def get_item(self, data_ret, index):
+ map_index = index % len(data_ret)
+ line_str = data_ret[map_index]
+ for split_key in self.separator:
+ if split_key != ' ':
+ line_str = line_str.replace(split_key, ' ')
+ line_str = line_str.split()
+ if len(line_str) <= max(self.keys_idx):
+ raise Exception(
+ f'key index: {max(self.keys_idx)} out of range: {line_str}')
+
+ line_info = {}
+ for i, key in enumerate(self.keys):
+ line_info[key] = line_str[self.keys_idx[i]]
+ return line_info
+
+
+@PARSERS.register_module()
+class LineJsonParser:
+ """Parse json-string of one line in annotation file to dict format.
+
+ Args:
+ keys (list[str]): Keys in both json-string and result dict.
+ """
+
+ def __init__(self, keys=[], **kwargs):
+ assert isinstance(keys, list)
+ assert len(keys) > 0
+ self.keys = keys
+
+ def get_item(self, data_ret, index):
+ map_index = index % len(data_ret)
+ line_json_obj = json.loads(data_ret[map_index])
+ line_info = {}
+ for key in self.keys:
+ if key not in line_json_obj:
+ raise Exception(f'key {key} not in line json {line_json_obj}')
+ line_info[key] = line_json_obj[key]
+
+ return line_info
diff --git a/mmocr/models/textrecog/__init__.py b/mmocr/models/textrecog/__init__.py
new file mode 100644
index 00000000..76bfa419
--- /dev/null
+++ b/mmocr/models/textrecog/__init__.py
@@ -0,0 +1,8 @@
+from .backbones import * # noqa: F401,F403
+from .convertors import * # noqa: F401,F403
+from .decoders import * # noqa: F401,F403
+from .encoders import * # noqa: F401,F403
+from .heads import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .recognizer import * # noqa: F401,F403
diff --git a/mmocr/models/textrecog/backbones/__init__.py b/mmocr/models/textrecog/backbones/__init__.py
new file mode 100644
index 00000000..5cfb1c30
--- /dev/null
+++ b/mmocr/models/textrecog/backbones/__init__.py
@@ -0,0 +1,4 @@
+from .resnet31_ocr import ResNet31OCR
+from .very_deep_vgg import VeryDeepVgg
+
+__all__ = ['ResNet31OCR', 'VeryDeepVgg']
diff --git a/mmocr/models/textrecog/backbones/resnet31_ocr.py b/mmocr/models/textrecog/backbones/resnet31_ocr.py
new file mode 100644
index 00000000..e2787556
--- /dev/null
+++ b/mmocr/models/textrecog/backbones/resnet31_ocr.py
@@ -0,0 +1,149 @@
+import torch.nn as nn
+from mmcv.cnn import kaiming_init, uniform_init
+
+import mmocr.utils as utils
+from mmdet.models.builder import BACKBONES
+from mmocr.models.textrecog.layers import BasicBlock
+
+
+@BACKBONES.register_module()
+class ResNet31OCR(nn.Module):
+ """Implement ResNet backbone for text recognition, modified from
+ `ResNet <https://arxiv.org/abs/1512.03385>`_.
+ Args:
+ base_channels (int): Number of channels of input image tensor.
+ layers (list[int]): List of BasicBlock number for each stage.
+ channels (list[int]): List of out_channels of Conv2d layer.
+ out_indices (None | Sequence[int]): Indices of output stages.
+ stage4_pool_cfg (dict): Dictionary to construct and configure
+ pooling layer in stage 4.
+ last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
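+
+ Example (illustrative sketch; the input size is a typical 32 x 160 crop):
+ >>> import torch
+ >>> backbone = ResNet31OCR()
+ >>> feat = backbone(torch.rand(1, 3, 32, 160))
+ >>> # height is halved in stages 2-4 and width in stages 2-3 only,
+ >>> # so the default config maps 32 x 160 to 4 x 40
+ >>> feat.shape
+ torch.Size([1, 512, 4, 40])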
+ """
+
+ def __init__(self,
+ base_channels=3,
+ layers=[1, 2, 5, 3],
+ channels=[64, 128, 256, 256, 512, 512, 512],
+ out_indices=None,
+ stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
+ last_stage_pool=False):
+ super().__init__()
+ assert isinstance(base_channels, int)
+ assert utils.is_type_list(layers, int)
+ assert utils.is_type_list(channels, int)
+ assert out_indices is None or (isinstance(out_indices, list)
+ or isinstance(out_indices, tuple))
+ assert isinstance(last_stage_pool, bool)
+
+ self.out_indices = out_indices
+ self.last_stage_pool = last_stage_pool
+
+ # conv 1 (Conv, Conv)
+ self.conv1_1 = nn.Conv2d(
+ base_channels, channels[0], kernel_size=3, stride=1, padding=1)
+ self.bn1_1 = nn.BatchNorm2d(channels[0])
+ self.relu1_1 = nn.ReLU(inplace=True)
+
+ self.conv1_2 = nn.Conv2d(
+ channels[0], channels[1], kernel_size=3, stride=1, padding=1)
+ self.bn1_2 = nn.BatchNorm2d(channels[1])
+ self.relu1_2 = nn.ReLU(inplace=True)
+
+ # conv 2 (Max-pooling, Residual block, Conv)
+ self.pool2 = nn.MaxPool2d(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block2 = self._make_layer(channels[1], channels[2], layers[0])
+ self.conv2 = nn.Conv2d(
+ channels[2], channels[2], kernel_size=3, stride=1, padding=1)
+ self.bn2 = nn.BatchNorm2d(channels[2])
+ self.relu2 = nn.ReLU(inplace=True)
+
+ # conv 3 (Max-pooling, Residual block, Conv)
+ self.pool3 = nn.MaxPool2d(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block3 = self._make_layer(channels[2], channels[3], layers[1])
+ self.conv3 = nn.Conv2d(
+ channels[3], channels[3], kernel_size=3, stride=1, padding=1)
+ self.bn3 = nn.BatchNorm2d(channels[3])
+ self.relu3 = nn.ReLU(inplace=True)
+
+ # conv 4 (Max-pooling, Residual block, Conv)
+ self.pool4 = nn.MaxPool2d(padding=0, ceil_mode=True, **stage4_pool_cfg)
+ self.block4 = self._make_layer(channels[3], channels[4], layers[2])
+ self.conv4 = nn.Conv2d(
+ channels[4], channels[4], kernel_size=3, stride=1, padding=1)
+ self.bn4 = nn.BatchNorm2d(channels[4])
+ self.relu4 = nn.ReLU(inplace=True)
+
+ # conv 5 ((Max-pooling), Residual block, Conv)
+ self.pool5 = None
+ if self.last_stage_pool:
+ self.pool5 = nn.MaxPool2d(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True) # 1/16
+ self.block5 = self._make_layer(channels[4], channels[5], layers[3])
+ self.conv5 = nn.Conv2d(
+ channels[5], channels[5], kernel_size=3, stride=1, padding=1)
+ self.bn5 = nn.BatchNorm2d(channels[5])
+ self.relu5 = nn.ReLU(inplace=True)
+
+ def init_weights(self, pretrained=None):
+ # initialize weight and bias
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ uniform_init(m)
+
+ def _make_layer(self, input_channels, output_channels, blocks):
+ layers = []
+ for _ in range(blocks):
+ downsample = None
+ if input_channels != output_channels:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ input_channels,
+ output_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False),
+ nn.BatchNorm2d(output_channels),
+ )
+ layers.append(
+ BasicBlock(
+ input_channels, output_channels, downsample=downsample))
+ input_channels = output_channels
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+
+ x = self.conv1_1(x)
+ x = self.bn1_1(x)
+ x = self.relu1_1(x)
+
+ x = self.conv1_2(x)
+ x = self.bn1_2(x)
+ x = self.relu1_2(x)
+
+ outs = []
+ for i in range(4):
+ layer_index = i + 2
+ pool_layer = getattr(self, f'pool{layer_index}')
+ block_layer = getattr(self, f'block{layer_index}')
+ conv_layer = getattr(self, f'conv{layer_index}')
+ bn_layer = getattr(self, f'bn{layer_index}')
+ relu_layer = getattr(self, f'relu{layer_index}')
+
+ if pool_layer is not None:
+ x = pool_layer(x)
+ x = block_layer(x)
+ x = conv_layer(x)
+ x = bn_layer(x)
+ x = relu_layer(x)
+
+ outs.append(x)
+
+ if self.out_indices is not None:
+ return tuple([outs[i] for i in self.out_indices])
+
+ return x
diff --git a/mmocr/models/textrecog/convertors/__init__.py b/mmocr/models/textrecog/convertors/__init__.py
new file mode 100644
index 00000000..60fc6300
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/__init__.py
@@ -0,0 +1,6 @@
+from .attn import AttnConvertor
+from .base import BaseConvertor
+from .ctc import CTCConvertor
+from .seg import SegConvertor
+
+__all__ = ['BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor']
diff --git a/mmocr/models/textrecog/convertors/attn.py b/mmocr/models/textrecog/convertors/attn.py
new file mode 100644
index 00000000..a80282e8
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/attn.py
@@ -0,0 +1,140 @@
+import torch
+
+import mmocr.utils as utils
+from mmocr.models.builder import CONVERTORS
+from .base import BaseConvertor
+
+
+@CONVERTORS.register_module()
+class AttnConvertor(BaseConvertor):
+ """Convert between text, index and tensor for encoder-decoder based
+ pipeline.
+
+ Args:
+ dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
+ dict_file (None|str): Character dict file path. If not none,
+ higher priority than dict_type.
+ dict_list (None|list[str]): Character list. If not none, higher
+ priority than dict_type, but lower than dict_file.
+ with_unknown (bool): If True, add `UKN` token to class.
+ max_seq_len (int): Maximum sequence length of label.
+ lower (bool): If True, convert original string to lower case.
+ start_end_same (bool): Whether to use the same index for
+ start and end token or not. Default: True.
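+
+ Example (illustrative sketch): with ``DICT36`` plus the ``<UKN>``,
+ ``<BOS/EOS>`` and ``<PAD>`` tokens, 'a' maps to 10, 'b' to 11,
+ start/end to 37 and padding to 38:
+ >>> convertor = AttnConvertor(dict_type='DICT36', max_seq_len=10)
+ >>> out = convertor.str2tensor(['ab'])
+ >>> out['padded_targets'][0].tolist()
+ [37, 10, 11, 37, 38, 38, 38, 38, 38, 38]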
+ """
+
+ def __init__(self,
+ dict_type='DICT90',
+ dict_file=None,
+ dict_list=None,
+ with_unknown=True,
+ max_seq_len=40,
+ lower=False,
+ start_end_same=True,
+ **kwargs):
+ super().__init__(dict_type, dict_file, dict_list)
+ assert isinstance(with_unknown, bool)
+ assert isinstance(max_seq_len, int)
+ assert isinstance(lower, bool)
+
+ self.with_unknown = with_unknown
+ self.max_seq_len = max_seq_len
+ self.lower = lower
+ self.start_end_same = start_end_same
+
+ self.update_dict()
+
+ def update_dict(self):
+ start_end_token = '<BOS/EOS>'
+ unknown_token = '<UKN>'
+ padding_token = '<PAD>'
+
+ # unknown
+ self.unknown_idx = None
+ if self.with_unknown:
+ self.idx2char.append(unknown_token)
+ self.unknown_idx = len(self.idx2char) - 1
+
+ # BOS/EOS
+ self.idx2char.append(start_end_token)
+ self.start_idx = len(self.idx2char) - 1
+ if not self.start_end_same:
+ self.idx2char.append(start_end_token)
+ self.end_idx = len(self.idx2char) - 1
+
+ # padding
+ self.idx2char.append(padding_token)
+ self.padding_idx = len(self.idx2char) - 1
+
+ # update char2idx
+ self.char2idx = {}
+ for idx, char in enumerate(self.idx2char):
+ self.char2idx[char] = idx
+
+ def str2tensor(self, strings):
+ """
+ Convert text-string into tensor.
+ Args:
+ strings (list[str]): ['hello', 'world']
+ Returns:
+ dict (str: Tensor | list[tensor]):
+ tensors (list[Tensor]): [torch.Tensor([1,2,3,3,4]),
+ torch.Tensor([5,4,6,3,7])]
+ padded_targets (Tensor(bsz * max_seq_len))
+ """
+ assert utils.is_type_list(strings, str)
+
+ tensors, padded_targets = [], []
+ indexes = self.str2idx(strings)
+ for index in indexes:
+ tensor = torch.LongTensor(index)
+ tensors.append(tensor)
+ # target tensor for loss
+ src_target = torch.LongTensor(tensor.size(0) + 2).fill_(0)
+ src_target[-1] = self.end_idx
+ src_target[0] = self.start_idx
+ src_target[1:-1] = tensor
+ padded_target = (torch.ones(self.max_seq_len) *
+ self.padding_idx).long()
+ char_num = src_target.size(0)
+ if char_num > self.max_seq_len:
+ padded_target = src_target[:self.max_seq_len]
+ else:
+ padded_target[:char_num] = src_target
+ padded_targets.append(padded_target)
+ padded_targets = torch.stack(padded_targets, 0).long()
+
+ return {'targets': tensors, 'padded_targets': padded_targets}
+
+ def tensor2idx(self, outputs, img_metas=None):
+ """
+ Convert output tensor to text-index
+ Args:
+ outputs (tensor): model outputs with size: N * T * C
+ img_metas (list[dict]): Each dict contains one image info.
+ Returns:
+ indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]]
+ scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
+ [0.9,0.9,0.98,0.97,0.96]]
+ """
+ batch_size = outputs.size(0)
+ ignore_indexes = [self.padding_idx]
+ indexes, scores = [], []
+ for idx in range(batch_size):
+ seq = outputs[idx, :, :]
+ max_value, max_idx = torch.max(seq, -1)
+ str_index, str_score = [], []
+ output_index = max_idx.cpu().detach().numpy().tolist()
+ output_score = max_value.cpu().detach().numpy().tolist()
+ for char_index, char_score in zip(output_index, output_score):
+ if char_index in ignore_indexes:
+ continue
+ if char_index == self.end_idx:
+ break
+ str_index.append(char_index)
+ str_score.append(char_score)
+
+ indexes.append(str_index)
+ scores.append(str_score)
+
+ return indexes, scores
diff --git a/mmocr/models/textrecog/convertors/base.py b/mmocr/models/textrecog/convertors/base.py
new file mode 100644
index 00000000..9002d4fe
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/base.py
@@ -0,0 +1,115 @@
+from mmocr.models.builder import CONVERTORS
+
+
+@CONVERTORS.register_module()
+class BaseConvertor:
+ """Convert between text, index and tensor for text recognize pipeline.
+
+ Args:
+ dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
+ dict_file (None|str): Character dict file path. If not none,
+ the dict_file is of higher priority than dict_type.
+ dict_list (None|list[str]): Character list. If not none, the list
+ is of higher priority than dict_type, but lower than dict_file.
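+
+ Example (illustrative sketch):
+ >>> convertor = BaseConvertor(dict_type='DICT36')
+ >>> convertor.str2idx(['10a'])
+ [[1, 0, 10]]
+ >>> convertor.idx2str([[1, 0, 10]])
+ ['10a']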
+ """
+ start_idx = end_idx = padding_idx = 0
+ unknown_idx = None
+ lower = False
+
+ DICT36 = tuple('0123456789abcdefghijklmnopqrstuvwxyz')
+ DICT90 = tuple('0123456789abcdefghijklmnopqrstuvwxyz'
+ 'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
+ '*+,-./:;<=>?@[\\]_`~')
+
+ def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None):
+ assert dict_type in ('DICT36', 'DICT90')
+ assert dict_file is None or isinstance(dict_file, str)
+ assert dict_list is None or isinstance(dict_list, list)
+ self.idx2char = []
+ if dict_file is not None:
+ with open(dict_file, encoding='utf-8') as fr:
+ for line in fr:
+ line = line.strip()
+ if line != '':
+ self.idx2char.append(line)
+ elif dict_list is not None:
+ self.idx2char = dict_list
+ else:
+ if dict_type == 'DICT36':
+ self.idx2char = list(self.DICT36)
+ else:
+ self.idx2char = list(self.DICT90)
+
+ self.char2idx = {}
+ for idx, char in enumerate(self.idx2char):
+ self.char2idx[char] = idx
+
+ def num_classes(self):
+ """Number of output classes."""
+ return len(self.idx2char)
+
+ def str2idx(self, strings):
+ """Convert strings to indexes.
+
+ Args:
+ strings (list[str]): ['hello', 'world'].
+ Returns:
+ indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
+ """
+ assert isinstance(strings, list)
+
+ indexes = []
+ for string in strings:
+ if self.lower:
+ string = string.lower()
+ index = []
+ for char in string:
+ char_idx = self.char2idx.get(char, self.unknown_idx)
+ if char_idx is None:
+ raise Exception(f'Character: {char} not in dict,'
+ f' please check gt_label and use'
+ f' custom dict file,'
+ f' or set "with_unknown=True"')
+ index.append(char_idx)
+ indexes.append(index)
+
+ return indexes
+
+ def str2tensor(self, strings):
+ """Convert text-string to input tensor.
+
+ Args:
+ strings (list[str]): ['hello', 'world'].
+ Returns:
+ tensors (list[torch.Tensor]): [torch.Tensor([1,2,3,3,4]),
+ torch.Tensor([5,4,6,3,7])].
+ """
+ raise NotImplementedError
+
+ def idx2str(self, indexes):
+ """Convert indexes to text strings.
+
+ Args:
+ indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
+ Returns:
+ strings (list[str]): ['hello', 'world'].
+ """
+ assert isinstance(indexes, list)
+
+ strings = []
+ for index in indexes:
+ string = [self.idx2char[i] for i in index]
+ strings.append(''.join(string))
+
+ return strings
+
+ def tensor2idx(self, output):
+ """Convert model output tensor to character indexes and scores.
+ Args:
+ output (tensor): The model outputs with size: N * T * C
+ Returns:
+ indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
+ scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
+ [0.9,0.9,0.98,0.97,0.96]].
+ """
+ raise NotImplementedError
diff --git a/mmocr/models/textrecog/convertors/ctc.py b/mmocr/models/textrecog/convertors/ctc.py
new file mode 100644
index 00000000..c14fc23f
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/ctc.py
@@ -0,0 +1,144 @@
+import math
+
+import torch
+import torch.nn.functional as F
+
+import mmocr.utils as utils
+from mmocr.models.builder import CONVERTORS
+from .base import BaseConvertor
+
+
+@CONVERTORS.register_module()
+class CTCConvertor(BaseConvertor):
+ """Convert between text, index and tensor for CTC loss-based pipeline.
+
+ Args:
+ dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
+ dict_file (None|str): Character dict file path. If not none, the file
+ is of higher priority than dict_type.
+ dict_list (None|list[str]): Character list. If not none, the list
+ is of higher priority than dict_type, but lower than dict_file.
+ with_unknown (bool): If True, add `UKN` token to class.
+ lower (bool): If True, convert original string to lower case.
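+
+ Example (illustrative sketch): index 0 is reserved for the CTC blank,
+ so '1' maps to 2 and 'a' to 11:
+ >>> convertor = CTCConvertor(dict_type='DICT36', with_unknown=False)
+ >>> out = convertor.str2tensor(['12', 'abc'])
+ >>> out['flatten_targets'].tolist()
+ [2, 3, 11, 12, 13]
+ >>> out['target_lengths'].tolist()
+ [2, 3]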
+ """
+
+ def __init__(self,
+ dict_type='DICT90',
+ dict_file=None,
+ dict_list=None,
+ with_unknown=True,
+ lower=False,
+ **kwargs):
+ super().__init__(dict_type, dict_file, dict_list)
+ assert isinstance(with_unknown, bool)
+ assert isinstance(lower, bool)
+
+ self.with_unknown = with_unknown
+ self.lower = lower
+ self.update_dict()
+
+ def update_dict(self):
+ # CTC-blank
+ blank_token = '<BLK>'
+ self.blank_idx = 0
+ self.idx2char.insert(0, blank_token)
+
+ # unknown
+ self.unknown_idx = None
+ if self.with_unknown:
+ self.idx2char.append('<UKN>')
+ self.unknown_idx = len(self.idx2char) - 1
+
+ # update char2idx
+ self.char2idx = {}
+ for idx, char in enumerate(self.idx2char):
+ self.char2idx[char] = idx
+
+ def str2tensor(self, strings):
+ """Convert text-string to ctc-loss input tensor.
+
+ Args:
+ strings (list[str]): ['hello', 'world'].
+ Returns:
+ dict (str: tensor | list[tensor]):
+ tensors (list[tensor]): [torch.Tensor([1,2,3,3,4]),
+ torch.Tensor([5,4,6,3,7])].
+ flatten_targets (tensor): torch.Tensor([1,2,3,3,4,5,4,6,3,7]).
+ target_lengths (tensor): torch.IntTensor([5, 5]).
+ """
+ assert utils.is_type_list(strings, str)
+
+ tensors = []
+ indexes = self.str2idx(strings)
+ for index in indexes:
+ tensor = torch.IntTensor(index)
+ tensors.append(tensor)
+ target_lengths = torch.IntTensor([len(t) for t in tensors])
+ flatten_target = torch.cat(tensors)
+
+ return {
+ 'targets': tensors,
+ 'flatten_targets': flatten_target,
+ 'target_lengths': target_lengths
+ }
+
+ def tensor2idx(self, output, img_metas, topk=1, return_topk=False):
+ """Convert model output tensor to index-list.
+ Args:
+ output (tensor): The model outputs with size: N * T * C.
+ img_metas (list[dict]): Each dict contains one image info.
+ topk (int): The highest k classes to be returned.
+ return_topk (bool): Whether to return topk or just top1.
+ Returns:
+ indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
+ scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
+ [0.9,0.9,0.98,0.97,0.96]]
+ (
+ indexes_topk (list[list[list[int]->len=topk]]):
+ scores_topk (list[list[list[float]->len=topk]])
+ ).
+ """
+ assert utils.is_type_list(img_metas, dict)
+ assert len(img_metas) == output.size(0)
+ assert isinstance(topk, int)
+ assert topk >= 1
+
+ valid_ratios = [
+ img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+ ]
+
+ batch_size = output.size(0)
+ output = F.softmax(output, dim=2)
+ output = output.cpu().detach()
+ batch_topk_value, batch_topk_idx = output.topk(topk, dim=2)
+ batch_max_idx = batch_topk_idx[:, :, 0]
+ scores_topk, indexes_topk = [], []
+ scores, indexes = [], []
+ feat_len = output.size(1)
+ for b in range(batch_size):
+ valid_ratio = valid_ratios[b]
+ decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
+ pred = batch_max_idx[b, :]
+ select_idx = []
+ prev_idx = self.blank_idx
+ for t in range(decode_len):
+ tmp_value = pred[t].item()
+ if tmp_value not in (prev_idx, self.blank_idx):
+ select_idx.append(t)
+ prev_idx = tmp_value
+ select_idx = torch.LongTensor(select_idx)
+ topk_value = torch.index_select(batch_topk_value[b, :, :], 0,
+ select_idx) # valid_seqlen * topk
+ topk_idx = torch.index_select(batch_topk_idx[b, :, :], 0,
+ select_idx)
+ topk_idx_list, topk_value_list = topk_idx.numpy().tolist(
+ ), topk_value.numpy().tolist()
+ indexes_topk.append(topk_idx_list)
+ scores_topk.append(topk_value_list)
+ indexes.append([x[0] for x in topk_idx_list])
+ scores.append([x[0] for x in topk_value_list])
+
+ if return_topk:
+ return indexes_topk, scores_topk
+
+ return indexes, scores
diff --git a/mmocr/models/textrecog/convertors/seg.py b/mmocr/models/textrecog/convertors/seg.py
new file mode 100644
index 00000000..0a626dac
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/seg.py
@@ -0,0 +1,123 @@
+import cv2
+import numpy as np
+import torch
+
+import mmocr.utils as utils
+from mmocr.models.builder import CONVERTORS
+from .base import BaseConvertor
+
+
+@CONVERTORS.register_module()
+class SegConvertor(BaseConvertor):
+ """Convert between text, index and tensor for segmentation based pipeline.
+
+ Args:
+ dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
+ dict_file (None|str): Character dict file path. If not none, the
+ file is of higher priority than dict_type.
+ dict_list (None|list[str]): Character list. If not none, the list
+ is of higher priority than dict_type, but lower than dict_file.
+ with_unknown (bool): If True, add `UKN` token to class.
+ lower (bool): If True, convert original string to lower case.
+ """
+
+ def __init__(self,
+ dict_type='DICT36',
+ dict_file=None,
+ dict_list=None,
+ with_unknown=True,
+ lower=False,
+ **kwargs):
+ super().__init__(dict_type, dict_file, dict_list)
+ assert isinstance(with_unknown, bool)
+ assert isinstance(lower, bool)
+
+ self.with_unknown = with_unknown
+ self.lower = lower
+ self.update_dict()
+
+ def update_dict(self):
+ # background
+ self.idx2char.insert(0, '<BG>')
+
+ # unknown
+ self.unknown_idx = None
+ if self.with_unknown:
+ self.idx2char.append('<UKN>')
+ self.unknown_idx = len(self.idx2char) - 1
+
+ # update char2idx
+ self.char2idx = {}
+ for idx, char in enumerate(self.idx2char):
+ self.char2idx[char] = idx
+
+ def tensor2str(self, output, img_metas=None):
+ """Convert model output tensor to string labels.
+ Args:
+ output (tensor): Model outputs with size: N * C * H * W
+ img_metas (list[dict]): Each dict contains one image info.
+ Returns:
+ texts (list[str]): Decoded text labels.
+ scores (list[list[float]]): Decoded chars scores.
+ """
+ assert utils.is_type_list(img_metas, dict)
+ assert len(img_metas) == output.size(0)
+
+ texts, scores = [], []
+ for b in range(output.size(0)):
+ seg_pred = output[b].detach()
+ seg_res = torch.argmax(
+ seg_pred, dim=0).cpu().numpy().astype(np.int32)
+
+ seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
+ _, labels, stats, centroids = cv2.connectedComponentsWithStats(
+ seg_thr)
+
+ component_num = stats.shape[0]
+
+ all_res = []
+ for i in range(component_num):
+ temp_loc = (labels == i)
+ temp_value = seg_res[temp_loc]
+ temp_center = centroids[i]
+
+ temp_max_num = 0
+ temp_max_cls = -1
+ temp_total_num = 0
+ for c in range(len(self.idx2char)):
+ c_num = np.sum(temp_value == c)
+ temp_total_num += c_num
+ if c_num > temp_max_num:
+ temp_max_num = c_num
+ temp_max_cls = c
+
+ if temp_max_cls == 0:
+ continue
+ temp_max_score = 1.0 * temp_max_num / temp_total_num
+ all_res.append(
+ [temp_max_cls, temp_center, temp_max_num, temp_max_score])
+
+ all_res = sorted(all_res, key=lambda s: s[1][0])
+ chars, char_scores = [], []
+ for res in all_res:
+ temp_area = res[2]
+ if temp_area < 20:
+ continue
+ temp_char_index = res[0]
+ if temp_char_index >= len(self.idx2char):
+ temp_char = ''
+ elif temp_char_index <= 0:
+ temp_char = ''
+ elif temp_char_index == self.unknown_idx:
+ temp_char = ''
+ else:
+ temp_char = self.idx2char[temp_char_index]
+ chars.append(temp_char)
+ char_scores.append(res[3])
+
+ text = ''.join(chars)
+
+ texts.append(text)
+ scores.append(char_scores)
+
+ return texts, scores
diff --git a/mmocr/models/textrecog/decoders/__init__.py b/mmocr/models/textrecog/decoders/__init__.py
new file mode 100755
index 00000000..8b374733
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/__init__.py
@@ -0,0 +1,15 @@
+from .base_decoder import BaseDecoder
+from .crnn_decoder import CRNNDecoder
+from .position_attention_decoder import PositionAttentionDecoder
+from .robust_scanner_decoder import RobustScannerDecoder
+from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder
+from .sar_decoder_with_bs import ParallelSARDecoderWithBS
+from .sequence_attention_decoder import SequenceAttentionDecoder
+from .transformer_decoder import TFDecoder
+
+__all__ = [
+ 'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder',
+ 'ParallelSARDecoderWithBS', 'TFDecoder', 'BaseDecoder',
+ 'SequenceAttentionDecoder', 'PositionAttentionDecoder',
+ 'RobustScannerDecoder'
+]
diff --git a/mmocr/models/textrecog/decoders/base_decoder.py b/mmocr/models/textrecog/decoders/base_decoder.py
new file mode 100644
index 00000000..543cd528
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/base_decoder.py
@@ -0,0 +1,32 @@
+import torch.nn as nn
+
+from mmocr.models.builder import DECODERS
+
+
+@DECODERS.register_module()
+class BaseDecoder(nn.Module):
+ """Base decoder class for text recognition."""
+
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ def init_weights(self):
+ pass
+
+ def forward_train(self, feat, out_enc, targets_dict, img_metas):
+ raise NotImplementedError
+
+ def forward_test(self, feat, out_enc, img_metas):
+ raise NotImplementedError
+
+ def forward(self,
+ feat,
+ out_enc,
+ targets_dict=None,
+ img_metas=None,
+ train_mode=True):
+ self.train_mode = train_mode
+ if train_mode:
+ return self.forward_train(feat, out_enc, targets_dict, img_metas)
+
+ return self.forward_test(feat, out_enc, img_metas)
diff --git a/mmocr/models/textrecog/decoders/crnn_decoder.py b/mmocr/models/textrecog/decoders/crnn_decoder.py
new file mode 100644
index 00000000..1ce5226a
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/crnn_decoder.py
@@ -0,0 +1,49 @@
+import torch.nn as nn
+from mmcv.cnn import xavier_init
+
+from mmocr.models.builder import DECODERS
+from mmocr.models.textrecog.layers import BidirectionalLSTM
+from .base_decoder import BaseDecoder
+
+
+@DECODERS.register_module()
+class CRNNDecoder(BaseDecoder):
+
+ def __init__(self,
+ in_channels=None,
+ num_classes=None,
+ rnn_flag=False,
+ **kwargs):
+ super().__init__()
+ self.num_classes = num_classes
+ self.rnn_flag = rnn_flag
+
+ if rnn_flag:
+ self.decoder = nn.Sequential(
+ BidirectionalLSTM(in_channels, 256, 256),
+ BidirectionalLSTM(256, 256, num_classes))
+ else:
+ self.decoder = nn.Conv2d(
+ in_channels, num_classes, kernel_size=1, stride=1)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m)
+
+ def forward_train(self, feat, out_enc, targets_dict, img_metas):
+ assert feat.size(2) == 1, 'feature height must be 1'
+ if self.rnn_flag:
+ x = feat.squeeze(2) # [N, C, W]
+ x = x.permute(2, 0, 1) # [W, N, C]
+ x = self.decoder(x) # [W, N, C]
+ outputs = x.permute(1, 0, 2).contiguous()
+ else:
+ x = self.decoder(feat)
+ x = x.permute(0, 3, 1, 2).contiguous()
+ n, w, c, h = x.size()
+ outputs = x.view(n, w, c * h)
+ return outputs
+
+ def forward_test(self, feat, out_enc, img_metas):
+ return self.forward_train(feat, out_enc, None, img_metas)
diff --git a/mmocr/models/textrecog/decoders/sar_decoder.py b/mmocr/models/textrecog/decoders/sar_decoder.py
new file mode 100755
index 00000000..110fa8e2
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/sar_decoder.py
@@ -0,0 +1,407 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import mmocr.utils as utils
+from mmocr.models.builder import DECODERS
+from .base_decoder import BaseDecoder
+
+
+@DECODERS.register_module()
+class ParallelSARDecoder(BaseDecoder):
+ """Implementation Parallel Decoder module in `SAR.
+
+ `_
+
+ Args:
+ number_classes (int): Output class number.
+ channels (list[int]): Network layer channels.
+ enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+ dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
+ dec_do_rnn (float): Dropout of RNN layer in decoder.
+ dec_gru (bool): If True, use GRU, else LSTM in decoder.
+ d_model (int): Dim of channels from backbone.
+ d_enc (int): Dim of encoder RNN layer.
+ d_k (int): Dim of channels of attention module.
+ pred_dropout (float): Dropout probability of prediction layer.
+ max_seq_len (int): Maximum sequence length for decoding.
+ mask (bool): If True, mask padding in feature map.
+ start_idx (int): Index of start token.
+ padding_idx (int): Index of padding token.
+ pred_concat (bool): If True, concat glimpse feature from
+ attention with holistic feature and hidden state.
+ """
+
+ def __init__(self,
+ num_classes=37,
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_do_rnn=0.0,
+ dec_gru=False,
+ d_model=512,
+ d_enc=512,
+ d_k=64,
+ pred_dropout=0.0,
+ max_seq_len=40,
+ mask=True,
+ start_idx=0,
+ padding_idx=92,
+ pred_concat=False,
+ **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.enc_bi_rnn = enc_bi_rnn
+ self.d_k = d_k
+ self.start_idx = start_idx
+ self.max_seq_len = max_seq_len
+ self.mask = mask
+ self.pred_concat = pred_concat
+
+ encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
+ decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
+ # 2D attention layer
+ self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
+ self.conv3x3_1 = nn.Conv2d(
+ d_model, d_k, kernel_size=3, stride=1, padding=1)
+ self.conv1x1_2 = nn.Linear(d_k, 1)
+
+ # Decoder RNN layer
+ kwargs = dict(
+ input_size=encoder_rnn_out_size,
+ hidden_size=encoder_rnn_out_size,
+ num_layers=2,
+ batch_first=True,
+ dropout=dec_do_rnn,
+ bidirectional=dec_bi_rnn)
+ if dec_gru:
+ self.rnn_decoder = nn.GRU(**kwargs)
+ else:
+ self.rnn_decoder = nn.LSTM(**kwargs)
+
+ # Decoder input embedding
+ self.embedding = nn.Embedding(
+ self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)
+
+ # Prediction layer
+ self.pred_dropout = nn.Dropout(pred_dropout)
+ pred_num_classes = num_classes - 1 # ignore padding_idx in prediction
+ if pred_concat:
+ fc_in_channel = decoder_rnn_out_size + d_model + d_enc
+ else:
+ fc_in_channel = d_model
+ self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
+
+ def _2d_attention(self,
+ decoder_input,
+ feat,
+ holistic_feat,
+ valid_ratios=None):
+ y = self.rnn_decoder(decoder_input)[0]
+ # y: bsz * (seq_len + 1) * hidden_size
+
+ attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
+ bsz, seq_len, attn_size = attn_query.size()
+ attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1)
+
+ attn_key = self.conv3x3_1(feat)
+ # bsz * attn_size * h * w
+ attn_key = attn_key.unsqueeze(1)
+ # bsz * 1 * attn_size * h * w
+
+ attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
+ # bsz * (seq_len + 1) * attn_size * h * w
+ attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
+ # bsz * (seq_len + 1) * h * w * attn_size
+ attn_weight = self.conv1x1_2(attn_weight)
+ # bsz * (seq_len + 1) * h * w * 1
+ bsz, T, h, w, c = attn_weight.size()
+ assert c == 1
+
+ if valid_ratios is not None:
+ # cal mask of attention weight
+ attn_mask = torch.zeros_like(attn_weight)
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(w, math.ceil(w * valid_ratio))
+ attn_mask[i, :, :, valid_width:, :] = 1
+ attn_weight = attn_weight.masked_fill(attn_mask.bool(),
+ float('-inf'))
+
+ attn_weight = attn_weight.view(bsz, T, -1)
+ attn_weight = F.softmax(attn_weight, dim=-1)
+ attn_weight = attn_weight.view(bsz, T, h, w,
+ c).permute(0, 1, 4, 2, 3).contiguous()
+
+ attn_feat = torch.sum(
+ torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
+ # bsz * (seq_len + 1) * C
+
+ # linear transformation
+ if self.pred_concat:
+ hf_c = holistic_feat.size(-1)
+ holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c)
+ y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2))
+ else:
+ y = self.prediction(attn_feat)
+ # bsz * (seq_len + 1) * num_classes
+ if self.train_mode:
+ y = self.pred_dropout(y)
+
+ return y
+
+ def forward_train(self, feat, out_enc, targets_dict, img_metas):
+ assert utils.is_type_list(img_metas, dict)
+ assert len(img_metas) == feat.size(0)
+
+ valid_ratios = [
+ img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+ ] if self.mask else None
+
+ targets = targets_dict['padded_targets'].to(feat.device)
+ tgt_embedding = self.embedding(targets)
+ # bsz * seq_len * emb_dim
+ out_enc = out_enc.unsqueeze(1)
+ # bsz * 1 * emb_dim
+ in_dec = torch.cat((out_enc, tgt_embedding), dim=1)
+ # bsz * (seq_len + 1) * C
+ out_dec = self._2d_attention(
+ in_dec, feat, out_enc, valid_ratios=valid_ratios)
+ # bsz * (seq_len + 1) * num_classes
+
+ return out_dec[:, 1:, :] # bsz * seq_len * num_classes
+
+ def forward_test(self, feat, out_enc, img_metas):
+ assert utils.is_type_list(img_metas, dict)
+ assert len(img_metas) == feat.size(0)
+
+ valid_ratios = [
+ img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+ ] if self.mask else None
+
+ seq_len = self.max_seq_len
+
+ bsz = feat.size(0)
+ start_token = torch.full((bsz, ),
+ self.start_idx,
+ device=feat.device,
+ dtype=torch.long)
+ # bsz
+ start_token = self.embedding(start_token)
+ # bsz * emb_dim
+ start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
+ # bsz * seq_len * emb_dim
+ out_enc = out_enc.unsqueeze(1)
+ # bsz * 1 * emb_dim
+ decoder_input = torch.cat((out_enc, start_token), dim=1)
+ # bsz * (seq_len + 1) * emb_dim
+
+ outputs = []
+ for i in range(1, seq_len + 1):
+ decoder_output = self._2d_attention(
+ decoder_input, feat, out_enc, valid_ratios=valid_ratios)
+ char_output = decoder_output[:, i, :] # bsz * num_classes
+ char_output = F.softmax(char_output, -1)
+ outputs.append(char_output)
+ _, max_idx = torch.max(char_output, dim=1, keepdim=False)
+ char_embedding = self.embedding(max_idx) # bsz * emb_dim
+ if i < seq_len:
+ decoder_input[:, i + 1, :] = char_embedding
+
+ outputs = torch.stack(outputs, 1) # bsz * seq_len * num_classes
+
+ return outputs
+
+
+@DECODERS.register_module()
+class SequentialSARDecoder(BaseDecoder):
+ """Implementation Sequential Decoder module in `SAR.
+
+ `_.
+
+ Args:
+ num_classes (int): Number of output classes.
+ enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+ dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
+ dec_gru (bool): If True, use GRU, else LSTM in decoder.
+ d_k (int): Dim of conv layers in attention module.
+ d_model (int): Dim of channels from backbone.
+ d_enc (int): Dim of encoder RNN layer.
+ pred_dropout (float): Dropout probability of prediction layer.
+ max_seq_len (int): Maximum sequence length during decoding.
+ mask (bool): If True, mask padding in feature map.
+ start_idx (int): Index of start token.
+ padding_idx (int): Index of padding token.
+ pred_concat (bool): If True, concat glimpse feature from
+ attention with holistic feature and hidden state.
+ """
+
+ def __init__(self,
+ num_classes=37,
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_gru=False,
+ d_k=64,
+ d_model=512,
+ d_enc=512,
+ pred_dropout=0.0,
+ mask=True,
+ max_seq_len=40,
+ start_idx=0,
+ padding_idx=92,
+ pred_concat=False,
+ **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.enc_bi_rnn = enc_bi_rnn
+ self.d_k = d_k
+ self.start_idx = start_idx
+ self.dec_gru = dec_gru
+ self.max_seq_len = max_seq_len
+ self.mask = mask
+ self.pred_concat = pred_concat
+
+ encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
+ decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
+ # 2D attention layer
+ self.conv1x1_1 = nn.Conv2d(
+ decoder_rnn_out_size, d_k, kernel_size=1, stride=1)
+ self.conv3x3_1 = nn.Conv2d(
+ d_model, d_k, kernel_size=3, stride=1, padding=1)
+ self.conv1x1_2 = nn.Conv2d(d_k, 1, kernel_size=1, stride=1)
+
+ # Decoder rnn layer
+ if dec_gru:
+ self.rnn_decoder_layer1 = nn.GRUCell(encoder_rnn_out_size,
+ encoder_rnn_out_size)
+ self.rnn_decoder_layer2 = nn.GRUCell(encoder_rnn_out_size,
+ encoder_rnn_out_size)
+ else:
+ self.rnn_decoder_layer1 = nn.LSTMCell(encoder_rnn_out_size,
+ encoder_rnn_out_size)
+ self.rnn_decoder_layer2 = nn.LSTMCell(encoder_rnn_out_size,
+ encoder_rnn_out_size)
+
+ # Decoder input embedding
+ self.embedding = nn.Embedding(
+ self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)
+
+ # Prediction layer
+ self.pred_dropout = nn.Dropout(pred_dropout)
+ pred_num_class = num_classes - 1 # ignore padding index
+ if pred_concat:
+ fc_in_channel = decoder_rnn_out_size + d_model + d_enc
+ else:
+ fc_in_channel = d_model
+ self.prediction = nn.Linear(fc_in_channel, pred_num_class)
+
+ def _2d_attention(self,
+ y_prev,
+ feat,
+ holistic_feat,
+ hx1,
+ cx1,
+ hx2,
+ cx2,
+ valid_ratios=None):
+ _, _, h_feat, w_feat = feat.size()
+ if self.dec_gru:
+ hx1 = cx1 = self.rnn_decoder_layer1(y_prev, hx1)
+ hx2 = cx2 = self.rnn_decoder_layer2(hx1, hx2)
+ else:
+ hx1, cx1 = self.rnn_decoder_layer1(y_prev, (hx1, cx1))
+ hx2, cx2 = self.rnn_decoder_layer2(hx1, (hx2, cx2))
+
+ tile_hx2 = hx2.view(hx2.size(0), hx2.size(1), 1, 1)
+ attn_query = self.conv1x1_1(tile_hx2) # bsz * attn_size * 1 * 1
+ attn_query = attn_query.expand(-1, -1, h_feat, w_feat)
+ attn_key = self.conv3x3_1(feat)
+ attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
+ attn_weight = self.conv1x1_2(attn_weight)
+ bsz, c, h, w = attn_weight.size()
+ assert c == 1
+
+ if valid_ratios is not None:
+ # cal mask of attention weight
+ attn_mask = torch.zeros_like(attn_weight)
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(w, math.ceil(w * valid_ratio))
+ attn_mask[i, :, :, valid_width:] = 1
+ attn_weight = attn_weight.masked_fill(attn_mask.bool(),
+ float('-inf'))
+
+ attn_weight = F.softmax(attn_weight.view(bsz, -1), dim=-1)
+ attn_weight = attn_weight.view(bsz, c, h, w)
+
+ attn_feat = torch.sum(
+ torch.mul(feat, attn_weight), (2, 3), keepdim=False) # n * c
+
+ # linear transformation
+ if self.pred_concat:
+ y = self.prediction(torch.cat((hx2, attn_feat, holistic_feat), 1))
+ else:
+ y = self.prediction(attn_feat)
+
+ return y, hx1, cx1, hx2, cx2
+
+ def forward_train(self, feat, out_enc, targets_dict, img_metas=None):
+ assert utils.is_type_list(img_metas, dict)
+ assert len(img_metas) == feat.size(0)
+
+ valid_ratios = [
+ img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+ ] if self.mask else None
+
+ if self.train_mode:
+ targets = targets_dict['padded_targets'].to(feat.device)
+ tgt_embedding = self.embedding(targets)
+
+ outputs = []
+ start_token = torch.full((feat.size(0), ),
+ self.start_idx,
+ device=feat.device,
+ dtype=torch.long)
+ start_token = self.embedding(start_token)
+ for i in range(-1, self.max_seq_len):
+ if i == -1:
+ if self.dec_gru:
+ hx1 = cx1 = self.rnn_decoder_layer1(out_enc)
+ hx2 = cx2 = self.rnn_decoder_layer2(hx1)
+ else:
+ hx1, cx1 = self.rnn_decoder_layer1(out_enc)
+ hx2, cx2 = self.rnn_decoder_layer2(hx1)
+ if not self.train_mode:
+ y_prev = start_token
+ else:
+ if self.train_mode:
+ y_prev = tgt_embedding[:, i, :]
+ y, hx1, cx1, hx2, cx2 = self._2d_attention(
+ y_prev,
+ feat,
+ out_enc,
+ hx1,
+ cx1,
+ hx2,
+ cx2,
+ valid_ratios=valid_ratios)
+ if self.train_mode:
+ y = self.pred_dropout(y)
+ else:
+ y = F.softmax(y, -1)
+ _, max_idx = torch.max(y, dim=1, keepdim=False)
+ char_embedding = self.embedding(max_idx)
+ y_prev = char_embedding
+ outputs.append(y)
+
+ outputs = torch.stack(outputs, 1)
+
+ return outputs
+
+ def forward_test(self, feat, out_enc, img_metas):
+ assert utils.is_type_list(img_metas, dict)
+ assert len(img_metas) == feat.size(0)
+
+ return self.forward_train(feat, out_enc, None, img_metas)
diff --git a/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py b/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py
new file mode 100755
index 00000000..98094dd4
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py
@@ -0,0 +1,148 @@
+from queue import PriorityQueue
+
+import torch
+import torch.nn.functional as F
+
+import mmocr.utils as utils
+from mmocr.models.builder import DECODERS
+from . import ParallelSARDecoder
+
+
+class DecodeNode:
+ """Node class to save decoded char indices and scores.
+
+ Args:
+ indexes (list[int]): Char indices that have been decoded so far.
+ scores (list[float]): Scores of the decoded chars.
+ """
+
+ def __init__(self, indexes=[1], scores=[0.9]):
+ assert utils.is_type_list(indexes, int)
+ assert utils.is_type_list(scores, float)
+ assert utils.equal_len(indexes, scores)
+
+ self.indexes = indexes
+ self.scores = scores
+
+ def eval(self):
+ """Calculate accumulated score."""
+ accu_score = sum(self.scores)
+ return accu_score
+
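+# Example: a hypothesis holding the start token plus two decoded characters;
+# eval() returns the accumulated score used to rank beams in the queue below.
+#   node = DecodeNode(indexes=[0, 5, 3], scores=[0.0, 0.9, 0.8])
+#   node.eval()   # 1.7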
+
+@DECODERS.register_module()
+class ParallelSARDecoderWithBS(ParallelSARDecoder):
+ """Parallel Decoder module with beam-search in SAR.
+
+ Args:
+ beam_width (int): Width for beam search.
+ """
+
+ def __init__(self,
+ beam_width=5,
+ num_classes=37,
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_do_rnn=0,
+ dec_gru=False,
+ d_model=512,
+ d_enc=512,
+ d_k=64,
+ pred_dropout=0.0,
+ max_seq_len=40,
+ mask=True,
+ start_idx=0,
+ padding_idx=0,
+ pred_concat=False,
+ **kwargs):
+ super().__init__(num_classes, enc_bi_rnn, dec_bi_rnn, dec_do_rnn,
+ dec_gru, d_model, d_enc, d_k, pred_dropout,
+ max_seq_len, mask, start_idx, padding_idx,
+ pred_concat)
+ assert isinstance(beam_width, int)
+ assert beam_width > 0
+
+ self.beam_width = beam_width
+
+ def forward_test(self, feat, out_enc, img_metas):
+ assert utils.is_type_list(img_metas, dict)
+ assert len(img_metas) == feat.size(0)
+
+ valid_ratios = [
+ img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+ ] if self.mask else None
+
+ seq_len = self.max_seq_len
+ bsz = feat.size(0)
+ assert bsz == 1, 'batch size must be 1 for beam search.'
+
+ start_token = torch.full((bsz, ),
+ self.start_idx,
+ device=feat.device,
+ dtype=torch.long)
+ # bsz
+ start_token = self.embedding(start_token)
+ # bsz * emb_dim
+ start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
+ # bsz * seq_len * emb_dim
+ out_enc = out_enc.unsqueeze(1)
+ # bsz * 1 * emb_dim
+ decoder_input = torch.cat((out_enc, start_token), dim=1)
+ # bsz * (seq_len + 1) * emb_dim
+
+ # Initialize beam-search queue
+ q = PriorityQueue()
+ init_node = DecodeNode([self.start_idx], [0.0])
+ q.put((-init_node.eval(), init_node))
+
+ for i in range(1, seq_len + 1):
+ next_nodes = []
+ beam_width = self.beam_width if i > 1 else 1
+ for _ in range(beam_width):
+ _, node = q.get()
+
+ input_seq = torch.clone(decoder_input) # bsz * T * emb_dim
+ # fill previous input tokens (step 1...i) in input_seq
+ for t, index in enumerate(node.indexes):
+ input_token = torch.full((bsz, ),
+ index,
+ device=input_seq.device,
+ dtype=torch.long)
+ input_token = self.embedding(input_token) # bsz * emb_dim
+ input_seq[:, t + 1, :] = input_token
+
+ output_seq = self._2d_attention(
+ input_seq, feat, out_enc, valid_ratios=valid_ratios)
+
+ output_char = output_seq[:, i, :] # bsz * num_classes
+ output_char = F.softmax(output_char, -1)
+ topk_value, topk_idx = output_char.topk(self.beam_width, dim=1)
+ topk_value, topk_idx = topk_value.squeeze(0), topk_idx.squeeze(
+ 0)
+
+ for k in range(self.beam_width):
+ kth_score = topk_value[k].item()
+ kth_idx = topk_idx[k].item()
+ next_node = DecodeNode(node.indexes + [kth_idx],
+ node.scores + [kth_score])
+ delta = k * 1e-6
+ next_nodes.append(
+ (-node.eval() - kth_score - delta, next_node))
+ # Use minus since priority queue sort
+ # with ascending order
+
+ while not q.empty():
+ q.get()
+
+ # Put all candidates to queue
+ for next_node in next_nodes:
+ q.put(next_node)
+
+ best_node = q.get()
+ num_classes = self.num_classes - 1 # ignore padding index
+ outputs = torch.zeros(bsz, seq_len, num_classes)
+ for i in range(seq_len):
+ idx = best_node[1].indexes[i + 1]
+ outputs[0, i, idx] = best_node[1].scores[i + 1]
+
+ return outputs
diff --git a/mmocr/models/textrecog/decoders/transformer_decoder.py b/mmocr/models/textrecog/decoders/transformer_decoder.py
new file mode 100644
index 00000000..c8ef2fad
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/transformer_decoder.py
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmocr.models.builder import DECODERS
+from mmocr.models.textrecog.layers import (DecoderLayer, PositionalEncoding,
+ get_pad_mask, get_subsequent_mask)
+from .base_decoder import BaseDecoder
+
+
+@DECODERS.register_module()
+class TFDecoder(BaseDecoder):
+ """Transformer Decoder block with self attention mechanism."""
+
+ def __init__(self,
+ n_layers=6,
+ d_embedding=512,
+ n_head=8,
+ d_k=64,
+ d_v=64,
+ d_model=512,
+ d_inner=256,
+ n_position=200,
+ dropout=0.1,
+ num_classes=93,
+ max_seq_len=40,
+ start_idx=1,
+ padding_idx=92,
+ **kwargs):
+ super().__init__()
+
+ self.padding_idx = padding_idx
+ self.start_idx = start_idx
+ self.max_seq_len = max_seq_len
+
+ self.trg_word_emb = nn.Embedding(
+ num_classes, d_embedding, padding_idx=padding_idx)
+
+ self.position_enc = PositionalEncoding(
+ d_embedding, n_position=n_position)
+ self.dropout = nn.Dropout(p=dropout)
+
+ self.layer_stack = nn.ModuleList([
+ DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+ for _ in range(n_layers)
+ ])
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ pred_num_class = num_classes - 1 # ignore padding_idx
+ self.classifier = nn.Linear(d_model, pred_num_class)
+
+ def _attention(self, trg_seq, src, src_mask=None):
+ trg_embedding = self.trg_word_emb(trg_seq)
+ trg_pos_encoded = self.position_enc(trg_embedding)
+ tgt = self.dropout(trg_pos_encoded)
+
+ trg_mask = get_pad_mask(
+ trg_seq, pad_idx=self.padding_idx) & get_subsequent_mask(trg_seq)
+ output = tgt
+ for dec_layer in self.layer_stack:
+ output = dec_layer(
+ output,
+ src,
+ slf_attn_mask=trg_mask,
+ dec_enc_attn_mask=src_mask)
+ output = self.layer_norm(output)
+
+ return output
+
+ def forward_train(self, feat, out_enc, targets_dict, img_metas):
+ targets = targets_dict['padded_targets'].to(out_enc.device)
+ attn_output = self._attention(targets, out_enc, src_mask=None)
+ outputs = self.classifier(attn_output)
+ return outputs
+
+ def forward_test(self, feat, out_enc, img_metas):
+ bsz = out_enc.size(0)
+ init_target_seq = torch.full((bsz, self.max_seq_len + 1),
+ self.padding_idx,
+ device=out_enc.device,
+ dtype=torch.long)
+ # bsz * seq_len
+ init_target_seq[:, 0] = self.start_idx
+
+ outputs = []
+ for step in range(0, self.max_seq_len):
+ decoder_output = self._attention(
+ init_target_seq, out_enc, src_mask=None)
+ # bsz * seq_len * 512
+ step_result = F.softmax(
+ self.classifier(decoder_output[:, step, :]), dim=-1)
+ # bsz * num_classes
+ outputs.append(step_result)
+ _, step_max_index = torch.max(step_result, dim=-1)
+ init_target_seq[:, step + 1] = step_max_index
+
+ outputs = torch.stack(outputs, dim=1)
+
+ return outputs
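+
+# Inference above is greedy autoregressive decoding: the target sequence is
+# initialised with <start> followed by padding, and at each step the argmax
+# of the current position is written back before the next pass. With the
+# defaults above (num_classes=93, padding_idx=92, max_seq_len=40):
+#   decoder = TFDecoder()
+#   scores = decoder(None, out_enc, train_mode=False)   # N x 40 x 92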
diff --git a/mmocr/models/textrecog/encoders/__init__.py b/mmocr/models/textrecog/encoders/__init__.py
new file mode 100755
index 00000000..e0d9394a
--- /dev/null
+++ b/mmocr/models/textrecog/encoders/__init__.py
@@ -0,0 +1,6 @@
+from .base_encoder import BaseEncoder
+from .channel_reduction_encoder import ChannelReductionEncoder
+from .sar_encoder import SAREncoder
+from .transformer_encoder import TFEncoder
+
+__all__ = ['SAREncoder', 'TFEncoder', 'BaseEncoder', 'ChannelReductionEncoder']
diff --git a/mmocr/models/textrecog/encoders/base_encoder.py b/mmocr/models/textrecog/encoders/base_encoder.py
new file mode 100644
index 00000000..3dadc687
--- /dev/null
+++ b/mmocr/models/textrecog/encoders/base_encoder.py
@@ -0,0 +1,14 @@
+import torch.nn as nn
+
+from mmocr.models.builder import ENCODERS
+
+
+@ENCODERS.register_module()
+class BaseEncoder(nn.Module):
+ """Base Encoder class for text recognition."""
+
+ def init_weights(self):
+ pass
+
+ def forward(self, feat, **kwargs):
+ return feat
diff --git a/mmocr/models/textrecog/encoders/sar_encoder.py b/mmocr/models/textrecog/encoders/sar_encoder.py
new file mode 100644
index 00000000..d0381d6c
--- /dev/null
+++ b/mmocr/models/textrecog/encoders/sar_encoder.py
@@ -0,0 +1,102 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import uniform_init, xavier_init
+
+import mmocr.utils as utils
+from mmocr.models.builder import ENCODERS
+from .base_encoder import BaseEncoder
+
+
+@ENCODERS.register_module()
+class SAREncoder(BaseEncoder):
+ """Implementation of encoder module in `SAR.
+
+ `_
+
+ Args:
+ enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+ enc_do_rnn (float): Dropout probability of RNN layer in encoder.
+ enc_gru (bool): If True, use GRU, else LSTM in encoder.
+ d_model (int): Dim of channels from backbone.
+ d_enc (int): Dim of encoder RNN layer.
+ mask (bool): If True, mask padding in RNN sequence.
+ """
+
+ def __init__(self,
+ enc_bi_rnn=False,
+ enc_do_rnn=0.0,
+ enc_gru=False,
+ d_model=512,
+ d_enc=512,
+ mask=True,
+ **kwargs):
+ super().__init__()
+ assert isinstance(enc_bi_rnn, bool)
+ assert isinstance(enc_do_rnn, (int, float))
+ assert 0 <= enc_do_rnn < 1.0
+ assert isinstance(enc_gru, bool)
+ assert isinstance(d_model, int)
+ assert isinstance(d_enc, int)
+ assert isinstance(mask, bool)
+
+ self.enc_bi_rnn = enc_bi_rnn
+ self.enc_do_rnn = enc_do_rnn
+ self.mask = mask
+
+ # LSTM Encoder
+ kwargs = dict(
+ input_size=d_model,
+ hidden_size=d_enc,
+ num_layers=2,
+ batch_first=True,
+ dropout=enc_do_rnn,
+ bidirectional=enc_bi_rnn)
+ if enc_gru:
+ self.rnn_encoder = nn.GRU(**kwargs)
+ else:
+ self.rnn_encoder = nn.LSTM(**kwargs)
+
+ # global feature transformation
+ encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
+ self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
+
+ def init_weights(self):
+ # initialize weight and bias
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ uniform_init(m)
+
+ def forward(self, feat, img_metas=None):
+ assert utils.is_type_list(img_metas, dict)
+ assert len(img_metas) == feat.size(0)
+
+ valid_ratios = [
+ img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+ ] if self.mask else None
+
+ h_feat = feat.size(2)
+ feat_v = F.max_pool2d(
+ feat, kernel_size=(h_feat, 1), stride=1, padding=0)
+ feat_v = feat_v.squeeze(2) # bsz * C * W
+ feat_v = feat_v.permute(0, 2, 1).contiguous() # bsz * W * C
+
+ holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
+
+ if valid_ratios is not None:
+ valid_hf = []
+ T = holistic_feat.size(1)
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_step = min(T, math.ceil(T * valid_ratio)) - 1
+ valid_hf.append(holistic_feat[i, valid_step, :])
+ valid_hf = torch.stack(valid_hf, dim=0)
+ else:
+ valid_hf = holistic_feat[:, -1, :] # bsz * C
+
+ holistic_feat = self.linear(valid_hf) # bsz * C
+
+ return holistic_feat
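+
+# The holistic feature is read from the last *valid* RNN step: for an image
+# whose text occupies only part of the padded width (valid_ratio < 1), the
+# step index is min(T, ceil(T * valid_ratio)) - 1 rather than T - 1, so
+# padding columns do not dilute the sentence-level feature.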
diff --git a/mmocr/models/textrecog/encoders/transformer_encoder.py b/mmocr/models/textrecog/encoders/transformer_encoder.py
new file mode 100644
index 00000000..d33d17bf
--- /dev/null
+++ b/mmocr/models/textrecog/encoders/transformer_encoder.py
@@ -0,0 +1,16 @@
+from mmocr.models.builder import ENCODERS
+from .base_encoder import BaseEncoder
+
+
+@ENCODERS.register_module()
+class TFEncoder(BaseEncoder):
+ """Encode 2d feature map to 1d sequence."""
+
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ def forward(self, feat, img_metas=None):
+ n, c, _, _ = feat.size()
+ enc_output = feat.view(n, c, -1).transpose(2, 1).contiguous()
+
+ return enc_output
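+
+# Shape sketch: a 2d feature map of shape N x C x H x W is flattened into a
+# sequence of H * W tokens with C channels each, i.e. N x (H * W) x C, which
+# the transformer decoder consumes as its memory.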
diff --git a/mmocr/models/textrecog/heads/__init__.py b/mmocr/models/textrecog/heads/__init__.py
new file mode 100755
index 00000000..761bb9a9
--- /dev/null
+++ b/mmocr/models/textrecog/heads/__init__.py
@@ -0,0 +1,3 @@
+from .seg_head import SegHead
+
+__all__ = ['SegHead']
diff --git a/mmocr/models/textrecog/heads/seg_head.py b/mmocr/models/textrecog/heads/seg_head.py
new file mode 100644
index 00000000..a0ca59cc
--- /dev/null
+++ b/mmocr/models/textrecog/heads/seg_head.py
@@ -0,0 +1,50 @@
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from torch import nn
+
+from mmdet.models.builder import HEADS
+
+
+@HEADS.register_module()
+class SegHead(nn.Module):
+ """Head for segmentation based text recognition.
+
+ Args:
+ in_channels (int): Number of input channels.
+ num_classes (int): Number of output classes.
+ upsample_param (dict | None): Config dict for interpolation layer.
+ Default: `dict(scale_factor=1.0, mode='nearest')`
+ """
+
+ def __init__(self, in_channels=128, num_classes=37, upsample_param=None):
+ super().__init__()
+ assert isinstance(num_classes, int)
+ assert num_classes > 0
+ assert upsample_param is None or isinstance(upsample_param, dict)
+
+ self.upsample_param = upsample_param
+
+ self.seg_conv = ConvModule(
+ in_channels,
+ in_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=dict(type='BN'))
+
+ # prediction
+ self.pred_conv = nn.Conv2d(
+ in_channels, num_classes, kernel_size=1, stride=1, padding=0)
+
+ def init_weights(self):
+ pass
+
+ def forward(self, out_neck):
+
+ seg_map = self.seg_conv(out_neck[-1])
+ seg_map = self.pred_conv(seg_map)
+
+ if self.upsample_param is not None:
+ seg_map = F.interpolate(seg_map, **self.upsample_param)
+
+ return seg_map
diff --git a/mmocr/models/textrecog/layers/__init__.py b/mmocr/models/textrecog/layers/__init__.py
new file mode 100755
index 00000000..a38a8c5d
--- /dev/null
+++ b/mmocr/models/textrecog/layers/__init__.py
@@ -0,0 +1,15 @@
+from .conv_layer import BasicBlock, Bottleneck
+from .dot_product_attention_layer import DotProductAttentionLayer
+from .lstm_layer import BidirectionalLSTM
+from .position_aware_layer import PositionAwareLayer
+from .robust_scanner_fusion_layer import RobustScannerFusionLayer
+from .transformer_layer import (DecoderLayer, MultiHeadAttention,
+ PositionalEncoding, PositionwiseFeedForward,
+ get_pad_mask, get_subsequent_mask)
+
+__all__ = [
+ 'BidirectionalLSTM', 'MultiHeadAttention', 'PositionalEncoding',
+ 'PositionwiseFeedForward', 'BasicBlock', 'Bottleneck',
+ 'RobustScannerFusionLayer', 'DotProductAttentionLayer',
+ 'PositionAwareLayer', 'DecoderLayer', 'get_pad_mask', 'get_subsequent_mask'
+]
diff --git a/mmocr/models/textrecog/layers/conv_layer.py b/mmocr/models/textrecog/layers/conv_layer.py
new file mode 100644
index 00000000..d0ce32a3
--- /dev/null
+++ b/mmocr/models/textrecog/layers/conv_layer.py
@@ -0,0 +1,93 @@
+import torch.nn as nn
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=False):
+ super().__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ if downsample:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes, planes * self.expansion, 1, stride, bias=False),
+ nn.BatchNorm2d(planes * self.expansion),
+ )
+ else:
+ self.downsample = nn.Sequential()
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=False):
+ super().__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(
+ planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ if downsample:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes, planes * self.expansion, 1, stride, bias=False),
+ nn.BatchNorm2d(planes * self.expansion),
+ )
+ else:
+ self.downsample = nn.Sequential()
+
+ def forward(self, x):
+ residual = self.downsample(x)
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
diff --git a/mmocr/models/textrecog/layers/lstm_layer.py b/mmocr/models/textrecog/layers/lstm_layer.py
new file mode 100644
index 00000000..e4017d02
--- /dev/null
+++ b/mmocr/models/textrecog/layers/lstm_layer.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+
+
+class BidirectionalLSTM(nn.Module):
+
+ def __init__(self, nIn, nHidden, nOut):
+ super().__init__()
+
+ self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
+ self.embedding = nn.Linear(nHidden * 2, nOut)
+
+ def forward(self, input):
+ recurrent, _ = self.rnn(input)
+ T, b, h = recurrent.size()
+ t_rec = recurrent.view(T * b, h)
+
+ output = self.embedding(t_rec) # [T * b, nOut]
+ output = output.view(T, b, -1)
+
+ return output
diff --git a/mmocr/models/textrecog/layers/transformer_layer.py b/mmocr/models/textrecog/layers/transformer_layer.py
new file mode 100644
index 00000000..a9bba4b3
--- /dev/null
+++ b/mmocr/models/textrecog/layers/transformer_layer.py
@@ -0,0 +1,172 @@
+"""This code is from https://github.com/jadore801120/attention-is-all-you-need-
+pytorch."""
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DecoderLayer(nn.Module):
+ """Compose with three layers:
+ 1. MultiHeadSelfAttn
+ 2. MultiHeadEncoderDecoderAttn
+ 3. PositionwiseFeedForward
+ """
+
+ def __init__(self,
+ d_model=512,
+ d_inner=256,
+ n_head=8,
+ d_k=64,
+ d_v=64,
+ dropout=0.1):
+ super().__init__()
+ self.slf_attn = MultiHeadAttention(
+ n_head, d_model, d_k, d_v, dropout=dropout)
+ self.enc_attn = MultiHeadAttention(
+ n_head, d_model, d_k, d_v, dropout=dropout)
+ self.pos_ffn = PositionwiseFeedForward(
+ d_model, d_inner, dropout=dropout)
+
+ def forward(self,
+ dec_input,
+ enc_output,
+ slf_attn_mask=None,
+ dec_enc_attn_mask=None):
+
+ dec_output = self.slf_attn(
+ dec_input, dec_input, dec_input, mask=slf_attn_mask)
+ dec_output = self.enc_attn(
+ dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
+ dec_output = self.pos_ffn(dec_output)
+
+ return dec_output
+
+
+class ScaledDotProductAttention(nn.Module):
+ """Scaled Dot-Product Attention."""
+
+ def __init__(self, temperature, attn_dropout=0.1):
+ super().__init__()
+ self.temperature = temperature
+ self.dropout = nn.Dropout(attn_dropout)
+
+ def forward(self, q, k, v, mask=None):
+
+ attn = torch.matmul(q / self.temperature, k.transpose(1, 2))
+
+ if mask is not None:
+ attn = attn.masked_fill(mask == 0, float('-inf'))
+
+ attn = self.dropout(F.softmax(attn, dim=-1))
+ output = torch.matmul(attn, v)
+
+ return output, attn
+
+
+class MultiHeadAttention(nn.Module):
+ """Multi-Head Attention module."""
+
+ def __init__(self, n_head=8, d_model=512, d_k=64, d_v=64, dropout=0.1):
+ super().__init__()
+
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_v = d_v
+
+ self.w_qs_list = nn.ModuleList(
+ [nn.Linear(d_model, d_k, bias=False) for _ in range(n_head)])
+ self.w_ks_list = nn.ModuleList(
+ [nn.Linear(d_model, d_k, bias=False) for _ in range(n_head)])
+ self.w_vs_list = nn.ModuleList(
+ [nn.Linear(d_model, d_v, bias=False) for _ in range(n_head)])
+ self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
+
+ self.attention = ScaledDotProductAttention(temperature=d_k**0.5)
+
+ self.dropout = nn.Dropout(dropout)
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ def forward(self, q, k, v, mask=None):
+
+ residual = q
+ q = self.layer_norm(q)
+
+ attention_q_list = []
+ for head_index in range(self.n_head):
+ q_each = self.w_qs_list[head_index](q) # bsz * seq_len * d_k
+ k_each = self.w_ks_list[head_index](k) # bsz * seq_len * d_k
+ v_each = self.w_vs_list[head_index](v) # bsz * seq_len * d_v
+ attention_q_each, _ = self.attention(
+ q_each, k_each, v_each, mask=mask)
+ attention_q_list.append(attention_q_each)
+
+ q = torch.cat(attention_q_list, dim=-1)
+ q = self.dropout(self.fc(q))
+ q += residual
+
+ return q
+
+
+class PositionwiseFeedForward(nn.Module):
+ """A two-feed-forward-layer module."""
+
+ def __init__(self, d_in, d_hid, dropout=0.1):
+ super().__init__()
+ self.w_1 = nn.Linear(d_in, d_hid) # position-wise
+ self.w_2 = nn.Linear(d_hid, d_in) # position-wise
+ self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+
+ residual = x
+ x = self.layer_norm(x)
+
+ x = self.w_2(F.relu(self.w_1(x)))
+ x = self.dropout(x)
+ x += residual
+
+ return x
+
+
+class PositionalEncoding(nn.Module):
+
+ def __init__(self, d_hid=512, n_position=200):
+ super().__init__()
+
+ # Not a parameter
+ self.register_buffer(
+ 'position_table',
+ self._get_sinusoid_encoding_table(n_position, d_hid))
+
+ def _get_sinusoid_encoding_table(self, n_position, d_hid):
+ """Sinusoid position encoding table."""
+ denominator = torch.Tensor([
+ 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
+ for hid_j in range(d_hid)
+ ])
+ denominator = denominator.view(1, -1)
+ pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
+ sinusoid_table = pos_tensor * denominator
+ sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
+ sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
+
+ return sinusoid_table.unsqueeze(0)
+
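+ # The table follows the standard sinusoidal encoding:
+ #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_hid))
+ #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_hid))
+ # so each dimension oscillates at its own frequency over positions.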
+ def forward(self, x):
+ self.device = x.device
+ return x + self.position_table[:, :x.size(1)].clone().detach()
+
+
+def get_pad_mask(seq, pad_idx):
+ return (seq != pad_idx).unsqueeze(-2)
+
+
+def get_subsequent_mask(seq):
+ """For masking out the subsequent info."""
+ len_s = seq.size(1)
+ subsequent_mask = 1 - torch.triu(
+ torch.ones((len_s, len_s), device=seq.device), diagonal=1)
+ subsequent_mask = subsequent_mask.unsqueeze(0).bool()
+ return subsequent_mask
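+
+# Example mask for a length-3 target (True = key position may be attended
+# to), so query position i only sees keys 0..i:
+#   get_subsequent_mask(torch.zeros(1, 3))
+#   # tensor([[[ True, False, False],
+#   #          [ True,  True, False],
+#   #          [ True,  True,  True]]])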
diff --git a/mmocr/models/textrecog/losses/__init__.py b/mmocr/models/textrecog/losses/__init__.py
new file mode 100755
index 00000000..4b5a24a2
--- /dev/null
+++ b/mmocr/models/textrecog/losses/__init__.py
@@ -0,0 +1,5 @@
+from .ce_loss import CELoss, SARLoss, TFLoss
+from .ctc_loss import CTCLoss
+from .seg_loss import CAFCNLoss, SegLoss
+
+__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss', 'CAFCNLoss']
diff --git a/mmocr/models/textrecog/losses/ce_loss.py b/mmocr/models/textrecog/losses/ce_loss.py
new file mode 100644
index 00000000..4fad5ae4
--- /dev/null
+++ b/mmocr/models/textrecog/losses/ce_loss.py
@@ -0,0 +1,94 @@
+import torch.nn as nn
+
+from mmdet.models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class CELoss(nn.Module):
+ """Implementation of loss module for encoder-decoder based text recognition
+ method with CrossEntropy loss.
+
+ Args:
+ ignore_index (int): Specifies a target value that is
+ ignored and does not contribute to the input gradient.
+ reduction (str): Specifies the reduction to apply to the output,
+ should be one of the following: ('none', 'mean', 'sum').
+ """
+
+ def __init__(self, ignore_index=-1, reduction='none'):
+ super().__init__()
+ assert isinstance(ignore_index, int)
+ assert isinstance(reduction, str)
+ assert reduction in ['none', 'mean', 'sum']
+
+ self.loss_ce = nn.CrossEntropyLoss(
+ ignore_index=ignore_index, reduction=reduction)
+
+ def format(self, outputs, targets_dict):
+ targets = targets_dict['padded_targets']
+
+ return outputs.permute(0, 2, 1).contiguous(), targets
+
+ def forward(self, outputs, targets_dict):
+ outputs, targets = self.format(outputs, targets_dict)
+
+ loss_ce = self.loss_ce(outputs, targets.to(outputs.device))
+ losses = dict(loss_ce=loss_ce)
+
+ return losses
+
+
+@LOSSES.register_module()
+class SARLoss(CELoss):
+ """Implementation of loss module in `SAR.
+
+ `_.
+
+ Args:
+ ignore_index (int): Specifies a target value that is
+ ignored and does not contribute to the input gradient.
+ reduction (str): Specifies the reduction to apply to the output,
+ should be one of the following: ('none', 'mean', 'sum').
+ """
+
+ def __init__(self, ignore_index=0, reduction='mean', **kwargs):
+ super().__init__(ignore_index, reduction)
+
+ def format(self, outputs, targets_dict):
+ targets = targets_dict['padded_targets']
+ # targets[0, :], [start_idx, idx1, idx2, ..., end_idx, pad_idx...]
+ # outputs[0, :, 0], [idx1, idx2, ..., end_idx, ...]
+
+ # ignore first index of target in loss calculation
+ targets = targets[:, 1:].contiguous()
+ # ignore last index of outputs to be in same seq_len with targets
+ outputs = outputs[:, :-1, :].permute(0, 2, 1).contiguous()
+
+ return outputs, targets
+
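+# Worked example of the shift in format() above (indices are illustrative):
+#   padded_targets[0] = [start, 7, 3, end, pad, pad]
+#   targets = padded_targets[:, 1:]  -> [7, 3, end, pad, pad]   drop <start>
+#   outputs = outputs[:, :-1, :]     -> predictions for those 5 steps
+# so the prediction at step t is supervised by the (t + 1)-th target token.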
+
+@LOSSES.register_module()
+class TFLoss(CELoss):
+ """Implementation of loss module for transformer."""
+
+ def __init__(self,
+ ignore_index=-1,
+ reduction='none',
+ flatten=True,
+ **kwargs):
+ super().__init__(ignore_index, reduction)
+ assert isinstance(flatten, bool)
+
+ self.flatten = flatten
+
+ def format(self, outputs, targets_dict):
+ outputs = outputs[:, :-1, :].contiguous()
+ targets = targets_dict['padded_targets']
+ targets = targets[:, 1:].contiguous()
+ if self.flatten:
+ outputs = outputs.view(-1, outputs.size(-1))
+ targets = targets.view(-1)
+ else:
+ outputs = outputs.permute(0, 2, 1).contiguous()
+
+ return outputs, targets
diff --git a/mmocr/models/textrecog/losses/ctc_loss.py b/mmocr/models/textrecog/losses/ctc_loss.py
new file mode 100644
index 00000000..231469a2
--- /dev/null
+++ b/mmocr/models/textrecog/losses/ctc_loss.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+
+from mmdet.models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class CTCLoss(nn.Module):
+ """Implementation of loss module for CTC-loss based text recognition.
+
+ Args:
+ flatten (bool): If True, use flattened targets, else padded targets.
+ blank (int): Blank label. Default 0.
+ reduction (str): Specifies the reduction to apply to the output,
+ should be one of the following: ('none', 'mean', 'sum').
+ zero_infinity (bool): Whether to zero infinite losses and
+ the associated gradients. Default: False.
+ Infinite losses mainly occur when the inputs
+ are too short to be aligned to the targets.
+ """
+
+ def __init__(self,
+ flatten=True,
+ blank=0,
+ reduction='mean',
+ zero_infinity=False,
+ **kwargs):
+ super().__init__()
+ assert isinstance(flatten, bool)
+ assert isinstance(blank, int)
+ assert isinstance(reduction, str)
+ assert isinstance(zero_infinity, bool)
+
+ self.flatten = flatten
+ self.blank = blank
+ self.ctc_loss = nn.CTCLoss(
+ blank=blank, reduction=reduction, zero_infinity=zero_infinity)
+
+ def forward(self, outputs, targets_dict):
+
+ outputs = torch.log_softmax(outputs, dim=2)
+ bsz, seq_len = outputs.size(0), outputs.size(1)
+ input_lengths = torch.full(
+ size=(bsz, ), fill_value=seq_len, dtype=torch.long)
+ outputs_for_loss = outputs.permute(1, 0, 2).contiguous() # T * N * C
+
+ if self.flatten:
+ targets = targets_dict['flatten_targets']
+ else:
+ targets = torch.full(
+ size=(bsz, seq_len), fill_value=self.blank, dtype=torch.long)
+ for idx, tensor in enumerate(targets_dict['targets']):
+ valid_len = min(tensor.size(0), seq_len)
+ targets[idx, :valid_len] = tensor[:valid_len]
+
+ target_lengths = targets_dict['target_lengths']
+
+ loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths,
+ target_lengths)
+
+ losses = dict(loss_ctc=loss_ctc)
+
+ return losses
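+
+# Shape sketch for the flattened path (flatten=True, assuming a label
+# convertor that supplies 'flatten_targets' and 'target_lengths'):
+#   outputs: N x T x C logits -> log_softmax over C -> permuted to T x N x C,
+#   input_lengths = [T] * N, since every sample spans the full output width.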
diff --git a/mmocr/models/textrecog/losses/seg_loss.py b/mmocr/models/textrecog/losses/seg_loss.py
new file mode 100644
index 00000000..0da9e61e
--- /dev/null
+++ b/mmocr/models/textrecog/losses/seg_loss.py
@@ -0,0 +1,176 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmdet.models.builder import LOSSES
+from mmocr.models.common.losses import DiceLoss
+
+
+@LOSSES.register_module()
+class SegLoss(nn.Module):
+ """Implementation of loss module for segmentation based text recognition
+ method.
+
+ Args:
+ seg_downsample_ratio (float): Downsample ratio of
+ segmentation map.
+ seg_with_loss_weight (bool): If True, set weight for
+ segmentation loss.
+ ignore_index (int): Specifies a target value that is ignored
+ and does not contribute to the input gradient.
+ """
+
+ def __init__(self,
+ seg_downsample_ratio=0.5,
+ seg_with_loss_weight=True,
+ ignore_index=255,
+ **kwargs):
+ super().__init__()
+
+ assert isinstance(seg_downsample_ratio, (int, float))
+ assert 0 < seg_downsample_ratio <= 1
+ assert isinstance(ignore_index, int)
+
+ self.seg_downsample_ratio = seg_downsample_ratio
+ self.seg_with_loss_weight = seg_with_loss_weight
+ self.ignore_index = ignore_index
+
+ def seg_loss(self, out_head, gt_kernels):
+ seg_map = out_head # bsz * num_classes * H/2 * W/2
+ seg_target = [
+ item[1].rescale(self.seg_downsample_ratio).to_tensor(
+ torch.long, seg_map.device) for item in gt_kernels
+ ]
+ seg_target = torch.stack(seg_target).squeeze(1)
+
+ loss_weight = None
+ if self.seg_with_loss_weight:
+ N = torch.sum(seg_target != self.ignore_index)
+ N_neg = torch.sum(seg_target == 0)
+ weight_val = 1.0 * N_neg / (N - N_neg)
+ loss_weight = torch.ones(seg_map.size(1), device=seg_map.device)
+ loss_weight[1:] = weight_val
+
+ loss_seg = F.cross_entropy(
+ seg_map,
+ seg_target,
+ weight=loss_weight,
+ ignore_index=self.ignore_index)
+
+ return loss_seg
+
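+ # Worked example of the class-balancing weight above: for a 64 x 64 target
+ # map with 3840 background pixels (label 0) and 256 character pixels,
+ # weight_val = 3840 / (4096 - 3840) = 15, so every character class is
+ # weighted 15x against the dominant background class.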
+ def forward(self, out_neck, out_head, gt_kernels):
+
+ losses = {}
+
+ loss_seg = self.seg_loss(out_head, gt_kernels)
+
+ losses['loss_seg'] = loss_seg
+
+ return losses
+
+
+@LOSSES.register_module()
+class CAFCNLoss(SegLoss):
+ """Implementation of loss module in `CA-FCN.
+
+ `_
+
+ Args:
+ alpha (float): Weight ratio of attention loss.
+ attn_s2_downsample_ratio (float): Downsample ratio
+ of attention map from output stage 2.
+ attn_s3_downsample_ratio (float): Downsample ratio
+ of attention map from output stage 3.
+ seg_downsample_ratio (float): Downsample ratio of
+ segmentation map.
+ attn_with_dice_loss (bool): If True, use dice_loss for attention,
+ else BCELoss.
+ with_attn (bool): If True, include attention loss, else
+ segmentation loss only.
+ seg_with_loss_weight (bool): If True, set weight for
+ segmentation loss.
+ ignore_index (int): Specifies a target value that is ignored
+ and does not contribute to the input gradient.
+ """
+
+ def __init__(self,
+ alpha=1.0,
+ attn_s2_downsample_ratio=0.25,
+ attn_s3_downsample_ratio=0.125,
+ seg_downsample_ratio=0.5,
+ attn_with_dice_loss=False,
+ with_attn=True,
+ seg_with_loss_weight=True,
+ ignore_index=255):
+ super().__init__(seg_downsample_ratio, seg_with_loss_weight,
+ ignore_index)
+ assert isinstance(alpha, (int, float))
+ assert isinstance(attn_s2_downsample_ratio, (int, float))
+ assert isinstance(attn_s3_downsample_ratio, (int, float))
+ assert 0 < attn_s2_downsample_ratio <= 1
+ assert 0 < attn_s3_downsample_ratio <= 1
+
+ self.alpha = alpha
+ self.attn_s2_downsample_ratio = attn_s2_downsample_ratio
+ self.attn_s3_downsample_ratio = attn_s3_downsample_ratio
+ self.with_attn = with_attn
+ self.attn_with_dice_loss = attn_with_dice_loss
+
+ # attention loss
+ if with_attn:
+ if attn_with_dice_loss:
+ self.criterion_attn = DiceLoss()
+ else:
+ self.criterion_attn = nn.BCELoss(reduction='none')
+
+ def attn_loss(self, out_neck, gt_kernels):
+ attn_map_s2 = out_neck[0] # bsz * 2 * H/4 * W/4
+
+ mask_s2 = torch.stack([
+ item[2].rescale(self.attn_s2_downsample_ratio).to_tensor(
+ torch.float, attn_map_s2.device) for item in gt_kernels
+ ])
+
+ attn_target_s2 = torch.stack([
+ item[0].rescale(self.attn_s2_downsample_ratio).to_tensor(
+ torch.float, attn_map_s2.device) for item in gt_kernels
+ ])
+
+ mask_s3 = torch.stack([
+ item[2].rescale(self.attn_s3_downsample_ratio).to_tensor(
+ torch.float, attn_map_s2.device) for item in gt_kernels
+ ])
+
+ attn_target_s3 = torch.stack([
+ item[0].rescale(self.attn_s3_downsample_ratio).to_tensor(
+ torch.float, attn_map_s2.device) for item in gt_kernels
+ ])
+
+ targets = [
+ attn_target_s2, attn_target_s3, attn_target_s3, attn_target_s3
+ ]
+
+ masks = [mask_s2, mask_s3, mask_s3, mask_s3]
+
+ loss_attn = 0.
+ for i in range(len(out_neck) - 1):
+ pred = out_neck[i]
+ if self.attn_with_dice_loss:
+ loss_attn += self.criterion_attn(pred, targets[i], masks[i])
+ else:
+ loss_attn += torch.sum(
+ self.criterion_attn(pred, targets[i]) *
+ masks[i]) / torch.sum(masks[i])
+
+ return loss_attn
+
+ def forward(self, out_neck, out_head, gt_kernels):
+
+ losses = super().forward(out_neck, out_head, gt_kernels)
+
+ if self.with_attn:
+ loss_attn = self.attn_loss(out_neck, gt_kernels)
+ losses['loss_attn'] = loss_attn
+
+ return losses
diff --git a/mmocr/models/textrecog/necks/__init__.py b/mmocr/models/textrecog/necks/__init__.py
new file mode 100755
index 00000000..c10a46a5
--- /dev/null
+++ b/mmocr/models/textrecog/necks/__init__.py
@@ -0,0 +1,5 @@
+from .cafcn_neck import CAFCNNeck
+from .fpn_ocr import FPNOCR
+from .fpn_seg import FPNSeg
+
+__all__ = ['CAFCNNeck', 'FPNSeg', 'FPNOCR']
diff --git a/mmocr/models/textrecog/necks/cafcn_neck.py b/mmocr/models/textrecog/necks/cafcn_neck.py
new file mode 100644
index 00000000..b5fca8d2
--- /dev/null
+++ b/mmocr/models/textrecog/necks/cafcn_neck.py
@@ -0,0 +1,223 @@
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.ops import DeformConv2dPack
+from torch import nn
+
+from mmdet.models.builder import NECKS
+
+
+class CharAttn(nn.Module):
+ """Implementation of Character attention module in `CA-FCN.
+
+ `_
+ """
+
+ def __init__(self, in_channels=128, out_channels=128, deformable=False):
+ super().__init__()
+ assert isinstance(in_channels, int)
+ assert isinstance(deformable, bool)
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.deformable = deformable
+
+ # attention layers
+ self.attn_layer = nn.Sequential(
+ ConvModule(
+ in_channels,
+ in_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=dict(type='BN')),
+ ConvModule(
+ in_channels,
+ 1,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='Sigmoid')))
+
+ conv_kwargs = dict(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=1,
+ padding=1)
+ if self.deformable:
+ self.conv = DeformConv2dPack(**conv_kwargs)
+ else:
+ self.conv = nn.Conv2d(**conv_kwargs)
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, in_feat):
+ # Calculate attn map
+ attn_map = self.attn_layer(in_feat) # N * 1 * H * W
+
+ in_feat = self.relu(self.bn(self.conv(in_feat)))
+
+ out_feat_map = self._upsample_mul(in_feat, 1 + attn_map)
+
+ return out_feat_map, attn_map
+
+ def _upsample_add(self, x, y):
+ return F.interpolate(x, size=y.size()[2:]) + y
+
+ def _upsample_mul(self, x, y):
+ return F.interpolate(x, size=y.size()[2:]) * y
+
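+# The attention map acts as a residual gate: the transformed feature is
+# multiplied by (1 + attn_map), so regions with high 1-channel attention are
+# amplified while the original response is never fully suppressed.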
+
+class FeatGenerator(nn.Module):
+ """Generate attention-augmented stage feature from backbone stage
+ feature."""
+
+ def __init__(self,
+ in_channels=512,
+ out_channels=128,
+ deformable=True,
+ concat=False,
+ upsample=False,
+ with_attn=True):
+ super().__init__()
+
+ self.concat = concat
+ self.upsample = upsample
+ self.with_attn = with_attn
+
+ if with_attn:
+ self.char_attn = CharAttn(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ deformable=deformable)
+ else:
+ self.char_attn = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=dict(type='BN'))
+
+ if concat:
+ self.conv_to_concat = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=dict(type='BN'))
+
+ kernel_size = (3, 1) if deformable else 3
+ padding = (1, 0) if deformable else 1
+ tmp_in_channels = out_channels * 2 if concat else out_channels
+
+ self.conv_after_concat = ConvModule(
+ tmp_in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ norm_cfg=dict(type='BN'))
+
+ def forward(self, x, y=None, size=None):
+ if self.with_attn:
+ feat_map, attn_map = self.char_attn(x)
+ else:
+ feat_map = self.char_attn(x)
+ attn_map = feat_map
+
+ if self.concat:
+ y = self.conv_to_concat(y)
+ feat_map = torch.cat((y, feat_map), dim=1)
+
+ feat_map = self.conv_after_concat(feat_map)
+
+ if self.upsample:
+ feat_map = F.interpolate(feat_map, size)
+
+ return attn_map, feat_map
+
+
+@NECKS.register_module()
+class CAFCNNeck(nn.Module):
+ """Implementation of neck module in `CA-FCN.
+
+ `_
+
+ Args:
+ in_channels (list[int]): Number of input channels for each scale.
+ out_channels (int): Number of output channels for each scale.
+ deformable (bool): If True, use deformable conv.
+ with_attn (bool): If True, add attention for each output feature map.
+ """
+
+ def __init__(self,
+ in_channels=[128, 256, 512, 512],
+ out_channels=128,
+ deformable=True,
+ with_attn=True):
+ super().__init__()
+
+ self.deformable = deformable
+ self.with_attn = with_attn
+
+ # stage_in5_to_out5
+ self.s5 = FeatGenerator(
+ in_channels=in_channels[-1],
+ out_channels=out_channels,
+ deformable=deformable,
+ concat=False,
+ with_attn=with_attn)
+
+ # stage_in4_to_out4
+ self.s4 = FeatGenerator(
+ in_channels=in_channels[-2],
+ out_channels=out_channels,
+ deformable=deformable,
+ concat=True,
+ with_attn=with_attn)
+
+ # stage_in3_to_out3
+ self.s3 = FeatGenerator(
+ in_channels=in_channels[-3],
+ out_channels=out_channels,
+ deformable=False,
+ concat=True,
+ upsample=True,
+ with_attn=with_attn)
+
+ # stage_in2_to_out2
+ self.s2 = FeatGenerator(
+ in_channels=in_channels[-4],
+ out_channels=out_channels,
+ deformable=False,
+ concat=True,
+ upsample=True,
+ with_attn=with_attn)
+
+ def init_weights(self):
+ pass
+
+ def forward(self, feats):
+ in_stage1 = feats[0]
+ in_stage2, in_stage3 = feats[1], feats[2]
+ in_stage4, in_stage5 = feats[3], feats[4]
+ # out stage 5
+ out_s5_attn_map, out_s5 = self.s5(in_stage5)
+
+ # out stage 4
+ out_s4_attn_map, out_s4 = self.s4(in_stage4, out_s5)
+
+ # out stage 3
+ out_s3_attn_map, out_s3 = self.s3(in_stage3, out_s4,
+ in_stage2.size()[2:])
+
+ # out stage 2
+ out_s2_attn_map, out_s2 = self.s2(in_stage2, out_s3,
+ in_stage1.size()[2:])
+
+ return (out_s2_attn_map, out_s3_attn_map, out_s4_attn_map,
+ out_s5_attn_map, out_s2)
diff --git a/mmocr/models/textrecog/necks/fpn_ocr.py b/mmocr/models/textrecog/necks/fpn_ocr.py
new file mode 100644
index 00000000..c1e9f178
--- /dev/null
+++ b/mmocr/models/textrecog/necks/fpn_ocr.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from mmdet.models.builder import NECKS
+
+
+@NECKS.register_module()
+class FPNOCR(nn.Module):
+ """FPN-like Network for segmentation based text recognition.
+
+ Args:
+ in_channels (list[int]): Number of input channels for each scale.
+ out_channels (int): Number of output channels for each scale.
+ last_stage_only (bool): If True, output last stage only.
+ """
+
+ def __init__(self, in_channels, out_channels, last_stage_only=True):
+ super(FPNOCR, self).__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+
+ self.last_stage_only = last_stage_only
+
+ self.lateral_convs = nn.ModuleList()
+ self.smooth_convs_1x1 = nn.ModuleList()
+ self.smooth_convs_3x3 = nn.ModuleList()
+
+ for i in range(self.num_ins):
+ l_conv = ConvModule(
+ in_channels[i], out_channels, 1, norm_cfg=dict(type='BN'))
+ self.lateral_convs.append(l_conv)
+
+ for i in range(self.num_ins - 1):
+ s_conv_1x1 = ConvModule(
+ out_channels * 2, out_channels, 1, norm_cfg=dict(type='BN'))
+ s_conv_3x3 = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ norm_cfg=dict(type='BN'))
+ self.smooth_convs_1x1.append(s_conv_1x1)
+ self.smooth_convs_3x3.append(s_conv_3x3)
+
+ def init_weights(self):
+ pass
+
+ def _upsample_x2(self, x):
+ return F.interpolate(x, scale_factor=2, mode='bilinear')
+
+ def forward(self, inputs):
+ lateral_features = [
+ l_conv(inputs[i]) for i, l_conv in enumerate(self.lateral_convs)
+ ]
+
+ outs = []
+ for i in range(len(self.smooth_convs_3x3), 0, -1): # 3, 2, 1
+ last_out = lateral_features[-1] if len(outs) == 0 else outs[-1]
+ upsample = self._upsample_x2(last_out)
+ upsample_cat = torch.cat((upsample, lateral_features[i - 1]),
+ dim=1)
+ smooth_1x1 = self.smooth_convs_1x1[i - 1](upsample_cat)
+ smooth_3x3 = self.smooth_convs_3x3[i - 1](smooth_1x1)
+ outs.append(smooth_3x3)
+
+ return tuple(outs[-1:]) if self.last_stage_only else tuple(outs)
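+
+# Sketch of the top-down pass above: the deepest lateral feature is upsampled
+# x2, concatenated with the next lateral map, and fused by a 1x1 + 3x3 conv
+# pair; with four input scales this repeats three times, and by default only
+# the finest (last) map is returned.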
diff --git a/mmocr/models/textrecog/necks/fpn_seg.py b/mmocr/models/textrecog/necks/fpn_seg.py
new file mode 100644
index 00000000..997951e4
--- /dev/null
+++ b/mmocr/models/textrecog/necks/fpn_seg.py
@@ -0,0 +1,43 @@
+import torch.nn.functional as F
+from mmcv.runner import auto_fp16
+
+from mmdet.models.builder import NECKS
+from mmdet.models.necks import FPN
+
+
+@NECKS.register_module()
+class FPNSeg(FPN):
+ """Feature Pyramid Network for segmentation based text recognition.
+
+ Args:
+ in_channels (list[int]): Number of input channels for each scale.
+ out_channels (int): Number of output channels for each scale.
+ num_outs (int): Number of output scales.
+ upsample_param (dict | None): Config dict for interpolate layer.
+ Default: `dict(scale_factor=1.0, mode='nearest')`
+ last_stage_only (bool): If True, output last stage of FPN only.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ upsample_param=None,
+ last_stage_only=True):
+ super().__init__(in_channels, out_channels, num_outs)
+ self.upsample_param = upsample_param
+ self.last_stage_only = last_stage_only
+
+ @auto_fp16()
+ def forward(self, inputs):
+ outs = super().forward(inputs)
+
+ outs = list(outs)
+
+ if self.upsample_param is not None:
+ outs[0] = F.interpolate(outs[0], **self.upsample_param)
+
+ if self.last_stage_only:
+ return tuple(outs[0:1])
+
+ return tuple(outs[::-1])
diff --git a/mmocr/models/textrecog/recognizer/__init__.py b/mmocr/models/textrecog/recognizer/__init__.py
new file mode 100644
index 00000000..abc9c132
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/__init__.py
@@ -0,0 +1,13 @@
+from .base import BaseRecognizer
+from .cafcn import CAFCNNet
+from .crnn import CRNNNet
+from .encode_decode_recognizer import EncodeDecodeRecognizer
+from .robust_scanner import RobustScanner
+from .sar import SARNet
+from .seg_recognizer import SegRecognizer
+from .transformer import TransformerNet
+
+__all__ = [
+ 'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet',
+ 'TransformerNet', 'SegRecognizer', 'RobustScanner', 'CAFCNNet'
+]
diff --git a/mmocr/models/textrecog/recognizer/base.py b/mmocr/models/textrecog/recognizer/base.py
new file mode 100644
index 00000000..6472a5a2
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/base.py
@@ -0,0 +1,236 @@
+import warnings
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import mmcv
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from mmcv.runner import auto_fp16
+from mmcv.utils import print_log
+
+from mmdet.utils import get_root_logger
+from mmocr.core import imshow_text_label
+
+
+class BaseRecognizer(nn.Module, metaclass=ABCMeta):
+ """Base class for text recognition."""
+
+ def __init__(self):
+ super().__init__()
+ self.fp16_enabled = False
+
+ @abstractmethod
+ def extract_feat(self, imgs):
+ """Extract features from images."""
+ pass
+
+ @abstractmethod
+ def forward_train(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ img (tensor): tensors with shape (N, C, H, W).
+ Typically should be mean centered and std scaled.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details of the values of these keys, see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ kwargs (keyword arguments): Specific to concrete implementation.
+ """
+ pass
+
+ @abstractmethod
+ def simple_test(self, img, img_metas, **kwargs):
+ pass
+
+ @abstractmethod
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Test function with test time augmentation.
+
+ Args:
+ imgs (list[tensor]): Tensor should have shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (list[list[dict]]): The metadata of images.
+ """
+ pass
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights for detector.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if pretrained is not None:
+ logger = get_root_logger()
+ print_log(f'load model from: {pretrained}', logger=logger)
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (tensor | list[tensor]): Tensor should have shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (list[dict] | list[list[dict]]):
+ The outer list indicates images in a batch.
+ """
+ if isinstance(imgs, list):
+ assert len(imgs) == len(img_metas)
+ assert len(imgs) > 0
+ assert imgs[0].size(0) == 1, 'aug test does not support ' \
+ 'inference with batch size ' \
+ f'{imgs[0].size(0)}'
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ return self.simple_test(imgs, img_metas, **kwargs)
+
+ @auto_fp16(apply_to=('img', ))
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note that img and img_meta are single-nested (i.e. tensor and
+ list[dict]).
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def _parse_losses(self, losses):
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw outputs of the network, which usually contain
+ losses and other necessary information.
+
+ Returns:
+ tuple[tensor, dict]: (loss, log_vars), loss is the loss tensor \
+ which may be a weighted sum of all losses, log_vars contains \
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+
+ loss = sum(_value for _key, _value in log_vars.items()
+ if 'loss' in _key)
+
+ log_vars['loss'] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
+
+ def train_step(self, data, optimizer):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer update, which are done by an optimizer
+ hook. Note that in some complicated cases or models (e.g. GAN),
+ the whole process (including the back propagation and optimizer update)
+ is also defined by this method.
+
+ Args:
+ data (dict): The outputs of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
+ ``num_samples``.
+
+ - ``loss`` is a tensor for back propagation, which is a \
+ weighted sum of multiple losses.
+ - ``log_vars`` contains all the variables to be sent to the
+ logger.
+ - ``num_samples`` indicates the batch size used for \
+ averaging the logs (Note: for the \
+ DDP model, num_samples refers to the batch size for each GPU).
+ """
+ losses = self(**data)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
+
+ return outputs
+
+ def val_step(self, data, optimizer):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but is
+ used during val epochs. Note that the evaluation after training epochs
+ is not implemented by this method, but by an evaluation hook.
+ """
+ losses = self(**data)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
+
+ return outputs
+
+ def show_result(self,
+ img,
+ result,
+ gt_label='',
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None,
+ **kwargs):
+ """Draw `result` on `img`.
+
+ Args:
+ img (str or tensor): The image to be displayed.
+ result (dict): The results to draw on `img`.
+ gt_label (str): Ground truth label of img.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The output filename.
+ Default: None.
+
+ Returns:
+ img (tensor): Only if not `show` or `out_file`.
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ pred_label = None
+ if 'text' in result.keys():
+ pred_label = result['text']
+
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+ # draw text label
+ if pred_label is not None:
+ img = imshow_text_label(
+ img,
+ pred_label,
+ gt_label,
+ show=show,
+ win_name=win_name,
+ wait_time=wait_time,
+ out_file=out_file)
+
+ if not (show or out_file):
+ warnings.warn('show==False and out_file is not specified, only '
+ 'result image will be returned')
+ return img
+
+ return img
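
Reviewer note (not part of the patch): the loss bookkeeping in `_parse_losses` / `train_step` is easiest to see on a toy loss dict. The sketch below, with made-up loss names and values, mirrors the non-distributed path: tensors and lists of tensors are reduced to means, every key containing `'loss'` is summed into the total, and all entries end up as plain floats for the logger.

```python
from collections import OrderedDict

import torch

losses = dict(
    loss_ce=torch.tensor(0.7),                        # plain tensor loss
    loss_aux=[torch.tensor(0.1), torch.tensor(0.2)],  # list of tensors
    acc=torch.tensor(0.9))                            # logged but not summed

log_vars = OrderedDict()
for name, value in losses.items():
    log_vars[name] = value.mean() if isinstance(value, torch.Tensor) \
        else sum(v.mean() for v in value)

loss = sum(v for k, v in log_vars.items() if 'loss' in k)  # 0.7 + 0.3 = 1.0
log_vars['loss'] = loss
print({k: round(float(v), 4) for k, v in log_vars.items()})
# {'loss_ce': 0.7, 'loss_aux': 0.3, 'acc': 0.9, 'loss': 1.0}
```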
diff --git a/mmocr/models/textrecog/recognizer/cafcn.py b/mmocr/models/textrecog/recognizer/cafcn.py
new file mode 100644
index 00000000..6acade5e
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/cafcn.py
@@ -0,0 +1,7 @@
+from mmdet.models.builder import DETECTORS
+from .seg_recognizer import SegRecognizer
+
+
+@DETECTORS.register_module()
+class CAFCNNet(SegRecognizer):
+ """Implementation of `CAFCN `_"""
diff --git a/mmocr/models/textrecog/recognizer/crnn.py b/mmocr/models/textrecog/recognizer/crnn.py
new file mode 100644
index 00000000..1ff68f7f
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/crnn.py
@@ -0,0 +1,18 @@
+import torch.nn.functional as F
+
+from mmdet.models.builder import DETECTORS
+from .encode_decode_recognizer import EncodeDecodeRecognizer
+
+
+@DETECTORS.register_module()
+class CRNNNet(EncodeDecodeRecognizer):
+ """CTC-loss based recognizer."""
+
+ def forward_conversion(self, params, img):
+ x = self.extract_feat(img)
+ x = self.encoder(x)
+ outs = self.decoder(x)
+ outs = F.softmax(outs, dim=2)
+ return outs, params
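
Reviewer note (not part of the patch): `forward_conversion` only returns per-frame class probabilities; turning them into a string is left to the CTC label convertor. Below is a hedged sketch of the usual greedy CTC decode (argmax per frame, collapse repeats, drop blanks). Treating index 0 as the blank is an assumption for illustration; the actual convertor defines its own blank index.

```python
import torch

probs = torch.softmax(torch.randn(1, 6, 5), dim=2)  # (N, T, num_classes), toy values
best = probs.argmax(dim=2)[0].tolist()              # best class per time step

decoded, prev = [], None
for idx in best:
    if idx != prev and idx != 0:                    # collapse repeats, skip blank (assumed idx 0)
        decoded.append(idx)
    prev = idx
print(decoded)                                      # e.g. [3, 1, 4]
```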
diff --git a/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py
new file mode 100644
index 00000000..793f3853
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py
@@ -0,0 +1,173 @@
+from mmdet.models.builder import DETECTORS, build_backbone, build_loss
+from mmocr.models.builder import (build_convertor, build_decoder,
+ build_encoder, build_preprocessor)
+from .base import BaseRecognizer
+
+
+@DETECTORS.register_module()
+class EncodeDecodeRecognizer(BaseRecognizer):
+ """Base class for encode-decode recognizer."""
+
+ def __init__(self,
+ preprocessor=None,
+ backbone=None,
+ encoder=None,
+ decoder=None,
+ loss=None,
+ label_convertor=None,
+ train_cfg=None,
+ test_cfg=None,
+ max_seq_len=40,
+ pretrained=None):
+ super().__init__()
+
+ # Label convertor (str2tensor, tensor2str)
+ assert label_convertor is not None
+ label_convertor.update(max_seq_len=max_seq_len)
+ self.label_convertor = build_convertor(label_convertor)
+
+ # Preprocessor module, e.g., TPS
+ self.preprocessor = None
+ if preprocessor is not None:
+ self.preprocessor = build_preprocessor(preprocessor)
+
+ # Backbone
+ assert backbone is not None
+ self.backbone = build_backbone(backbone)
+
+ # Encoder module
+ self.encoder = None
+ if encoder is not None:
+ self.encoder = build_encoder(encoder)
+
+ # Decoder module
+ assert decoder is not None
+ decoder.update(num_classes=self.label_convertor.num_classes())
+ decoder.update(start_idx=self.label_convertor.start_idx)
+ decoder.update(padding_idx=self.label_convertor.padding_idx)
+ decoder.update(max_seq_len=max_seq_len)
+ self.decoder = build_decoder(decoder)
+
+ # Loss
+ assert loss is not None
+ loss.update(ignore_index=self.label_convertor.padding_idx)
+ self.loss = build_loss(loss)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.max_seq_len = max_seq_len
+ self.init_weights(pretrained=pretrained)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights of recognizer."""
+ super().init_weights(pretrained)
+
+ if self.preprocessor is not None:
+ self.preprocessor.init_weights()
+
+ self.backbone.init_weights()
+
+ if self.encoder is not None:
+ self.encoder.init_weights()
+
+ self.decoder.init_weights()
+
+ def extract_feat(self, img):
+ """Directly extract features from the backbone."""
+ if self.preprocessor is not None:
+ img = self.preprocessor(img)
+
+ x = self.backbone(img)
+
+ return x
+
+ def forward_train(self, img, img_metas):
+ """
+ Args:
+ img (tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A list of image info dict where each dict
+ contains: 'img_shape', 'filename', and may also contain
+ 'ori_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+
+ Returns:
+ dict[str, tensor]: A dictionary of loss components.
+ """
+ feat = self.extract_feat(img)
+
+ gt_labels = [img_meta['text'] for img_meta in img_metas]
+
+ targets_dict = self.label_convertor.str2tensor(gt_labels)
+
+ out_enc = None
+ if self.encoder is not None:
+ out_enc = self.encoder(feat, img_metas)
+
+ out_dec = self.decoder(
+ feat, out_enc, targets_dict, img_metas, train_mode=True)
+
+ loss_inputs = (
+ out_dec,
+ targets_dict,
+ )
+ losses = self.loss(*loss_inputs)
+
+ return losses
+
+ def simple_test(self, img, img_metas, **kwargs):
+ """Test function with test time augmentation.
+
+ Args:
+ imgs (torch.Tensor): Image input tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ list[str]: Text label result of each image.
+ """
+ feat = self.extract_feat(img)
+
+ out_enc = None
+ if self.encoder is not None:
+ out_enc = self.encoder(feat, img_metas)
+
+ out_dec = self.decoder(
+ feat, out_enc, None, img_metas, train_mode=False)
+
+ label_indexes, label_scores = \
+ self.label_convertor.tensor2idx(out_dec, img_metas)
+ label_strings = self.label_convertor.idx2str(label_indexes)
+
+ # flatten batch results
+ results = []
+ for string, score in zip(label_strings, label_scores):
+ results.append(dict(text=string, score=score))
+
+ return results
+
+ def merge_aug_results(self, aug_results):
+ out_text, out_score = '', -1
+ for result in aug_results:
+ text = result[0]['text']
+ score = sum(result[0]['score']) / max(1, len(text))
+ if score > out_score:
+ out_text = text
+ out_score = score
+ out_results = [dict(text=out_text, score=out_score)]
+ return out_results
+
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Test function as well as time augmentation.
+
+ Args:
+ imgs (list[tensor]): Tensor should have shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (list[list[dict]]): The metadata of images.
+ """
+ aug_results = []
+ for img, img_meta in zip(imgs, img_metas):
+ result = self.simple_test(img, img_meta, **kwargs)
+ aug_results.append(result)
+
+ return self.merge_aug_results(aug_results)
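
Reviewer note (not part of the patch): a hypothetical config sketch showing how the components wired up in `__init__` above are typically specified. The module names are ones that appear elsewhere in this PR's tests; the exact arguments are illustrative assumptions, not values taken from the repository's configs.

```python
# sketch only -- argument values are assumptions for illustration
label_convertor = dict(type='AttnConvertor', dict_type='DICT90', with_unknown=True)

model = dict(
    type='EncodeDecodeRecognizer',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(type='SAREncoder'),
    decoder=dict(type='ParallelSARDecoder'),
    loss=dict(type='SARLoss'),
    label_convertor=label_convertor,
    max_seq_len=40)
```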
diff --git a/mmocr/models/textrecog/recognizer/sar.py b/mmocr/models/textrecog/recognizer/sar.py
new file mode 100644
index 00000000..bce67dca
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/sar.py
@@ -0,0 +1,7 @@
+from mmdet.models.builder import DETECTORS
+from .encode_decode_recognizer import EncodeDecodeRecognizer
+
+
+@DETECTORS.register_module()
+class SARNet(EncodeDecodeRecognizer):
+ """Implementation of `SAR `_"""
diff --git a/mmocr/models/textrecog/recognizer/seg_recognizer.py b/mmocr/models/textrecog/recognizer/seg_recognizer.py
new file mode 100644
index 00000000..e013e1d7
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/seg_recognizer.py
@@ -0,0 +1,153 @@
+from mmdet.models.builder import (DETECTORS, build_backbone, build_head,
+ build_loss, build_neck)
+from mmocr.models.builder import build_convertor, build_preprocessor
+from .base import BaseRecognizer
+
+
+@DETECTORS.register_module()
+class SegRecognizer(BaseRecognizer):
+ """Base class for segmentation based recognizer."""
+
+ def __init__(self,
+ preprocessor=None,
+ backbone=None,
+ neck=None,
+ head=None,
+ loss=None,
+ label_convertor=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super().__init__()
+
+ # Label_convertor
+ assert label_convertor is not None
+ self.label_convertor = build_convertor(label_convertor)
+
+ # Preprocessor module, e.g., TPS
+ self.preprocessor = None
+ if preprocessor is not None:
+ self.preprocessor = build_preprocessor(preprocessor)
+
+ # Backbone
+ assert backbone is not None
+ self.backbone = build_backbone(backbone)
+
+ # Neck
+ assert neck is not None
+ self.neck = build_neck(neck)
+
+ # Head
+ assert head is not None
+ head.update(num_classes=self.label_convertor.num_classes())
+ self.head = build_head(head)
+
+ # Loss
+ assert loss is not None
+ self.loss = build_loss(loss)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.init_weights(pretrained=pretrained)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights of recognizer."""
+ super().init_weights(pretrained)
+
+ if self.preprocessor is not None:
+ self.preprocessor.init_weights()
+
+ self.backbone.init_weights(pretrained=pretrained)
+
+ if self.neck is not None:
+ self.neck.init_weights()
+
+ self.head.init_weights()
+
+ def extract_feat(self, img):
+ """Directly extract features from the backbone."""
+ if self.preprocessor is not None:
+ img = self.preprocessor(img)
+
+ x = self.backbone(img)
+
+ return x
+
+ def forward_train(self, img, img_metas, gt_kernels=None):
+ """
+ Args:
+ img (tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A list of image info dict where each dict
+ contains: 'img_shape', 'filename', and may also contain
+ 'ori_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+
+ Returns:
+ dict[str, tensor]: A dictionary of loss components.
+ """
+
+ feats = self.extract_feat(img)
+
+ out_neck = self.neck(feats)
+
+ out_head = self.head(out_neck)
+
+ loss_inputs = (out_neck, out_head, gt_kernels)
+
+ losses = self.loss(*loss_inputs)
+
+ return losses
+
+ def simple_test(self, img, img_metas, **kwargs):
+ """Test function without test time augmentation.
+
+ Args:
+ img (torch.Tensor): Image input tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ list[str]: Text label result of each image.
+ """
+
+ feat = self.extract_feat(img)
+
+ out_neck = self.neck(feat)
+
+ out_head = self.head(out_neck)
+
+ texts, scores = self.label_convertor.tensor2str(out_head, img_metas)
+
+ # flatten batch results
+ results = []
+ for text, score in zip(texts, scores):
+ results.append(dict(text=text, score=score))
+
+ return results
+
+ def merge_aug_results(self, aug_results):
+ out_text, out_score = '', -1
+ for result in aug_results:
+ text = result[0]['text']
+ score = sum(result[0]['score']) / max(1, len(text))
+ if score > out_score:
+ out_text = text
+ out_score = score
+ out_results = [dict(text=out_text, score=out_score)]
+ return out_results
+
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Test function with test time augmentation.
+
+ Args:
+ imgs (list[tensor]): Tensor should have shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (list[list[dict]]): The metadata of images.
+ """
+ aug_results = []
+ for img, img_meta in zip(imgs, img_metas):
+ result = self.simple_test(img, img_meta, **kwargs)
+ aug_results.append(result)
+
+ return self.merge_aug_results(aug_results)
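
Reviewer note (not part of the patch): `merge_aug_results` (identical here and in `EncodeDecodeRecognizer`) keeps the rotation candidate whose mean per-character score is highest. A small worked example with made-up scores:

```python
aug_results = [
    [dict(text='hello', score=[0.9, 0.9, 0.8, 0.9, 0.9])],  # mean 0.88
    [dict(text='he1lo', score=[0.9, 0.9, 0.3, 0.9, 0.9])],  # mean 0.78
]

best_text, best_score = '', -1
for result in aug_results:
    text = result[0]['text']
    score = sum(result[0]['score']) / max(1, len(text))
    if score > best_score:
        best_text, best_score = text, score
print(best_text)  # -> hello (mean score of roughly 0.88)
```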
diff --git a/mmocr/models/textrecog/recognizer/transformer.py b/mmocr/models/textrecog/recognizer/transformer.py
new file mode 100644
index 00000000..86636424
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/transformer.py
@@ -0,0 +1,7 @@
+from mmdet.models.builder import DETECTORS
+from .encode_decode_recognizer import EncodeDecodeRecognizer
+
+
+@DETECTORS.register_module()
+class TransformerNet(EncodeDecodeRecognizer):
+ """Implementation of Transformer based OCR."""
diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py
new file mode 100644
index 00000000..43260d37
--- /dev/null
+++ b/tests/test_dataset/test_base_dataset.py
@@ -0,0 +1,74 @@
+import os.path as osp
+import tempfile
+
+import numpy as np
+import pytest
+
+from mmocr.datasets.base_dataset import BaseDataset
+
+
+def _create_dummy_ann_file(ann_file):
+ ann_info1 = 'sample1.jpg hello'
+ ann_info2 = 'sample2.jpg world'
+
+ with open(ann_file, 'w') as fw:
+ for ann_info in [ann_info1, ann_info2]:
+ fw.write(ann_info + '\n')
+
+
+def _create_dummy_loader():
+ loader = dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(type='LineStrParser', keys=['file_name', 'text']))
+ return loader
+
+
+def test_custom_dataset():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy data
+ ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+ _create_dummy_ann_file(ann_file)
+ loader = _create_dummy_loader()
+
+ for mode in [True, False]:
+ dataset = BaseDataset(ann_file, loader, pipeline=[], test_mode=mode)
+
+ # test len
+ assert len(dataset) == len(dataset.data_infos)
+
+ # test set group flag
+ assert np.allclose(dataset.flag, [0, 0])
+
+ # test prepare_train_img
+ expect_results = {
+ 'img_info': {
+ 'file_name': 'sample1.jpg',
+ 'text': 'hello'
+ },
+ 'img_prefix': ''
+ }
+ assert dataset.prepare_train_img(0) == expect_results
+
+ # test prepare_test_img
+ assert dataset.prepare_test_img(0) == expect_results
+
+ # test __getitem__
+ assert dataset[0] == expect_results
+
+ # test get_next_index
+ assert dataset._get_next_index(0) == 1
+
+ # test format_results
+ expect_results_copy = {
+ key: value
+ for key, value in expect_results.items()
+ }
+ dataset.format_results(expect_results)
+ assert expect_results_copy == expect_results
+
+ # test evaluate
+ with pytest.raises(NotImplementedError):
+ dataset.evaluate(expect_results)
+
+ tmp_dir.cleanup()
diff --git a/tests/test_dataset/test_crop.py b/tests/test_dataset/test_crop.py
new file mode 100644
index 00000000..99155344
--- /dev/null
+++ b/tests/test_dataset/test_crop.py
@@ -0,0 +1,96 @@
+import math
+
+import numpy as np
+import pytest
+
+from mmocr.datasets.pipelines.crop import (box_jitter, convert_canonical,
+ crop_img, sort_vertex, warp_img)
+
+
+def test_order_vertex():
+ dummy_points_x = [20, 20, 120, 120]
+ dummy_points_y = [20, 40, 40, 20]
+
+ with pytest.raises(AssertionError):
+ sort_vertex([], dummy_points_y)
+ with pytest.raises(AssertionError):
+ sort_vertex(dummy_points_x, [])
+
+ ordered_points_x, ordered_points_y = sort_vertex(dummy_points_x,
+ dummy_points_y)
+
+ expect_points_x = [20, 120, 120, 20]
+ expect_points_y = [20, 20, 40, 40]
+
+ assert np.allclose(ordered_points_x, expect_points_x)
+ assert np.allclose(ordered_points_y, expect_points_y)
+
+
+def test_convert_canonical():
+ dummy_points_x = [120, 120, 20, 20]
+ dummy_points_y = [20, 40, 40, 20]
+
+ with pytest.raises(AssertionError):
+ convert_canonical([], dummy_points_y)
+ with pytest.raises(AssertionError):
+ convert_canonical(dummy_points_x, [])
+
+ ordered_points_x, ordered_points_y = convert_canonical(
+ dummy_points_x, dummy_points_y)
+
+ expect_points_x = [20, 120, 120, 20]
+ expect_points_y = [20, 20, 40, 40]
+
+ assert np.allclose(ordered_points_x, expect_points_x)
+ assert np.allclose(ordered_points_y, expect_points_y)
+
+
+def test_box_jitter():
+ dummy_points_x = [20, 120, 120, 20]
+ dummy_points_y = [20, 20, 40, 40]
+
+ kwargs = dict(jitter_ratio_x=0.0, jitter_ratio_y=0.0)
+
+ with pytest.raises(AssertionError):
+ box_jitter([], dummy_points_y)
+ with pytest.raises(AssertionError):
+ box_jitter(dummy_points_x, [])
+ with pytest.raises(AssertionError):
+ box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_x=1.)
+ with pytest.raises(AssertionError):
+ box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_y=1.)
+
+ box_jitter(dummy_points_x, dummy_points_y, **kwargs)
+
+ assert np.allclose(dummy_points_x, [20, 120, 120, 20])
+ assert np.allclose(dummy_points_y, [20, 20, 40, 40])
+
+
+def test_opencv_crop():
+ dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
+ dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
+
+ cropped_img = warp_img(dummy_img, dummy_box)
+
+ with pytest.raises(AssertionError):
+ warp_img(dummy_img, [])
+ with pytest.raises(AssertionError):
+ warp_img(dummy_img, [20, 40, 40, 20])
+
+ assert math.isclose(cropped_img.shape[0], 20)
+ assert math.isclose(cropped_img.shape[1], 100)
+
+
+def test_min_rect_crop():
+ dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
+ dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
+
+ cropped_img = crop_img(dummy_img, dummy_box)
+
+ with pytest.raises(AssertionError):
+ crop_img(dummy_img, [])
+ with pytest.raises(AssertionError):
+ crop_img(dummy_img, [20, 40, 40, 20])
+
+ assert math.isclose(cropped_img.shape[0], 20)
+ assert math.isclose(cropped_img.shape[1], 100)
diff --git a/tests/test_dataset/test_detect_dataset.py b/tests/test_dataset/test_detect_dataset.py
new file mode 100644
index 00000000..83480c45
--- /dev/null
+++ b/tests/test_dataset/test_detect_dataset.py
@@ -0,0 +1,83 @@
+import json
+import os.path as osp
+import tempfile
+
+import numpy as np
+
+from mmocr.datasets.text_det_dataset import TextDetDataset
+
+
+def _create_dummy_ann_file(ann_file):
+ ann_info1 = {
+ 'file_name':
+ 'sample1.jpg',
+ 'height':
+ 640,
+ 'width':
+ 640,
+ 'annotations': [{
+ 'iscrowd': 0,
+ 'category_id': 1,
+ 'bbox': [50, 70, 80, 100],
+ 'segmentation': [[50, 70, 80, 70, 80, 100, 50, 100]]
+ }, {
+ 'iscrowd':
+ 1,
+ 'category_id':
+ 1,
+ 'bbox': [120, 140, 200, 200],
+ 'segmentation': [[120, 140, 200, 140, 200, 200, 120, 200]]
+ }]
+ }
+
+ with open(ann_file, 'w') as fw:
+ fw.write(json.dumps(ann_info1) + '\n')
+
+
+def _create_dummy_loader():
+ loader = dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineJsonParser',
+ keys=['file_name', 'height', 'width', 'annotations']))
+ return loader
+
+
+def test_detect_dataset():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy data
+ ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+ _create_dummy_ann_file(ann_file)
+
+ # test initialization
+ loader = _create_dummy_loader()
+ dataset = TextDetDataset(ann_file, loader, pipeline=[])
+
+ # test _parse_anno_info
+ img_ann_info = dataset.data_infos[0]
+ ann = dataset._parse_anno_info(img_ann_info['annotations'])
+ print(ann['bboxes'])
+ assert np.allclose(ann['bboxes'], [[50., 70., 80., 100.]])
+ assert np.allclose(ann['labels'], [1])
+ assert np.allclose(ann['bboxes_ignore'], [[120, 140, 200, 200]])
+ assert np.allclose(ann['masks'], [[[50, 70, 80, 70, 80, 100, 50, 100]]])
+ assert np.allclose(ann['masks_ignore'],
+ [[[120, 140, 200, 140, 200, 200, 120, 200]]])
+
+ tmp_dir.cleanup()
+
+ # test prepare_train_img
+ pipeline_results = dataset.prepare_train_img(0)
+ assert np.allclose(pipeline_results['bbox_fields'], [])
+ assert np.allclose(pipeline_results['mask_fields'], [])
+ assert np.allclose(pipeline_results['seg_fields'], [])
+ expect_img_info = {'filename': 'sample1.jpg', 'height': 640, 'width': 640}
+ assert pipeline_results['img_info'] == expect_img_info
+
+ # test evaluation
+ metrics = 'hmean-iou'
+ results = [{'boundary_result': [[50, 70, 80, 70, 80, 100, 50, 100, 1]]}]
+ eval_res = dataset.evaluate(results, metrics)
+
+ assert eval_res['hmean-iou:hmean'] == 1
diff --git a/tests/test_dataset/test_loader.py b/tests/test_dataset/test_loader.py
new file mode 100644
index 00000000..2d3de5c1
--- /dev/null
+++ b/tests/test_dataset/test_loader.py
@@ -0,0 +1,71 @@
+import json
+import os.path as osp
+import tempfile
+
+import pytest
+from tools.data.utils.txt2lmdb import converter
+
+from mmocr.datasets.utils.loader import HardDiskLoader, LmdbLoader, Loader
+
+
+def _create_dummy_line_str_file(ann_file):
+ ann_info1 = 'sample1.jpg hello'
+ ann_info2 = 'sample2.jpg world'
+
+ with open(ann_file, 'w') as fw:
+ for ann_info in [ann_info1, ann_info2]:
+ fw.write(ann_info + '\n')
+
+
+def _create_dummy_line_json_file(ann_file):
+ ann_info1 = {'filename': 'sample1.jpg', 'text': 'hello'}
+ ann_info2 = {'filename': 'sample2.jpg', 'text': 'world'}
+
+ with open(ann_file, 'w') as fw:
+ for ann_info in [ann_info1, ann_info2]:
+ fw.write(json.dumps(ann_info) + '\n')
+
+
+def test_loader():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy data
+ ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+ _create_dummy_line_str_file(ann_file)
+
+ parser = dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')
+
+ with pytest.raises(AssertionError):
+ Loader(ann_file, parser, repeat=0)
+ with pytest.raises(AssertionError):
+ Loader(ann_file, [], repeat=1)
+ with pytest.raises(AssertionError):
+ Loader('sample.txt', parser, repeat=1)
+ with pytest.raises(NotImplementedError):
+ loader = Loader(ann_file, parser, repeat=1)
+ print(loader)
+
+ # test text loader and line str parser
+ text_loader = HardDiskLoader(ann_file, parser, repeat=1)
+ assert len(text_loader) == 2
+ assert text_loader.ori_data_infos[0] == 'sample1.jpg hello'
+ assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}
+
+ # test text loader and line json parser
+ _create_dummy_line_json_file(ann_file)
+ json_parser = dict(type='LineJsonParser', keys=['filename', 'text'])
+ text_loader = HardDiskLoader(ann_file, json_parser, repeat=1)
+ assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}
+
+ # test lmdb loader and line str parser
+ _create_dummy_line_str_file(ann_file)
+ lmdb_file = osp.join(tmp_dir.name, 'fake_data.lmdb')
+ converter(ann_file, lmdb_file)
+
+ lmdb_loader = LmdbLoader(lmdb_file, parser, repeat=1)
+ assert lmdb_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}
+
+ tmp_dir.cleanup()
diff --git a/tests/test_dataset/test_ocr_dataset.py b/tests/test_dataset/test_ocr_dataset.py
new file mode 100644
index 00000000..1787db88
--- /dev/null
+++ b/tests/test_dataset/test_ocr_dataset.py
@@ -0,0 +1,51 @@
+import math
+import os.path as osp
+import tempfile
+
+from mmocr.datasets.ocr_dataset import OCRDataset
+
+
+def _create_dummy_ann_file(ann_file):
+ ann_info1 = 'sample1.jpg hello'
+ ann_info2 = 'sample2.jpg world'
+
+ with open(ann_file, 'w') as fw:
+ for ann_info in [ann_info1, ann_info2]:
+ fw.write(ann_info + '\n')
+
+
+def _create_dummy_loader():
+ loader = dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(type='LineStrParser', keys=['file_name', 'text']))
+ return loader
+
+
+def test_ocr_dataset():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy data
+ ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+ _create_dummy_ann_file(ann_file)
+
+ # test initialization
+ loader = _create_dummy_loader()
+ dataset = OCRDataset(ann_file, loader, pipeline=[])
+
+ tmp_dir.cleanup()
+
+ # test pre_pipeline
+ img_info = dataset.data_infos[0]
+ results = dict(img_info=img_info)
+ dataset.pre_pipeline(results)
+ assert results['img_prefix'] == dataset.img_prefix
+ assert results['text'] == img_info['text']
+
+ # test evaluation
+ metric = 'acc'
+ results = [{'text': 'hello'}, {'text': 'worl'}]
+ eval_res = dataset.evaluate(results, metric)
+
+ assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
+ assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
+ assert math.isclose(eval_res['char_recall'], 0.9, abs_tol=1e-4)
diff --git a/tests/test_dataset/test_ocr_seg_dataset.py b/tests/test_dataset/test_ocr_seg_dataset.py
new file mode 100644
index 00000000..0ecfcfdf
--- /dev/null
+++ b/tests/test_dataset/test_ocr_seg_dataset.py
@@ -0,0 +1,127 @@
+import json
+import math
+import os.path as osp
+import tempfile
+
+import pytest
+
+from mmocr.datasets.ocr_seg_dataset import OCRSegDataset
+
+
+def _create_dummy_ann_file(ann_file):
+ ann_info1 = {
+ 'file_name':
+ 'sample1.png',
+ 'annotations': [{
+ 'char_text':
+ 'F',
+ 'char_box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
+ }, {
+ 'char_text':
+ 'r',
+ 'char_box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]
+ }, {
+ 'char_text':
+ 'o',
+ 'char_box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]
+ }, {
+ 'char_text':
+ 'm',
+ 'char_box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]
+ }, {
+ 'char_text':
+ ':',
+ 'char_box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]
+ }],
+ 'text':
+ 'From:'
+ }
+ ann_info2 = {
+ 'file_name':
+ 'sample2.png',
+ 'annotations': [{
+ 'char_text': 'o',
+ 'char_box': [0.0, 5.0, 7.0, 5.0, 9.0, 15.0, 2.0, 15.0]
+ }, {
+ 'char_text':
+ 'u',
+ 'char_box': [7.0, 4.0, 14.0, 4.0, 18.0, 18.0, 11.0, 18.0]
+ }, {
+ 'char_text':
+ 't',
+ 'char_box': [13.0, 1.0, 19.0, 2.0, 24.0, 18.0, 17.0, 18.0]
+ }],
+ 'text':
+ 'out'
+ }
+
+ with open(ann_file, 'w') as fw:
+ for ann_info in [ann_info1, ann_info2]:
+ fw.write(json.dumps(ann_info) + '\n')
+
+ return ann_info1, ann_info2
+
+
+def _create_dummy_loader():
+ loader = dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineJsonParser', keys=['file_name', 'text', 'annotations']))
+ return loader
+
+
+def test_ocr_seg_dataset():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy data
+ ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+ ann_info1, ann_info2 = _create_dummy_ann_file(ann_file)
+
+ # test initialization
+ loader = _create_dummy_loader()
+ dataset = OCRSegDataset(ann_file, loader, pipeline=[])
+
+ tmp_dir.cleanup()
+
+ # test pre_pipeline
+ img_info = dataset.data_infos[0]
+ results = dict(img_info=img_info)
+ dataset.pre_pipeline(results)
+ assert results['img_prefix'] == dataset.img_prefix
+
+ # test _parse_anno_info
+ annos = ann_info1['annotations']
+ with pytest.raises(AssertionError):
+ dataset._parse_anno_info(annos[0])
+ annos2 = ann_info2['annotations']
+ with pytest.raises(AssertionError):
+ dataset._parse_anno_info([{'char_text': 'i'}])
+ with pytest.raises(AssertionError):
+ dataset._parse_anno_info([{'char_box': [1, 2, 3, 4, 5, 6, 7, 8]}])
+ annos2[0]['char_box'] = [1, 2, 3]
+ with pytest.raises(AssertionError):
+ dataset._parse_anno_info(annos2)
+
+ return_anno = dataset._parse_anno_info(annos)
+ assert return_anno['chars'] == ['F', 'r', 'o', 'm', ':']
+ assert len(return_anno['char_rects']) == 5
+
+ # test prepare_train_img
+ expect_results = {
+ 'img_info': {
+ 'filename': 'sample1.png'
+ },
+ 'img_prefix': '',
+ 'ann_info': return_anno
+ }
+ data = dataset.prepare_train_img(0)
+ assert data == expect_results
+
+ # test evaluation
+ metric = 'acc'
+ results = [{'text': 'From:'}, {'text': 'ou'}]
+ eval_res = dataset.evaluate(results, metric)
+
+ assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
+ assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
+ assert math.isclose(eval_res['char_recall'], 0.857, abs_tol=1e-4)
diff --git a/tests/test_dataset/test_ocr_seg_target.py b/tests/test_dataset/test_ocr_seg_target.py
new file mode 100644
index 00000000..45b85352
--- /dev/null
+++ b/tests/test_dataset/test_ocr_seg_target.py
@@ -0,0 +1,93 @@
+import os.path as osp
+import tempfile
+
+import numpy as np
+import pytest
+
+from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets
+
+
+def _create_dummy_dict_file(dict_file):
+ chars = list('0123456789')
+ with open(dict_file, 'w') as fw:
+ for char in chars:
+ fw.write(char + '\n')
+
+
+def test_ocr_segm_targets():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy dict file
+ dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+ _create_dummy_dict_file(dict_file)
+ # dummy label convertor
+ label_convertor = dict(
+ type='SegConvertor',
+ dict_file=dict_file,
+ with_unknown=True,
+ lower=True)
+ # test init
+ with pytest.raises(AssertionError):
+ OCRSegTargets(None, 0.5, 0.5)
+ with pytest.raises(AssertionError):
+ OCRSegTargets(label_convertor, '1by2', 0.5)
+ with pytest.raises(AssertionError):
+ OCRSegTargets(label_convertor, 0.5, 2)
+
+ ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)
+ # test generate kernels
+ img_size = (8, 8)
+ pad_size = (8, 10)
+ char_boxes = [[2, 2, 6, 6]]
+ char_idxs = [2]
+
+ with pytest.raises(AssertionError):
+ ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs, 0.5,
+ True)
+ with pytest.raises(AssertionError):
+ ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
+ char_idxs, 0.5, True)
+ with pytest.raises(AssertionError):
+ ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2, 0.5,
+ True)
+
+ attn_tgt = ocr_seg_tgt.generate_kernels(
+ img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
+ expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+ [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+ [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+ [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
+ [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
+ [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
+ [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+ [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
+ assert np.allclose(attn_tgt, np.array(expect_attn_tgt, dtype=np.int32))
+
+ segm_tgt = ocr_seg_tgt.generate_kernels(
+ img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
+ expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+ [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+ [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+ [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
+ [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
+ [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
+ [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+ [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
+ assert np.allclose(segm_tgt, np.array(expect_segm_tgt, dtype=np.int32))
+
+ # test __call__
+ results = {}
+ results['img_shape'] = (4, 4, 3)
+ results['resize_shape'] = (8, 8, 3)
+ results['pad_shape'] = (8, 10)
+ results['ann_info'] = {}
+ results['ann_info']['char_rects'] = [[1, 1, 3, 3]]
+ results['ann_info']['chars'] = ['1']
+
+ results = ocr_seg_tgt(results)
+ assert results['mask_fields'] == ['gt_kernels']
+ assert np.allclose(results['gt_kernels'].masks[0],
+ np.array(expect_attn_tgt, dtype=np.int32))
+ assert np.allclose(results['gt_kernels'].masks[1],
+ np.array(expect_segm_tgt, dtype=np.int32))
+
+ tmp_dir.cleanup()
diff --git a/tests/test_dataset/test_ocr_transforms.py b/tests/test_dataset/test_ocr_transforms.py
new file mode 100644
index 00000000..a568b908
--- /dev/null
+++ b/tests/test_dataset/test_ocr_transforms.py
@@ -0,0 +1,94 @@
+import math
+import unittest.mock as mock
+
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+
+import mmocr.datasets.pipelines.ocr_transforms as transforms
+
+
+def test_resize_ocr():
+ input_img = np.ones((64, 256, 3), dtype=np.uint8)
+
+ rci = transforms.ResizeOCR(
+ 32, min_width=32, max_width=160, keep_aspect_ratio=True)
+ results = {'img_shape': input_img.shape, 'img': input_img}
+
+ # test call
+ results = rci(results)
+ assert np.allclose([32, 160, 3], results['pad_shape'])
+ assert np.allclose([32, 160, 3], results['img'].shape)
+ assert 'valid_ratio' in results
+ assert math.isclose(results['valid_ratio'], 0.8)
+ assert math.isclose(np.sum(results['img'][:, 129:, :]), 0)
+
+ rci = transforms.ResizeOCR(
+ 32, min_width=32, max_width=160, keep_aspect_ratio=False)
+ results = {'img_shape': input_img.shape, 'img': input_img}
+ results = rci(results)
+ assert math.isclose(results['valid_ratio'], 1)
+
+
+def test_to_tensor():
+ input_img = np.ones((64, 256, 3), dtype=np.uint8)
+
+ expect_output = TF.to_tensor(input_img)
+ rci = transforms.ToTensorOCR()
+
+ results = {'img': input_img}
+ results = rci(results)
+
+ assert np.allclose(results['img'].numpy(), expect_output.numpy())
+
+
+def test_normalize():
+ inputs = torch.zeros(3, 10, 10)
+
+ expect_output = torch.ones_like(inputs) * (-1)
+ rci = transforms.NormalizeOCR(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ results = {'img': inputs}
+ results = rci(results)
+
+ assert np.allclose(results['img'].numpy(), expect_output.numpy())
+
+
+@mock.patch('%s.transforms.np.random.random' % __name__)
+def test_online_crop(mock_random):
+ kwargs = dict(
+ box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
+ jitter_prob=0.5,
+ max_jitter_ratio_x=0.05,
+ max_jitter_ratio_y=0.02)
+
+ mock_random.side_effect = [0.1, 1, 1, 1]
+
+ src_img = np.ones((100, 100, 3), dtype=np.uint8)
+ results = {
+ 'img': src_img,
+ 'img_info': {
+ 'x1': '20',
+ 'y1': '20',
+ 'x2': '40',
+ 'y2': '20',
+ 'x3': '40',
+ 'y3': '40',
+ 'x4': '20',
+ 'y4': '40'
+ }
+ }
+
+ rci = transforms.OnlineCropOCR(**kwargs)
+
+ results = rci(results)
+
+ assert np.allclose(results['img_shape'], [20, 20, 3])
+
+ # test not crop
+ mock_random.side_effect = [0.1, 1, 1, 1]
+ results['img_info'] = {}
+ results['img'] = src_img
+
+ results = rci(results)
+ assert np.allclose(results['img'].shape, [100, 100, 3])
diff --git a/tests/test_dataset/test_parser.py b/tests/test_dataset/test_parser.py
new file mode 100644
index 00000000..881ac92e
--- /dev/null
+++ b/tests/test_dataset/test_parser.py
@@ -0,0 +1,59 @@
+import json
+
+import pytest
+
+from mmocr.datasets.utils.parser import LineJsonParser, LineStrParser
+
+
+def test_line_str_parser():
+ data_ret = ['sample1.jpg hello', 'sample2.jpg world']
+ keys = ['filename', 'text']
+ keys_idx = [0, 1]
+ separator = ' '
+
+ # test init
+ with pytest.raises(AssertionError):
+ parser = LineStrParser('filename', keys_idx, separator)
+ with pytest.raises(AssertionError):
+ parser = LineStrParser(keys, keys_idx, [' '])
+ with pytest.raises(AssertionError):
+ parser = LineStrParser(keys, [0], separator)
+
+ # test get_item
+ parser = LineStrParser(keys, keys_idx, separator)
+ assert parser.get_item(data_ret, 0) == \
+ {'filename': 'sample1.jpg', 'text': 'hello'}
+
+ with pytest.raises(Exception):
+ parser = LineStrParser(['filename', 'text', 'ignore'], [0, 1, 2],
+ separator)
+ parser.get_item(data_ret, 0)
+
+
+def test_line_dict_parser():
+ data_ret = [
+ json.dumps({
+ 'filename': 'sample1.jpg',
+ 'text': 'hello'
+ }),
+ json.dumps({
+ 'filename': 'sample2.jpg',
+ 'text': 'world'
+ })
+ ]
+ keys = ['filename', 'text']
+
+ # test init
+ with pytest.raises(AssertionError):
+ parser = LineJsonParser('filename')
+ with pytest.raises(AssertionError):
+ parser = LineJsonParser([])
+
+ # test get_item
+ parser = LineJsonParser(keys)
+ assert parser.get_item(data_ret, 0) == \
+ {'filename': 'sample1.jpg', 'text': 'hello'}
+
+ with pytest.raises(Exception):
+ parser = LineJsonParser(['img_name', 'text'])
+ parser.get_item(data_ret, 0)
diff --git a/tests/test_dataset/test_test_time_aug.py b/tests/test_dataset/test_test_time_aug.py
new file mode 100644
index 00000000..22bf80c6
--- /dev/null
+++ b/tests/test_dataset/test_test_time_aug.py
@@ -0,0 +1,33 @@
+import numpy as np
+import pytest
+
+from mmocr.datasets.pipelines.test_time_aug import MultiRotateAugOCR
+
+
+def test_multi_rotate_aug_ocr():
+ input_img1 = np.ones((64, 256, 3), dtype=np.uint8)
+ input_img2 = np.ones((64, 32, 3), dtype=np.uint8)
+
+ rci = MultiRotateAugOCR(transforms=[], rotate_degrees=[0, 90, 270])
+
+ # test invalid arguments
+ with pytest.raises(AssertionError):
+ MultiRotateAugOCR(transforms=[], rotate_degrees=[45])
+ with pytest.raises(AssertionError):
+ MultiRotateAugOCR(transforms=[], rotate_degrees=[20.5])
+
+ # test call with input_img1
+ results = {'img_shape': input_img1.shape, 'img': input_img1}
+ results = rci(results)
+ assert np.allclose([64, 256, 3], results['img_shape'])
+ assert len(results['img']) == 1
+ assert len(results['img_shape']) == 1
+ assert np.allclose([64, 256, 3], results['img_shape'][0])
+
+ # test call with input_img2
+ results = {'img_shape': input_img2.shape, 'img': input_img2}
+ results = rci(results)
+ assert np.allclose([64, 32, 3], results['img_shape'])
+ assert len(results['img']) == 3
+ assert len(results['img_shape']) == 3
+ assert np.allclose([64, 32, 3], results['img_shape'][0])
diff --git a/tests/test_models/test_label_convertor/test_attn_label_convertor.py b/tests/test_models/test_label_convertor/test_attn_label_convertor.py
new file mode 100644
index 00000000..00eaeacc
--- /dev/null
+++ b/tests/test_models/test_label_convertor/test_attn_label_convertor.py
@@ -0,0 +1,77 @@
+import os.path as osp
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+
+from mmocr.models.textrecog.convertors import AttnConvertor
+
+
+def _create_dummy_dict_file(dict_file):
+ characters = list('helowrd')
+ with open(dict_file, 'w') as fw:
+ for char in characters:
+ fw.write(char + '\n')
+
+
+def test_attn_label_convertor():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy data
+ dict_file = osp.join(tmp_dir.name, 'fake_dict.txt')
+ _create_dummy_dict_file(dict_file)
+
+ # test invalid arguments
+ with pytest.raises(AssertionError):
+ AttnConvertor(5)
+ with pytest.raises(AssertionError):
+ AttnConvertor('DICT90', dict_file, '1')
+ with pytest.raises(AssertionError):
+ AttnConvertor('DICT90', dict_file, True, '1')
+
+ label_convertor = AttnConvertor(dict_file=dict_file, max_seq_len=10)
+ # test init and parse_dict
+ assert label_convertor.num_classes() == 10
+ assert len(label_convertor.idx2char) == 10
+ assert label_convertor.idx2char[0] == 'h'
+ assert label_convertor.idx2char[1] == 'e'
+ assert label_convertor.idx2char[-3] == '<UKN>'
+ assert label_convertor.char2idx['h'] == 0
+ assert label_convertor.unknown_idx == 7
+
+ # test encode str to tensor
+ strings = ['hell']
+ targets_dict = label_convertor.str2tensor(strings)
+ assert torch.allclose(targets_dict['targets'][0],
+ torch.LongTensor([0, 1, 2, 2]))
+ assert torch.allclose(targets_dict['padded_targets'][0],
+ torch.LongTensor([8, 0, 1, 2, 2, 8, 9, 9, 9, 9]))
+
+ # test decode output to index
+ dummy_output = torch.Tensor([[[100, 2, 3, 4, 5, 6, 7, 8, 9],
+ [1, 100, 3, 4, 5, 6, 7, 8, 9],
+ [1, 2, 100, 4, 5, 6, 7, 8, 9],
+ [1, 2, 100, 4, 5, 6, 7, 8, 9],
+ [1, 2, 3, 4, 5, 6, 7, 8, 100],
+ [1, 2, 3, 4, 5, 6, 7, 100, 9],
+ [1, 2, 3, 4, 5, 6, 7, 100, 9],
+ [1, 2, 3, 4, 5, 6, 7, 100, 9],
+ [1, 2, 3, 4, 5, 6, 7, 100, 9],
+ [1, 2, 3, 4, 5, 6, 7, 100, 9]]])
+ indexes, scores = label_convertor.tensor2idx(dummy_output)
+ assert np.allclose(indexes, [[0, 1, 2, 2]])
+
+ # test encode_str_label_to_index
+ with pytest.raises(AssertionError):
+ label_convertor.str2idx('hell')
+ tmp_indexes = label_convertor.str2idx(strings)
+ assert np.allclose(tmp_indexes, [[0, 1, 2, 2]])
+
+ # test decode_index to str_label
+ input_indexes = [[0, 1, 2, 2]]
+ with pytest.raises(AssertionError):
+ label_convertor.idx2str('hell')
+ output_strings = label_convertor.idx2str(input_indexes)
+ assert output_strings[0] == 'hell'
+
+ tmp_dir.cleanup()
diff --git a/tests/test_models/test_label_convertor/test_ctc_label_convertor.py b/tests/test_models/test_label_convertor/test_ctc_label_convertor.py
new file mode 100644
index 00000000..07c9cbf0
--- /dev/null
+++ b/tests/test_models/test_label_convertor/test_ctc_label_convertor.py
@@ -0,0 +1,79 @@
+import os.path as osp
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+
+from mmocr.models.textrecog.convertors import BaseConvertor, CTCConvertor
+
+
+def _create_dummy_dict_file(dict_file):
+ chars = list('helowrd')
+ with open(dict_file, 'w') as fw:
+ for char in chars:
+ fw.write(char + '\n')
+
+
+def test_ctc_label_convertor():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy data
+ dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+ _create_dummy_dict_file(dict_file)
+
+ # test invalid arguments
+ with pytest.raises(AssertionError):
+ CTCConvertor(5)
+
+ label_convertor = CTCConvertor(dict_file=dict_file, with_unknown=False)
+ # test init and parse_chars
+ assert label_convertor.num_classes() == 8
+ assert len(label_convertor.idx2char) == 8
+ assert label_convertor.idx2char[0] == '<BLK>'
+ assert label_convertor.char2idx['h'] == 1
+ assert label_convertor.unknown_idx is None
+
+ # test encode str to tensor
+ strings = ['hell']
+ expect_tensor = torch.IntTensor([1, 2, 3, 3])
+ targets_dict = label_convertor.str2tensor(strings)
+ assert torch.allclose(targets_dict['targets'][0], expect_tensor)
+ assert torch.allclose(targets_dict['flatten_targets'], expect_tensor)
+ assert torch.allclose(targets_dict['target_lengths'], torch.IntTensor([4]))
+
+ # test decode output to index
+ dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
+ [100, 2, 3, 4, 5, 6, 7, 8],
+ [1, 2, 100, 4, 5, 6, 7, 8],
+ [1, 2, 100, 4, 5, 6, 7, 8],
+ [100, 2, 3, 4, 5, 6, 7, 8],
+ [1, 2, 3, 100, 5, 6, 7, 8],
+ [100, 2, 3, 4, 5, 6, 7, 8],
+ [1, 2, 3, 100, 5, 6, 7, 8]]])
+ indexes, scores = label_convertor.tensor2idx(
+ dummy_output, img_metas=[{
+ 'valid_ratio': 1.0
+ }])
+ assert np.allclose(indexes, [[1, 2, 3, 3]])
+
+ # test encode_str_label_to_index
+ with pytest.raises(AssertionError):
+ label_convertor.str2idx('hell')
+ tmp_indexes = label_convertor.str2idx(strings)
+ assert np.allclose(tmp_indexes, [[1, 2, 3, 3]])
+
+ # test decode_index_to_str_label
+ input_indexes = [[1, 2, 3, 3]]
+ with pytest.raises(AssertionError):
+ label_convertor.idx2str('hell')
+ output_strings = label_convertor.idx2str(input_indexes)
+ assert output_strings[0] == 'hell'
+
+ tmp_dir.cleanup()
+
+
+def test_base_label_convertor():
+ with pytest.raises(NotImplementedError):
+ label_convertor = BaseConvertor()
+ label_convertor.str2tensor(None)
+ label_convertor.tensor2idx(None)
diff --git a/tests/test_models/test_ocr_backbone.py b/tests/test_models/test_ocr_backbone.py
new file mode 100644
index 00000000..f49a334d
--- /dev/null
+++ b/tests/test_models/test_ocr_backbone.py
@@ -0,0 +1,36 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog.backbones import ResNet31OCR, VeryDeepVgg
+
+
+def test_resnet31_ocr_backbone():
+ """Test resnet backbone."""
+ with pytest.raises(AssertionError):
+ ResNet31OCR(2.5)
+
+ with pytest.raises(AssertionError):
+ ResNet31OCR(3, layers=5)
+
+ with pytest.raises(AssertionError):
+ ResNet31OCR(3, channels=5)
+
+ # Test ResNet31OCR forward
+ model = ResNet31OCR()
+ model.init_weights()
+ model.train()
+
+ imgs = torch.randn(1, 3, 32, 160)
+ feat = model(imgs)
+ assert feat.shape == torch.Size([1, 512, 4, 40])
+
+
+def test_vgg_deep_vgg_ocr_backbone():
+
+ model = VeryDeepVgg()
+ model.init_weights()
+ model.train()
+
+ imgs = torch.randn(1, 3, 32, 160)
+ feats = model(imgs)
+ assert feats.shape == torch.Size([1, 512, 1, 41])
diff --git a/tests/test_models/test_ocr_decoder.py b/tests/test_models/test_ocr_decoder.py
new file mode 100644
index 00000000..9a063e26
--- /dev/null
+++ b/tests/test_models/test_ocr_decoder.py
@@ -0,0 +1,112 @@
+import math
+
+import pytest
+import torch
+
+from mmocr.models.textrecog.decoders import (BaseDecoder, ParallelSARDecoder,
+ ParallelSARDecoderWithBS,
+ SequentialSARDecoder, TFDecoder)
+from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode
+
+
+def _create_dummy_input():
+ feat = torch.rand(1, 512, 4, 40)
+ out_enc = torch.rand(1, 512)
+ tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
+ img_metas = [{'valid_ratio': 1.0}]
+
+ return feat, out_enc, tgt_dict, img_metas
+
+
+def test_base_decoder():
+ decoder = BaseDecoder()
+ with pytest.raises(NotImplementedError):
+ decoder.forward_train(None, None, None, None)
+ with pytest.raises(NotImplementedError):
+ decoder.forward_test(None, None, None)
+
+
+def test_parallel_sar_decoder():
+ # test parallel sar decoder
+ decoder = ParallelSARDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
+ decoder.init_weights()
+ decoder.train()
+
+ feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
+ with pytest.raises(AssertionError):
+ decoder(feat, out_enc, tgt_dict, [], True)
+ with pytest.raises(AssertionError):
+ decoder(feat, out_enc, tgt_dict, img_metas * 2, True)
+
+ out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
+ assert out_train.shape == torch.Size([1, 5, 36])
+
+ out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
+ assert out_test.shape == torch.Size([1, 5, 36])
+
+
+def test_sequential_sar_decoder():
+ # test sequential sar decoder
+ decoder = SequentialSARDecoder(
+ num_classes=37, padding_idx=36, max_seq_len=5)
+ decoder.init_weights()
+ decoder.train()
+
+ feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
+ with pytest.raises(AssertionError):
+ decoder(feat, out_enc, tgt_dict, [])
+ with pytest.raises(AssertionError):
+ decoder(feat, out_enc, tgt_dict, img_metas * 2)
+
+ out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
+ assert out_train.shape == torch.Size([1, 5, 36])
+
+ out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
+ assert out_test.shape == torch.Size([1, 5, 36])
+
+
+def test_parallel_sar_decoder_with_beam_search():
+ with pytest.raises(AssertionError):
+ ParallelSARDecoderWithBS(beam_width='beam')
+ with pytest.raises(AssertionError):
+ ParallelSARDecoderWithBS(beam_width=0)
+
+ feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
+ decoder = ParallelSARDecoderWithBS(
+ beam_width=1, num_classes=37, padding_idx=36, max_seq_len=5)
+ decoder.init_weights()
+ decoder.train()
+ with pytest.raises(AssertionError):
+ decoder(feat, out_enc, tgt_dict, [])
+ with pytest.raises(AssertionError):
+ decoder(feat, out_enc, tgt_dict, img_metas * 2)
+
+ out_test = decoder(feat, out_enc, tgt_dict, img_metas, train_mode=False)
+ assert out_test.shape == torch.Size([1, 5, 36])
+
+ # test decodenode
+ with pytest.raises(AssertionError):
+ DecodeNode(1, 1)
+ with pytest.raises(AssertionError):
+ DecodeNode([1, 2], ['4', '3'])
+ with pytest.raises(AssertionError):
+ DecodeNode([1, 2], [0.5])
+ decode_node = DecodeNode([1, 2], [0.7, 0.8])
+ assert math.isclose(decode_node.eval(), 1.5)
+
+
+def test_transformer_decoder():
+ decoder = TFDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
+ decoder.init_weights()
+ decoder.train()
+
+ out_enc = torch.rand(1, 128, 512)
+ tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
+ img_metas = [{'valid_ratio': 1.0}]
+
+ out_train = decoder(None, out_enc, tgt_dict, img_metas, True)
+ assert out_train.shape == torch.Size([1, 5, 36])
+
+ out_test = decoder(None, out_enc, tgt_dict, img_metas, False)
+ assert out_test.shape == torch.Size([1, 5, 36])
diff --git a/tests/test_models/test_ocr_encoder.py b/tests/test_models/test_ocr_encoder.py
new file mode 100644
index 00000000..c5a26667
--- /dev/null
+++ b/tests/test_models/test_ocr_encoder.py
@@ -0,0 +1,53 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog.encoders import BaseEncoder, SAREncoder, TFEncoder
+
+
+def test_sar_encoder():
+ with pytest.raises(AssertionError):
+ SAREncoder(enc_bi_rnn='bi')
+ with pytest.raises(AssertionError):
+ SAREncoder(enc_do_rnn=2)
+ with pytest.raises(AssertionError):
+ SAREncoder(enc_gru='gru')
+ with pytest.raises(AssertionError):
+ SAREncoder(d_model=512.5)
+ with pytest.raises(AssertionError):
+ SAREncoder(d_enc=200.5)
+ with pytest.raises(AssertionError):
+ SAREncoder(mask='mask')
+
+ encoder = SAREncoder()
+ encoder.init_weights()
+ encoder.train()
+
+ feat = torch.randn(1, 512, 4, 40)
+ with pytest.raises(AssertionError):
+ encoder(feat)
+ img_metas = [{'valid_ratio': 1.0}]
+ with pytest.raises(AssertionError):
+ encoder(feat, img_metas * 2)
+ out_enc = encoder(feat, img_metas)
+
+ assert out_enc.shape == torch.Size([1, 512])
+
+
+def test_transformer_encoder():
+ tf_encoder = TFEncoder()
+ tf_encoder.init_weights()
+ tf_encoder.train()
+
+ feat = torch.randn(1, 512, 4, 40)
+ out_enc = tf_encoder(feat)
+ assert out_enc.shape == torch.Size([1, 160, 512])
+
+
+def test_base_encoder():
+ encoder = BaseEncoder()
+ encoder.init_weights()
+ encoder.train()
+
+ feat = torch.randn(1, 256, 4, 40)
+ out_enc = encoder(feat)
+ assert out_enc.shape == torch.Size([1, 256, 4, 40])
diff --git a/tests/test_models/test_ocr_head.py b/tests/test_models/test_ocr_head.py
new file mode 100644
index 00000000..52761405
--- /dev/null
+++ b/tests/test_models/test_ocr_head.py
@@ -0,0 +1,16 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog import SegHead
+
+
+def test_cafcn_head():
+ with pytest.raises(AssertionError):
+ SegHead(num_classes='100')
+ with pytest.raises(AssertionError):
+ SegHead(num_classes=-1)
+
+ cafcn_head = SegHead(num_classes=37)
+ out_neck = (torch.rand(1, 128, 32, 32), )
+ out_head = cafcn_head(out_neck)
+ assert out_head.shape == torch.Size([1, 37, 32, 32])
diff --git a/tests/test_models/test_ocr_layer.py b/tests/test_models/test_ocr_layer.py
new file mode 100644
index 00000000..da96273d
--- /dev/null
+++ b/tests/test_models/test_ocr_layer.py
@@ -0,0 +1,55 @@
+import torch
+
+from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck,
+ DecoderLayer, PositionalEncoding,
+ get_pad_mask, get_subsequent_mask)
+from mmocr.models.textrecog.layers.conv_layer import conv3x3
+
+
+def test_conv_layer():
+ conv3by3 = conv3x3(3, 6)
+ assert conv3by3.in_channels == 3
+ assert conv3by3.out_channels == 6
+ assert conv3by3.kernel_size == (3, 3)
+
+ x = torch.rand(1, 64, 224, 224)
+ # test basic block
+ basic_block = BasicBlock(64, 64)
+ assert basic_block.expansion == 1
+
+ out = basic_block(x)
+
+ assert out.shape == torch.Size([1, 64, 224, 224])
+
+ # test bottle neck
+ bottle_neck = Bottleneck(64, 64, downsample=True)
+ assert bottle_neck.expansion == 4
+
+ out = bottle_neck(x)
+
+ assert out.shape == torch.Size([1, 256, 224, 224])
+
+
+def test_transformer_layer():
+ # test decoder_layer
+ decoder_layer = DecoderLayer()
+ in_dec = torch.rand(1, 30, 512)
+ out_enc = torch.rand(1, 128, 512)
+ out_dec = decoder_layer(in_dec, out_enc)
+ assert out_dec.shape == torch.Size([1, 30, 512])
+
+ # test positional_encoding
+ pos_encoder = PositionalEncoding()
+ x = torch.rand(1, 30, 512)
+ out = pos_encoder(x)
+ assert out.size() == x.size()
+
+ # test get pad mask
+ seq = torch.rand(1, 30)
+ pad_idx = 0
+ out = get_pad_mask(seq, pad_idx)
+ assert out.shape == torch.Size([1, 1, 30])
+
+ # test get_subsequent_mask
+ out_mask = get_subsequent_mask(seq)
+ assert out_mask.shape == torch.Size([1, 30, 30])
diff --git a/tests/test_models/test_ocr_loss.py b/tests/test_models/test_ocr_loss.py
new file mode 100644
index 00000000..fa118f5e
--- /dev/null
+++ b/tests/test_models/test_ocr_loss.py
@@ -0,0 +1,123 @@
+import numpy as np
+import pytest
+import torch
+
+from mmdet.core import BitmapMasks
+from mmocr.models.common.losses import DiceLoss
+from mmocr.models.textrecog.losses import (CAFCNLoss, CELoss, CTCLoss, SARLoss,
+ TFLoss)
+
+
+def test_ctc_loss():
+ # test CTCLoss
+ ctc_loss = CTCLoss()
+ outputs = torch.zeros(2, 40, 37)
+ targets_dict = {
+ 'flatten_targets': torch.IntTensor([1, 2, 3, 4, 5]),
+ 'target_lengths': torch.LongTensor([2, 3])
+ }
+
+ losses = ctc_loss(outputs, targets_dict)
+ assert isinstance(losses, dict)
+ assert 'loss_ctc' in losses
+ assert torch.allclose(losses['loss_ctc'],
+ torch.tensor(losses['loss_ctc'].item()).float())
+
+
+def test_ce_loss():
+ with pytest.raises(AssertionError):
+ CELoss(ignore_index='ignore')
+ with pytest.raises(AssertionError):
+ CELoss(reduction=1)
+ with pytest.raises(AssertionError):
+ CELoss(reduction='avg')
+
+ ce_loss = CELoss(ignore_index=0)
+ outputs = torch.rand(1, 10, 37)
+ targets_dict = {
+ 'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
+ }
+ losses = ce_loss(outputs, targets_dict)
+ assert isinstance(losses, dict)
+ assert 'loss_ce' in losses
+ print(losses['loss_ce'].size())
+ assert losses['loss_ce'].size(1) == 10
+
+
+def test_sar_loss():
+ outputs = torch.rand(1, 10, 37)
+ targets_dict = {
+ 'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
+ }
+ sar_loss = SARLoss()
+ new_output, new_target = sar_loss.format(outputs, targets_dict)
+ assert new_output.shape == torch.Size([1, 37, 9])
+ assert new_target.shape == torch.Size([1, 9])
+
+
+def test_tf_loss():
+ with pytest.raises(AssertionError):
+ TFLoss(flatten=1.0)
+
+ outputs = torch.rand(1, 10, 37)
+ targets_dict = {
+ 'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
+ }
+ tf_loss = TFLoss(flatten=False)
+ new_output, new_target = tf_loss.format(outputs, targets_dict)
+ assert new_output.shape == torch.Size([1, 37, 9])
+ assert new_target.shape == torch.Size([1, 9])
+
+
+def test_cafcn_loss():
+ with pytest.raises(AssertionError):
+ CAFCNLoss(alpha='1')
+ with pytest.raises(AssertionError):
+ CAFCNLoss(attn_s2_downsample_ratio='2')
+ with pytest.raises(AssertionError):
+ CAFCNLoss(attn_s3_downsample_ratio='1.5')
+ with pytest.raises(AssertionError):
+ CAFCNLoss(seg_downsample_ratio='1.5')
+ with pytest.raises(AssertionError):
+ CAFCNLoss(attn_s2_downsample_ratio=2)
+ with pytest.raises(AssertionError):
+ CAFCNLoss(attn_s3_downsample_ratio=1.5)
+ with pytest.raises(AssertionError):
+ CAFCNLoss(seg_downsample_ratio=1.5)
+
+ bsz = 1
+ H = W = 64
+ out_neck = (torch.ones(bsz, 1, H // 4, W // 4) * 0.5,
+ torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
+ torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
+ torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
+ torch.ones(bsz, 1, H // 2, W // 2) * 0.5)
+ out_head = torch.rand(bsz, 37, H // 2, W // 2)
+
+ attn_tgt = np.zeros((H, W), dtype=np.float32)
+ segm_tgt = np.zeros((H, W), dtype=np.float32)
+ mask = np.ones((H, W), dtype=np.float32)
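+ # pack attention target, segmentation target and valid mask into a single
+ # BitmapMasks instance per image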
+ gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], H, W)
+
+ cafcn_loss = CAFCNLoss()
+ losses = cafcn_loss(out_neck, out_head, [gt_kernels])
+ assert isinstance(losses, dict)
+ assert 'loss_seg' in losses
+ assert torch.allclose(losses['loss_seg'],
+ torch.tensor(losses['loss_seg'].item()).float())
+
+
+def test_dice_loss():
+ with pytest.raises(AssertionError):
+ DiceLoss(eps='1')
+
+ dice_loss = DiceLoss()
+ pred = torch.rand(1, 1, 32, 32)
+ gt = torch.rand(1, 1, 32, 32)
+
+ loss = dice_loss(pred, gt, None)
+ assert isinstance(loss, torch.Tensor)
+
+ mask = torch.rand(1, 1, 1, 1)
+ loss = dice_loss(pred, gt, mask)
+ assert isinstance(loss, torch.Tensor)
diff --git a/tests/test_models/test_ocr_neck.py b/tests/test_models/test_ocr_neck.py
new file mode 100644
index 00000000..8af3e971
--- /dev/null
+++ b/tests/test_models/test_ocr_neck.py
@@ -0,0 +1,48 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog.necks.cafcn_neck import (CAFCNNeck, CharAttn,
+ FeatGenerator)
+
+
+def test_char_attn():
+ with pytest.raises(AssertionError):
+ CharAttn(in_channels=5.0)
+ with pytest.raises(AssertionError):
+ CharAttn(deformable='deformabel')
+
+ in_feat = torch.rand(1, 128, 32, 32)
+ char_attn = CharAttn()
+ out_feat_map, attn_map = char_attn(in_feat)
+ assert attn_map.shape == torch.Size([1, 1, 32, 32])
+ assert out_feat_map.shape == torch.Size([1, 128, 32, 32])
+
+
+@pytest.mark.skip(reason='TODO: re-enable after CI supports pytorch>1.4')
+def test_feat_generator():
+ in_feat = torch.rand(1, 128, 32, 32)
+ feat_generator = FeatGenerator(in_channels=128, out_channels=128)
+
+ attn_map, feat_map = feat_generator(in_feat)
+ assert attn_map.shape == torch.Size([1, 1, 32, 32])
+ assert feat_map.shape == torch.Size([1, 128, 32, 32])
+
+
+@pytest.mark.skip(reason='TODO: re-enable after CI supports pytorch>1.4')
+def test_cafcn_neck():
+ in_s1 = torch.rand(1, 64, 64, 64)
+ in_s2 = torch.rand(1, 128, 32, 32)
+ in_s3 = torch.rand(1, 256, 16, 16)
+ in_s4 = torch.rand(1, 512, 16, 16)
+ in_s5 = torch.rand(1, 512, 16, 16)
+
+ cafcn_neck = CAFCNNeck()
+ cafcn_neck.init_weights()
+ cafcn_neck.train()
+
+ out_neck = cafcn_neck((in_s1, in_s2, in_s3, in_s4, in_s5))
+ assert out_neck[0].shape == torch.Size([1, 1, 32, 32])
+ assert out_neck[1].shape == torch.Size([1, 1, 16, 16])
+ assert out_neck[2].shape == torch.Size([1, 1, 16, 16])
+ assert out_neck[3].shape == torch.Size([1, 1, 16, 16])
+ assert out_neck[4].shape == torch.Size([1, 128, 64, 64])
diff --git a/tests/test_models/test_recognizer.py b/tests/test_models/test_recognizer.py
new file mode 100644
index 00000000..54f66ca6
--- /dev/null
+++ b/tests/test_models/test_recognizer.py
@@ -0,0 +1,156 @@
+import os.path as osp
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+
+from mmdet.core import BitmapMasks
+from mmocr.models.textrecog.recognizer import (EncodeDecodeRecognizer,
+ SegRecognizer)
+
+
+def _create_dummy_dict_file(dict_file):
+ chars = list('helowrd')
+ with open(dict_file, 'w') as fw:
+ for char in chars:
+ fw.write(char + '\n')
+
+
+def test_base_recognizer():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy data
+ dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+ _create_dummy_dict_file(dict_file)
+
+ label_convertor = dict(
+ type='CTCConvertor', dict_file=dict_file, with_unknown=False)
+
+ preprocessor = None
+ backbone = dict(type='VeryDeepVgg', leakyRelu=False)
+ encoder = None
+ decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
+ loss = dict(type='CTCLoss')
+
+ with pytest.raises(AssertionError):
+ EncodeDecodeRecognizer(backbone=None)
+ with pytest.raises(AssertionError):
+ EncodeDecodeRecognizer(decoder=None)
+ with pytest.raises(AssertionError):
+ EncodeDecodeRecognizer(loss=None)
+ with pytest.raises(AssertionError):
+ EncodeDecodeRecognizer(label_convertor=None)
+
+ recognizer = EncodeDecodeRecognizer(
+ preprocessor=preprocessor,
+ backbone=backbone,
+ encoder=encoder,
+ decoder=decoder,
+ loss=loss,
+ label_convertor=label_convertor)
+
+ recognizer.init_weights()
+ recognizer.train()
+
+ imgs = torch.rand(1, 3, 32, 160)
+
+ # test extract feat
+ feat = recognizer.extract_feat(imgs)
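+ # VeryDeepVgg collapses the height to 1 and downsamples the width roughly
+ # by 4, so the 32x160 input yields a 41-step feature sequence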
+ assert feat.shape == torch.Size([1, 512, 1, 41])
+
+ # test forward train
+ img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
+ losses = recognizer.forward_train(imgs, img_metas)
+ assert isinstance(losses, dict)
+ assert 'loss_ctc' in losses
+
+ # test simple test
+ results = recognizer.simple_test(imgs, img_metas)
+ assert isinstance(results, list)
+ assert isinstance(results[0], dict)
+ assert 'text' in results[0]
+ assert 'score' in results[0]
+
+ # test aug_test
+ aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
+ assert isinstance(aug_results, list)
+ assert isinstance(aug_results[0], dict)
+ assert 'text' in aug_results[0]
+ assert 'score' in aug_results[0]
+
+ tmp_dir.cleanup()
+
+
+@pytest.mark.skip(reason='TODO: re-enable after CI supports pytorch>1.4')
+def test_seg_recognizer():
+ tmp_dir = tempfile.TemporaryDirectory()
+ # create dummy data
+ dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+ _create_dummy_dict_file(dict_file)
+
+ label_convertor = dict(
+ type='SegConvertor', dict_file=dict_file, with_unknown=False)
+
+ preprocessor = None
+ backbone = dict(type='ResNet31OCR')
+ neck = dict(type='FPNOCR')
+ head = dict(type='SegHead')
+ loss = dict(type='SegLoss')
+
+ with pytest.raises(AssertionError):
+ SegRecognizer(backbone=None)
+ with pytest.raises(AssertionError):
+ SegRecognizer(neck=None)
+ with pytest.raises(AssertionError):
+ SegRecognizer(head=None)
+ with pytest.raises(AssertionError):
+ SegRecognizer(loss=None)
+ with pytest.raises(AssertionError):
+ SegRecognizer(label_convertor=None)
+
+ recognizer = SegRecognizer(
+ preprocessor=preprocessor,
+ backbone=backbone,
+ neck=neck,
+ head=head,
+ loss=loss,
+ label_convertor=label_convertor)
+
+ recognizer.init_weights()
+ recognizer.train()
+
+ imgs = torch.rand(1, 3, 64, 256)
+
+ # test extract feat
+ feats = recognizer.extract_feat(imgs)
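+ # ResNet31OCR is expected to return five stages at strides 2, 4, 8, 8, 8
+ # of the 64x256 input, matching the shape checks below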
+ assert len(feats) == 5
+ assert feats[0].shape == torch.Size([1, 64, 32, 128])
+ assert feats[1].shape == torch.Size([1, 128, 16, 64])
+ assert feats[2].shape == torch.Size([1, 256, 8, 32])
+ assert feats[3].shape == torch.Size([1, 512, 8, 32])
+ assert feats[4].shape == torch.Size([1, 512, 8, 32])
+
+ attn_tgt = np.zeros((64, 256), dtype=np.float32)
+ segm_tgt = np.zeros((64, 256), dtype=np.float32)
+ gt_kernels = BitmapMasks([attn_tgt, segm_tgt], 64, 256)
+
+ # test forward train
+ img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
+ losses = recognizer.forward_train(imgs, img_metas, gt_kernels=[gt_kernels])
+ assert isinstance(losses, dict)
+
+ # test simple test
+ results = recognizer.simple_test(imgs, img_metas)
+ assert isinstance(results, list)
+ assert isinstance(results[0], dict)
+ assert 'text' in results[0]
+ assert 'score' in results[0]
+
+ # test aug_test
+ aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
+ assert isinstance(aug_results, list)
+ assert isinstance(aug_results[0], dict)
+ assert 'text' in aug_results[0]
+ assert 'score' in aug_results[0]
+
+ tmp_dir.cleanup()
diff --git a/tests/test_utils/test_text/test_text_utils.py b/tests/test_utils/test_text/test_text_utils.py
new file mode 100644
index 00000000..a1d4a50f
--- /dev/null
+++ b/tests/test_utils/test_text/test_text_utils.py
@@ -0,0 +1,66 @@
+"""Test text label visualize."""
+import os.path as osp
+import random
+import tempfile
+from unittest import mock
+
+import numpy as np
+import pytest
+
+import mmocr.core.visualize as visualize_utils
+
+
+def test_tile_image():
+ dummy_imgs, heights, widths = [], [], []
+ for _ in range(3):
+ h = random.randint(100, 300)
+ w = random.randint(100, 300)
+ heights.append(h)
+ widths.append(w)
+ dummy_img = np.ones((h, w, 3), dtype=np.uint8)
+ dummy_imgs.append(dummy_img)
+ joint_img = visualize_utils.tile_image(dummy_imgs)
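+ # tile_image should stack the images vertically: output height is the sum
+ # of heights and output width is that of the widest image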
+ assert joint_img.shape[0] == sum(heights)
+ assert joint_img.shape[1] == max(widths)
+
+ # test invalid arguments
+ with pytest.raises(AssertionError):
+ visualize_utils.tile_image(dummy_imgs[0])
+ with pytest.raises(AssertionError):
+ visualize_utils.tile_image([])
+
+
+@mock.patch('%s.visualize_utils.mmcv.imread' % __name__)
+@mock.patch('%s.visualize_utils.mmcv.imshow' % __name__)
+@mock.patch('%s.visualize_utils.mmcv.imwrite' % __name__)
+def test_show_text_label(mock_imwrite, mock_imshow, mock_imread):
+ img = np.ones((32, 160), dtype=np.uint8)
+ pred_label = 'hello'
+ gt_label = 'world'
+
+ tmp_dir = tempfile.TemporaryDirectory()
+ out_file = osp.join(tmp_dir.name, 'tmp.jpg')
+
+ # test invalid arguments
+ with pytest.raises(AssertionError):
+ visualize_utils.imshow_text_label(5, pred_label, gt_label)
+ with pytest.raises(AssertionError):
+ visualize_utils.imshow_text_label(img, pred_label, 4)
+ with pytest.raises(AssertionError):
+ visualize_utils.imshow_text_label(img, 3, gt_label)
+ with pytest.raises(AssertionError):
+ visualize_utils.imshow_text_label(
+ img, pred_label, gt_label, show=True, wait_time=0.1)
+
+ mock_imread.side_effect = [img, img]
+ visualize_utils.imshow_text_label(
+ img, pred_label, gt_label, out_file=out_file)
+ visualize_utils.imshow_text_label(
+ img, pred_label, gt_label, out_file=None, show=True)
+
+ # test showing img
+ mock_imshow.assert_called_once()
+ mock_imwrite.assert_called_once()
+
+ tmp_dir.cleanup()
diff --git a/tools/data/textrecog/seg_synthtext_converter.py b/tools/data/textrecog/seg_synthtext_converter.py
new file mode 100644
index 00000000..64ed9701
--- /dev/null
+++ b/tools/data/textrecog/seg_synthtext_converter.py
@@ -0,0 +1,93 @@
+import argparse
+import codecs
+import json
+import os.path as osp
+
+import cv2
+
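+# Example (paths are illustrative):
+#   python tools/data/textrecog/seg_synthtext_converter.py \
+#     --img-prefix data/synthtext/imgs \
+#     --in-path data/synthtext/gt_list.txt \
+#     --out-path data/synthtext/instances_train.txt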
+
+def read_json(fpath):
+ with codecs.open(fpath, 'r', 'utf-8') as f:
+ obj = json.load(f)
+ return obj
+
+
+def parse_old_label(img_prefix, in_path):
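+ """Parse the legacy annotation file.
+
+ Each line is assumed to hold an image name followed by groups of eight
+ integers, one quadrilateral (x1 y1 ... x4 y4) per box.
+ """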
+ imgid2imgname = {}
+ imgid2anno = {}
+ idx = 0
+ with open(in_path, 'r') as fr:
+ for line in fr:
+ line = line.strip().split()
+ img_full_path = osp.join(img_prefix, line[0])
+ if not osp.exists(img_full_path):
+ continue
+ img = cv2.imread(img_full_path)
+ h, w = img.shape[:2]
+ img_info = {}
+ img_info['file_name'] = line[0]
+ img_info['height'] = h
+ img_info['width'] = w
+ imgid2imgname[idx] = img_info
+ imgid2anno[idx] = []
+ for i in range(len(line[1:]) // 8):
+ seg = [int(x) for x in line[(1 + i * 8):(1 + (i + 1) * 8)]]
+ # x / y coordinates of the four corners
+ points_x = seg[0::2]
+ points_y = seg[1::2]
+ box = [
+ min(points_x),
+ min(points_y),
+ max(points_x),
+ max(points_y)
+ ]
+ new_anno = {}
+ new_anno['iscrowd'] = 0
+ new_anno['category_id'] = 1
+ new_anno['bbox'] = box
+ new_anno['segmentation'] = [seg]
+ imgid2anno[idx].append(new_anno)
+ idx += 1
+
+ return imgid2imgname, imgid2anno
+
+
+def gen_line_dict_file(out_path, imgid2imgname, imgid2anno):
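+ """Dump one JSON dict per line with file_name, height, width and annotations."""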
+ with codecs.open(out_path, 'w', 'utf-8') as fw:
+ for key, value in imgid2imgname.items():
+ if key in imgid2anno:
+ anno = imgid2anno[key]
+ line_dict = {}
+ line_dict['file_name'] = value['file_name']
+ line_dict['height'] = value['height']
+ line_dict['width'] = value['width']
+ line_dict['annotations'] = anno
+ line_dict_str = json.dumps(line_dict)
+ fw.write(line_dict_str + '\n')
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--img-prefix',
+ help='image root path, joined with each image name to build the full path')
+ parser.add_argument(
+ '--in-path',
+ help='legacy annotation file: each line is an image name followed by'
+ ' groups of eight box coordinates')
+ parser.add_argument(
+ '--out-path', help='output txt path with line-json format')
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+ imgid2imgname, imgid2anno = parse_old_label(args.img_prefix, args.in_path)
+ gen_line_dict_file(args.out_path, imgid2imgname, imgid2anno)
+ print('Done')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/data/utils/txt2lmdb.py b/tools/data/utils/txt2lmdb.py
new file mode 100644
index 00000000..e7554ccf
--- /dev/null
+++ b/tools/data/utils/txt2lmdb.py
@@ -0,0 +1,74 @@
+import argparse
+import shutil
+import sys
+import time
+from pathlib import Path
+
+import lmdb
+
+
+def converter(imglist, output, batch_size=1000, coding='utf-8'):
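+ """Convert a text label list into an LMDB database.
+
+ Each line of imglist is stored under the key str(line_index); an extra
+ 'total_number' key records how many lines were written.
+ """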
+ # read imglist
+ with open(imglist) as f:
+ lines = f.readlines()
+
+ # create lmdb database
+ if Path(output).is_dir():
+ while True:
+ print('%s already exists, delete it? [Y/n]' % output)
+ Yn = input().strip()
+ if Yn in ['Y', 'y']:
+ shutil.rmtree(output)
+ break
+ elif Yn in ['N', 'n']:
+ return
+ print('creating database %s' % output)
+ Path(output).mkdir(parents=True, exist_ok=False)
+ env = lmdb.open(output, map_size=1099511627776)
+
+ # build lmdb
+ beg_time = time.strftime('%H:%M:%S')
+ for beg_index in range(0, len(lines), batch_size):
+ end_index = min(beg_index + batch_size, len(lines))
+ sys.stdout.write('\r[%s-%s], processing [%d-%d] / %d' %
+ (beg_time, time.strftime('%H:%M:%S'), beg_index,
+ end_index, len(lines)))
+ sys.stdout.flush()
+ batch = [(str(index).encode(coding), lines[index].encode(coding))
+ for index in range(beg_index, end_index)]
+ with env.begin(write=True) as txn:
+ cursor = txn.cursor()
+ cursor.putmulti(batch, dupdata=False, overwrite=True)
+ sys.stdout.write('\n')
+ with env.begin(write=True) as txn:
+ key = 'total_number'.encode(coding)
+ value = str(len(lines)).encode(coding)
+ txn.put(key, value)
+ print('done', flush=True)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--imglist', '-i', required=True, help='input imglist path')
+ parser.add_argument(
+ '--output', '-o', required=True, help='output lmdb path')
+ parser.add_argument(
+ '--batch_size',
+ '-b',
+ type=int,
+ default=10000,
+ help='processing batch size, default 10000')
+ parser.add_argument(
+ '--coding',
+ '-c',
+ default='utf-8',
+ help='bytes coding scheme, default utf-8')
+ opt = parser.parse_args()
+
+ converter(
+ opt.imglist, opt.output, batch_size=opt.batch_size, coding=opt.coding)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/ocr_test_imgs.py b/tools/ocr_test_imgs.py
new file mode 100644
index 00000000..a2d642b1
--- /dev/null
+++ b/tools/ocr_test_imgs.py
@@ -0,0 +1,132 @@
+import os.path as osp
+import shutil
+import time
+from argparse import ArgumentParser
+
+import mmcv
+import torch
+from mmcv.utils import ProgressBar
+
+from mmdet.apis import init_detector
+from mmdet.utils import get_root_logger
+from mmocr.apis import model_inference
+from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
+from mmocr.datasets import build_dataset # noqa: F401
+from mmocr.models import build_detector # noqa: F401
+
+
+def save_results(img_paths, pred_labels, gt_labels, res_dir):
+ """Save predicted results to txt file.
+
+ Args:
+ img_paths (list[str])
+ pred_labels (list[str])
+ gt_labels (list[str])
+ res_dir (str)
+ """
+ assert len(img_paths) == len(pred_labels) == len(gt_labels)
+ res_file = osp.join(res_dir, 'results.txt')
+ correct_file = osp.join(res_dir, 'correct.txt')
+ wrong_file = osp.join(res_dir, 'wrong.txt')
+ with open(res_file, 'w') as fw, \
+ open(correct_file, 'w') as fw_correct, \
+ open(wrong_file, 'w') as fw_wrong:
+ for img_path, pred_label, gt_label in zip(img_paths, pred_labels,
+ gt_labels):
+ fw.write(img_path + ' ' + pred_label + ' ' + gt_label + '\n')
+ if pred_label == gt_label:
+ fw_correct.write(img_path + ' ' + pred_label + ' ' + gt_label +
+ '\n')
+ else:
+ fw_wrong.write(img_path + ' ' + pred_label + ' ' + gt_label +
+ '\n')
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('--img_root_path', type=str, help='Image root path')
+ parser.add_argument('--img_list', type=str, help='Image path list file')
+ parser.add_argument('--config', type=str, help='Config file')
+ parser.add_argument('--checkpoint', type=str, help='Checkpoint file')
+ parser.add_argument(
+ '--out_dir', type=str, default='./results', help='Dir to save results')
+ parser.add_argument(
+ '--show', action='store_true', help='show images instead of saving them')
+ args = parser.parse_args()
+
+ # init the logger before other steps
+ mmcv.mkdir_or_exist(args.out_dir)
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ log_file = osp.join(args.out_dir, f'{timestamp}.log')
+ logger = get_root_logger(log_file=log_file, log_level='INFO')
+
+ # build the model from a config file and a checkpoint file
+ device = 'cuda:' + str(torch.cuda.current_device())
+ model = init_detector(args.config, args.checkpoint, device=device)
+ if hasattr(model, 'module'):
+ model = model.module
+ if model.cfg.data.test['type'] == 'ConcatDataset':
+ model.cfg.data.test.pipeline = \
+ model.cfg.data.test['datasets'][0].pipeline
+
+ # Start Inference
+ out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
+ mmcv.mkdir_or_exist(out_vis_dir)
+ correct_vis_dir = osp.join(args.out_dir, 'correct')
+ mmcv.mkdir_or_exist(correct_vis_dir)
+ wrong_vis_dir = osp.join(args.out_dir, 'wrong')
+ mmcv.mkdir_or_exist(wrong_vis_dir)
+ img_paths, pred_labels, gt_labels = [], [], []
+ with open(args.img_list, 'r') as f:
+ total_img_num = sum(1 for _ in f)
+ progressbar = ProgressBar(task_num=total_img_num)
+ num_gt_label = 0
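+ # each line of img_list is '<image file> [gt_label]'; the gt label is optional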
+ with open(args.img_list, 'r') as fr:
+ for line in fr:
+ progressbar.update()
+ item_list = line.strip().split()
+ img_file = item_list[0]
+ gt_label = ''
+ if len(item_list) >= 2:
+ gt_label = item_list[1]
+ num_gt_label += 1
+ img_path = osp.join(args.img_root_path, img_file)
+ if not osp.exists(img_path):
+ raise FileNotFoundError(img_path)
+ # Test a single image
+ result = model_inference(model, img_path)
+ pred_label = result['text']
+
+ out_img_name = '_'.join(img_file.split('/'))
+ out_file = osp.join(out_vis_dir, out_img_name)
+ kwargs_dict = {
+ 'gt_label': gt_label,
+ 'show': args.show,
+ 'out_file': '' if args.show else out_file
+ }
+ model.show_result(img_path, result, **kwargs_dict)
+ if gt_label != '':
+ if gt_label == pred_label:
+ dst_file = osp.join(correct_vis_dir, out_img_name)
+ else:
+ dst_file = osp.join(wrong_vis_dir, out_img_name)
+ shutil.copy(out_file, dst_file)
+ img_paths.append(img_path)
+ gt_labels.append(gt_label)
+ pred_labels.append(pred_label)
+
+ # Save results
+ save_results(img_paths, pred_labels, gt_labels, args.out_dir)
+
+ if num_gt_label == len(pred_labels):
+ # eval
+ eval_results = eval_ocr_metric(pred_labels, gt_labels)
+ logger.info('\n' + '-' * 100)
+ info = 'eval on testset with img_root_path ' + \
+ f'{args.img_root_path} and img_list {args.img_list}\n'
+ logger.info(info)
+ logger.info(eval_results)
+
+ print(f'\nInference done, and results saved in {args.out_dir}\n')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/ocr_test_imgs.sh b/tools/ocr_test_imgs.sh
new file mode 100644
index 00000000..69d719a1
--- /dev/null
+++ b/tools/ocr_test_imgs.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+DATE=`date +%Y-%m-%d`
+TIME=`date +"%H-%M-%S"`
+
+if [ $# -lt 5 ]
+then
+ echo "Usage: bash $0 CONFIG CHECKPOINT IMG_PREFIX IMG_LIST RESULTS_DIR"
+ exit
+fi
+
+CONFIG_FILE=$1
+CHECKPOINT=$2
+IMG_ROOT_PATH=$3
+IMG_LIST=$4
+OUT_DIR=$5_${DATE}_${TIME}
+
+mkdir -p ${OUT_DIR} &&
+
+python tools/ocr_test_imgs.py \
+ --img_root_path ${IMG_ROOT_PATH} \
+ --img_list ${IMG_LIST} \
+ --config ${CONFIG_FILE} \
+ --checkpoint ${CHECKPOINT} \
+ --out_dir ${OUT_DIR}