diff --git a/configs/_base_/recog_datasets/seg_toy_dataset.py b/configs/_base_/recog_datasets/seg_toy_dataset.py new file mode 100644 index 00000000..d6c49dbf --- /dev/null +++ b/configs/_base_/recog_datasets/seg_toy_dataset.py @@ -0,0 +1,96 @@ +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + +gt_label_convertor = dict( + type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomPaddingOCR', + max_ratio=[0.15, 0.2, 0.15, 0.2], + box_type='char_quads'), + dict(type='OpencvToPil'), + dict( + type='RandomRotateImageBox', + min_angle=-17, + max_angle=17, + box_type='char_quads'), + dict(type='PilToOpencv'), + dict( + type='ResizeOCR', + height=64, + min_width=64, + max_width=512, + keep_aspect_ratio=True), + dict( + type='OCRSegTargets', + label_convertor=gt_label_convertor, + box_type='char_quads'), + dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), + dict(type='ToTensorOCR'), + dict(type='FancyPCA'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='CustomFormatBundle', + keys=['gt_kernels'], + visualize=dict(flag=False, boundary_key=None), + call_super=False), + dict( + type='Collect', + keys=['img', 'gt_kernels'], + meta_keys=['filename', 'ori_shape', 'img_shape']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=64, + min_width=64, + max_width=None, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict(type='CustomFormatBundle', call_super=False), + dict( + type='Collect', + keys=['img'], + meta_keys=['filename', 'ori_shape', 'img_shape']) +] + +prefix = 'tests/data/ocr_char_ann_toy_dataset/' +train = dict( + type='OCRSegDataset', + img_prefix=prefix + 'imgs', + ann_file=prefix + 'instances_train.txt', + loader=dict( + type='HardDiskLoader', + repeat=100, + parser=dict( + type='LineJsonParser', keys=['file_name', 'annotations', 'text'])), + pipeline=train_pipeline, + test_mode=True) + +test = dict( + type='OCRDataset', + img_prefix=prefix + 'imgs', + ann_file=prefix + 'instances_test.txt', + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +data = dict( + samples_per_gpu=8, + workers_per_gpu=1, + train=dict(type='ConcatDataset', datasets=[train]), + val=dict(type='ConcatDataset', datasets=[test]), + test=dict(type='ConcatDataset', datasets=[test])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/_base_/recog_datasets/toy_dataset.py b/configs/_base_/recog_datasets/toy_dataset.py new file mode 100755 index 00000000..83848863 --- /dev/null +++ b/configs/_base_/recog_datasets/toy_dataset.py @@ -0,0 +1,99 @@ +img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=32, + min_width=32, + max_width=160, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio' + ]), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiRotateAugOCR', + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=32, + min_width=32, + max_width=160, + keep_aspect_ratio=True), + 
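+            # Test-time augmentation: MultiRotateAugOCR additionally feeds
+            # rotated copies of the image (angles in `rotate_degrees`,
+            # intended for vertical text) through the transforms below.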
dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ]) +] + +dataset_type = 'OCRDataset' +img_prefix = 'tests/data/ocr_toy_dataset/imgs' +train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt' +train1 = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file1, + loader=dict( + type='HardDiskLoader', + repeat=100, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb' +train2 = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file2, + loader=dict( + type='LmdbLoader', + repeat=100, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb' +test = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=test_anno_file1, + loader=dict( + type='LmdbLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +data = dict( + samples_per_gpu=16, + workers_per_gpu=2, + train=dict(type='ConcatDataset', datasets=[train1, train2]), + val=dict(type='ConcatDataset', datasets=[test]), + test=dict(type='ConcatDataset', datasets=[test])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/_base_/recog_models/crnn.py b/configs/_base_/recog_models/crnn.py new file mode 100644 index 00000000..6b98c3d9 --- /dev/null +++ b/configs/_base_/recog_models/crnn.py @@ -0,0 +1,11 @@ +label_convertor = dict( + type='CTCConvertor', dict_type='DICT90', with_unknown=False) + +model = dict( + type='CRNNNet', + preprocessor=None, + backbone=dict(type='VeryDeepVgg', leakyRelu=False), + encoder=None, + decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True), + loss=dict(type='CTCLoss', flatten=False), + label_convertor=label_convertor) diff --git a/configs/_base_/recog_models/robust_scanner.py b/configs/_base_/recog_models/robust_scanner.py new file mode 100644 index 00000000..4cc2fa10 --- /dev/null +++ b/configs/_base_/recog_models/robust_scanner.py @@ -0,0 +1,24 @@ +label_convertor = dict( + type='AttnConvertor', dict_type='DICT90', with_unknown=True) + +hybrid_decoder = dict(type='SequenceAttentionDecoder') + +position_decoder = dict(type='PositionAttentionDecoder') + +model = dict( + type='RobustScanner', + backbone=dict(type='ResNet31OCR'), + encoder=dict( + type='ChannelReductionEncoder', + in_channels=512, + out_channels=128, + ), + decoder=dict( + type='RobustScannerDecoder', + dim_input=512, + dim_model=128, + hybrid_decoder=hybrid_decoder, + position_decoder=position_decoder), + loss=dict(type='SARLoss'), + label_convertor=label_convertor, + max_seq_len=30) diff --git a/configs/_base_/recog_models/sar.py b/configs/_base_/recog_models/sar.py new file mode 100755 index 00000000..8438d9b9 --- /dev/null +++ b/configs/_base_/recog_models/sar.py @@ -0,0 +1,24 @@ +label_convertor = dict( + type='AttnConvertor', dict_type='DICT90', with_unknown=True) + +model = dict( + type='SARNet', + backbone=dict(type='ResNet31OCR'), + encoder=dict( + type='SAREncoder', + enc_bi_rnn=False, + enc_do_rnn=0.1, + enc_gru=False, + ), + decoder=dict( + type='ParallelSARDecoder', + enc_bi_rnn=False, + 
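+        # `enc_bi_rnn` is repeated on the decoder side so the decoder can
+        # infer the channel size of the encoder's holistic feature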
dec_bi_rnn=False,
+        dec_do_rnn=0,
+        dec_gru=False,
+        pred_dropout=0.1,
+        d_k=512,
+        pred_concat=True),
+    loss=dict(type='SARLoss'),
+    label_convertor=label_convertor,
+    max_seq_len=30)
diff --git a/configs/_base_/recog_models/transformer.py b/configs/_base_/recog_models/transformer.py
new file mode 100644
index 00000000..476643fa
--- /dev/null
+++ b/configs/_base_/recog_models/transformer.py
@@ -0,0 +1,11 @@
+label_convertor = dict(
+    type='AttnConvertor', dict_type='DICT90', with_unknown=False)
+
+model = dict(
+    type='TransformerNet',
+    backbone=dict(type='ResNet31OCR'),
+    encoder=dict(type='TFEncoder'),
+    decoder=dict(type='TFDecoder'),
+    loss=dict(type='TFLoss'),
+    label_convertor=label_convertor,
+    max_seq_len=40)
diff --git a/configs/textrecog/sar/README.md b/configs/textrecog/sar/README.md
new file mode 100644
index 00000000..517c6985
--- /dev/null
+++ b/configs/textrecog/sar/README.md
@@ -0,0 +1,65 @@
+# Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition
+
+## Introduction
+
+[ALGORITHM]
+
+```
+@inproceedings{li2019show,
+  title={Show, attend and read: A simple and strong baseline for irregular text recognition},
+  author={Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu},
+  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+  volume={33},
+  number={01},
+  pages={8610--8617},
+  year={2019}
+}
+```
+
+## Dataset
+
+### Train Dataset
+
+|  trainset  | instance_num | repeat_num |          source          |
+| :--------: | :----------: | :--------: | :----------------------: |
+| icdar_2011 |     3567     |     20     |           real           |
+| icdar_2013 |     848      |     20     |           real           |
+| icdar2015  |     4468     |     20     |           real           |
+| coco_text  |    42142     |     20     |           real           |
+|   IIIT5K   |     2000     |     20     |           real           |
+| SynthText  |   2400000    |     1      |          synth           |
+|  SynthAdd  |   1216889    |     1      | synth, 1.6m in [[1]](#1) |
+|   Syn90k   |   2400000    |     1      |          synth           |
+
+### Test Dataset
+
+| testset | instance_num |            type             |
+| :-----: | :----------: | :-------------------------: |
+| IIIT5K  |     3000     |           regular           |
+|   SVT   |     647      |           regular           |
+|  IC13   |     1015     |           regular           |
+|  IC15   |     2077     |          irregular          |
+|  SVTP   |     645      | irregular, 639 in [[1]](#1) |
+|  CT80   |     288      |          irregular          |
+
+## Results and Models
+
+| Methods | Backbone | Decoder | Regular Text |  |  |  | Irregular Text |  |  | download |
+| :-----: | :------: | :-----: | :----------: | :-: | :-: | :-: | :------------: | :-: | :-: | :------: |
+|  |  |  | IIIT5K | SVT | IC13 |  | IC15 | SVTP | CT80 |  |
+| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 95.0 | 89.6 | 93.7 |  | 79.0 | 82.2 | 88.9 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210327_154129.log.json) |
+| [SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 95.2 | 88.7 | 92.4 |  | 78.2 | 81.9 | 89.6 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic-d06c9a8e.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210330_105728.log.json) |
+
+**Notes:**
+- `R31-1/8-1/4` means the backbone feature map is 1/8 of the input image in height and 1/4 in width.
+- We did not use beam search during decoding.
+- We implemented two kinds of decoders, namely `ParallelSARDecoder` and `SequentialSARDecoder`.
+  - `ParallelSARDecoder`: parallel decoding during training with an `LSTM` layer. It is faster.
+  - `SequentialSARDecoder`: sequential decoding during training with `LSTMCell`. It is easier to understand.
+- For the training datasets:
+  - We did not construct the distinct data groups (20 groups in [[1]](#1)) to train the model group-by-group, since that would make training overly complicated.
+  - Instead, we randomly selected `2.4m` patches from `Syn90k`, `2.4m` from `SynthText` and `1.2m` from `SynthAdd`, and grouped all data together. See [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_academic.py) for details.
+- We used 48 GPUs with `total_batch_size = 64 * 48` in the experiment above to speed up training, while keeping the `initial lr = 1e-3` unchanged.
+
+## References
+
+<a id="1">[1]</a> Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu. Show, attend and read: A simple and strong baseline for irregular text recognition. In AAAI 2019.
diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
new file mode 100644
index 00000000..4e405227
--- /dev/null
+++ b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
@@ -0,0 +1,219 @@
+_base_ = ['../../_base_/default_runtime.py']
+
+label_convertor = dict(
+    type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+model = dict(
+    type='SARNet',
+    backbone=dict(type='ResNet31OCR'),
+    encoder=dict(
+        type='SAREncoder',
+        enc_bi_rnn=False,
+        enc_do_rnn=0.1,
+        enc_gru=False,
+    ),
+    decoder=dict(
+        type='ParallelSARDecoder',
+        enc_bi_rnn=False,
+        dec_bi_rnn=False,
+        dec_do_rnn=0,
+        dec_gru=False,
+        pred_dropout=0.1,
+        d_k=512,
+        pred_concat=True),
+    loss=dict(type='SARLoss'),
+    label_convertor=label_convertor,
+    max_seq_len=30)
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeOCR',
+        height=48,
+        min_width=48,
+        max_width=160,
+        keep_aspect_ratio=True,
+        width_downsample_ratio=0.25),
+    dict(type='ToTensorOCR'),
+    dict(type='NormalizeOCR', **img_norm_cfg),
+    dict(
+        type='Collect',
+        keys=['img'],
+        meta_keys=[
+            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+        ]),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiRotateAugOCR',
+        rotate_degrees=[0, 90, 270],
+        transforms=[
+            dict(
+                type='ResizeOCR',
+                height=48,
+                min_width=48,
+                max_width=160,
+                keep_aspect_ratio=True,
+                width_downsample_ratio=0.25),
+            dict(type='ToTensorOCR'),
+            dict(type='NormalizeOCR', **img_norm_cfg),
+            dict(
+                type='Collect',
+                keys=['img'],
+                meta_keys=[
+                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+                ]),
+        ])
+]
+
+dataset_type = 'OCRDataset'
+
+train_prefix = 'data/mixture/'
+
+train_img_prefix1 = train_prefix + 'icdar_2011'
+train_img_prefix2 = train_prefix + 'icdar_2013'
+train_img_prefix3 = train_prefix + 'icdar_2015'
+train_img_prefix4 = train_prefix + 'coco_text'
+train_img_prefix5 = train_prefix + 'III5K'
+train_img_prefix6 = train_prefix + 'SynthText_Add'
+train_img_prefix7 = train_prefix + 'SynthText'
+train_img_prefix8 = train_prefix + 'Syn90k'
+
+train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
+train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
+train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
+train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
+train_ann_file5 = train_prefix + 'III5K/train_label.txt'
+train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
+train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
+train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'
+
+train1 = dict(
+    type=dataset_type,
+    img_prefix=train_img_prefix1,
+    ann_file=train_ann_file1,
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=20,
+        parser=dict(
+            type='LineStrParser',
+            keys=['filename', 'text'],
+            keys_idx=[0, 1],
+            separator=' ')),
+    pipeline=train_pipeline,
+    test_mode=False)
+
+# shallow copy: nested dicts (e.g. loader) are shared with train1, so only
+# top-level keys such as img_prefix and ann_file may be overridden
+train2 = {key: value for key, value in train1.items()}
+train2['img_prefix'] = train_img_prefix2
+train2['ann_file'] = train_ann_file2
+
+train3 = {key: value for key, value in train1.items()}
+train3['img_prefix'] = train_img_prefix3
+train3['ann_file'] = train_ann_file3
+
+train4 = {key: value for key, value in train1.items()}
+train4['img_prefix'] = train_img_prefix4
+train4['ann_file'] = train_ann_file4
+
+train5 = {key: value for key, value in train1.items()}
+train5['img_prefix'] = train_img_prefix5
+train5['ann_file'] = train_ann_file5
+
+train6 = dict(
+    type=dataset_type,
+    img_prefix=train_img_prefix6,
+    ann_file=train_ann_file6,
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=1,
+        parser=dict(
+            type='LineStrParser',
+            keys=['filename', 'text'],
+            keys_idx=[0, 1],
+            separator=' ')),
+    pipeline=train_pipeline,
+    test_mode=False)
+
+train7 = {key: value for key, value in train6.items()}
+train7['img_prefix'] = train_img_prefix7
+train7['ann_file'] = train_ann_file7
+
+train8 = {key: value for key, value in train6.items()}
+train8['img_prefix'] = train_img_prefix8
+train8['ann_file'] = train_ann_file8
+
+test_prefix = 'data/mixture/'
+test_img_prefix1 = test_prefix + 'IIIT5K/'
+test_img_prefix2 = test_prefix + 'svt/'
+test_img_prefix3 = test_prefix + 'icdar_2013/'
+test_img_prefix4 = test_prefix + 'icdar_2015/'
+test_img_prefix5 = test_prefix + 'svtp/'
+test_img_prefix6 = test_prefix + 'ct80/'
+
+test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file2 = test_prefix + 'svt/test_label.txt'
+test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
+test_ann_file5 = test_prefix + 'svtp/test_label.txt'
+test_ann_file6 = test_prefix + 'ct80/test_label.txt'
+
+test1 = dict(
+    type=dataset_type,
+    img_prefix=test_img_prefix1,
+    ann_file=test_ann_file1,
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=1,
+        parser=dict(
+            type='LineStrParser',
+            keys=['filename', 'text'],
+            keys_idx=[0, 1],
+            separator=' ')),
+    pipeline=test_pipeline,
+    test_mode=True)
+
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+test4 = {key: value for key, value in test1.items()}
+test4['img_prefix'] = test_img_prefix4
+test4['ann_file'] = test_ann_file4
+
+test5 = {key: value for key, value in test1.items()}
+test5['img_prefix'] = test_img_prefix5
+test5['ann_file'] = test_ann_file5
+
+test6 = {key: value for key, value in test1.items()}
+test6['img_prefix'] = test_img_prefix6
+test6['ann_file'] = test_ann_file6
+
+data = dict(
+    samples_per_gpu=64,
+    workers_per_gpu=2,
+    train=dict(
+        type='ConcatDataset',
+        datasets=[
+            train1, train2, train3, train4, train5, train6, train7, train8
+        ]),
+    val=dict(
type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6]), + test=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py new file mode 100755 index 00000000..0c8b53e2 --- /dev/null +++ b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py @@ -0,0 +1,110 @@ +_base_ = [ + '../../_base_/default_runtime.py', '../../_base_/recog_models/sar.py' +] + +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 5 + +img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio' + ]), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiRotateAugOCR', + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ]) +] + +dataset_type = 'OCRDataset' +img_prefix = 'tests/data/ocr_toy_dataset/imgs' +train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt' +train1 = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file1, + loader=dict( + type='HardDiskLoader', + repeat=100, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb' +train2 = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file2, + loader=dict( + type='LmdbLoader', + repeat=100, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb' +test = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=test_anno_file1, + loader=dict( + type='LmdbLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +data = dict( + samples_per_gpu=16, + workers_per_gpu=2, + train=dict(type='ConcatDataset', datasets=[train1, train2]), + val=dict(type='ConcatDataset', datasets=[test]), + test=dict(type='ConcatDataset', datasets=[test])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py new file mode 100644 index 00000000..6fa00dd7 --- /dev/null +++ b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py @@ -0,0 +1,219 @@ +_base_ = ['../../_base_/default_runtime.py'] + +label_convertor = dict( + type='AttnConvertor', dict_type='DICT90', with_unknown=True) + +model = dict( + type='SARNet', + backbone=dict(type='ResNet31OCR'), + encoder=dict( + 
type='SAREncoder',
+        enc_bi_rnn=False,
+        enc_do_rnn=0.1,
+        enc_gru=False,
+    ),
+    decoder=dict(
+        type='SequentialSARDecoder',
+        enc_bi_rnn=False,
+        dec_bi_rnn=False,
+        dec_do_rnn=0,
+        dec_gru=False,
+        pred_dropout=0.1,
+        d_k=512,
+        pred_concat=True),
+    loss=dict(type='SARLoss'),
+    label_convertor=label_convertor,
+    max_seq_len=30)
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeOCR',
+        height=48,
+        min_width=48,
+        max_width=160,
+        keep_aspect_ratio=True,
+        width_downsample_ratio=0.25),
+    dict(type='ToTensorOCR'),
+    dict(type='NormalizeOCR', **img_norm_cfg),
+    dict(
+        type='Collect',
+        keys=['img'],
+        meta_keys=[
+            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+        ]),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiRotateAugOCR',
+        rotate_degrees=[0, 90, 270],
+        transforms=[
+            dict(
+                type='ResizeOCR',
+                height=48,
+                min_width=48,
+                max_width=160,
+                keep_aspect_ratio=True,
+                width_downsample_ratio=0.25),
+            dict(type='ToTensorOCR'),
+            dict(type='NormalizeOCR', **img_norm_cfg),
+            dict(
+                type='Collect',
+                keys=['img'],
+                meta_keys=[
+                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+                ]),
+        ])
+]
+
+dataset_type = 'OCRDataset'
+
+train_prefix = 'data/mixture/'
+
+train_img_prefix1 = train_prefix + 'icdar_2011'
+train_img_prefix2 = train_prefix + 'icdar_2013'
+train_img_prefix3 = train_prefix + 'icdar_2015'
+train_img_prefix4 = train_prefix + 'coco_text'
+train_img_prefix5 = train_prefix + 'III5K'
+train_img_prefix6 = train_prefix + 'SynthText_Add'
+train_img_prefix7 = train_prefix + 'SynthText'
+train_img_prefix8 = train_prefix + 'Syn90k'
+
+train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
+train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
+train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
+train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
+train_ann_file5 = train_prefix + 'III5K/train_label.txt'
+train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
+train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
+train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'
+
+train1 = dict(
+    type=dataset_type,
+    img_prefix=train_img_prefix1,
+    ann_file=train_ann_file1,
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=20,
+        parser=dict(
+            type='LineStrParser',
+            keys=['filename', 'text'],
+            keys_idx=[0, 1],
+            separator=' ')),
+    pipeline=train_pipeline,
+    test_mode=False)
+
+# shallow copy: nested dicts (e.g. loader) are shared with train1, so only
+# top-level keys such as img_prefix and ann_file may be overridden
+train2 = {key: value for key, value in train1.items()}
+train2['img_prefix'] = train_img_prefix2
+train2['ann_file'] = train_ann_file2
+
+train3 = {key: value for key, value in train1.items()}
+train3['img_prefix'] = train_img_prefix3
+train3['ann_file'] = train_ann_file3
+
+train4 = {key: value for key, value in train1.items()}
+train4['img_prefix'] = train_img_prefix4
+train4['ann_file'] = train_ann_file4
+
+train5 = {key: value for key, value in train1.items()}
+train5['img_prefix'] = train_img_prefix5
+train5['ann_file'] = train_ann_file5
+
+train6 = dict(
+    type=dataset_type,
+    img_prefix=train_img_prefix6,
+    ann_file=train_ann_file6,
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=1,
+        parser=dict(
+            type='LineStrParser',
+            keys=['filename', 'text'],
+            keys_idx=[0, 1],
+            separator=' ')),
+    pipeline=train_pipeline,
+    test_mode=False)
+
+train7 = 
{key: value for key, value in train6.items()} +train7['img_prefix'] = train_img_prefix7 +train7['ann_file'] = train_ann_file7 + +train8 = {key: value for key, value in train6.items()} +train8['img_prefix'] = train_img_prefix8 +train8['ann_file'] = train_ann_file8 + +test_prefix = 'data/mixture/' +test_img_prefix1 = test_prefix + 'IIIT5K/' +test_img_prefix2 = test_prefix + 'svt/' +test_img_prefix3 = test_prefix + 'icdar_2013/' +test_img_prefix4 = test_prefix + 'icdar_2015/' +test_img_prefix5 = test_prefix + 'svtp/' +test_img_prefix6 = test_prefix + 'ct80/' + +test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt' +test_ann_file2 = test_prefix + 'svt/test_label.txt' +test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt' +test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt' +test_ann_file5 = test_prefix + 'svtp/test_label.txt' +test_ann_file6 = test_prefix + 'ct80/test_label.txt' + +test1 = dict( + type=dataset_type, + img_prefix=test_img_prefix1, + ann_file=test_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +test2 = {key: value for key, value in test1.items()} +test2['img_prefix'] = test_img_prefix2 +test2['ann_file'] = test_ann_file2 + +test3 = {key: value for key, value in test1.items()} +test3['img_prefix'] = test_img_prefix3 +test3['ann_file'] = test_ann_file3 + +test4 = {key: value for key, value in test1.items()} +test4['img_prefix'] = test_img_prefix4 +test4['ann_file'] = test_ann_file4 + +test5 = {key: value for key, value in test1.items()} +test5['img_prefix'] = test_img_prefix5 +test5['ann_file'] = test_ann_file5 + +test6 = {key: value for key, value in test1.items()} +test6['img_prefix'] = test_img_prefix6 +test6['ann_file'] = test_ann_file6 + +data = dict( + samples_per_gpu=64, + workers_per_gpu=2, + train=dict( + type='ConcatDataset', + datasets=[ + train1, train2, train3, train4, train5, train6, train7, train8 + ]), + val=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6]), + test=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/seg/README.md b/configs/textrecog/seg/README.md new file mode 100644 index 00000000..49ca2014 --- /dev/null +++ b/configs/textrecog/seg/README.md @@ -0,0 +1,36 @@ +# Baseline of segmentation based text recognition method. + +## Introduction + +A Baseline Method for Segmentation based Text Recognition. 
+
+[ALGORITHM]
+
+## Dataset
+
+### Train Dataset
+
+| trainset  | instance_num | repeat_num | source |
+| :-------: | :----------: | :--------: | :----: |
+| SynthText |   7266686    |     1      | synth  |
+
+### Test Dataset
+
+| testset | instance_num |   type    |
+| :-----: | :----------: | :-------: |
+| IIIT5K  |     3000     |  regular  |
+|   SVT   |     647      |  regular  |
+|  IC13   |     1015     |  regular  |
+|  CT80   |     288      | irregular |
+
+## Results and Models
+
+| Backbone | Neck | Head |  | Regular Text |  |  |  | Irregular Text | base_lr | batch_size/gpu | gpus | download |
+| :------: | :--: | :--: | :-: | :----------: | :-: | :-: | :-: | :------------: | :-----: | :------------: | :--: | :------: |
+|  |  |  |  | IIIT5K | SVT | IC13 |  | CT80 |  |  |  |  |
+| R31-1/16 | FPNOCR | 1x |  | 90.9 | 81.8 | 90.7 |  | 80.9 | 1e-4 | 16 | 4 | [model](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic-0c50e163.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/seg/20210325_112835.log.json) |
+
+**Notes:**
+
+- `R31-1/16` means the feature map from the backbone is 1/16 of the input image in both height and width.
+- `1x` means the feature map from the head has the same height and width as the input image.
diff --git a/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py b/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py
new file mode 100644
index 00000000..8a568f8e
--- /dev/null
+++ b/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py
@@ -0,0 +1,160 @@
+_base_ = ['../../_base_/default_runtime.py']
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+label_convertor = dict(
+    type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+model = dict(
+    type='SegRecognizer',
+    backbone=dict(
+        type='ResNet31OCR',
+        layers=[1, 2, 5, 3],
+        channels=[32, 64, 128, 256, 512, 512],
+        out_indices=[0, 1, 2, 3],
+        stage4_pool_cfg=dict(kernel_size=2, stride=2),
+        last_stage_pool=True),
+    neck=dict(
+        type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
+    head=dict(
+        type='SegHead',
+        in_channels=256,
+        upsample_param=dict(scale_factor=2.0, mode='nearest')),
+    loss=dict(
+        type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=True),
+    label_convertor=label_convertor)
+
+find_unused_parameters = True
+
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+gt_label_convertor = dict(
+    type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomPaddingOCR',
+        max_ratio=[0.15, 0.2, 0.15, 0.2],
+        box_type='char_quads'),
+    dict(type='OpencvToPil'),
+    dict(
+        type='RandomRotateImageBox',
+        min_angle=-17,
+        max_angle=17,
+        box_type='char_quads'),
+    dict(type='PilToOpencv'),
+    dict(
+        type='ResizeOCR',
+        height=64,
+        min_width=64,
+        max_width=512,
+        keep_aspect_ratio=True),
+    dict(
+        type='OCRSegTargets',
+        label_convertor=gt_label_convertor,
+        box_type='char_quads'),
+    dict(type='RandomRotateTextDet', rotate_ratio=0.5, max_angle=15),
+    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+    dict(type='ToTensorOCR'),
+    dict(type='FancyPCA'),
+    dict(type='NormalizeOCR', **img_norm_cfg),
+    dict(
+        type='CustomFormatBundle',
+        keys=['gt_kernels'],
+        visualize=dict(flag=False, boundary_key=None),
+        call_super=False),
+    dict(
+        type='Collect',
+        keys=['img', 'gt_kernels'],
meta_keys=['filename', 'ori_shape', 'img_shape']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=64, + min_width=64, + max_width=None, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict(type='CustomFormatBundle', call_super=False), + dict( + type='Collect', + keys=['img'], + meta_keys=['filename', 'ori_shape', 'img_shape']) +] + +train_img_root = 'data/mixture/' + +train_img_prefix = train_img_root + 'SynthText' + +train_ann_file = train_img_root + 'SynthText/instances_train.txt' + +train = dict( + type='OCRSegDataset', + img_prefix=train_img_prefix, + ann_file=train_ann_file, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineJsonParser', keys=['file_name', 'annotations', 'text'])), + pipeline=train_pipeline, + test_mode=False) + +dataset_type = 'OCRDataset' +test_prefix = 'data/mixture/' + +test_img_prefix1 = test_prefix + 'IIIT5K/' +test_img_prefix2 = test_prefix + 'svt/' +test_img_prefix3 = test_prefix + 'icdar_2013/' +test_img_prefix4 = test_prefix + 'ct80/' + +test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt' +test_ann_file2 = test_prefix + 'svt/test_label.txt' +test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt' +test_ann_file4 = test_prefix + 'ct80/test_label.txt' + +test1 = dict( + type=dataset_type, + img_prefix=test_img_prefix1, + ann_file=test_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +test2 = {key: value for key, value in test1.items()} +test2['img_prefix'] = test_img_prefix2 +test2['ann_file'] = test_ann_file2 + +test3 = {key: value for key, value in test1.items()} +test3['img_prefix'] = test_img_prefix3 +test3['ann_file'] = test_ann_file3 + +test4 = {key: value for key, value in test1.items()} +test4['img_prefix'] = test_img_prefix4 +test4['ann_file'] = test_ann_file4 + +data = dict( + samples_per_gpu=16, + workers_per_gpu=2, + train=dict(type='ConcatDataset', datasets=[train]), + val=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]), + test=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py b/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py new file mode 100644 index 00000000..63b3d08c --- /dev/null +++ b/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py @@ -0,0 +1,35 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/recog_datasets/seg_toy_dataset.py' +] + +# optimizer +optimizer = dict(type='Adam', lr=1e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 5 + +label_convertor = dict( + type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True) + +model = dict( + type='SegRecognizer', + backbone=dict( + type='ResNet31OCR', + layers=[1, 2, 5, 3], + channels=[32, 64, 128, 256, 512, 512], + out_indices=[0, 1, 2, 3], + stage4_pool_cfg=dict(kernel_size=2, stride=2), + last_stage_pool=True), + neck=dict( + type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256), + head=dict( + type='SegHead', + in_channels=256, + upsample_param=dict(scale_factor=2.0, mode='nearest')), + loss=dict( + type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=False), + label_convertor=label_convertor) 
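+
+# The dataset settings of this toy config come entirely from the inherited
+# _base_ file `recog_datasets/seg_toy_dataset.py`; only the model, optimizer
+# and schedule are defined here. Illustrative sketch (assuming mmcv's
+# `Config` API) of inspecting the composed config:
+#
+#   from mmcv import Config
+#   cfg = Config.fromfile(
+#       'configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py')
+#   assert cfg.model.backbone.type == 'ResNet31OCR'
+#   assert cfg.data.samples_per_gpu == 8  # from the _base_ dataset file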
+ +find_unused_parameters = True diff --git a/configs/textrecog/transformer/README.md b/configs/textrecog/transformer/README.md new file mode 100644 index 00000000..46a132ad --- /dev/null +++ b/configs/textrecog/transformer/README.md @@ -0,0 +1,30 @@ +## Introduction + +### Train Dataset + +| trainset | instance_num | repeat_num | note | +| :--------: | :----------: | :--------: | :---: | +| icdar_2011 | 3567 | 20 | real | +| icdar_2013 | 848 | 20 | real | +| icdar2015 | 4468 | 20 | real | +| coco_text | 42142 | 20 | real | +| IIIT5K | 2000 | 20 | real | +| SynthText | 2400000 | 1 | synth | + +### Test Dataset + +| testset | instance_num | note | +| :-----: | :----------: | :-------------------------: | +| IIIT5K | 3000 | regular | +| SVT | 647 | regular | +| IC13 | 1015 | regular | +| IC15 | 2077 | irregular | +| SVTP | 645 | irregular, 639 in [[1]](#1) | +| CT80 | 288 | irregular | + +## Results and models + +| methods | | Regular Text | | | | Irregular Text | | download | +| :---------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :------: | +| | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | +| Transformer | 93.3 | 85.8 | 91.3 | | 73.2 | 76.6 | 87.8 | | diff --git a/configs/textrecog/transformer/transformer_r31_toy_dataset.py b/configs/textrecog/transformer/transformer_r31_toy_dataset.py new file mode 100755 index 00000000..9fc4a3ab --- /dev/null +++ b/configs/textrecog/transformer/transformer_r31_toy_dataset.py @@ -0,0 +1,12 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/recog_models/transformer.py', + '../../_base_/recog_datasets/toy_dataset.py' +] + +# optimizer +optimizer = dict(type='Adadelta', lr=1) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 5 diff --git a/mmocr/__init__.py b/mmocr/__init__.py new file mode 100644 index 00000000..1c4f7e8f --- /dev/null +++ b/mmocr/__init__.py @@ -0,0 +1,3 @@ +from .version import __version__, short_version + +__all__ = ['__version__', 'short_version'] diff --git a/mmocr/core/evaluation/ocr_metric.py b/mmocr/core/evaluation/ocr_metric.py new file mode 100644 index 00000000..5c5124f0 --- /dev/null +++ b/mmocr/core/evaluation/ocr_metric.py @@ -0,0 +1,133 @@ +import re +from difflib import SequenceMatcher + +import Levenshtein + + +def cal_true_positive_char(pred, gt): + """Calculate correct character number in prediction. + + Args: + pred (str): Prediction text. + gt (str): Ground truth text. + + Returns: + true_positive_char_num (int): The true positive number. + """ + + all_opt = SequenceMatcher(None, pred, gt) + true_positive_char_num = 0 + for opt, _, _, s2, e2 in all_opt.get_opcodes(): + if opt == 'equal': + true_positive_char_num += (e2 - s2) + else: + pass + return true_positive_char_num + + +def count_matches(pred_texts, gt_texts): + """Count the various match number for metric calculation. + + Args: + pred_texts (list[str]): Predicted text string. + gt_texts (list[str]): Ground truth text string. + + Returns: + match_res: (dict[str: int]): Match number used for + metric calculation. 
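+
+    Example (illustrative):
+        >>> count_matches(['hell'], ['hello'])['true_positive_char_num']
+        4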
+ """ + match_res = { + 'gt_char_num': 0, + 'pred_char_num': 0, + 'true_positive_char_num': 0, + 'gt_word_num': 0, + 'match_word_num': 0, + 'match_word_ignore_case': 0, + 'match_word_ignore_case_symbol': 0 + } + comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]') + norm_ed_sum = 0.0 + for pred_text, gt_text in zip(pred_texts, gt_texts): + if gt_text == pred_text: + match_res['match_word_num'] += 1 + gt_text_lower = gt_text.lower() + pred_text_lower = pred_text.lower() + if gt_text_lower == pred_text_lower: + match_res['match_word_ignore_case'] += 1 + gt_text_lower_ignore = comp.sub('', gt_text_lower) + pred_text_lower_ignore = comp.sub('', pred_text_lower) + if gt_text_lower_ignore == pred_text_lower_ignore: + match_res['match_word_ignore_case_symbol'] += 1 + match_res['gt_word_num'] += 1 + + # normalized edit distance + edit_dist = Levenshtein.distance(pred_text_lower_ignore, + gt_text_lower_ignore) + norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore), + len(pred_text_lower_ignore)) + norm_ed_sum += norm_ed + + # number to calculate char level recall & precision + match_res['gt_char_num'] += len(gt_text_lower_ignore) + match_res['pred_char_num'] += len(pred_text_lower_ignore) + true_positive_char_num = cal_true_positive_char( + pred_text_lower_ignore, gt_text_lower_ignore) + match_res['true_positive_char_num'] += true_positive_char_num + + normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts)) + match_res['ned'] = normalized_edit_distance + + return match_res + + +def eval_ocr_metric(pred_texts, gt_texts): + """Evaluate the text recognition performance with metric: word accuracy and + 1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details. + + Args: + pred_texts (list[str]): Text strings of prediction. + gt_texts (list[str]): Text strings of ground truth. + + Returns: + eval_res (dict[str: float]): Metric dict for text recognition, include: + - word_acc: Accuracy in word level. + - word_acc_ignore_case: Accuracy in word level, ignore letter case. + - word_acc_ignore_case_symbol: Accuracy in word level, ignore + letter case and symbol. (default metric for + academic evaluation) + - char_recall: Recall in character level, ignore + letter case and symbol. + - char_precision: Precision in character level, ignore + letter case and symbol. + - 1-N.E.D: 1 - normalized_edit_distance. 
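+
+    Example (illustrative; the values follow from the definitions above):
+        >>> eval_ocr_metric(['Hello!'], ['hello'])
+        {'word_acc': 0.0, 'word_acc_ignore_case': 0.0,
+        'word_acc_ignore_case_symbol': 1.0, 'char_recall': 1.0,
+        'char_precision': 1.0, '1-N.E.D': 1.0}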
+ """ + assert isinstance(pred_texts, list) + assert isinstance(gt_texts, list) + assert len(pred_texts) == len(gt_texts) + + match_res = count_matches(pred_texts, gt_texts) + eps = 1e-8 + char_recall = 1.0 * match_res['true_positive_char_num'] / ( + eps + match_res['gt_char_num']) + char_precision = 1.0 * match_res['true_positive_char_num'] / ( + eps + match_res['pred_char_num']) + word_acc = 1.0 * match_res['match_word_num'] / ( + eps + match_res['gt_word_num']) + word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / ( + eps + match_res['gt_word_num']) + word_acc_ignore_case_symbol = 1.0 * match_res[ + 'match_word_ignore_case_symbol'] / ( + eps + match_res['gt_word_num']) + + eval_res = {} + eval_res['word_acc'] = word_acc + eval_res['word_acc_ignore_case'] = word_acc_ignore_case + eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol + eval_res['char_recall'] = char_recall + eval_res['char_precision'] = char_precision + eval_res['1-N.E.D'] = 1.0 - match_res['ned'] + + for key, value in eval_res.items(): + eval_res[key] = float('{:.4f}'.format(value)) + + return eval_res diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py new file mode 100644 index 00000000..dd219226 --- /dev/null +++ b/mmocr/datasets/__init__.py @@ -0,0 +1,15 @@ +from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset +from .base_dataset import BaseDataset +from .icdar_dataset import IcdarDataset +from .kie_dataset import KIEDataset +from .ocr_dataset import OCRDataset +from .ocr_seg_dataset import OCRSegDataset +from .pipelines import CustomFormatBundle, DBNetTargets, DRRGTargets +from .text_det_dataset import TextDetDataset +from .utils import * # noqa: F401,F403 + +__all__ = [ + 'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset', + 'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle', + 'DBNetTargets', 'OCRSegDataset', 'DRRGTargets', 'KIEDataset' +] diff --git a/mmocr/datasets/base_dataset.py b/mmocr/datasets/base_dataset.py new file mode 100644 index 00000000..c91497d9 --- /dev/null +++ b/mmocr/datasets/base_dataset.py @@ -0,0 +1,166 @@ +import numpy as np +from mmcv.utils import print_log +from torch.utils.data import Dataset + +from mmdet.datasets.builder import DATASETS +from mmdet.datasets.pipelines import Compose +from mmocr.datasets.builder import build_loader + + +@DATASETS.register_module() +class BaseDataset(Dataset): + """Custom dataset for text detection, text recognition, and their + downstream tasks. + + 1. The text detection annotation format is as follows: + The `annotations` field is optional for testing + (this is one line of anno_file, with line-json-str + converted to dict for visualizing only). + + { + "file_name": "sample.jpg", + "height": 1080, + "width": 960, + "annotations": + [ + { + "iscrowd": 0, + "category_id": 1, + "bbox": [357.0, 667.0, 804.0, 100.0], + "segmentation": [[361, 667, 710, 670, + 72, 767, 357, 763]] + } + ] + } + + 2. The two text recognition annotation formats are as follows: + The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop + augmentation during training. + + format1: sample.jpg hello + format2: sample.jpg 20 20 100 20 100 40 20 40 hello + + Args: + ann_file (str): Annotation file path. + pipeline (list[dict]): Processing pipeline. + loader (dict): Dictionary to construct loader + to load annotation infos. + img_prefix (str, optional): Image prefix to generate full + image path. 
+ test_mode (bool, optional): If set True, try...except will + be turned off in __getitem__. + """ + + def __init__(self, + ann_file, + loader, + pipeline, + img_prefix='', + test_mode=False): + super().__init__() + self.test_mode = test_mode + self.img_prefix = img_prefix + self.ann_file = ann_file + # load annotations + loader.update(ann_file=ann_file) + self.data_infos = build_loader(loader) + # processing pipeline + self.pipeline = Compose(pipeline) + # set group flag and class, no meaning + # for text detect and recognize + self._set_group_flag() + self.CLASSES = 0 + + def __len__(self): + return len(self.data_infos) + + def _set_group_flag(self): + """Set flag.""" + self.flag = np.zeros(len(self), dtype=np.uint8) + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['img_prefix'] = self.img_prefix + + def prepare_train_img(self, index): + """Get training data and annotations from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys \ + introduced by pipeline. + """ + img_info = self.data_infos[index] + results = dict(img_info=img_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def prepare_test_img(self, img_info): + """Get testing data from pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys introduced by \ + pipeline. + """ + return self.prepare_train_img(img_info) + + def _log_error_index(self, index): + """Logging data info of bad index.""" + try: + data_info = self.data_infos[index] + img_prefix = self.img_prefix + print_log(f'Warning: skip broken file {data_info} ' + f'with img_prefix {img_prefix}') + except Exception as e: + print_log(f'load index {index} with error {e}') + + def _get_next_index(self, index): + """Get next index from dataset.""" + self._log_error_index(index) + index = (index + 1) % len(self) + return index + + def __getitem__(self, index): + """Get training/test data from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training/test data. + """ + if self.test_mode: + return self.prepare_test_img(index) + + while True: + try: + data = self.prepare_train_img(index) + if data is None: + raise Exception('prepared train data empty') + break + except Exception as e: + print_log(f'prepare index {index} with error {e}') + index = self._get_next_index(index) + return data + + def format_results(self, results, **kwargs): + """Placeholder to format result to dataset-specific output.""" + pass + + def evaluate(self, results, metric=None, logger=None, **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. 
+ Returns: + dict[str: float] + """ + raise NotImplementedError diff --git a/mmocr/datasets/builder.py b/mmocr/datasets/builder.py new file mode 100644 index 00000000..e7bcf423 --- /dev/null +++ b/mmocr/datasets/builder.py @@ -0,0 +1,14 @@ +from mmcv.utils import Registry, build_from_cfg + +LOADERS = Registry('loader') +PARSERS = Registry('parser') + + +def build_loader(cfg): + """Build anno file loader.""" + return build_from_cfg(cfg, LOADERS) + + +def build_parser(cfg): + """Build anno file parser.""" + return build_from_cfg(cfg, PARSERS) diff --git a/mmocr/datasets/ocr_dataset.py b/mmocr/datasets/ocr_dataset.py new file mode 100644 index 00000000..4ec6d962 --- /dev/null +++ b/mmocr/datasets/ocr_dataset.py @@ -0,0 +1,34 @@ +from mmdet.datasets.builder import DATASETS +from mmocr.core.evaluation.ocr_metric import eval_ocr_metric +from mmocr.datasets.base_dataset import BaseDataset + + +@DATASETS.register_module() +class OCRDataset(BaseDataset): + + def pre_pipeline(self, results): + results['img_prefix'] = self.img_prefix + results['text'] = results['img_info']['text'] + + def evaluate(self, results, metric='acc', logger=None, **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + Returns: + dict[str: float] + """ + gt_texts = [] + pred_texts = [] + for i in range(len(self)): + item_info = self.data_infos[i] + text = item_info['text'] + gt_texts.append(text) + pred_texts.append(results[i]['text']) + + eval_results = eval_ocr_metric(pred_texts, gt_texts) + + return eval_results diff --git a/mmocr/datasets/ocr_seg_dataset.py b/mmocr/datasets/ocr_seg_dataset.py new file mode 100644 index 00000000..0b149af1 --- /dev/null +++ b/mmocr/datasets/ocr_seg_dataset.py @@ -0,0 +1,90 @@ +import mmocr.utils as utils +from mmdet.datasets.builder import DATASETS +from mmocr.datasets.ocr_dataset import OCRDataset + + +@DATASETS.register_module() +class OCRSegDataset(OCRDataset): + + def pre_pipeline(self, results): + results['img_prefix'] = self.img_prefix + + def _parse_anno_info(self, annotations): + """Parse char boxes annotations. + Args: + annotations (list[dict]): Annotations of one image, where + each dict is for one character. + + Returns: + dict: A dict containing the following keys: + + - chars (list[str]): List of character strings. + - char_rects (list[list[float]]): List of char box, with each + in style of rectangle: [x_min, y_min, x_max, y_max]. + - char_quads (list[list[float]]): List of char box, with each + in style of quadrangle: [x1, y1, x2, y2, x3, y3, x4, y4]. 
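+
+        Example (illustrative):
+            >>> anns = [dict(char_box=[0, 0, 1, 0, 1, 1, 0, 1],
+            ...              char_text='a')]
+            >>> self._parse_anno_info(anns)
+            {'chars': ['a'], 'char_rects': [[0, 0, 1, 1]],
+            'char_quads': [[0, 0, 1, 0, 1, 1, 0, 1]]}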
+ """ + + assert utils.is_type_list(annotations, dict) + assert 'char_box' in annotations[0] + assert 'char_text' in annotations[0] + assert len(annotations[0]['char_box']) == 4 or \ + len(annotations[0]['char_box']) == 8 + + chars, char_rects, char_quads = [], [], [] + for ann in annotations: + char_box = ann['char_box'] + if len(char_box) == 4: + char_box_type = ann.get('char_box_type', 'xyxy') + if char_box_type == 'xyxy': + char_rects.append(char_box) + char_quads.append([ + char_box[0], char_box[1], char_box[2], char_box[1], + char_box[2], char_box[3], char_box[0], char_box[3] + ]) + elif char_box_type == 'xywh': + x1, y1, w, h = char_box + x2 = x1 + w + y2 = y1 + h + char_rects.append([x1, y1, x2, y2]) + char_quads.append([x1, y1, x2, y1, x2, y2, x1, y2]) + else: + raise ValueError(f'invalid char_box_type {char_box_type}') + elif len(char_box) == 8: + x_list, y_list = [], [] + for i in range(4): + x_list.append(char_box[2 * i]) + y_list.append(char_box[2 * i + 1]) + x_max, x_min = max(x_list), min(x_list) + y_max, y_min = max(y_list), min(y_list) + char_rects.append([x_min, y_min, x_max, y_max]) + char_quads.append(char_box) + else: + raise Exception( + f'invalid num in char box: {len(char_box)} not in (4, 8)') + chars.append(ann['char_text']) + + ann = dict(chars=chars, char_rects=char_rects, char_quads=char_quads) + + return ann + + def prepare_train_img(self, index): + """Get training data and annotations from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys \ + introduced by pipeline. + """ + img_ann_info = self.data_infos[index] + img_info = { + 'filename': img_ann_info['file_name'], + } + ann_info = self._parse_anno_info(img_ann_info['annotations']) + results = dict(img_info=img_info, ann_info=ann_info) + + self.pre_pipeline(results) + + return self.pipeline(results) diff --git a/mmocr/datasets/pipelines/crop.py b/mmocr/datasets/pipelines/crop.py new file mode 100644 index 00000000..c2e57bdd --- /dev/null +++ b/mmocr/datasets/pipelines/crop.py @@ -0,0 +1,185 @@ +import cv2 +import numpy as np +from shapely.geometry import LineString, Point, Polygon + +import mmocr.utils as utils + + +def sort_vertex(points_x, points_y): + """Sort box vertices in clockwise order from left-top first. + + Args: + points_x (list[float]): x of four vertices. + points_y (list[float]): y of four vertices. + Returns: + sorted_points_x (list[float]): x of sorted four vertices. + sorted_points_y (list[float]): y of sorted four vertices. + """ + assert utils.is_type_list(points_x, float) or utils.is_type_list( + points_x, int) + assert utils.is_type_list(points_y, float) or utils.is_type_list( + points_y, int) + assert len(points_x) == 4 + assert len(points_y) == 4 + + x = np.array(points_x) + y = np.array(points_y) + center_x = np.sum(x) * 0.25 + center_y = np.sum(y) * 0.25 + + x_arr = np.array(x - center_x) + y_arr = np.array(y - center_y) + + angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi + sort_idx = np.argsort(angle) + + sorted_points_x, sorted_points_y = [], [] + for i in range(4): + sorted_points_x.append(points_x[sort_idx[i]]) + sorted_points_y.append(points_y[sort_idx[i]]) + + return convert_canonical(sorted_points_x, sorted_points_y) + + +def convert_canonical(points_x, points_y): + """Make left-top be first. + + Args: + points_x (list[float]): x of four vertices. + points_y (list[float]): y of four vertices. + Returns: + sorted_points_x (list[float]): x of sorted four vertices. 
+ sorted_points_y (list[float]): y of sorted four vertices. + """ + assert utils.is_type_list(points_x, float) or utils.is_type_list( + points_x, int) + assert utils.is_type_list(points_y, float) or utils.is_type_list( + points_y, int) + assert len(points_x) == 4 + assert len(points_y) == 4 + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + + polygon = Polygon([(p.x, p.y) for p in points]) + min_x, min_y, _, _ = polygon.bounds + points_to_lefttop = [ + LineString([points[i], Point(min_x, min_y)]) for i in range(4) + ] + distances = np.array([line.length for line in points_to_lefttop]) + sort_dist_idx = np.argsort(distances) + lefttop_idx = sort_dist_idx[0] + + if lefttop_idx == 0: + point_orders = [0, 1, 2, 3] + elif lefttop_idx == 1: + point_orders = [1, 2, 3, 0] + elif lefttop_idx == 2: + point_orders = [2, 3, 0, 1] + else: + point_orders = [3, 0, 1, 2] + + sorted_points_x = [points_x[i] for i in point_orders] + sorted_points_y = [points_y[j] for j in point_orders] + + return sorted_points_x, sorted_points_y + + +def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1): + """Jitter on the coordinates of bounding box. + + Args: + points_x (list[float | int]): List of y for four vertices. + points_y (list[float | int]): List of x for four vertices. + jitter_ratio_x (float): Horizontal jitter ratio relative to the height. + jitter_ratio_y (float): Vertical jitter ratio relative to the height. + """ + assert len(points_x) == 4 + assert len(points_y) == 4 + assert isinstance(jitter_ratio_x, float) + assert isinstance(jitter_ratio_y, float) + assert 0 <= jitter_ratio_x < 1 + assert 0 <= jitter_ratio_y < 1 + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + line_list = [ + LineString([points[i], points[i + 1 if i < 3 else 0]]) + for i in range(4) + ] + + tmp_h = max(line_list[1].length, line_list[3].length) + + for i in range(4): + jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h + jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h + points_x[i] += jitter_pixel_x + points_y[i] += jitter_pixel_y + + +def warp_img(src_img, + box, + jitter_flag=False, + jitter_ratio_x=0.5, + jitter_ratio_y=0.1): + """Crop box area from image using opencv warpPerspective w/o box jitter. + + Args: + src_img (np.array): Image before cropping. + box (list[float | int]): Coordinates of quadrangle. + """ + assert utils.is_type_list(box, float) or utils.is_type_list(box, int) + assert len(box) == 8 + + h, w = src_img.shape[:2] + points_x = [min(max(x, 0), w) for x in box[0:8:2]] + points_y = [min(max(y, 0), h) for y in box[1:9:2]] + + points_x, points_y = sort_vertex(points_x, points_y) + + if jitter_flag: + box_jitter( + points_x, + points_y, + jitter_ratio_x=jitter_ratio_x, + jitter_ratio_y=jitter_ratio_y) + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + edges = [ + LineString([points[i], points[i + 1 if i < 3 else 0]]) + for i in range(4) + ] + + pts1 = np.float32([[points[i].x, points[i].y] for i in range(4)]) + box_width = max(edges[0].length, edges[2].length) + box_height = max(edges[1].length, edges[3].length) + + pts2 = np.float32([[0, 0], [box_width, 0], [box_width, box_height], + [0, box_height]]) + M = cv2.getPerspectiveTransform(pts1, pts2) + dst_img = cv2.warpPerspective(src_img, M, + (int(box_width), int(box_height))) + + return dst_img + + +def crop_img(src_img, box): + """Crop box area to rectangle. + + Args: + src_img (np.array): Image before crop. + box (list[float | int]): Points of quadrangle. 
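+
+    Example (illustrative):
+        >>> img = np.zeros((100, 200, 3), dtype=np.uint8)
+        >>> crop_img(img, [10., 20., 60., 20., 60., 50., 10., 50.]).shape
+        (30, 50, 3)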
+ """ + assert utils.is_type_list(box, float) or utils.is_type_list(box, int) + assert len(box) == 8 + + h, w = src_img.shape[:2] + points_x = [min(max(x, 0), w) for x in box[0:8:2]] + points_y = [min(max(y, 0), h) for y in box[1:9:2]] + + left = int(min(points_x)) + top = int(min(points_y)) + right = int(max(points_x)) + bottom = int(max(points_y)) + + dst_img = src_img[top:bottom, left:right] + + return dst_img diff --git a/mmocr/datasets/pipelines/ocr_seg_targets.py b/mmocr/datasets/pipelines/ocr_seg_targets.py new file mode 100644 index 00000000..cbd9b869 --- /dev/null +++ b/mmocr/datasets/pipelines/ocr_seg_targets.py @@ -0,0 +1,201 @@ +import cv2 +import numpy as np + +import mmocr.utils.check_argument as check_argument +from mmdet.core import BitmapMasks +from mmdet.datasets.builder import PIPELINES +from mmocr.models.builder import build_convertor + + +@PIPELINES.register_module() +class OCRSegTargets: + """Generate gt shrinked kernels for segmentation based OCR framework. + + Args: + label_convertor (dict): Dictionary to construct label_convertor + to convert char to index. + attn_shrink_ratio (float): The area shrinked ratio + between attention kernels and gt text masks. + seg_shrink_ratio (float): The area shrinked ratio + between segmentation kernels and gt text masks. + box_type (str): Character box type, should be either + 'char_rects' or 'char_quads', with 'char_rects' + for rectangle with ``xyxy`` style and 'char_quads' + for quadrangle with ``x1y1x2y2x3y3x4y4`` style. + """ + + def __init__(self, + label_convertor=None, + attn_shrink_ratio=0.5, + seg_shrink_ratio=0.25, + box_type='char_rects', + pad_val=255): + + assert isinstance(attn_shrink_ratio, float) + assert isinstance(seg_shrink_ratio, float) + assert 0. < attn_shrink_ratio < 1.0 + assert 0. < seg_shrink_ratio < 1.0 + assert label_convertor is not None + assert box_type in ('char_rects', 'char_quads') + + self.attn_shrink_ratio = attn_shrink_ratio + self.seg_shrink_ratio = seg_shrink_ratio + self.label_convertor = build_convertor(label_convertor) + self.box_type = box_type + self.pad_val = pad_val + + def shrink_char_quad(self, char_quad, shrink_ratio): + """Shrink char box in style of quadrangle. + + Args: + char_quad (list[float]): Char box with format + [x1, y1, x2, y2, x3, y3, x4, y4]. + shrink_ratio (float): The area shrinked ratio + between gt kernels and gt text masks. + """ + points = [[char_quad[0], char_quad[1]], [char_quad[2], char_quad[3]], + [char_quad[4], char_quad[5]], [char_quad[6], char_quad[7]]] + shrink_points = [] + for p_idx, point in enumerate(points): + p1 = points[(p_idx + 3) % 4] + p2 = points[(p_idx + 1) % 4] + + dist1 = self.l2_dist_two_points(p1, point) + dist2 = self.l2_dist_two_points(p2, point) + min_dist = min(dist1, dist2) + + v1 = [p1[0] - point[0], p1[1] - point[1]] + v2 = [p2[0] - point[0], p2[1] - point[1]] + + temp_dist1 = (shrink_ratio * min_dist / + dist1) if min_dist != 0 else 0. + temp_dist2 = (shrink_ratio * min_dist / + dist2) if min_dist != 0 else 0. + + v1 = [temp * temp_dist1 for temp in v1] + v2 = [temp * temp_dist2 for temp in v2] + + shrink_point = [ + round(point[0] + v1[0] + v2[0]), + round(point[1] + v1[1] + v2[1]) + ] + shrink_points.append(shrink_point) + + poly = np.array(shrink_points) + + return poly + + def shrink_char_rect(self, char_rect, shrink_ratio): + """Shrink char box in style of rectangle. + + Args: + char_rect (list[float]): Char box with format + [x_min, y_min, x_max, y_max]. 
+ shrink_ratio (float): The area shrinked ratio + between gt kernels and gt text masks. + """ + x_min, y_min, x_max, y_max = char_rect + w = x_max - x_min + h = y_max - y_min + x_min_s = round((x_min + x_max - w * shrink_ratio) / 2) + y_min_s = round((y_min + y_max - h * shrink_ratio) / 2) + x_max_s = round((x_min + x_max + w * shrink_ratio) / 2) + y_max_s = round((y_min + y_max + h * shrink_ratio) / 2) + poly = np.array([[x_min_s, y_min_s], [x_max_s, y_min_s], + [x_max_s, y_max_s], [x_min_s, y_max_s]]) + + return poly + + def generate_kernels(self, + resize_shape, + pad_shape, + char_boxes, + char_inds, + shrink_ratio=0.5, + binary=True): + """Generate char instance kernels for one shrink ratio. + + Args: + resize_shape (tuple(int, int)): Image size (height, width) + after resizing. + pad_shape (tuple(int, int)): Image size (height, width) + after padding. + char_boxes (list[list[float]]): The list of char polygons. + char_inds (list[int]): List of char indexes. + shrink_ratio (float): The shrink ratio of kernel. + binary (bool): If True, return binary ndarray + containing 0 & 1 only. + Returns: + char_kernel (ndarray): The text kernel mask of (height, width). + """ + assert isinstance(resize_shape, tuple) + assert isinstance(pad_shape, tuple) + assert check_argument.is_2dlist(char_boxes) + assert check_argument.is_type_list(char_inds, int) + assert isinstance(shrink_ratio, float) + assert isinstance(binary, bool) + + char_kernel = np.zeros(pad_shape, dtype=np.int32) + char_kernel[:resize_shape[0], resize_shape[1]:] = self.pad_val + + for i, char_box in enumerate(char_boxes): + if self.box_type == 'char_rects': + poly = self.shrink_char_rect(char_box, shrink_ratio) + elif self.box_type == 'char_quads': + poly = self.shrink_char_quad(char_box, shrink_ratio) + + fill_value = 1 if binary else char_inds[i] + cv2.fillConvexPoly(char_kernel, poly.astype(np.int32), + (fill_value)) + + return char_kernel + + def l2_dist_two_points(self, p1, p2): + return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5 + + def __call__(self, results): + img_shape = results['img_shape'] + resize_shape = results['resize_shape'] + + h_scale = 1.0 * resize_shape[0] / img_shape[0] + w_scale = 1.0 * resize_shape[1] / img_shape[1] + + char_boxes, char_inds = [], [] + char_num = len(results['ann_info'][self.box_type]) + for i in range(char_num): + char_box = results['ann_info'][self.box_type][i] + num_points = 2 if self.box_type == 'char_rects' else 4 + for j in range(num_points): + char_box[j * 2] = round(char_box[j * 2] * w_scale) + char_box[j * 2 + 1] = round(char_box[j * 2 + 1] * h_scale) + char_boxes.append(char_box) + char = results['ann_info']['chars'][i] + char_ind = self.label_convertor.str2idx([char])[0][0] + char_inds.append(char_ind) + + resize_shape = tuple(results['resize_shape'][:2]) + pad_shape = tuple(results['pad_shape'][:2]) + binary_target = self.generate_kernels( + resize_shape, + pad_shape, + char_boxes, + char_inds, + shrink_ratio=self.attn_shrink_ratio, + binary=True) + + seg_target = self.generate_kernels( + resize_shape, + pad_shape, + char_boxes, + char_inds, + shrink_ratio=self.seg_shrink_ratio, + binary=False) + + mask = np.ones(pad_shape, dtype=np.int32) + mask[:resize_shape[0], resize_shape[1]:] = 0 + + results['gt_kernels'] = BitmapMasks([binary_target, seg_target, mask], + pad_shape[0], pad_shape[1]) + results['mask_fields'] = ['gt_kernels'] + + return results diff --git a/mmocr/datasets/pipelines/ocr_transforms.py b/mmocr/datasets/pipelines/ocr_transforms.py new file mode 100644 index 
00000000..afddd40e --- /dev/null +++ b/mmocr/datasets/pipelines/ocr_transforms.py @@ -0,0 +1,447 @@ +import math + +import cv2 +import mmcv +import numpy as np +import torch +import torchvision.transforms.functional as TF +from mmcv.runner.dist_utils import get_dist_info +from PIL import Image +from shapely.geometry import Polygon +from shapely.geometry import box as shapely_box + +import mmocr.utils as utils +from mmdet.datasets.builder import PIPELINES +from mmocr.datasets.pipelines.crop import warp_img + + +@PIPELINES.register_module() +class ResizeOCR: + """Image resizing and padding for OCR. + + Args: + height (int | tuple(int)): Image height after resizing. + min_width (none | int | tuple(int)): Image minimum width + after resizing. + max_width (none | int | tuple(int)): Image maximum width + after resizing. + keep_aspect_ratio (bool): Keep image aspect ratio if True + during resizing, Otherwise resize to the size height * + max_width. + img_pad_value (int): Scalar to fill padding area. + width_downsample_ratio (float): Downsample ratio in horizontal + direction from input image to output feature. + """ + + def __init__(self, + height, + min_width=None, + max_width=None, + keep_aspect_ratio=True, + img_pad_value=0, + width_downsample_ratio=1.0 / 16): + assert isinstance(height, (int, tuple)) + assert utils.is_none_or_type(min_width, (int, tuple)) + assert utils.is_none_or_type(max_width, (int, tuple)) + if not keep_aspect_ratio: + assert max_width is not None, \ + '"max_width" must assigned ' + \ + 'if "keep_aspect_ratio" is False' + assert isinstance(img_pad_value, int) + if isinstance(height, tuple): + assert isinstance(min_width, tuple) + assert isinstance(max_width, tuple) + assert len(height) == len(min_width) == len(max_width) + + self.height = height + self.min_width = min_width + self.max_width = max_width + self.keep_aspect_ratio = keep_aspect_ratio + self.img_pad_value = img_pad_value + self.width_downsample_ratio = width_downsample_ratio + + def __call__(self, results): + rank, _ = get_dist_info() + if isinstance(self.height, int): + dst_height = self.height + dst_min_width = self.min_width + dst_max_width = self.max_width + else: + """Multi-scale resize used in distributed training. + + Choose one (height, width) pair for one rank id. + """ + idx = rank % len(self.height) + dst_height = self.height[idx] + dst_min_width = self.min_width[idx] + dst_max_width = self.max_width[idx] + + img_shape = results['img_shape'] + ori_height, ori_width = img_shape[:2] + valid_ratio = 1.0 + resize_shape = list(img_shape) + pad_shape = list(img_shape) + + if self.keep_aspect_ratio: + new_width = math.ceil(float(dst_height) / ori_height * ori_width) + width_divisor = int(1 / self.width_downsample_ratio) + # make sure new_width is an integral multiple of width_divisor. 
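+            # e.g. with the default width_downsample_ratio of 1/16, a
+            # provisional width of 100 px becomes round(100 / 16) * 16 = 96,
+            # so the backbone's output feature width stays integral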
+ if new_width % width_divisor != 0: + new_width = round(new_width / width_divisor) * width_divisor + if dst_min_width is not None: + new_width = max(dst_min_width, new_width) + if dst_max_width is not None: + valid_ratio = min(1.0, 1.0 * new_width / dst_max_width) + resize_width = min(dst_max_width, new_width) + img_resize = cv2.resize(results['img'], + (resize_width, dst_height)) + resize_shape = img_resize.shape + pad_shape = img_resize.shape + if new_width < dst_max_width: + img_resize = mmcv.impad( + img_resize, + shape=(dst_height, dst_max_width), + pad_val=self.img_pad_value) + pad_shape = img_resize.shape + else: + img_resize = cv2.resize(results['img'], + (new_width, dst_height)) + resize_shape = img_resize.shape + pad_shape = img_resize.shape + else: + img_resize = cv2.resize(results['img'], + (dst_max_width, dst_height)) + resize_shape = img_resize.shape + pad_shape = img_resize.shape + + results['img'] = img_resize + results['resize_shape'] = resize_shape + results['pad_shape'] = pad_shape + results['valid_ratio'] = valid_ratio + + return results + + +@PIPELINES.register_module() +class ToTensorOCR: + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.""" + + def __init__(self): + pass + + def __call__(self, results): + results['img'] = TF.to_tensor(results['img'].copy()) + + return results + + +@PIPELINES.register_module() +class NormalizeOCR: + """Normalize a tensor image with mean and standard deviation.""" + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, results): + results['img'] = TF.normalize(results['img'], self.mean, self.std) + + return results + + +@PIPELINES.register_module() +class OnlineCropOCR: + """Crop text areas from whole image with bounding box jitter. If no bbox is + given, return directly. + + Args: + box_keys (list[str]): Keys in results which correspond to RoI bbox. + jitter_prob (float): The probability of box jitter. + max_jitter_ratio_x (float): Maximum horizontal jitter ratio + relative to height. + max_jitter_ratio_y (float): Maximum vertical jitter ratio + relative to height. + """ + + def __init__(self, + box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'], + jitter_prob=0.5, + max_jitter_ratio_x=0.05, + max_jitter_ratio_y=0.02): + assert utils.is_type_list(box_keys, str) + assert 0 <= jitter_prob <= 1 + assert 0 <= max_jitter_ratio_x <= 1 + assert 0 <= max_jitter_ratio_y <= 1 + + self.box_keys = box_keys + self.jitter_prob = jitter_prob + self.max_jitter_ratio_x = max_jitter_ratio_x + self.max_jitter_ratio_y = max_jitter_ratio_y + + def __call__(self, results): + + if 'img_info' not in results: + return results + + crop_flag = True + box = [] + for key in self.box_keys: + if key not in results['img_info']: + crop_flag = False + break + + box.append(float(results['img_info'][key])) + + if not crop_flag: + return results + + jitter_flag = np.random.random() > self.jitter_prob + + kwargs = dict( + jitter_flag=jitter_flag, + jitter_ratio_x=self.max_jitter_ratio_x, + jitter_ratio_y=self.max_jitter_ratio_y) + crop_img = warp_img(results['img'], box, **kwargs) + + results['img'] = crop_img + results['img_shape'] = crop_img.shape + + return results + + +@PIPELINES.register_module() +class FancyPCA: + """Implementation of PCA based image augmentation, proposed in the paper + ``Imagenet Classification With Deep Convolutional Neural Networks``. + + It alters the intensities of RGB values along the principal components of + ImageNet dataset. 
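+
+    Each image is perturbed along those components (a sketch of the idea,
+    following the AlexNet paper):
+
+    .. math::
+        I \leftarrow I + [\mathbf{p}_1, \mathbf{p}_2, \mathbf{p}_3]
+        [\alpha_1 \lambda_1, \alpha_2 \lambda_2, \alpha_3 \lambda_3]^T,
+        \quad \alpha_i \sim \mathcal{N}(0, 0.1)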
+ """ + + def __init__(self, eig_vec=None, eig_val=None): + if eig_vec is None: + eig_vec = torch.Tensor([ + [-0.5675, +0.7192, +0.4009], + [-0.5808, -0.0045, -0.8140], + [-0.5836, -0.6948, +0.4203], + ]).t() + if eig_val is None: + eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]]) + self.eig_val = eig_val # 1*3 + self.eig_vec = eig_vec # 3*3 + + def pca(self, tensor): + assert tensor.size(0) == 3 + alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1 + reconst = torch.mm(self.eig_val * alpha, self.eig_vec) + tensor = tensor + reconst.view(3, 1, 1) + + return tensor + + def __call__(self, results): + img = results['img'] + tensor = self.pca(img) + results['img'] = tensor + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomPaddingOCR: + """Pad the given image on all sides, as well as modify the coordinates of + character bounding box in image. + + Args: + max_ratio (list[int]): [left, top, right, bottom]. + box_type (None|str): Character box type. If not none, + should be either 'char_rects' or 'char_quads', with + 'char_rects' for rectangle with ``xyxy`` style and + 'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style. + """ + + def __init__(self, max_ratio=None, box_type=None): + if max_ratio is None: + max_ratio = [0.1, 0.2, 0.1, 0.2] + else: + assert utils.is_type_list(max_ratio, float) + assert len(max_ratio) == 4 + assert box_type is None or box_type in ('char_rects', 'char_quads') + + self.max_ratio = max_ratio + self.box_type = box_type + + def __call__(self, results): + + img_shape = results['img_shape'] + ori_height, ori_width = img_shape[:2] + + random_padding_left = round( + np.random.uniform(0, self.max_ratio[0]) * ori_width) + random_padding_top = round( + np.random.uniform(0, self.max_ratio[1]) * ori_height) + random_padding_right = round( + np.random.uniform(0, self.max_ratio[2]) * ori_width) + random_padding_bottom = round( + np.random.uniform(0, self.max_ratio[3]) * ori_height) + + img = np.copy(results['img']) + img = cv2.copyMakeBorder(img, random_padding_top, + random_padding_bottom, random_padding_left, + random_padding_right, cv2.BORDER_REPLICATE) + results['img'] = img + results['img_shape'] = img.shape + + if self.box_type is not None: + num_points = 2 if self.box_type == 'char_rects' else 4 + char_num = len(results['ann_info'][self.box_type]) + for i in range(char_num): + for j in range(num_points): + results['ann_info'][self.box_type][i][ + j * 2] += random_padding_left + results['ann_info'][self.box_type][i][ + j * 2 + 1] += random_padding_top + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomRotateImageBox: + """Rotate augmentation for segmentation based text recognition. + + Args: + min_angle (int): Minimum rotation angle for image and box. + max_angle (int): Maximum rotation angle for image and box. + box_type (str): Character box type, should be either + 'char_rects' or 'char_quads', with 'char_rects' + for rectangle with ``xyxy`` style and 'char_quads' + for quadrangle with ``x1y1x2y2x3y3x4y4`` style. 
+ """ + + def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'): + assert box_type in ('char_rects', 'char_quads') + + self.min_angle = min_angle + self.max_angle = max_angle + self.box_type = box_type + + def __call__(self, results): + in_img = results['img'] + in_chars = results['ann_info']['chars'] + in_boxes = results['ann_info'][self.box_type] + + img_width, img_height = in_img.size + rotate_center = [img_width / 2., img_height / 2.] + + tan_temp_max_angle = rotate_center[1] / rotate_center[0] + temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi + + random_angle = np.random.uniform( + max(self.min_angle, -temp_max_angle), + min(self.max_angle, temp_max_angle)) + random_angle_radian = random_angle * np.pi / 180. + + img_box = shapely_box(0, 0, img_width, img_height) + + out_img = TF.rotate( + in_img, + random_angle, + resample=False, + expand=False, + center=rotate_center) + + out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars, + random_angle_radian, + rotate_center, img_box) + + results['img'] = out_img + results['ann_info']['chars'] = out_chars + results['ann_info'][self.box_type] = out_boxes + + return results + + @staticmethod + def rotate_bbox(boxes, chars, angle, center, img_box): + out_boxes = [] + out_chars = [] + for idx, bbox in enumerate(boxes): + temp_bbox = [] + for i in range(len(bbox) // 2): + point = [bbox[2 * i], bbox[2 * i + 1]] + temp_bbox.append( + RandomRotateImageBox.rotate_point(point, angle, center)) + poly_temp_bbox = Polygon(temp_bbox).buffer(0) + if poly_temp_bbox.is_valid: + if img_box.intersects(poly_temp_bbox) and ( + not img_box.touches(poly_temp_bbox)): + temp_bbox_area = poly_temp_bbox.area + + intersect_area = img_box.intersection(poly_temp_bbox).area + intersect_ratio = intersect_area / temp_bbox_area + + if intersect_ratio >= 0.7: + out_box = [] + for p in temp_bbox: + out_box.extend(p) + out_boxes.append(out_box) + out_chars.append(chars[idx]) + + return out_boxes, out_chars + + @staticmethod + def rotate_point(point, angle, center): + cos_theta = math.cos(-angle) + sin_theta = math.sin(-angle) + c_x = center[0] + c_y = center[1] + new_x = (point[0] - c_x) * cos_theta - (point[1] - + c_y) * sin_theta + c_x + new_y = (point[0] - c_x) * sin_theta + (point[1] - + c_y) * cos_theta + c_y + + return [new_x, new_y] + + +@PIPELINES.register_module() +class OpencvToPil: + """Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb).""" + + def __init__(self, **kwargs): + pass + + def __call__(self, results): + img = results['img'][..., ::-1] + img = Image.fromarray(img) + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class PilToOpencv: + """Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr).""" + + def __init__(self, **kwargs): + pass + + def __call__(self, results): + img = np.asarray(results['img']) + img = img[..., ::-1] + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str diff --git a/mmocr/datasets/pipelines/test_time_aug.py b/mmocr/datasets/pipelines/test_time_aug.py new file mode 100644 index 00000000..5c8c1a60 --- /dev/null +++ b/mmocr/datasets/pipelines/test_time_aug.py @@ -0,0 +1,108 @@ +import mmcv +import numpy as np + +from mmdet.datasets.builder import PIPELINES +from mmdet.datasets.pipelines.compose import Compose + + +@PIPELINES.register_module() +class MultiRotateAugOCR: + """Test-time augmentation with multiple rotations in the 
case that + img_height > img_width. + + An example configuration is as follows: + + .. code-block:: + + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=32, + min_width=32, + max_width=160, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ] + + After MultiRotateAugOCR with above configuration, the results are wrapped + into lists of the same length as follows: + + .. code-block:: + + dict( + img=[...], + img_shape=[...] + ... + ) + + Args: + transforms (list[dict]): Transformation applied for each augmentation. + rotate_degrees (list[int] | None): Degrees of anti-clockwise rotation. + force_rotate (bool): If True, rotate image by 'rotate_degrees' + while ignore image aspect ratio. + """ + + def __init__(self, transforms, rotate_degrees=None, force_rotate=False): + self.transforms = Compose(transforms) + self.force_rotate = force_rotate + if rotate_degrees is not None: + self.rotate_degrees = rotate_degrees if isinstance( + rotate_degrees, list) else [rotate_degrees] + assert mmcv.is_list_of(self.rotate_degrees, int) + for degree in self.rotate_degrees: + assert 0 <= degree < 360 + assert degree % 90 == 0 + if 0 not in self.rotate_degrees: + self.rotate_degrees.append(0) + else: + self.rotate_degrees = [0] + + def __call__(self, results): + """Call function to apply test time augment transformation to results. + + Args: + results (dict): Result dict contains the data to be transformed. + + Returns: + dict[str: list]: The augmented data, where each value is wrapped + into a list. + """ + img_shape = results['img_shape'] + ori_height, ori_width = img_shape[:2] + if not self.force_rotate and ori_height <= ori_width: + rotate_degrees = [0] + else: + rotate_degrees = self.rotate_degrees + aug_data = [] + for degree in set(rotate_degrees): + _results = results.copy() + if degree == 0: + pass + elif degree == 90: + _results['img'] = np.rot90(_results['img'], 1) + elif degree == 180: + _results['img'] = np.rot90(_results['img'], 2) + elif degree == 270: + _results['img'] = np.rot90(_results['img'], 3) + data = self.transforms(_results) + aug_data.append(data) + # list of dict to dict of list + aug_data_dict = {key: [] for key in aug_data[0]} + for data in aug_data: + for key, val in data.items(): + aug_data_dict[key].append(val) + return aug_data_dict + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms}, ' + repr_str += f'rotate_degrees={self.rotate_degrees})' + return repr_str diff --git a/mmocr/datasets/text_det_dataset.py b/mmocr/datasets/text_det_dataset.py new file mode 100644 index 00000000..fb5d3fa4 --- /dev/null +++ b/mmocr/datasets/text_det_dataset.py @@ -0,0 +1,121 @@ +import numpy as np + +from mmdet.datasets.builder import DATASETS +from mmocr.core.evaluation.hmean import eval_hmean +from mmocr.datasets.base_dataset import BaseDataset + + +@DATASETS.register_module() +class TextDetDataset(BaseDataset): + + def _parse_anno_info(self, annotations): + """Parse bbox and mask annotation. + Args: + annotations (dict): Annotations of one image. + + Returns: + dict: A dict containing the following keys: bboxes, bboxes_ignore, + labels, masks, masks_ignore. "masks" and + "masks_ignore" are represented by polygon boundary + point sequences. 
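+
+            For an image with one regular and one ``iscrowd`` instance, the
+            returned dict looks roughly like this (an illustrative sketch,
+            shapes only):
+
+            .. code-block:: python
+
+                dict(
+                    bboxes=np.zeros((1, 4), dtype=np.float32),
+                    labels=np.zeros((1, ), dtype=np.int64),
+                    bboxes_ignore=np.zeros((1, 4), dtype=np.float32),
+                    masks=[[[0., 0., 10., 0., 10., 10.]]],
+                    masks_ignore=[[[0., 0., 10., 0., 10., 10.]]])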
+ """ + gt_bboxes, gt_bboxes_ignore = [], [] + gt_masks, gt_masks_ignore = [], [] + gt_labels = [] + for ann in annotations: + if ann.get('iscrowd', False): + gt_bboxes_ignore.append(ann['bbox']) + gt_masks_ignore.append(ann.get('segmentation', None)) + else: + gt_bboxes.append(ann['bbox']) + gt_labels.append(ann['category_id']) + gt_masks.append(ann.get('segmentation', None)) + if gt_bboxes: + gt_bboxes = np.array(gt_bboxes, dtype=np.float32) + gt_labels = np.array(gt_labels, dtype=np.int64) + else: + gt_bboxes = np.zeros((0, 4), dtype=np.float32) + gt_labels = np.array([], dtype=np.int64) + + if gt_bboxes_ignore: + gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) + else: + gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) + + ann = dict( + bboxes=gt_bboxes, + labels=gt_labels, + bboxes_ignore=gt_bboxes_ignore, + masks_ignore=gt_masks_ignore, + masks=gt_masks) + + return ann + + def prepare_train_img(self, index): + """Get training data and annotations from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys \ + introduced by pipeline. + """ + img_ann_info = self.data_infos[index] + img_info = { + 'filename': img_ann_info['file_name'], + 'height': img_ann_info['height'], + 'width': img_ann_info['width'] + } + ann_info = self._parse_anno_info(img_ann_info['annotations']) + results = dict(img_info=img_info, ann_info=ann_info) + results['bbox_fields'] = [] + results['mask_fields'] = [] + results['seg_fields'] = [] + self.pre_pipeline(results) + + return self.pipeline(results) + + def evaluate(self, + results, + metric='hmean-iou', + score_thr=0.3, + rank_list=None, + logger=None, + **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + score_thr (float): Score threshold for prediction map. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + rank_list (str): json file used to save eval result + of each image after ranking. 
+ Returns: + dict[str: float] + """ + metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['hmean-iou', 'hmean-ic13'] + metrics = set(metrics) & set(allowed_metrics) + + img_infos = [] + ann_infos = [] + for i in range(len(self)): + img_ann_info = self.data_infos[i] + img_info = {'filename': img_ann_info['file_name']} + ann_info = self._parse_anno_info(img_ann_info['annotations']) + img_infos.append(img_info) + ann_infos.append(ann_info) + + eval_results = eval_hmean( + results, + img_infos, + ann_infos, + metrics=metrics, + score_thr=score_thr, + logger=logger, + rank_list=rank_list) + + return eval_results diff --git a/mmocr/datasets/utils/__init__.py b/mmocr/datasets/utils/__init__.py new file mode 100644 index 00000000..f014de7e --- /dev/null +++ b/mmocr/datasets/utils/__init__.py @@ -0,0 +1,4 @@ +from .loader import HardDiskLoader, LmdbLoader +from .parser import LineJsonParser, LineStrParser + +__all__ = ['HardDiskLoader', 'LmdbLoader', 'LineStrParser', 'LineJsonParser'] diff --git a/mmocr/datasets/utils/loader.py b/mmocr/datasets/utils/loader.py new file mode 100644 index 00000000..55a4d075 --- /dev/null +++ b/mmocr/datasets/utils/loader.py @@ -0,0 +1,108 @@ +import os.path as osp + +from mmocr.datasets.builder import LOADERS, build_parser + + +@LOADERS.register_module() +class Loader: + """Load annotation from annotation file, and parse instance information to + dict format with parser. + + Args: + ann_file (str): Annotation file path. + parser (dict): Dictionary to construct parser + to parse original annotation infos. + repeat (int): Repeated times of annotations. + """ + + def __init__(self, ann_file, parser, repeat=1): + assert isinstance(ann_file, str) + assert isinstance(repeat, int) + assert isinstance(parser, dict) + assert repeat > 0 + assert osp.exists(ann_file), f'{ann_file} is not exist' + + self.ori_data_infos = self._load(ann_file) + self.parser = build_parser(parser) + self.repeat = repeat + + def __len__(self): + return len(self.ori_data_infos) * self.repeat + + def _load(self, ann_file): + """Load annotation file.""" + raise NotImplementedError + + def __getitem__(self, index): + """Retrieve anno info of one instance with dict format.""" + return self.parser.get_item(self.ori_data_infos, index) + + +@LOADERS.register_module() +class HardDiskLoader(Loader): + """Load annotation file from hard disk to RAM. + + Args: + ann_file (str): Annotation file path. + """ + + def _load(self, ann_file): + data_ret = [] + with open(ann_file, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + data_ret.append(line) + + return data_ret + + +@LOADERS.register_module() +class LmdbLoader(Loader): + """Load annotation file with lmdb storage backend.""" + + def _load(self, ann_file): + lmdb_anno_obj = LmdbAnnFileBackend(ann_file) + + return lmdb_anno_obj + + +class LmdbAnnFileBackend: + """Lmdb storage backend for annotation file. + + Args: + lmdb_path (str): Lmdb file path. 
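+
+        The file is expected to map the key ``'total_number'`` to the
+        number of lines and ``str(index)`` to each annotation line, all
+        utf-8 encoded. A minimal sketch of building such a file with the
+        ``lmdb`` package (illustrative only):
+
+        .. code-block:: python
+
+            env = lmdb.open('label.lmdb', map_size=2 ** 30)
+            with env.begin(write=True) as txn:
+                for i, line in enumerate(lines):
+                    txn.put(str(i).encode('utf8'), line.encode('utf8'))
+                txn.put(b'total_number', str(len(lines)).encode('utf8'))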
+ """ + + def __init__(self, lmdb_path, coding='utf8'): + self.lmdb_path = lmdb_path + self.coding = coding + env = self._get_env() + with env.begin(write=False) as txn: + self.total_number = int( + txn.get('total_number'.encode(self.coding)).decode( + self.coding)) + + def __getitem__(self, index): + """Retrieval one line from lmdb file by index.""" + # only attach env to self when __getitem__ is called + # because env object cannot be pickle + if not hasattr(self, 'env'): + self.env = self._get_env() + + with self.env.begin(write=False) as txn: + line = txn.get(str(index).encode(self.coding)).decode(self.coding) + return line + + def __len__(self): + return self.total_number + + def _get_env(self): + import lmdb + return lmdb.open( + self.lmdb_path, + max_readers=1, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) diff --git a/mmocr/datasets/utils/parser.py b/mmocr/datasets/utils/parser.py new file mode 100644 index 00000000..a895e217 --- /dev/null +++ b/mmocr/datasets/utils/parser.py @@ -0,0 +1,69 @@ +import json + +from mmocr.datasets.builder import PARSERS + + +@PARSERS.register_module() +class LineStrParser: + """Parse string of one line in annotation file to dict format. + + Args: + keys (list[str]): Keys in result dict. + keys_idx (list[int]): Value index in sub-string list + for each key above. + separator (str): Separator to separate string to list of sub-string. + """ + + def __init__(self, + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' '): + assert isinstance(keys, list) + assert isinstance(keys_idx, list) + assert isinstance(separator, str) + assert len(keys) > 0 + assert len(keys) == len(keys_idx) + self.keys = keys + self.keys_idx = keys_idx + self.separator = separator + + def get_item(self, data_ret, index): + map_index = index % len(data_ret) + line_str = data_ret[map_index] + for split_key in self.separator: + if split_key != ' ': + line_str = line_str.replace(split_key, ' ') + line_str = line_str.split() + if len(line_str) <= max(self.keys_idx): + raise Exception( + f'key index: {max(self.keys_idx)} out of range: {line_str}') + + line_info = {} + for i, key in enumerate(self.keys): + line_info[key] = line_str[self.keys_idx[i]] + return line_info + + +@PARSERS.register_module() +class LineJsonParser: + """Parse json-string of one line in annotation file to dict format. + + Args: + keys (list[str]): Keys in both json-string and result dict. 
+ """ + + def __init__(self, keys=[], **kwargs): + assert isinstance(keys, list) + assert len(keys) > 0 + self.keys = keys + + def get_item(self, data_ret, index): + map_index = index % len(data_ret) + line_json_obj = json.loads(data_ret[map_index]) + line_info = {} + for key in self.keys: + if key not in line_json_obj: + raise Exception(f'key {key} not in line json {line_json_obj}') + line_info[key] = line_json_obj[key] + + return line_info diff --git a/mmocr/models/textrecog/__init__.py b/mmocr/models/textrecog/__init__.py new file mode 100644 index 00000000..76bfa419 --- /dev/null +++ b/mmocr/models/textrecog/__init__.py @@ -0,0 +1,8 @@ +from .backbones import * # noqa: F401,F403 +from .convertors import * # noqa: F401,F403 +from .decoders import * # noqa: F401,F403 +from .encoders import * # noqa: F401,F403 +from .heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .recognizer import * # noqa: F401,F403 diff --git a/mmocr/models/textrecog/backbones/__init__.py b/mmocr/models/textrecog/backbones/__init__.py new file mode 100644 index 00000000..5cfb1c30 --- /dev/null +++ b/mmocr/models/textrecog/backbones/__init__.py @@ -0,0 +1,4 @@ +from .resnet31_ocr import ResNet31OCR +from .very_deep_vgg import VeryDeepVgg + +__all__ = ['ResNet31OCR', 'VeryDeepVgg'] diff --git a/mmocr/models/textrecog/backbones/resnet31_ocr.py b/mmocr/models/textrecog/backbones/resnet31_ocr.py new file mode 100644 index 00000000..e2787556 --- /dev/null +++ b/mmocr/models/textrecog/backbones/resnet31_ocr.py @@ -0,0 +1,149 @@ +import torch.nn as nn +from mmcv.cnn import kaiming_init, uniform_init + +import mmocr.utils as utils +from mmdet.models.builder import BACKBONES +from mmocr.models.textrecog.layers import BasicBlock + + +@BACKBONES.register_module() +class ResNet31OCR(nn.Module): + """Implement ResNet backbone for text recognition, modified from + `ResNet `_ + Args: + base_channels (int): Number of channels of input image tensor. + layers (list[int]): List of BasicBlock number for each stage. + channels (list[int]): List of out_channels of Conv2d layer. + out_indices (None | Sequence[int]): Indicdes of output stages. + stage4_pool_cfg (dict): Dictionary to construct and configure + pooling layer in stage 4. + last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage. 
+ """ + + def __init__(self, + base_channels=3, + layers=[1, 2, 5, 3], + channels=[64, 128, 256, 256, 512, 512, 512], + out_indices=None, + stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)), + last_stage_pool=False): + super().__init__() + assert isinstance(base_channels, int) + assert utils.is_type_list(layers, int) + assert utils.is_type_list(channels, int) + assert out_indices is None or (isinstance(out_indices, list) + or isinstance(out_indices, tuple)) + assert isinstance(last_stage_pool, bool) + + self.out_indices = out_indices + self.last_stage_pool = last_stage_pool + + # conv 1 (Conv, Conv) + self.conv1_1 = nn.Conv2d( + base_channels, channels[0], kernel_size=3, stride=1, padding=1) + self.bn1_1 = nn.BatchNorm2d(channels[0]) + self.relu1_1 = nn.ReLU(inplace=True) + + self.conv1_2 = nn.Conv2d( + channels[0], channels[1], kernel_size=3, stride=1, padding=1) + self.bn1_2 = nn.BatchNorm2d(channels[1]) + self.relu1_2 = nn.ReLU(inplace=True) + + # conv 2 (Max-pooling, Residual block, Conv) + self.pool2 = nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block2 = self._make_layer(channels[1], channels[2], layers[0]) + self.conv2 = nn.Conv2d( + channels[2], channels[2], kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2d(channels[2]) + self.relu2 = nn.ReLU(inplace=True) + + # conv 3 (Max-pooling, Residual block, Conv) + self.pool3 = nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block3 = self._make_layer(channels[2], channels[3], layers[1]) + self.conv3 = nn.Conv2d( + channels[3], channels[3], kernel_size=3, stride=1, padding=1) + self.bn3 = nn.BatchNorm2d(channels[3]) + self.relu3 = nn.ReLU(inplace=True) + + # conv 4 (Max-pooling, Residual block, Conv) + self.pool4 = nn.MaxPool2d(padding=0, ceil_mode=True, **stage4_pool_cfg) + self.block4 = self._make_layer(channels[3], channels[4], layers[2]) + self.conv4 = nn.Conv2d( + channels[4], channels[4], kernel_size=3, stride=1, padding=1) + self.bn4 = nn.BatchNorm2d(channels[4]) + self.relu4 = nn.ReLU(inplace=True) + + # conv 5 ((Max-pooling), Residual block, Conv) + self.pool5 = None + if self.last_stage_pool: + self.pool5 = nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True) # 1/16 + self.block5 = self._make_layer(channels[4], channels[5], layers[3]) + self.conv5 = nn.Conv2d( + channels[5], channels[5], kernel_size=3, stride=1, padding=1) + self.bn5 = nn.BatchNorm2d(channels[5]) + self.relu5 = nn.ReLU(inplace=True) + + def init_weights(self, pretrained=None): + # initialize weight and bias + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + uniform_init(m) + + def _make_layer(self, input_channels, output_channels, blocks): + layers = [] + for _ in range(blocks): + downsample = None + if input_channels != output_channels: + downsample = nn.Sequential( + nn.Conv2d( + input_channels, + output_channels, + kernel_size=1, + stride=1, + bias=False), + nn.BatchNorm2d(output_channels), + ) + layers.append( + BasicBlock( + input_channels, output_channels, downsample=downsample)) + input_channels = output_channels + + return nn.Sequential(*layers) + + def forward(self, x): + + x = self.conv1_1(x) + x = self.bn1_1(x) + x = self.relu1_1(x) + + x = self.conv1_2(x) + x = self.bn1_2(x) + x = self.relu1_2(x) + + outs = [] + for i in range(4): + layer_index = i + 2 + pool_layer = getattr(self, f'pool{layer_index}') + block_layer = getattr(self, f'block{layer_index}') + conv_layer = getattr(self, 
f'conv{layer_index}')
+            bn_layer = getattr(self, f'bn{layer_index}')
+            relu_layer = getattr(self, f'relu{layer_index}')
+
+            if pool_layer is not None:
+                x = pool_layer(x)
+            x = block_layer(x)
+            x = conv_layer(x)
+            x = bn_layer(x)
+            x = relu_layer(x)
+
+            outs.append(x)
+
+        if self.out_indices is not None:
+            return tuple([outs[i] for i in self.out_indices])
+
+        return x
diff --git a/mmocr/models/textrecog/convertors/__init__.py b/mmocr/models/textrecog/convertors/__init__.py
new file mode 100644
index 00000000..60fc6300
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/__init__.py
@@ -0,0 +1,6 @@
+from .attn import AttnConvertor
+from .base import BaseConvertor
+from .ctc import CTCConvertor
+from .seg import SegConvertor
+
+__all__ = ['BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor']
diff --git a/mmocr/models/textrecog/convertors/attn.py b/mmocr/models/textrecog/convertors/attn.py
new file mode 100644
index 00000000..a80282e8
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/attn.py
@@ -0,0 +1,140 @@
+import torch
+
+import mmocr.utils as utils
+from mmocr.models.builder import CONVERTORS
+from .base import BaseConvertor
+
+
+@CONVERTORS.register_module()
+class AttnConvertor(BaseConvertor):
+    """Convert between text, index and tensor for encoder-decoder based
+    pipeline.
+
+    Args:
+        dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
+        dict_file (None|str): Character dict file path. If not none,
+            higher priority than dict_type.
+        dict_list (None|list[str]): Character list. If not none, higher
+            priority than dict_type, but lower than dict_file.
+        with_unknown (bool): If True, add `UKN` token to class.
+        max_seq_len (int): Maximum sequence length of label.
+        lower (bool): If True, convert original string to lower case.
+        start_end_same (bool): Whether use the same index for
+            start and end token or not. Default: True.
+    """
+
+    def __init__(self,
+                 dict_type='DICT90',
+                 dict_file=None,
+                 dict_list=None,
+                 with_unknown=True,
+                 max_seq_len=40,
+                 lower=False,
+                 start_end_same=True,
+                 **kwargs):
+        super().__init__(dict_type, dict_file, dict_list)
+        assert isinstance(with_unknown, bool)
+        assert isinstance(max_seq_len, int)
+        assert isinstance(lower, bool)
+
+        self.with_unknown = with_unknown
+        self.max_seq_len = max_seq_len
+        self.lower = lower
+        self.start_end_same = start_end_same
+
+        self.update_dict()
+
+    def update_dict(self):
+        start_end_token = '<BOS/EOS>'
+        unknown_token = '<UKN>'
+        padding_token = '<PAD>'
+
+        # unknown
+        self.unknown_idx = None
+        if self.with_unknown:
+            self.idx2char.append(unknown_token)
+            self.unknown_idx = len(self.idx2char) - 1
+
+        # BOS/EOS
+        self.idx2char.append(start_end_token)
+        self.start_idx = len(self.idx2char) - 1
+        if not self.start_end_same:
+            self.idx2char.append(start_end_token)
+        self.end_idx = len(self.idx2char) - 1
+
+        # padding
+        self.idx2char.append(padding_token)
+        self.padding_idx = len(self.idx2char) - 1
+
+        # update char2idx
+        self.char2idx = {}
+        for idx, char in enumerate(self.idx2char):
+            self.char2idx[char] = idx
+
+    def str2tensor(self, strings):
+        """
+        Convert text-string into tensor.
+ Args: + strings (list[str]): ['hello', 'world'] + Returns: + dict (str: Tensor | list[tensor]): + tensors (list[Tensor]): [torch.Tensor([1,2,3,3,4]), + torch.Tensor([5,4,6,3,7])] + padded_targets (Tensor(bsz * max_seq_len)) + """ + assert utils.is_type_list(strings, str) + + tensors, padded_targets = [], [] + indexes = self.str2idx(strings) + for index in indexes: + tensor = torch.LongTensor(index) + tensors.append(tensor) + # target tensor for loss + src_target = torch.LongTensor(tensor.size(0) + 2).fill_(0) + src_target[-1] = self.end_idx + src_target[0] = self.start_idx + src_target[1:-1] = tensor + padded_target = (torch.ones(self.max_seq_len) * + self.padding_idx).long() + char_num = src_target.size(0) + if char_num > self.max_seq_len: + padded_target = src_target[:self.max_seq_len] + else: + padded_target[:char_num] = src_target + padded_targets.append(padded_target) + padded_targets = torch.stack(padded_targets, 0).long() + + return {'targets': tensors, 'padded_targets': padded_targets} + + def tensor2idx(self, outputs, img_metas=None): + """ + Convert output tensor to text-index + Args: + outputs (tensor): model outputs with size: N * T * C + img_metas (list[dict]): Each dict contains one image info. + Returns: + indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]] + scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94], + [0.9,0.9,0.98,0.97,0.96]] + """ + batch_size = outputs.size(0) + ignore_indexes = [self.padding_idx] + indexes, scores = [], [] + for idx in range(batch_size): + seq = outputs[idx, :, :] + max_value, max_idx = torch.max(seq, -1) + str_index, str_score = [], [] + output_index = max_idx.cpu().detach().numpy().tolist() + output_score = max_value.cpu().detach().numpy().tolist() + for char_index, char_score in zip(output_index, output_score): + if char_index in ignore_indexes: + continue + if char_index == self.end_idx: + break + str_index.append(char_index) + str_score.append(char_score) + + indexes.append(str_index) + scores.append(str_score) + + return indexes, scores diff --git a/mmocr/models/textrecog/convertors/base.py b/mmocr/models/textrecog/convertors/base.py new file mode 100644 index 00000000..9002d4fe --- /dev/null +++ b/mmocr/models/textrecog/convertors/base.py @@ -0,0 +1,115 @@ +from mmocr.models.builder import CONVERTORS + + +@CONVERTORS.register_module() +class BaseConvertor: + """Convert between text, index and tensor for text recognize pipeline. + + Args: + dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'. + dict_file (None|str): Character dict file path. If not none, + the dict_file is of higher priority than dict_type. + dict_list (None|list[str]): Character list. If not none, the list + is of higher priority than dict_type, but lower than dict_file. 
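+
+    Example of the resulting mapping (an illustrative sketch):
+
+    .. code-block:: python
+
+        convertor = BaseConvertor(dict_type='DICT36')
+        convertor.num_classes()          # 36
+        convertor.str2idx(['10a'])       # [[1, 0, 10]]
+        convertor.idx2str([[1, 0, 10]])  # ['10a']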
+ """ + start_idx = end_idx = padding_idx = 0 + unknown_idx = None + lower = False + + DICT36 = tuple('0123456789abcdefghijklmnopqrstuvwxyz') + DICT90 = tuple('0123456789abcdefghijklmnopqrstuvwxyz' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()' + '*+,-./:;<=>?@[\\]_`~') + + def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None): + assert dict_type in ('DICT36', 'DICT90') + assert dict_file is None or isinstance(dict_file, str) + assert dict_list is None or isinstance(dict_list, list) + self.idx2char = [] + if dict_file is not None: + with open(dict_file, encoding='utf-8') as fr: + for line in fr: + line = line.strip() + if line != '': + self.idx2char.append(line) + elif dict_list is not None: + self.idx2char = dict_list + else: + if dict_type == 'DICT36': + self.idx2char = list(self.DICT36) + else: + self.idx2char = list(self.DICT90) + + self.char2idx = {} + for idx, char in enumerate(self.idx2char): + self.char2idx[char] = idx + + def num_classes(self): + """Number of output classes.""" + return len(self.idx2char) + + def str2idx(self, strings): + """Convert strings to indexes. + + Args: + strings (list[str]): ['hello', 'world']. + Returns: + indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]]. + """ + assert isinstance(strings, list) + + indexes = [] + for string in strings: + if self.lower: + string = string.lower() + index = [] + for char in string: + char_idx = self.char2idx.get(char, self.unknown_idx) + if char_idx is None: + raise Exception(f'Chararcter: {char} not in dict,' + f' please check gt_label and use' + f' custom dict file,' + f' or set "with_unknown=True"') + index.append(char_idx) + indexes.append(index) + + return indexes + + def str2tensor(self, strings): + """Convert text-string to input tensor. + + Args: + strings (list[str]): ['hello', 'world']. + Returns: + tensors (list[torch.Tensor]): [torch.Tensor([1,2,3,3,4]), + torch.Tensor([5,4,6,3,7])]. + """ + raise NotImplementedError + + def idx2str(self, indexes): + """Convert indexes to text strings. + + Args: + indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]]. + Returns: + strings (list[str]): ['hello', 'world']. + """ + assert isinstance(indexes, list) + + strings = [] + for index in indexes: + string = [self.idx2char[i] for i in index] + strings.append(''.join(string)) + + return strings + + def tensor2idx(self, output): + """Convert model output tensor to character indexes and scores. + Args: + output (tensor): The model outputs with size: N * T * C + Returns: + indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]]. + scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94], + [0.9,0.9,0.98,0.97,0.96]]. + """ + raise NotImplementedError diff --git a/mmocr/models/textrecog/convertors/ctc.py b/mmocr/models/textrecog/convertors/ctc.py new file mode 100644 index 00000000..c14fc23f --- /dev/null +++ b/mmocr/models/textrecog/convertors/ctc.py @@ -0,0 +1,144 @@ +import math + +import torch +import torch.nn.functional as F + +import mmocr.utils as utils +from mmocr.models.builder import CONVERTORS +from .base import BaseConvertor + + +@CONVERTORS.register_module() +class CTCConvertor(BaseConvertor): + """Convert between text, index and tensor for CTC loss-based pipeline. + + Args: + dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'. + dict_file (None|str): Character dict file path. If not none, the file + is of higher priority than dict_type. + dict_list (None|list[str]): Character list. If not none, the list + is of higher priority than dict_type, but lower than dict_file. 
+        with_unknown (bool): If True, add `UKN` token to class.
+        lower (bool): If True, convert original string to lower case.
+    """
+
+    def __init__(self,
+                 dict_type='DICT90',
+                 dict_file=None,
+                 dict_list=None,
+                 with_unknown=True,
+                 lower=False,
+                 **kwargs):
+        super().__init__(dict_type, dict_file, dict_list)
+        assert isinstance(with_unknown, bool)
+        assert isinstance(lower, bool)
+
+        self.with_unknown = with_unknown
+        self.lower = lower
+        self.update_dict()
+
+    def update_dict(self):
+        # CTC-blank
+        blank_token = '<BLK>'
+        self.blank_idx = 0
+        self.idx2char.insert(0, blank_token)
+
+        # unknown
+        self.unknown_idx = None
+        if self.with_unknown:
+            self.idx2char.append('<UKN>')
+            self.unknown_idx = len(self.idx2char) - 1
+
+        # update char2idx
+        self.char2idx = {}
+        for idx, char in enumerate(self.idx2char):
+            self.char2idx[char] = idx
+
+    def str2tensor(self, strings):
+        """Convert text-string to ctc-loss input tensor.
+
+        Args:
+            strings (list[str]): ['hello', 'world'].
+        Returns:
+            dict (str: tensor | list[tensor]):
+                tensors (list[tensor]): [torch.Tensor([1,2,3,3,4]),
+                    torch.Tensor([5,4,6,3,7])].
+                flatten_targets (tensor): torch.Tensor([1,2,3,3,4,5,4,6,3,7]).
+                target_lengths (tensor): torch.IntTensor([5,5]).
+        """
+        assert utils.is_type_list(strings, str)
+
+        tensors = []
+        indexes = self.str2idx(strings)
+        for index in indexes:
+            tensor = torch.IntTensor(index)
+            tensors.append(tensor)
+        target_lengths = torch.IntTensor([len(t) for t in tensors])
+        flatten_target = torch.cat(tensors)
+
+        return {
+            'targets': tensors,
+            'flatten_targets': flatten_target,
+            'target_lengths': target_lengths
+        }
+
+    def tensor2idx(self, output, img_metas, topk=1, return_topk=False):
+        """Convert model output tensor to index-list.
+
+        Args:
+            output (tensor): The model outputs with size: N * T * C.
+            img_metas (list[dict]): Each dict contains one image info.
+            topk (int): The highest k classes to be returned.
+            return_topk (bool): Whether to return topk or just top1.
+        Returns:
+            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
+            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
+                [0.9,0.9,0.98,0.97,0.96]]
+            (
+                indexes_topk (list[list[list[int]->len=topk]]):
+                scores_topk (list[list[list[float]->len=topk]])
+            ).
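+
+            A quick sketch of the collapse rule with ``blank_idx = 0``: the
+            frame-wise argmax ``[0, 1, 1, 0, 1, 2, 2]`` decodes to
+            ``[1, 1, 2]``, since repeats are merged unless separated by a
+            blank.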
+ """ + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == output.size(0) + assert isinstance(topk, int) + assert topk >= 1 + + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] + + batch_size = output.size(0) + output = F.softmax(output, dim=2) + output = output.cpu().detach() + batch_topk_value, batch_topk_idx = output.topk(topk, dim=2) + batch_max_idx = batch_topk_idx[:, :, 0] + scores_topk, indexes_topk = [], [] + scores, indexes = [], [] + feat_len = output.size(1) + for b in range(batch_size): + valid_ratio = valid_ratios[b] + decode_len = min(feat_len, math.ceil(feat_len * valid_ratio)) + pred = batch_max_idx[b, :] + select_idx = [] + prev_idx = self.blank_idx + for t in range(decode_len): + tmp_value = pred[t].item() + if tmp_value not in (prev_idx, self.blank_idx): + select_idx.append(t) + prev_idx = tmp_value + select_idx = torch.LongTensor(select_idx) + topk_value = torch.index_select(batch_topk_value[b, :, :], 0, + select_idx) # valid_seqlen * topk + topk_idx = torch.index_select(batch_topk_idx[b, :, :], 0, + select_idx) + topk_idx_list, topk_value_list = topk_idx.numpy().tolist( + ), topk_value.numpy().tolist() + indexes_topk.append(topk_idx_list) + scores_topk.append(topk_value_list) + indexes.append([x[0] for x in topk_idx_list]) + scores.append([x[0] for x in topk_value_list]) + + if return_topk: + return indexes_topk, scores_topk + + return indexes, scores diff --git a/mmocr/models/textrecog/convertors/seg.py b/mmocr/models/textrecog/convertors/seg.py new file mode 100644 index 00000000..0a626dac --- /dev/null +++ b/mmocr/models/textrecog/convertors/seg.py @@ -0,0 +1,123 @@ +import cv2 +import numpy as np +import torch + +import mmocr.utils as utils +from mmocr.models.builder import CONVERTORS +from .base import BaseConvertor + + +@CONVERTORS.register_module() +class SegConvertor(BaseConvertor): + """Convert between text, index and tensor for segmentation based pipeline. + + Args: + dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'. + dict_file (None|str): Character dict file path. If not none, the + file is of higher priority than dict_type. + dict_list (None|list[str]): Character list. If not none, the list + is of higher priority than dict_type, but lower than dict_file. + with_unknown (bool): If True, add `UKN` token to class. + lower (bool): If True, convert original string to lower case. + """ + + def __init__(self, + dict_type='DICT36', + dict_file=None, + dict_list=None, + with_unknown=True, + lower=False, + **kwargs): + super().__init__(dict_type, dict_file, dict_list) + assert isinstance(with_unknown, bool) + assert isinstance(lower, bool) + + self.with_unknown = with_unknown + self.lower = lower + self.update_dict() + + def update_dict(self): + # background + self.idx2char.insert(0, '') + + # unknown + self.unknown_idx = None + if self.with_unknown: + self.idx2char.append('') + self.unknown_idx = len(self.idx2char) - 1 + + # update char2idx + self.char2idx = {} + for idx, char in enumerate(self.idx2char): + self.char2idx[char] = idx + + def tensor2str(self, output, img_metas=None): + """Convert model output tensor to string labels. + Args: + output (tensor): Model outputs with size: N * C * H * W + img_metas (list[dict]): Each dict contains one image info. + Returns: + texts (list[str]): Decoded text labels. + scores (list[list[float]]): Decoded chars scores. 
+ """ + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == output.size(0) + + texts, scores = [], [] + for b in range(output.size(0)): + seg_pred = output[b].detach() + seg_res = torch.argmax( + seg_pred, dim=0).cpu().numpy().astype(np.int32) + + seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8) + _, labels, stats, centroids = cv2.connectedComponentsWithStats( + seg_thr) + + component_num = stats.shape[0] + + all_res = [] + for i in range(component_num): + temp_loc = (labels == i) + temp_value = seg_res[temp_loc] + temp_center = centroids[i] + + temp_max_num = 0 + temp_max_cls = -1 + temp_total_num = 0 + for c in range(len(self.idx2char)): + c_num = np.sum(temp_value == c) + temp_total_num += c_num + if c_num > temp_max_num: + temp_max_num = c_num + temp_max_cls = c + + if temp_max_cls == 0: + continue + temp_max_score = 1.0 * temp_max_num / temp_total_num + all_res.append( + [temp_max_cls, temp_center, temp_max_num, temp_max_score]) + + all_res = sorted(all_res, key=lambda s: s[1][0]) + chars, char_scores = [], [] + for res in all_res: + temp_area = res[2] + if temp_area < 20: + continue + temp_char_index = res[0] + if temp_char_index >= len(self.idx2char): + temp_char = '' + elif temp_char_index <= 0: + temp_char = '' + elif temp_char_index == self.unknown_idx: + temp_char = '' + else: + temp_char = self.idx2char[temp_char_index] + chars.append(temp_char) + char_scores.append(res[3]) + + text = ''.join(chars) + + texts.append(text) + scores.append(char_scores) + + return texts, scores diff --git a/mmocr/models/textrecog/decoders/__init__.py b/mmocr/models/textrecog/decoders/__init__.py new file mode 100755 index 00000000..8b374733 --- /dev/null +++ b/mmocr/models/textrecog/decoders/__init__.py @@ -0,0 +1,15 @@ +from .base_decoder import BaseDecoder +from .crnn_decoder import CRNNDecoder +from .position_attention_decoder import PositionAttentionDecoder +from .robust_scanner_decoder import RobustScannerDecoder +from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder +from .sar_decoder_with_bs import ParallelSARDecoderWithBS +from .sequence_attention_decoder import SequenceAttentionDecoder +from .transformer_decoder import TFDecoder + +__all__ = [ + 'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder', + 'ParallelSARDecoderWithBS', 'TFDecoder', 'BaseDecoder', + 'SequenceAttentionDecoder', 'PositionAttentionDecoder', + 'RobustScannerDecoder' +] diff --git a/mmocr/models/textrecog/decoders/base_decoder.py b/mmocr/models/textrecog/decoders/base_decoder.py new file mode 100644 index 00000000..543cd528 --- /dev/null +++ b/mmocr/models/textrecog/decoders/base_decoder.py @@ -0,0 +1,32 @@ +import torch.nn as nn + +from mmocr.models.builder import DECODERS + + +@DECODERS.register_module() +class BaseDecoder(nn.Module): + """Base decoder class for text recognition.""" + + def __init__(self, **kwargs): + super().__init__() + + def init_weights(self): + pass + + def forward_train(self, feat, out_enc, targets_dict, img_metas): + raise NotImplementedError + + def forward_test(self, feat, out_enc, img_metas): + raise NotImplementedError + + def forward(self, + feat, + out_enc, + targets_dict=None, + img_metas=None, + train_mode=True): + self.train_mode = train_mode + if train_mode: + return self.forward_train(feat, out_enc, targets_dict, img_metas) + + return self.forward_test(feat, out_enc, img_metas) diff --git a/mmocr/models/textrecog/decoders/crnn_decoder.py b/mmocr/models/textrecog/decoders/crnn_decoder.py new file mode 100644 index 00000000..1ce5226a --- 
/dev/null
+++ b/mmocr/models/textrecog/decoders/crnn_decoder.py
@@ -0,0 +1,49 @@
+import torch.nn as nn
+from mmcv.cnn import xavier_init
+
+from mmocr.models.builder import DECODERS
+from mmocr.models.textrecog.layers import BidirectionalLSTM
+from .base_decoder import BaseDecoder
+
+
+@DECODERS.register_module()
+class CRNNDecoder(BaseDecoder):
+    """Decoder for CRNN: either a two-layer BiLSTM head (``rnn_flag=True``)
+    or a 1x1 convolution over the feature map."""
+
+    def __init__(self,
+                 in_channels=None,
+                 num_classes=None,
+                 rnn_flag=False,
+                 **kwargs):
+        super().__init__()
+        self.num_classes = num_classes
+        self.rnn_flag = rnn_flag
+
+        if rnn_flag:
+            self.decoder = nn.Sequential(
+                BidirectionalLSTM(in_channels, 256, 256),
+                BidirectionalLSTM(256, 256, num_classes))
+        else:
+            self.decoder = nn.Conv2d(
+                in_channels, num_classes, kernel_size=1, stride=1)
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m)
+
+    def forward_train(self, feat, out_enc, targets_dict, img_metas):
+        assert feat.size(2) == 1, 'feature height must be 1'
+        if self.rnn_flag:
+            x = feat.squeeze(2)  # [N, C, W]
+            x = x.permute(2, 0, 1)  # [W, N, C]
+            x = self.decoder(x)  # [W, N, C]
+            outputs = x.permute(1, 0, 2).contiguous()
+        else:
+            x = self.decoder(feat)
+            x = x.permute(0, 3, 1, 2).contiguous()
+            n, w, c, h = x.size()
+            outputs = x.view(n, w, c * h)
+        return outputs
+
+    def forward_test(self, feat, out_enc, img_metas):
+        return self.forward_train(feat, out_enc, None, img_metas)
diff --git a/mmocr/models/textrecog/decoders/sar_decoder.py b/mmocr/models/textrecog/decoders/sar_decoder.py
new file mode 100755
index 00000000..110fa8e2
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/sar_decoder.py
@@ -0,0 +1,407 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import mmocr.utils as utils
+from mmocr.models.builder import DECODERS
+from .base_decoder import BaseDecoder
+
+
+@DECODERS.register_module()
+class ParallelSARDecoder(BaseDecoder):
+    """Implementation of the parallel decoder module in `SAR
+    <https://arxiv.org/abs/1811.00751>`_.
+
+    Args:
+        num_classes (int): Output class number.
+        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
+        dec_do_rnn (float): Dropout of RNN layer in decoder.
+        dec_gru (bool): If True, use GRU, else LSTM in decoder.
+        d_model (int): Dim of channels from backbone.
+        d_enc (int): Dim of encoder RNN layer.
+        d_k (int): Dim of channels of attention module.
+        pred_dropout (float): Dropout probability of prediction layer.
+        max_seq_len (int): Maximum sequence length for decoding.
+        mask (bool): If True, mask padding in feature map.
+        start_idx (int): Index of start token.
+        padding_idx (int): Index of padding token.
+        pred_concat (bool): If True, concat glimpse feature from
+            attention with holistic feature and hidden state.
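+
+        A rough shape walk-through of one training step (an illustrative
+        sketch with the defaults above and batch size 2):
+
+        .. code-block:: python
+
+            feat = torch.rand(2, 512, 8, 40)  # 2D feature map from backbone
+            out_enc = torch.rand(2, 512)      # holistic feature from encoder
+            # targets_dict['padded_targets']: LongTensor of shape (2, 40)
+            # forward_train(feat, out_enc, targets_dict, img_metas)
+            # -> logits of shape (2, 40, num_classes - 1)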
+ """ + + def __init__(self, + num_classes=37, + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_do_rnn=0.0, + dec_gru=False, + d_model=512, + d_enc=512, + d_k=64, + pred_dropout=0.0, + max_seq_len=40, + mask=True, + start_idx=0, + padding_idx=92, + pred_concat=False, + **kwargs): + super().__init__() + + self.num_classes = num_classes + self.enc_bi_rnn = enc_bi_rnn + self.d_k = d_k + self.start_idx = start_idx + self.max_seq_len = max_seq_len + self.mask = mask + self.pred_concat = pred_concat + + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1) + # 2D attention layer + self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k) + self.conv3x3_1 = nn.Conv2d( + d_model, d_k, kernel_size=3, stride=1, padding=1) + self.conv1x1_2 = nn.Linear(d_k, 1) + + # Decoder RNN layer + kwargs = dict( + input_size=encoder_rnn_out_size, + hidden_size=encoder_rnn_out_size, + num_layers=2, + batch_first=True, + dropout=dec_do_rnn, + bidirectional=dec_bi_rnn) + if dec_gru: + self.rnn_decoder = nn.GRU(**kwargs) + else: + self.rnn_decoder = nn.LSTM(**kwargs) + + # Decoder input embedding + self.embedding = nn.Embedding( + self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx) + + # Prediction layer + self.pred_dropout = nn.Dropout(pred_dropout) + pred_num_classes = num_classes - 1 # ignore padding_idx in prediction + if pred_concat: + fc_in_channel = decoder_rnn_out_size + d_model + d_enc + else: + fc_in_channel = d_model + self.prediction = nn.Linear(fc_in_channel, pred_num_classes) + + def _2d_attention(self, + decoder_input, + feat, + holistic_feat, + valid_ratios=None): + y = self.rnn_decoder(decoder_input)[0] + # y: bsz * (seq_len + 1) * hidden_size + + attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size + bsz, seq_len, attn_size = attn_query.size() + attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1) + + attn_key = self.conv3x3_1(feat) + # bsz * attn_size * h * w + attn_key = attn_key.unsqueeze(1) + # bsz * 1 * attn_size * h * w + + attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1)) + # bsz * (seq_len + 1) * attn_size * h * w + attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous() + # bsz * (seq_len + 1) * h * w * attn_size + attn_weight = self.conv1x1_2(attn_weight) + # bsz * (seq_len + 1) * h * w * 1 + bsz, T, h, w, c = attn_weight.size() + assert c == 1 + + if valid_ratios is not None: + # cal mask of attention weight + attn_mask = torch.zeros_like(attn_weight) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + attn_mask[i, :, :, valid_width:, :] = 1 + attn_weight = attn_weight.masked_fill(attn_mask.bool(), + float('-inf')) + + attn_weight = attn_weight.view(bsz, T, -1) + attn_weight = F.softmax(attn_weight, dim=-1) + attn_weight = attn_weight.view(bsz, T, h, w, + c).permute(0, 1, 4, 2, 3).contiguous() + + attn_feat = torch.sum( + torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False) + # bsz * (seq_len + 1) * C + + # linear transformation + if self.pred_concat: + hf_c = holistic_feat.size(-1) + holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c) + y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2)) + else: + y = self.prediction(attn_feat) + # bsz * (seq_len + 1) * num_classes + if self.train_mode: + y = self.pred_dropout(y) + + return y + + def forward_train(self, feat, out_enc, targets_dict, img_metas): + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == feat.size(0) + + valid_ratios 
= [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + targets = targets_dict['padded_targets'].to(feat.device) + tgt_embedding = self.embedding(targets) + # bsz * seq_len * emb_dim + out_enc = out_enc.unsqueeze(1) + # bsz * 1 * emb_dim + in_dec = torch.cat((out_enc, tgt_embedding), dim=1) + # bsz * (seq_len + 1) * C + out_dec = self._2d_attention( + in_dec, feat, out_enc, valid_ratios=valid_ratios) + # bsz * (seq_len + 1) * num_classes + + return out_dec[:, 1:, :] # bsz * seq_len * num_classes + + def forward_test(self, feat, out_enc, img_metas): + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == feat.size(0) + + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + seq_len = self.max_seq_len + + bsz = feat.size(0) + start_token = torch.full((bsz, ), + self.start_idx, + device=feat.device, + dtype=torch.long) + # bsz + start_token = self.embedding(start_token) + # bsz * emb_dim + start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1) + # bsz * seq_len * emb_dim + out_enc = out_enc.unsqueeze(1) + # bsz * 1 * emb_dim + decoder_input = torch.cat((out_enc, start_token), dim=1) + # bsz * (seq_len + 1) * emb_dim + + outputs = [] + for i in range(1, seq_len + 1): + decoder_output = self._2d_attention( + decoder_input, feat, out_enc, valid_ratios=valid_ratios) + char_output = decoder_output[:, i, :] # bsz * num_classes + char_output = F.softmax(char_output, -1) + outputs.append(char_output) + _, max_idx = torch.max(char_output, dim=1, keepdim=False) + char_embedding = self.embedding(max_idx) # bsz * emb_dim + if i < seq_len: + decoder_input[:, i + 1, :] = char_embedding + + outputs = torch.stack(outputs, 1) # bsz * seq_len * num_classes + + return outputs + +
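+# Editor's note: a minimal usage sketch, illustrative only and not part of
+# the original patch. It assumes the default shapes: a backbone feature map
+# of N x d_model x H x W and a holistic SAREncoder feature of N x d_enc.
+# The padding_idx override below is needed because the default (92) would
+# exceed num_classes=37.
+#   decoder = ParallelSARDecoder(num_classes=37, padding_idx=0)
+#   decoder.train_mode = False  # normally set by the decoder's forward()
+#   out = decoder.forward_test(torch.rand(1, 512, 4, 40),  # feat
+#                              torch.rand(1, 512),         # out_enc
+#                              [dict(valid_ratio=1.0)])    # img_metas
+#   out.shape == (1, 40, 36)  # max_seq_len x (num_classes - 1)
+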
+ """ + + def __init__(self, + num_classes=37, + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_gru=False, + d_k=64, + d_model=512, + d_enc=512, + pred_dropout=0.0, + mask=True, + max_seq_len=40, + start_idx=0, + padding_idx=92, + pred_concat=False, + **kwargs): + super().__init__() + + self.num_classes = num_classes + self.enc_bi_rnn = enc_bi_rnn + self.d_k = d_k + self.start_idx = start_idx + self.dec_gru = dec_gru + self.max_seq_len = max_seq_len + self.mask = mask + self.pred_concat = pred_concat + + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1) + # 2D attention layer + self.conv1x1_1 = nn.Conv2d( + decoder_rnn_out_size, d_k, kernel_size=1, stride=1) + self.conv3x3_1 = nn.Conv2d( + d_model, d_k, kernel_size=3, stride=1, padding=1) + self.conv1x1_2 = nn.Conv2d(d_k, 1, kernel_size=1, stride=1) + + # Decoder rnn layer + if dec_gru: + self.rnn_decoder_layer1 = nn.GRUCell(encoder_rnn_out_size, + encoder_rnn_out_size) + self.rnn_decoder_layer2 = nn.GRUCell(encoder_rnn_out_size, + encoder_rnn_out_size) + else: + self.rnn_decoder_layer1 = nn.LSTMCell(encoder_rnn_out_size, + encoder_rnn_out_size) + self.rnn_decoder_layer2 = nn.LSTMCell(encoder_rnn_out_size, + encoder_rnn_out_size) + + # Decoder input embedding + self.embedding = nn.Embedding( + self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx) + + # Prediction layer + self.pred_dropout = nn.Dropout(pred_dropout) + pred_num_class = num_classes - 1 # ignore padding index + if pred_concat: + fc_in_channel = decoder_rnn_out_size + d_model + d_enc + else: + fc_in_channel = d_model + self.prediction = nn.Linear(fc_in_channel, pred_num_class) + + def _2d_attention(self, + y_prev, + feat, + holistic_feat, + hx1, + cx1, + hx2, + cx2, + valid_ratios=None): + _, _, h_feat, w_feat = feat.size() + if self.dec_gru: + hx1 = cx1 = self.rnn_decoder_layer1(y_prev, hx1) + hx2 = cx2 = self.rnn_decoder_layer2(hx1, hx2) + else: + hx1, cx1 = self.rnn_decoder_layer1(y_prev, (hx1, cx1)) + hx2, cx2 = self.rnn_decoder_layer2(hx1, (hx2, cx2)) + + tile_hx2 = hx2.view(hx2.size(0), hx2.size(1), 1, 1) + attn_query = self.conv1x1_1(tile_hx2) # bsz * attn_size * 1 * 1 + attn_query = attn_query.expand(-1, -1, h_feat, w_feat) + attn_key = self.conv3x3_1(feat) + attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1)) + attn_weight = self.conv1x1_2(attn_weight) + bsz, c, h, w = attn_weight.size() + assert c == 1 + + if valid_ratios is not None: + # cal mask of attention weight + attn_mask = torch.zeros_like(attn_weight) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + attn_mask[i, :, :, valid_width:] = 1 + attn_weight = attn_weight.masked_fill(attn_mask.bool(), + float('-inf')) + + attn_weight = F.softmax(attn_weight.view(bsz, -1), dim=-1) + attn_weight = attn_weight.view(bsz, c, h, w) + + attn_feat = torch.sum( + torch.mul(feat, attn_weight), (2, 3), keepdim=False) # n * c + + # linear transformation + if self.pred_concat: + y = self.prediction(torch.cat((hx2, attn_feat, holistic_feat), 1)) + else: + y = self.prediction(attn_feat) + + return y, hx1, hx1, hx2, hx2 + + def forward_train(self, feat, out_enc, targets_dict, img_metas=None): + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == feat.size(0) + + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + if self.train_mode: + targets = targets_dict['padded_targets'].to(feat.device) + tgt_embedding = 
+ + outputs = [] + start_token = torch.full((feat.size(0), ), + self.start_idx, + device=feat.device, + dtype=torch.long) + start_token = self.embedding(start_token) + for i in range(-1, self.max_seq_len): + if i == -1: + if self.dec_gru: + hx1 = cx1 = self.rnn_decoder_layer1(out_enc) + hx2 = cx2 = self.rnn_decoder_layer2(hx1) + else: + hx1, cx1 = self.rnn_decoder_layer1(out_enc) + hx2, cx2 = self.rnn_decoder_layer2(hx1) + if not self.train_mode: + y_prev = start_token + else: + if self.train_mode: + y_prev = tgt_embedding[:, i, :] + y, hx1, cx1, hx2, cx2 = self._2d_attention( + y_prev, + feat, + out_enc, + hx1, + cx1, + hx2, + cx2, + valid_ratios=valid_ratios) + if self.train_mode: + y = self.pred_dropout(y) + else: + y = F.softmax(y, -1) + _, max_idx = torch.max(y, dim=1, keepdim=False) + char_embedding = self.embedding(max_idx) + y_prev = char_embedding + outputs.append(y) + + outputs = torch.stack(outputs, 1) + + return outputs + + def forward_test(self, feat, out_enc, img_metas): + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == feat.size(0) + + return self.forward_train(feat, out_enc, None, img_metas) diff --git a/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py b/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py new file mode 100755 index 00000000..98094dd4 --- /dev/null +++ b/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py @@ -0,0 +1,148 @@ +from queue import PriorityQueue + +import torch +import torch.nn.functional as F + +import mmocr.utils as utils +from mmocr.models.builder import DECODERS +from . import ParallelSARDecoder + + +class DecodeNode: + """Node class to save decoded char indices and scores. + + Args: + indexes (list[int]): Char indices decoded so far. + scores (list[float]): Scores of the chars decoded so far. + """ + + def __init__(self, indexes=[1], scores=[0.9]): + assert utils.is_type_list(indexes, int) + assert utils.is_type_list(scores, float) + assert utils.equal_len(indexes, scores) + + self.indexes = indexes + self.scores = scores + + def eval(self): + """Calculate accumulated score.""" + accu_score = sum(self.scores) + return accu_score + + +@DECODERS.register_module() +class ParallelSARDecoderWithBS(ParallelSARDecoder): + """Parallel Decoder module with beam-search in SAR. + + Args: + beam_width (int): Width for beam search. + """ + + def __init__(self, + beam_width=5, + num_classes=37, + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_do_rnn=0, + dec_gru=False, + d_model=512, + d_enc=512, + d_k=64, + pred_dropout=0.0, + max_seq_len=40, + mask=True, + start_idx=0, + padding_idx=0, + pred_concat=False, + **kwargs): + super().__init__(num_classes, enc_bi_rnn, dec_bi_rnn, dec_do_rnn, + dec_gru, d_model, d_enc, d_k, pred_dropout, + max_seq_len, mask, start_idx, padding_idx, + pred_concat) + assert isinstance(beam_width, int) + assert beam_width > 0 + + self.beam_width = beam_width + + def forward_test(self, feat, out_enc, img_metas): + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == feat.size(0) + + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + seq_len = self.max_seq_len + bsz = feat.size(0) + assert bsz == 1, 'batch size must be 1 for beam search.'
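+ # Editor's note (not part of the original patch): the beam-search queue
+ # built below holds (-accumulated_score, node) tuples, because
+ # PriorityQueue pops its smallest entry first, so negating the score
+ # surfaces the best hypothesis; the small k-dependent delta added to
+ # each key keeps keys from one expansion distinct, so ties do not fall
+ # through to comparing (unorderable) DecodeNode objects.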
+ + start_token = torch.full((bsz, ), + self.start_idx, + device=feat.device, + dtype=torch.long) + # bsz + start_token = self.embedding(start_token) + # bsz * emb_dim + start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1) + # bsz * seq_len * emb_dim + out_enc = out_enc.unsqueeze(1) + # bsz * 1 * emb_dim + decoder_input = torch.cat((out_enc, start_token), dim=1) + # bsz * (seq_len + 1) * emb_dim + + # Initialize beam-search queue + q = PriorityQueue() + init_node = DecodeNode([self.start_idx], [0.0]) + q.put((-init_node.eval(), init_node)) + + for i in range(1, seq_len + 1): + next_nodes = [] + beam_width = self.beam_width if i > 1 else 1 + for _ in range(beam_width): + _, node = q.get() + + input_seq = torch.clone(decoder_input) # bsz * T * emb_dim + # fill previous input tokens (step 1...i) in input_seq + for t, index in enumerate(node.indexes): + input_token = torch.full((bsz, ), + index, + device=input_seq.device, + dtype=torch.long) + input_token = self.embedding(input_token) # bsz * emb_dim + input_seq[:, t + 1, :] = input_token + + output_seq = self._2d_attention( + input_seq, feat, out_enc, valid_ratios=valid_ratios) + + output_char = output_seq[:, i, :] # bsz * num_classes + output_char = F.softmax(output_char, -1) + topk_value, topk_idx = output_char.topk(self.beam_width, dim=1) + topk_value, topk_idx = topk_value.squeeze(0), topk_idx.squeeze( + 0) + + for k in range(self.beam_width): + kth_score = topk_value[k].item() + kth_idx = topk_idx[k].item() + next_node = DecodeNode(node.indexes + [kth_idx], + node.scores + [kth_score]) + delta = k * 1e-6 + next_nodes.append( + (-node.eval() - kth_score - delta, next_node)) + # Negate the score since PriorityQueue pops + # its smallest entry first + + while not q.empty(): + q.get() + + # Put all candidates into the queue + for next_node in next_nodes: + q.put(next_node) + + best_node = q.get() + num_classes = self.num_classes - 1 # ignore padding index + outputs = torch.zeros(bsz, seq_len, num_classes) + for i in range(seq_len): + idx = best_node[1].indexes[i + 1] + outputs[0, i, idx] = best_node[1].scores[i + 1] + + return outputs diff --git a/mmocr/models/textrecog/decoders/transformer_decoder.py b/mmocr/models/textrecog/decoders/transformer_decoder.py new file mode 100644 index 00000000..c8ef2fad --- /dev/null +++ b/mmocr/models/textrecog/decoders/transformer_decoder.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmocr.models.builder import DECODERS +from mmocr.models.textrecog.layers import (DecoderLayer, PositionalEncoding, + get_pad_mask, get_subsequent_mask) +from .base_decoder import BaseDecoder + + +@DECODERS.register_module() +class TFDecoder(BaseDecoder): + """Transformer Decoder block with self attention mechanism.""" + + def __init__(self, + n_layers=6, + d_embedding=512, + n_head=8, + d_k=64, + d_v=64, + d_model=512, + d_inner=256, + n_position=200, + dropout=0.1, + num_classes=93, + max_seq_len=40, + start_idx=1, + padding_idx=92, + **kwargs): + super().__init__() + + self.padding_idx = padding_idx + self.start_idx = start_idx + self.max_seq_len = max_seq_len + + self.trg_word_emb = nn.Embedding( + num_classes, d_embedding, padding_idx=padding_idx) + + self.position_enc = PositionalEncoding( + d_embedding, n_position=n_position) + self.dropout = nn.Dropout(p=dropout) + + self.layer_stack = nn.ModuleList([ + DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers) + ]) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + pred_num_class = 
num_classes - 1 # ignore padding_idx + self.classifier = nn.Linear(d_model, pred_num_class) + + def _attention(self, trg_seq, src, src_mask=None): + trg_embedding = self.trg_word_emb(trg_seq) + trg_pos_encoded = self.position_enc(trg_embedding) + tgt = self.dropout(trg_pos_encoded) + + trg_mask = get_pad_mask( + trg_seq, pad_idx=self.padding_idx) & get_subsequent_mask(trg_seq) + output = tgt + for dec_layer in self.layer_stack: + output = dec_layer( + output, + src, + slf_attn_mask=trg_mask, + dec_enc_attn_mask=src_mask) + output = self.layer_norm(output) + + return output + + def forward_train(self, feat, out_enc, targets_dict, img_metas): + targets = targets_dict['padded_targets'].to(out_enc.device) + attn_output = self._attention(targets, out_enc, src_mask=None) + outputs = self.classifier(attn_output) + return outputs + + def forward_test(self, feat, out_enc, img_metas): + bsz = out_enc.size(0) + init_target_seq = torch.full((bsz, self.max_seq_len + 1), + self.padding_idx, + device=out_enc.device, + dtype=torch.long) + # bsz * seq_len + init_target_seq[:, 0] = self.start_idx + + outputs = [] + for step in range(0, self.max_seq_len): + decoder_output = self._attention( + init_target_seq, out_enc, src_mask=None) + # bsz * seq_len * 512 + step_result = F.softmax( + self.classifier(decoder_output[:, step, :]), dim=-1) + # bsz * num_classes + outputs.append(step_result) + _, step_max_index = torch.max(step_result, dim=-1) + init_target_seq[:, step + 1] = step_max_index + + outputs = torch.stack(outputs, dim=1) + + return outputs diff --git a/mmocr/models/textrecog/encoders/__init__.py b/mmocr/models/textrecog/encoders/__init__.py new file mode 100755 index 00000000..e0d9394a --- /dev/null +++ b/mmocr/models/textrecog/encoders/__init__.py @@ -0,0 +1,6 @@ +from .base_encoder import BaseEncoder +from .channel_reduction_encoder import ChannelReductionEncoder +from .sar_encoder import SAREncoder +from .transformer_encoder import TFEncoder + +__all__ = ['SAREncoder', 'TFEncoder', 'BaseEncoder', 'ChannelReductionEncoder'] diff --git a/mmocr/models/textrecog/encoders/base_encoder.py b/mmocr/models/textrecog/encoders/base_encoder.py new file mode 100644 index 00000000..3dadc687 --- /dev/null +++ b/mmocr/models/textrecog/encoders/base_encoder.py @@ -0,0 +1,14 @@ +import torch.nn as nn + +from mmocr.models.builder import ENCODERS + + +@ENCODERS.register_module() +class BaseEncoder(nn.Module): + """Base Encoder class for text recognition.""" + + def init_weights(self): + pass + + def forward(self, feat, **kwargs): + return feat diff --git a/mmocr/models/textrecog/encoders/sar_encoder.py b/mmocr/models/textrecog/encoders/sar_encoder.py new file mode 100644 index 00000000..d0381d6c --- /dev/null +++ b/mmocr/models/textrecog/encoders/sar_encoder.py @@ -0,0 +1,102 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import uniform_init, xavier_init + +import mmocr.utils as utils +from mmocr.models.builder import ENCODERS +from .base_encoder import BaseEncoder + + +@ENCODERS.register_module() +class SAREncoder(BaseEncoder): + """Implementation of encoder module in `SAR. + + `_ + + Args: + enc_bi_rnn (bool): If True, use bidirectional RNN in encoder. + enc_do_rnn (float): Dropout probability of RNN layer in encoder. + enc_gru (bool): If True, use GRU, else LSTM in encoder. + d_model (int): Dim of channels from backbone. + d_enc (int): Dim of encoder RNN layer. + mask (bool): If True, mask padding in RNN sequence. 
+ """ + + def __init__(self, + enc_bi_rnn=False, + enc_do_rnn=0.0, + enc_gru=False, + d_model=512, + d_enc=512, + mask=True, + **kwargs): + super().__init__() + assert isinstance(enc_bi_rnn, bool) + assert isinstance(enc_do_rnn, (int, float)) + assert 0 <= enc_do_rnn < 1.0 + assert isinstance(enc_gru, bool) + assert isinstance(d_model, int) + assert isinstance(d_enc, int) + assert isinstance(mask, bool) + + self.enc_bi_rnn = enc_bi_rnn + self.enc_do_rnn = enc_do_rnn + self.mask = mask + + # LSTM Encoder + kwargs = dict( + input_size=d_model, + hidden_size=d_enc, + num_layers=2, + batch_first=True, + dropout=enc_do_rnn, + bidirectional=enc_bi_rnn) + if enc_gru: + self.rnn_encoder = nn.GRU(**kwargs) + else: + self.rnn_encoder = nn.LSTM(**kwargs) + + # global feature transformation + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size) + + def init_weights(self): + # initialize weight and bias + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m) + elif isinstance(m, nn.BatchNorm2d): + uniform_init(m) + + def forward(self, feat, img_metas=None): + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == feat.size(0) + + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + h_feat = feat.size(2) + feat_v = F.max_pool2d( + feat, kernel_size=(h_feat, 1), stride=1, padding=0) + feat_v = feat_v.squeeze(2) # bsz * C * W + feat_v = feat_v.permute(0, 2, 1).contiguous() # bsz * W * C + + holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C + + if valid_ratios is not None: + valid_hf = [] + T = holistic_feat.size(1) + for i, valid_ratio in enumerate(valid_ratios): + valid_step = min(T, math.ceil(T * valid_ratio)) - 1 + valid_hf.append(holistic_feat[i, valid_step, :]) + valid_hf = torch.stack(valid_hf, dim=0) + else: + valid_hf = holistic_feat[:, -1, :] # bsz * C + + holistic_feat = self.linear(valid_hf) # bsz * C + + return holistic_feat diff --git a/mmocr/models/textrecog/encoders/transformer_encoder.py b/mmocr/models/textrecog/encoders/transformer_encoder.py new file mode 100644 index 00000000..d33d17bf --- /dev/null +++ b/mmocr/models/textrecog/encoders/transformer_encoder.py @@ -0,0 +1,16 @@ +from mmocr.models.builder import ENCODERS +from .base_encoder import BaseEncoder + + +@ENCODERS.register_module() +class TFEncoder(BaseEncoder): + """Encode 2d feature map to 1d sequence.""" + + def __init__(self, **kwargs): + super().__init__() + + def forward(self, feat, img_metas=None): + n, c, _, _ = feat.size() + enc_output = feat.view(n, c, -1).transpose(2, 1).contiguous() + + return enc_output diff --git a/mmocr/models/textrecog/heads/__init__.py b/mmocr/models/textrecog/heads/__init__.py new file mode 100755 index 00000000..761bb9a9 --- /dev/null +++ b/mmocr/models/textrecog/heads/__init__.py @@ -0,0 +1,3 @@ +from .seg_head import SegHead + +__all__ = ['SegHead'] diff --git a/mmocr/models/textrecog/heads/seg_head.py b/mmocr/models/textrecog/heads/seg_head.py new file mode 100644 index 00000000..a0ca59cc --- /dev/null +++ b/mmocr/models/textrecog/heads/seg_head.py @@ -0,0 +1,50 @@ +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from torch import nn + +from mmdet.models.builder import HEADS + + +@HEADS.register_module() +class SegHead(nn.Module): + """Head for segmentation based text recognition. + + Args: + in_channels (int): Number of input channels. + num_classes (int): Number of output classes. 
+ upsample_param (dict | None): Config dict for interpolation layer. + Default: `dict(scale_factor=1.0, mode='nearest')` + """ + + def __init__(self, in_channels=128, num_classes=37, upsample_param=None): + super().__init__() + assert isinstance(num_classes, int) + assert num_classes > 0 + assert upsample_param is None or isinstance(upsample_param, dict) + + self.upsample_param = upsample_param + + self.seg_conv = ConvModule( + in_channels, + in_channels, + 3, + stride=1, + padding=1, + norm_cfg=dict(type='BN')) + + # prediction + self.pred_conv = nn.Conv2d( + in_channels, num_classes, kernel_size=1, stride=1, padding=0) + + def init_weights(self): + pass + + def forward(self, out_neck): + + seg_map = self.seg_conv(out_neck[-1]) + seg_map = self.pred_conv(seg_map) + + if self.upsample_param is not None: + seg_map = F.interpolate(seg_map, **self.upsample_param) + + return seg_map diff --git a/mmocr/models/textrecog/layers/__init__.py b/mmocr/models/textrecog/layers/__init__.py new file mode 100755 index 00000000..a38a8c5d --- /dev/null +++ b/mmocr/models/textrecog/layers/__init__.py @@ -0,0 +1,15 @@ +from .conv_layer import BasicBlock, Bottleneck +from .dot_product_attention_layer import DotProductAttentionLayer +from .lstm_layer import BidirectionalLSTM +from .position_aware_layer import PositionAwareLayer +from .robust_scanner_fusion_layer import RobustScannerFusionLayer +from .transformer_layer import (DecoderLayer, MultiHeadAttention, + PositionalEncoding, PositionwiseFeedForward, + get_pad_mask, get_subsequent_mask) + +__all__ = [ + 'BidirectionalLSTM', 'MultiHeadAttention', 'PositionalEncoding', + 'PositionwiseFeedForward', 'BasicBlock', 'Bottleneck', + 'RobustScannerFusionLayer', 'DotProductAttentionLayer', + 'PositionAwareLayer', 'DecoderLayer', 'get_pad_mask', 'get_subsequent_mask' +] diff --git a/mmocr/models/textrecog/layers/conv_layer.py b/mmocr/models/textrecog/layers/conv_layer.py new file mode 100644 index 00000000..d0ce32a3 --- /dev/null +++ b/mmocr/models/textrecog/layers/conv_layer.py @@ -0,0 +1,93 @@ +import torch.nn as nn + + +def conv3x3(in_planes, out_planes, stride=1): + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=False): + super().__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + if downsample: + self.downsample = nn.Sequential( + nn.Conv2d( + inplanes, planes * self.expansion, 1, stride, bias=False), + nn.BatchNorm2d(planes * self.expansion), + ) + else: + self.downsample = nn.Sequential() + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=False): + super().__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) 
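+ # Editor's note: the three convs defined above follow the standard
+ # ResNet bottleneck design: a 1x1 conv reduces channels, a (possibly
+ # strided) 3x3 conv processes them, and a final 1x1 conv expands them
+ # by `expansion` (4x); batch normalization follows each conv below.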
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + if downsample: + self.downsample = nn.Sequential( + nn.Conv2d( + inplanes, planes * self.expansion, 1, stride, bias=False), + nn.BatchNorm2d(planes * self.expansion), + ) + else: + self.downsample = nn.Sequential() + + def forward(self, x): + residual = self.downsample(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out diff --git a/mmocr/models/textrecog/layers/lstm_layer.py b/mmocr/models/textrecog/layers/lstm_layer.py new file mode 100644 index 00000000..e4017d02 --- /dev/null +++ b/mmocr/models/textrecog/layers/lstm_layer.py @@ -0,0 +1,20 @@ +import torch.nn as nn + + +class BidirectionalLSTM(nn.Module): + + def __init__(self, nIn, nHidden, nOut): + super().__init__() + + self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) + self.embedding = nn.Linear(nHidden * 2, nOut) + + def forward(self, input): + recurrent, _ = self.rnn(input) + T, b, h = recurrent.size() + t_rec = recurrent.view(T * b, h) + + output = self.embedding(t_rec) # [T * b, nOut] + output = output.view(T, b, -1) + + return output diff --git a/mmocr/models/textrecog/layers/transformer_layer.py b/mmocr/models/textrecog/layers/transformer_layer.py new file mode 100644 index 00000000..a9bba4b3 --- /dev/null +++ b/mmocr/models/textrecog/layers/transformer_layer.py @@ -0,0 +1,172 @@ +"""This code is from https://github.com/jadore801120/attention-is-all-you-need- +pytorch.""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DecoderLayer(nn.Module): + """Compose with three layers: + 1. MultiHeadSelfAttn + 2. MultiHeadEncoderDecoderAttn + 3. 
PositionwiseFeedForward + """ + + def __init__(self, + d_model=512, + d_inner=256, + n_head=8, + d_k=64, + d_v=64, + dropout=0.1): + super().__init__() + self.slf_attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, dropout=dropout) + self.enc_attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward( + d_model, d_inner, dropout=dropout) + + def forward(self, + dec_input, + enc_output, + slf_attn_mask=None, + dec_enc_attn_mask=None): + + dec_output = self.slf_attn( + dec_input, dec_input, dec_input, mask=slf_attn_mask) + dec_output = self.enc_attn( + dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) + dec_output = self.pos_ffn(dec_output) + + return dec_output + + +class ScaledDotProductAttention(nn.Module): + """Scaled Dot-Product Attention.""" + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + + def forward(self, q, k, v, mask=None): + + attn = torch.matmul(q / self.temperature, k.transpose(1, 2)) + + if mask is not None: + attn = attn.masked_fill(mask == 0, float('-inf')) + + attn = self.dropout(F.softmax(attn, dim=-1)) + output = torch.matmul(attn, v) + + return output, attn + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention module.""" + + def __init__(self, n_head=8, d_model=512, d_k=64, d_v=64, dropout=0.1): + super().__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs_list = nn.ModuleList( + [nn.Linear(d_model, d_k, bias=False) for _ in range(n_head)]) + self.w_ks_list = nn.ModuleList( + [nn.Linear(d_model, d_k, bias=False) for _ in range(n_head)]) + self.w_vs_list = nn.ModuleList( + [nn.Linear(d_model, d_v, bias=False) for _ in range(n_head)]) + self.fc = nn.Linear(n_head * d_v, d_model, bias=False) + + self.attention = ScaledDotProductAttention(temperature=d_k**0.5) + + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + def forward(self, q, k, v, mask=None): + + residual = q + q = self.layer_norm(q) + + attention_q_list = [] + for head_index in range(self.n_head): + q_each = self.w_qs_list[head_index](q) # bsz * seq_len * d_k + k_each = self.w_ks_list[head_index](k) # bsz * seq_len * d_k + v_each = self.w_vs_list[head_index](v) # bsz * seq_len * d_v + attention_q_each, _ = self.attention( + q_each, k_each, v_each, mask=mask) + attention_q_list.append(attention_q_each) + + q = torch.cat(attention_q_list, dim=-1) + q = self.dropout(self.fc(q)) + q += residual + + return q + + +class PositionwiseFeedForward(nn.Module): + """A two-feed-forward-layer module.""" + + def __init__(self, d_in, d_hid, dropout=0.1): + super().__init__() + self.w_1 = nn.Linear(d_in, d_hid) # position-wise + self.w_2 = nn.Linear(d_hid, d_in) # position-wise + self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + + residual = x + x = self.layer_norm(x) + + x = self.w_2(F.relu(self.w_1(x))) + x = self.dropout(x) + x += residual + + return x + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_hid=512, n_position=200): + super().__init__() + + # Not a parameter + self.register_buffer( + 'position_table', + self._get_sinusoid_encoding_table(n_position, d_hid)) + + def _get_sinusoid_encoding_table(self, n_position, d_hid): + """Sinusoid position encoding table.""" + denominator = torch.Tensor([ + 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ]) + denominator 
= denominator.view(1, -1) + pos_tensor = torch.arange(n_position).unsqueeze(-1).float() + sinusoid_table = pos_tensor * denominator + sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) + + return sinusoid_table.unsqueeze(0) + + def forward(self, x): + self.device = x.device + return x + self.position_table[:, :x.size(1)].clone().detach() + + +def get_pad_mask(seq, pad_idx): + return (seq != pad_idx).unsqueeze(-2) + + +def get_subsequent_mask(seq): + """For masking out the subsequent info.""" + len_s = seq.size(1) + subsequent_mask = 1 - torch.triu( + torch.ones((len_s, len_s), device=seq.device), diagonal=1) + subsequent_mask = subsequent_mask.unsqueeze(0).bool() + return subsequent_mask diff --git a/mmocr/models/textrecog/losses/__init__.py b/mmocr/models/textrecog/losses/__init__.py new file mode 100755 index 00000000..4b5a24a2 --- /dev/null +++ b/mmocr/models/textrecog/losses/__init__.py @@ -0,0 +1,5 @@ +from .ce_loss import CELoss, SARLoss, TFLoss +from .ctc_loss import CTCLoss +from .seg_loss import CAFCNLoss, SegLoss + +__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss', 'CAFCNLoss'] diff --git a/mmocr/models/textrecog/losses/ce_loss.py b/mmocr/models/textrecog/losses/ce_loss.py new file mode 100644 index 00000000..4fad5ae4 --- /dev/null +++ b/mmocr/models/textrecog/losses/ce_loss.py @@ -0,0 +1,94 @@ +import torch.nn as nn + +from mmdet.models.builder import LOSSES + + +@LOSSES.register_module() +class CELoss(nn.Module): + """Implementation of loss module for encoder-decoder based text recognition + method with CrossEntropy loss. + + Args: + ignore_index (int): Specifies a target value that is + ignored and does not contribute to the input gradient. + reduction (str): Specifies the reduction to apply to the output, + should be one of the following: ('none', 'mean', 'sum'). + """ + + def __init__(self, ignore_index=-1, reduction='none'): + super().__init__() + assert isinstance(ignore_index, int) + assert isinstance(reduction, str) + assert reduction in ['none', 'mean', 'sum'] + + self.loss_ce = nn.CrossEntropyLoss( + ignore_index=ignore_index, reduction=reduction) + + def format(self, outputs, targets_dict): + targets = targets_dict['padded_targets'] + + return outputs.permute(0, 2, 1).contiguous(), targets + + def forward(self, outputs, targets_dict): + outputs, targets = self.format(outputs, targets_dict) + + loss_ce = self.loss_ce(outputs, targets.to(outputs.device)) + losses = dict(loss_ce=loss_ce) + + return losses + + +@LOSSES.register_module() +class SARLoss(CELoss): + """Implementation of loss module in `SAR. + + `_. + + Args: + ignore_index (int): Specifies a target value that is + ignored and does not contribute to the input gradient. + reduction (str): Specifies the reduction to apply to the output, + should be one of the following: ('none', 'mean', 'sum'). + """ + + def __init__(self, ignore_index=0, reduction='mean', **kwargs): + super().__init__(ignore_index, reduction) + + def format(self, outputs, targets_dict): + targets = targets_dict['padded_targets'] + # targets[0, :], [start_idx, idx1, idx2, ..., end_idx, pad_idx...] + # outputs[0, :, 0], [idx1, idx2, ..., end_idx, ...] 
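+ # Editor's note, illustrative only: for a padded row
+ # [start, 5, 8, end, pad], the slicing below scores outputs[:, 0]
+ # against token 5 and outputs[:, 1] against 8, i.e. prediction step t
+ # is aligned with target token t + 1.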
+ + # ignore first index of target in loss calculation + targets = targets[:, 1:].contiguous() + # ignore last index of outputs to be in same seq_len with targets + outputs = outputs[:, :-1, :].permute(0, 2, 1).contiguous() + + return outputs, targets + + +@LOSSES.register_module() +class TFLoss(CELoss): + """Implementation of loss module for transformer.""" + + def __init__(self, + ignore_index=-1, + reduction='none', + flatten=True, + **kwargs): + super().__init__(ignore_index, reduction) + assert isinstance(flatten, bool) + + self.flatten = flatten + + def format(self, outputs, targets_dict): + outputs = outputs[:, :-1, :].contiguous() + targets = targets_dict['padded_targets'] + targets = targets[:, 1:].contiguous() + if self.flatten: + outputs = outputs.view(-1, outputs.size(-1)) + targets = targets.view(-1) + else: + outputs = outputs.permute(0, 2, 1).contiguous() + + return outputs, targets diff --git a/mmocr/models/textrecog/losses/ctc_loss.py b/mmocr/models/textrecog/losses/ctc_loss.py new file mode 100644 index 00000000..231469a2 --- /dev/null +++ b/mmocr/models/textrecog/losses/ctc_loss.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn + +from mmdet.models.builder import LOSSES + + +@LOSSES.register_module() +class CTCLoss(nn.Module): + """Implementation of loss module for CTC-loss based text recognition. + + Args: + flatten (bool): If True, use flattened targets, else padded targets. + blank (int): Blank label. Default 0. + reduction (str): Specifies the reduction to apply to the output, + should be one of the following: ('none', 'mean', 'sum'). + zero_infinity (bool): Whether to zero infinite losses and + the associated gradients. Default: False. + Infinite losses mainly occur when the inputs + are too short to be aligned to the targets. + """ + + def __init__(self, + flatten=True, + blank=0, + reduction='mean', + zero_infinity=False, + **kwargs): + super().__init__() + assert isinstance(flatten, bool) + assert isinstance(blank, int) + assert isinstance(reduction, str) + assert isinstance(zero_infinity, bool) + + self.flatten = flatten + self.blank = blank + self.ctc_loss = nn.CTCLoss( + blank=blank, reduction=reduction, zero_infinity=zero_infinity) + + def forward(self, outputs, targets_dict): + + outputs = torch.log_softmax(outputs, dim=2) + bsz, seq_len = outputs.size(0), outputs.size(1) + input_lengths = torch.full( + size=(bsz, ), fill_value=seq_len, dtype=torch.long) + outputs_for_loss = outputs.permute(1, 0, 2).contiguous() # T * N * C + + if self.flatten: + targets = targets_dict['flatten_targets'] + else: + targets = torch.full( + size=(bsz, seq_len), fill_value=self.blank, dtype=torch.long) + for idx, tensor in enumerate(targets_dict['targets']): + valid_len = min(tensor.size(0), seq_len) + targets[idx, :valid_len] = tensor[:valid_len] + + target_lengths = targets_dict['target_lengths'] + + loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths, + target_lengths) + + losses = dict(loss_ctc=loss_ctc) + + return losses diff --git a/mmocr/models/textrecog/losses/seg_loss.py b/mmocr/models/textrecog/losses/seg_loss.py new file mode 100644 index 00000000..0da9e61e --- /dev/null +++ b/mmocr/models/textrecog/losses/seg_loss.py @@ -0,0 +1,176 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmdet.models.builder import LOSSES +from mmocr.models.common.losses import DiceLoss + + +@LOSSES.register_module() +class SegLoss(nn.Module): + """Implementation of loss module for segmentation based text recognition + method. 
+ + Args: + seg_downsample_ratio (float): Downsample ratio of + segmentation map. + seg_with_loss_weight (bool): If True, set weight for + segmentation loss. + ignore_index (int): Specifies a target value that is ignored + and does not contribute to the input gradient. + """ + + def __init__(self, + seg_downsample_ratio=0.5, + seg_with_loss_weight=True, + ignore_index=255, + **kwargs): + super().__init__() + + assert isinstance(seg_downsample_ratio, (int, float)) + assert 0 < seg_downsample_ratio <= 1 + assert isinstance(ignore_index, int) + + self.seg_downsample_ratio = seg_downsample_ratio + self.seg_with_loss_weight = seg_with_loss_weight + self.ignore_index = ignore_index + + def seg_loss(self, out_head, gt_kernels): + seg_map = out_head # bsz * num_classes * H/2 * W/2 + seg_target = [ + item[1].rescale(self.seg_downsample_ratio).to_tensor( + torch.long, seg_map.device) for item in gt_kernels + ] + seg_target = torch.stack(seg_target).squeeze(1) + + loss_weight = None + if self.seg_with_loss_weight: + N = torch.sum(seg_target != self.ignore_index) + N_neg = torch.sum(seg_target == 0) + weight_val = 1.0 * N_neg / (N - N_neg) + loss_weight = torch.ones(seg_map.size(1), device=seg_map.device) + loss_weight[1:] = weight_val + + loss_seg = F.cross_entropy( + seg_map, + seg_target, + weight=loss_weight, + ignore_index=self.ignore_index) + + return loss_seg + + def forward(self, out_neck, out_head, gt_kernels): + + losses = {} + + loss_seg = self.seg_loss(out_head, gt_kernels) + + losses['loss_seg'] = loss_seg + + return losses + + +@LOSSES.register_module() +class CAFCNLoss(SegLoss): + """Implementation of loss module in `CA-FCN. + + `_ + + Args: + alpha (float): Weight ratio of attention loss. + attn_s2_downsample_ratio (float): Downsample ratio + of attention map from output stage 2. + attn_s3_downsample_ratio (float): Downsample ratio + of attention map from output stage 3. + seg_downsample_ratio (float): Downsample ratio of + segmentation map. + attn_with_dice_loss (bool): If True, use dice_loss for attention, + else BCELoss. + with_attn (bool): If True, include attention loss, else + segmentation loss only. + seg_with_loss_weight (bool): If True, set weight for + segmentation loss. + ignore_index (int): Specifies a target value that is ignored + and does not contribute to the input gradient. 
+ """ + + def __init__(self, + alpha=1.0, + attn_s2_downsample_ratio=0.25, + attn_s3_downsample_ratio=0.125, + seg_downsample_ratio=0.5, + attn_with_dice_loss=False, + with_attn=True, + seg_with_loss_weight=True, + ignore_index=255): + super().__init__(seg_downsample_ratio, seg_with_loss_weight, + ignore_index) + assert isinstance(alpha, (int, float)) + assert isinstance(attn_s2_downsample_ratio, (int, float)) + assert isinstance(attn_s3_downsample_ratio, (int, float)) + assert 0 < attn_s2_downsample_ratio <= 1 + assert 0 < attn_s3_downsample_ratio <= 1 + + self.alpha = alpha + self.attn_s2_downsample_ratio = attn_s2_downsample_ratio + self.attn_s3_downsample_ratio = attn_s3_downsample_ratio + self.with_attn = with_attn + self.attn_with_dice_loss = attn_with_dice_loss + + # attention loss + if with_attn: + if attn_with_dice_loss: + self.criterion_attn = DiceLoss() + else: + self.criterion_attn = nn.BCELoss(reduction='none') + + def attn_loss(self, out_neck, gt_kernels): + attn_map_s2 = out_neck[0] # bsz * 2 * H/4 * W/4 + + mask_s2 = torch.stack([ + item[2].rescale(self.attn_s2_downsample_ratio).to_tensor( + torch.float, attn_map_s2.device) for item in gt_kernels + ]) + + attn_target_s2 = torch.stack([ + item[0].rescale(self.attn_s2_downsample_ratio).to_tensor( + torch.float, attn_map_s2.device) for item in gt_kernels + ]) + + mask_s3 = torch.stack([ + item[2].rescale(self.attn_s3_downsample_ratio).to_tensor( + torch.float, attn_map_s2.device) for item in gt_kernels + ]) + + attn_target_s3 = torch.stack([ + item[0].rescale(self.attn_s3_downsample_ratio).to_tensor( + torch.float, attn_map_s2.device) for item in gt_kernels + ]) + + targets = [ + attn_target_s2, attn_target_s3, attn_target_s3, attn_target_s3 + ] + + masks = [mask_s2, mask_s3, mask_s3, mask_s3] + + loss_attn = 0. + for i in range(len(out_neck) - 1): + pred = out_neck[i] + if self.attn_with_dice_loss: + loss_attn += self.criterion_attn(pred, targets[i], masks[i]) + else: + loss_attn += torch.sum( + self.criterion_attn(pred, targets[i]) * + masks[i]) / torch.sum(masks[i]) + + return loss_attn + + def forward(self, out_neck, out_head, gt_kernels): + + losses = super().forward(out_neck, out_head, gt_kernels) + + if self.with_attn: + loss_attn = self.attn_loss(out_neck, gt_kernels) + losses['loss_attn'] = loss_attn + + return losses diff --git a/mmocr/models/textrecog/necks/__init__.py b/mmocr/models/textrecog/necks/__init__.py new file mode 100755 index 00000000..c10a46a5 --- /dev/null +++ b/mmocr/models/textrecog/necks/__init__.py @@ -0,0 +1,5 @@ +from .cafcn_neck import CAFCNNeck +from .fpn_ocr import FPNOCR +from .fpn_seg import FPNSeg + +__all__ = ['CAFCNNeck', 'FPNSeg', 'FPNOCR'] diff --git a/mmocr/models/textrecog/necks/cafcn_neck.py b/mmocr/models/textrecog/necks/cafcn_neck.py new file mode 100644 index 00000000..b5fca8d2 --- /dev/null +++ b/mmocr/models/textrecog/necks/cafcn_neck.py @@ -0,0 +1,223 @@ +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.ops import DeformConv2dPack +from torch import nn + +from mmdet.models.builder import NECKS + + +class CharAttn(nn.Module): + """Implementation of Character attention module in `CA-FCN. 
+ + `_ + """ + + def __init__(self, in_channels=128, out_channels=128, deformable=False): + super().__init__() + assert isinstance(in_channels, int) + assert isinstance(deformable, bool) + + self.in_channels = in_channels + self.out_channels = out_channels + self.deformable = deformable + + # attention layers + self.attn_layer = nn.Sequential( + ConvModule( + in_channels, + in_channels, + 3, + stride=1, + padding=1, + norm_cfg=dict(type='BN')), + ConvModule( + in_channels, + 1, + 3, + stride=1, + padding=1, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='Sigmoid'))) + + conv_kwargs = dict( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=1, + padding=1) + if self.deformable: + self.conv = DeformConv2dPack(**conv_kwargs) + else: + self.conv = nn.Conv2d(**conv_kwargs) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, in_feat): + # Calculate attn map + attn_map = self.attn_layer(in_feat) # N * 1 * H * W + + in_feat = self.relu(self.bn(self.conv(in_feat))) + + out_feat_map = self._upsample_mul(in_feat, 1 + attn_map) + + return out_feat_map, attn_map + + def _upsample_add(self, x, y): + return F.interpolate(x, size=y.size()[2:]) + y + + def _upsample_mul(self, x, y): + return F.interpolate(x, size=y.size()[2:]) * y + + +class FeatGenerator(nn.Module): + """Generate attention-augmented stage feature from backbone stage + feature.""" + + def __init__(self, + in_channels=512, + out_channels=128, + deformable=True, + concat=False, + upsample=False, + with_attn=True): + super().__init__() + + self.concat = concat + self.upsample = upsample + self.with_attn = with_attn + + if with_attn: + self.char_attn = CharAttn( + in_channels=in_channels, + out_channels=out_channels, + deformable=deformable) + else: + self.char_attn = ConvModule( + in_channels, + out_channels, + 3, + stride=1, + padding=1, + norm_cfg=dict(type='BN')) + + if concat: + self.conv_to_concat = ConvModule( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + norm_cfg=dict(type='BN')) + + kernel_size = (3, 1) if deformable else 3 + padding = (1, 0) if deformable else 1 + tmp_in_channels = out_channels * 2 if concat else out_channels + + self.conv_after_concat = ConvModule( + tmp_in_channels, + out_channels, + kernel_size, + stride=1, + padding=padding, + norm_cfg=dict(type='BN')) + + def forward(self, x, y=None, size=None): + if self.with_attn: + feat_map, attn_map = self.char_attn(x) + else: + feat_map = self.char_attn(x) + attn_map = feat_map + + if self.concat: + y = self.conv_to_concat(y) + feat_map = torch.cat((y, feat_map), dim=1) + + feat_map = self.conv_after_concat(feat_map) + + if self.upsample: + feat_map = F.interpolate(feat_map, size) + + return attn_map, feat_map + + +@NECKS.register_module() +class CAFCNNeck(nn.Module): + """Implementation of neck module in `CA-FCN. + + `_ + + Args: + in_channels (list[int]): Number of input channels for each scale. + out_channels (int): Number of output channels for each scale. + deformable (bool): If True, use deformable conv. + with_attn (bool): If True, add attention for each output feature map. 
+ """ + + def __init__(self, + in_channels=[128, 256, 512, 512], + out_channels=128, + deformable=True, + with_attn=True): + super().__init__() + + self.deformable = deformable + self.with_attn = with_attn + + # stage_in5_to_out5 + self.s5 = FeatGenerator( + in_channels=in_channels[-1], + out_channels=out_channels, + deformable=deformable, + concat=False, + with_attn=with_attn) + + # stage_in4_to_out4 + self.s4 = FeatGenerator( + in_channels=in_channels[-2], + out_channels=out_channels, + deformable=deformable, + concat=True, + with_attn=with_attn) + + # stage_in3_to_out3 + self.s3 = FeatGenerator( + in_channels=in_channels[-3], + out_channels=out_channels, + deformable=False, + concat=True, + upsample=True, + with_attn=with_attn) + + # stage_in2_to_out2 + self.s2 = FeatGenerator( + in_channels=in_channels[-4], + out_channels=out_channels, + deformable=False, + concat=True, + upsample=True, + with_attn=with_attn) + + def init_weights(self): + pass + + def forward(self, feats): + in_stage1 = feats[0] + in_stage2, in_stage3 = feats[1], feats[2] + in_stage4, in_stage5 = feats[3], feats[4] + # out stage 5 + out_s5_attn_map, out_s5 = self.s5(in_stage5) + + # out stage 4 + out_s4_attn_map, out_s4 = self.s4(in_stage4, out_s5) + + # out stage 3 + out_s3_attn_map, out_s3 = self.s3(in_stage3, out_s4, + in_stage2.size()[2:]) + + # out stage 2 + out_s2_attn_map, out_s2 = self.s2(in_stage2, out_s3, + in_stage1.size()[2:]) + + return (out_s2_attn_map, out_s3_attn_map, out_s4_attn_map, + out_s5_attn_map, out_s2) diff --git a/mmocr/models/textrecog/necks/fpn_ocr.py b/mmocr/models/textrecog/necks/fpn_ocr.py new file mode 100644 index 00000000..c1e9f178 --- /dev/null +++ b/mmocr/models/textrecog/necks/fpn_ocr.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmdet.models.builder import NECKS + + +@NECKS.register_module() +class FPNOCR(nn.Module): + """FPN-like Network for segmentation based text recognition. + + Args: + in_channels (list[int]): Number of input channels for each scale. + out_channels (int): Number of output channels for each scale. + last_stage_only (bool): If True, output last stage only. 
+ """ + + def __init__(self, in_channels, out_channels, last_stage_only=True): + super(FPNOCR, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + + self.last_stage_only = last_stage_only + + self.lateral_convs = nn.ModuleList() + self.smooth_convs_1x1 = nn.ModuleList() + self.smooth_convs_3x3 = nn.ModuleList() + + for i in range(self.num_ins): + l_conv = ConvModule( + in_channels[i], out_channels, 1, norm_cfg=dict(type='BN')) + self.lateral_convs.append(l_conv) + + for i in range(self.num_ins - 1): + s_conv_1x1 = ConvModule( + out_channels * 2, out_channels, 1, norm_cfg=dict(type='BN')) + s_conv_3x3 = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=dict(type='BN')) + self.smooth_convs_1x1.append(s_conv_1x1) + self.smooth_convs_3x3.append(s_conv_3x3) + + def init_weights(self): + pass + + def _upsample_x2(self, x): + return F.interpolate(x, scale_factor=2, mode='bilinear') + + def forward(self, inputs): + lateral_features = [ + l_conv(inputs[i]) for i, l_conv in enumerate(self.lateral_convs) + ] + + outs = [] + for i in range(len(self.smooth_convs_3x3), 0, -1): # 3, 2, 1 + last_out = lateral_features[-1] if len(outs) == 0 else outs[-1] + upsample = self._upsample_x2(last_out) + upsample_cat = torch.cat((upsample, lateral_features[i - 1]), + dim=1) + smooth_1x1 = self.smooth_convs_1x1[i - 1](upsample_cat) + smooth_3x3 = self.smooth_convs_3x3[i - 1](smooth_1x1) + outs.append(smooth_3x3) + + return tuple(outs[-1:]) if self.last_stage_only else tuple(outs) diff --git a/mmocr/models/textrecog/necks/fpn_seg.py b/mmocr/models/textrecog/necks/fpn_seg.py new file mode 100644 index 00000000..997951e4 --- /dev/null +++ b/mmocr/models/textrecog/necks/fpn_seg.py @@ -0,0 +1,43 @@ +import torch.nn.functional as F +from mmcv.runner import auto_fp16 + +from mmdet.models.builder import NECKS +from mmdet.models.necks import FPN + + +@NECKS.register_module() +class FPNSeg(FPN): + """Feature Pyramid Network for segmentation based text recognition. + + Args: + in_channels (list[int]): Number of input channels for each scale. + out_channels (int): Number of output channels for each scale. + num_outs (int): Number of output scales. + upsample_param (dict | None): Config dict for interpolate layer. + Default: `dict(scale_factor=1.0, mode='nearest')` + last_stage_only (bool): If True, output last stage of FPN only. 
+ """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + upsample_param=None, + last_stage_only=True): + super().__init__(in_channels, out_channels, num_outs) + self.upsample_param = upsample_param + self.last_stage_only = last_stage_only + + @auto_fp16() + def forward(self, inputs): + outs = super().forward(inputs) + + outs = list(outs) + + if self.upsample_param is not None: + outs[0] = F.interpolate(outs[0], **self.upsample_param) + + if self.last_stage_only: + return tuple(outs[0:1]) + + return tuple(outs[::-1]) diff --git a/mmocr/models/textrecog/recognizer/__init__.py b/mmocr/models/textrecog/recognizer/__init__.py new file mode 100644 index 00000000..abc9c132 --- /dev/null +++ b/mmocr/models/textrecog/recognizer/__init__.py @@ -0,0 +1,13 @@ +from .base import BaseRecognizer +from .cafcn import CAFCNNet +from .crnn import CRNNNet +from .encode_decode_recognizer import EncodeDecodeRecognizer +from .robust_scanner import RobustScanner +from .sar import SARNet +from .seg_recognizer import SegRecognizer +from .transformer import TransformerNet + +__all__ = [ + 'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', + 'TransformerNet', 'SegRecognizer', 'RobustScanner', 'CAFCNNet' +] diff --git a/mmocr/models/textrecog/recognizer/base.py b/mmocr/models/textrecog/recognizer/base.py new file mode 100644 index 00000000..6472a5a2 --- /dev/null +++ b/mmocr/models/textrecog/recognizer/base.py @@ -0,0 +1,236 @@ +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import mmcv +import torch +import torch.distributed as dist +import torch.nn as nn +from mmcv.runner import auto_fp16 +from mmcv.utils import print_log + +from mmdet.utils import get_root_logger +from mmocr.core import imshow_text_label + + +class BaseRecognizer(nn.Module, metaclass=ABCMeta): + """Base class for text recognition.""" + + def __init__(self): + super().__init__() + self.fp16_enabled = False + + @abstractmethod + def extract_feat(self, imgs): + """Extract features from images.""" + pass + + @abstractmethod + def forward_train(self, imgs, img_metas, **kwargs): + """ + Args: + img (tensor): tensors with shape (N, C, H, W). + Typically should be mean centered and std scaled. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details of the values of these keys, see + :class:`mmdet.datasets.pipelines.Collect`. + kwargs (keyword arguments): Specific to concrete implementation. + """ + pass + + @abstractmethod + def simple_test(self, img, img_metas, **kwargs): + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Test function with test time augmentation. + + Args: + imgs (list[tensor]): Tensor should have shape NxCxHxW, + which contains all images in the batch. + img_metas (list[list[dict]]): The metadata of images. + """ + pass + + def init_weights(self, pretrained=None): + """Initialize the weights for detector. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if pretrained is not None: + logger = get_root_logger() + print_log(f'load model from: {pretrained}', logger=logger) + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (tensor | list[tensor]): Tensor should have shape NxCxHxW, + which contains all images in the batch. + img_metas (list[dict] | list[list[dict]]): + The outer list indicates images in a batch. 
+ """ + if isinstance(imgs, list): + assert len(imgs) == len(img_metas) + assert len(imgs) > 0 + assert imgs[0].size(0) == 1, 'aug test does not support ' \ + 'inference with batch size ' \ + f'{imgs[0].size(0)}' + return self.aug_test(imgs, img_metas, **kwargs) + + return self.simple_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=('img', )) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note that img and img_meta are single-nested (i.e. tensor and + list[dict]). + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + + return self.forward_test(img, img_metas, **kwargs) + + def _parse_losses(self, losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw outputs of the network, which usually contain + losses and other necessary infomation. + + Returns: + tuple[tensor, dict]: (loss, log_vars), loss is the loss tensor \ + which may be a weighted sum of all losses, log_vars contains \ + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def train_step(self, data, optimizer): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer update, which are done by an optimizer + hook. Note that in some complicated cases or models (e.g. GAN), + the whole process (including the back propagation and optimizer update) + is also defined by this method. + + Args: + data (dict): The outputs of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \ + ``num_samples``. + + - ``loss`` is a tensor for back propagation, which is a \ + weighted sum of multiple losses. + - ``log_vars`` contains all the variables to be sent to the + logger. + - ``num_samples`` indicates the batch size used for \ + averaging the logs (Note: for the \ + DDP model, num_samples refers to the batch size for each GPU). + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs + + def val_step(self, data, optimizer): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but is + used during val epochs. Note that the evaluation after training epochs + is not implemented by this method, but by an evaluation hook. 
+ """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs + + def show_result(self, + img, + result, + gt_label='', + win_name='', + show=False, + wait_time=0, + out_file=None, + **kwargs): + """Draw `result` on `img`. + + Args: + img (str or tensor): The image to be displayed. + result (dict): The results to draw on `img`. + gt_label (str): Ground truth label of img. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_file (str or None): The output filename. + Default: None. + + Returns: + img (tensor): Only if not `show` or `out_file`. + """ + img = mmcv.imread(img) + img = img.copy() + pred_label = None + if 'text' in result.keys(): + pred_label = result['text'] + + # if out_file specified, do not show image in window + if out_file is not None: + show = False + # draw text label + if pred_label is not None: + img = imshow_text_label( + img, + pred_label, + gt_label, + show=show, + win_name=win_name, + wait_time=wait_time, + out_file=out_file) + + if not (show or out_file): + warnings.warn('show==False and out_file is not specified, only ' + 'result image will be returned') + return img + + return img diff --git a/mmocr/models/textrecog/recognizer/cafcn.py b/mmocr/models/textrecog/recognizer/cafcn.py new file mode 100644 index 00000000..6acade5e --- /dev/null +++ b/mmocr/models/textrecog/recognizer/cafcn.py @@ -0,0 +1,7 @@ +from mmdet.models.builder import DETECTORS +from .seg_recognizer import SegRecognizer + + +@DETECTORS.register_module() +class CAFCNNet(SegRecognizer): + """Implementation of `CAFCN `_""" diff --git a/mmocr/models/textrecog/recognizer/crnn.py b/mmocr/models/textrecog/recognizer/crnn.py new file mode 100644 index 00000000..1ff68f7f --- /dev/null +++ b/mmocr/models/textrecog/recognizer/crnn.py @@ -0,0 +1,18 @@ +import torch +import torch.nn.functional as F + +from mmdet.models.builder import DETECTORS +from .encode_decode_recognizer import EncodeDecodeRecognizer + + +@DETECTORS.register_module() +class CRNNNet(EncodeDecodeRecognizer): + """CTC-loss based recognizer.""" + + def forward_conversion(self, params, img): + x = self.extract_feat(img) + x = self.encoder(x) + outs = self.decoder(x) + outs = F.softmax(outs, dim=2) + params = torch.pow(params, 1) + return outs, params diff --git a/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py new file mode 100644 index 00000000..793f3853 --- /dev/null +++ b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py @@ -0,0 +1,173 @@ +from mmdet.models.builder import DETECTORS, build_backbone, build_loss +from mmocr.models.builder import (build_convertor, build_decoder, + build_encoder, build_preprocessor) +from .base import BaseRecognizer + + +@DETECTORS.register_module() +class EncodeDecodeRecognizer(BaseRecognizer): + """Base class for encode-decode recognizer.""" + + def __init__(self, + preprocessor=None, + backbone=None, + encoder=None, + decoder=None, + loss=None, + label_convertor=None, + train_cfg=None, + test_cfg=None, + max_seq_len=40, + pretrained=None): + super().__init__() + + # Label convertor (str2tensor, tensor2str) + assert label_convertor is not None + label_convertor.update(max_seq_len=max_seq_len) + self.label_convertor = build_convertor(label_convertor) + + # Preprocessor module, e.g., TPS + 
diff --git a/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py
new file mode 100644
index 00000000..793f3853
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py
@@ -0,0 +1,173 @@
+from mmdet.models.builder import DETECTORS, build_backbone, build_loss
+from mmocr.models.builder import (build_convertor, build_decoder,
+                                  build_encoder, build_preprocessor)
+from .base import BaseRecognizer
+
+
+@DETECTORS.register_module()
+class EncodeDecodeRecognizer(BaseRecognizer):
+    """Base class for encode-decode recognizer."""
+
+    def __init__(self,
+                 preprocessor=None,
+                 backbone=None,
+                 encoder=None,
+                 decoder=None,
+                 loss=None,
+                 label_convertor=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 max_seq_len=40,
+                 pretrained=None):
+        super().__init__()
+
+        # Label convertor (str2tensor, tensor2str)
+        assert label_convertor is not None
+        label_convertor.update(max_seq_len=max_seq_len)
+        self.label_convertor = build_convertor(label_convertor)
+
+        # Preprocessor module, e.g., TPS
+        self.preprocessor = None
+        if preprocessor is not None:
+            self.preprocessor = build_preprocessor(preprocessor)
+
+        # Backbone
+        assert backbone is not None
+        self.backbone = build_backbone(backbone)
+
+        # Encoder module
+        self.encoder = None
+        if encoder is not None:
+            self.encoder = build_encoder(encoder)
+
+        # Decoder module
+        assert decoder is not None
+        decoder.update(num_classes=self.label_convertor.num_classes())
+        decoder.update(start_idx=self.label_convertor.start_idx)
+        decoder.update(padding_idx=self.label_convertor.padding_idx)
+        decoder.update(max_seq_len=max_seq_len)
+        self.decoder = build_decoder(decoder)
+
+        # Loss
+        assert loss is not None
+        loss.update(ignore_index=self.label_convertor.padding_idx)
+        self.loss = build_loss(loss)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        self.max_seq_len = max_seq_len
+        self.init_weights(pretrained=pretrained)
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights of recognizer."""
+        super().init_weights(pretrained)
+
+        if self.preprocessor is not None:
+            self.preprocessor.init_weights()
+
+        self.backbone.init_weights()
+
+        if self.encoder is not None:
+            self.encoder.init_weights()
+
+        self.decoder.init_weights()
+
+    def extract_feat(self, img):
+        """Directly extract features from the backbone."""
+        if self.preprocessor is not None:
+            img = self.preprocessor(img)
+
+        x = self.backbone(img)
+
+        return x
+
+    def forward_train(self, img, img_metas):
+        """
+        Args:
+            img (tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            img_metas (list[dict]): A list of image info dict where each dict
+                contains: 'img_shape', 'filename', and may also contain
+                'ori_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                :class:`mmdet.datasets.pipelines.Collect`.
+
+        Returns:
+            dict[str, tensor]: A dictionary of loss components.
+        """
+        feat = self.extract_feat(img)
+
+        gt_labels = [img_meta['text'] for img_meta in img_metas]
+
+        targets_dict = self.label_convertor.str2tensor(gt_labels)
+
+        out_enc = None
+        if self.encoder is not None:
+            out_enc = self.encoder(feat, img_metas)
+
+        out_dec = self.decoder(
+            feat, out_enc, targets_dict, img_metas, train_mode=True)
+
+        loss_inputs = (
+            out_dec,
+            targets_dict,
+        )
+        losses = self.loss(*loss_inputs)
+
+        return losses
+
+    def simple_test(self, img, img_metas, **kwargs):
+        """Test function without test time augmentation.
+
+        Args:
+            img (torch.Tensor): Image input tensor.
+            img_metas (list[dict]): List of image information.
+
+        Returns:
+            list[str]: Text label result of each image.
+        """
+        feat = self.extract_feat(img)
+
+        out_enc = None
+        if self.encoder is not None:
+            out_enc = self.encoder(feat, img_metas)
+
+        out_dec = self.decoder(
+            feat, out_enc, None, img_metas, train_mode=False)
+
+        label_indexes, label_scores = \
+            self.label_convertor.tensor2idx(out_dec, img_metas)
+        label_strings = self.label_convertor.idx2str(label_indexes)
+
+        # flatten batch results
+        results = []
+        for string, score in zip(label_strings, label_scores):
+            results.append(dict(text=string, score=score))
+
+        return results
+
+    def merge_aug_results(self, aug_results):
+        out_text, out_score = '', -1
+        for result in aug_results:
+            text = result[0]['text']
+            score = sum(result[0]['score']) / max(1, len(text))
+            if score > out_score:
+                out_text = text
+                out_score = score
+        out_results = [dict(text=out_text, score=out_score)]
+        return out_results
+
+    def aug_test(self, imgs, img_metas, **kwargs):
+        """Test function with test time augmentation.
+
+        Args:
+            imgs (list[tensor]): Tensor should have shape NxCxHxW,
+                which contains all images in the batch.
+            img_metas (list[list[dict]]): The metadata of images.
+        """
+        aug_results = []
+        for img, img_meta in zip(imgs, img_metas):
+            result = self.simple_test(img, img_meta, **kwargs)
+            aug_results.append(result)
+
+        return self.merge_aug_results(aug_results)
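Note for reviewers: an illustrative sketch (assuming an already-built recognizer, which this snippet does not construct) of the training data flow above. The ground-truth text travels in `img_metas` and is turned into tensors by the label convertor before the decoder and loss see it.

```python
import torch

img = torch.rand(2, 3, 32, 160)  # a toy batch of two text-line crops
img_metas = [{'text': 'hello', 'valid_ratio': 1.0},
             {'text': 'world', 'valid_ratio': 1.0}]

# losses = recognizer.forward_train(img, img_metas)
# -> e.g. {'loss_ctc': tensor(...)} or {'loss_ce': tensor(...)} by config
```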
+ """ + feat = self.extract_feat(img) + + out_enc = None + if self.encoder is not None: + out_enc = self.encoder(feat, img_metas) + + out_dec = self.decoder( + feat, out_enc, None, img_metas, train_mode=False) + + label_indexes, label_scores = \ + self.label_convertor.tensor2idx(out_dec, img_metas) + label_strings = self.label_convertor.idx2str(label_indexes) + + # flatten batch results + results = [] + for string, score in zip(label_strings, label_scores): + results.append(dict(text=string, score=score)) + + return results + + def merge_aug_results(self, aug_results): + out_text, out_score = '', -1 + for result in aug_results: + text = result[0]['text'] + score = sum(result[0]['score']) / max(1, len(text)) + if score > out_score: + out_text = text + out_score = score + out_results = [dict(text=out_text, score=out_score)] + return out_results + + def aug_test(self, imgs, img_metas, **kwargs): + """Test function as well as time augmentation. + + Args: + imgs (list[tensor]): Tensor should have shape NxCxHxW, + which contains all images in the batch. + img_metas (list[list[dict]]): The metadata of images. + """ + aug_results = [] + for img, img_meta in zip(imgs, img_metas): + result = self.simple_test(img, img_meta, **kwargs) + aug_results.append(result) + + return self.merge_aug_results(aug_results) diff --git a/mmocr/models/textrecog/recognizer/sar.py b/mmocr/models/textrecog/recognizer/sar.py new file mode 100644 index 00000000..bce67dca --- /dev/null +++ b/mmocr/models/textrecog/recognizer/sar.py @@ -0,0 +1,7 @@ +from mmdet.models.builder import DETECTORS +from .encode_decode_recognizer import EncodeDecodeRecognizer + + +@DETECTORS.register_module() +class SARNet(EncodeDecodeRecognizer): + """Implementation of `SAR `_""" diff --git a/mmocr/models/textrecog/recognizer/seg_recognizer.py b/mmocr/models/textrecog/recognizer/seg_recognizer.py new file mode 100644 index 00000000..e013e1d7 --- /dev/null +++ b/mmocr/models/textrecog/recognizer/seg_recognizer.py @@ -0,0 +1,153 @@ +from mmdet.models.builder import (DETECTORS, build_backbone, build_head, + build_loss, build_neck) +from mmocr.models.builder import build_convertor, build_preprocessor +from .base import BaseRecognizer + + +@DETECTORS.register_module() +class SegRecognizer(BaseRecognizer): + """Base class for segmentation based recognizer.""" + + def __init__(self, + preprocessor=None, + backbone=None, + neck=None, + head=None, + loss=None, + label_convertor=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + super().__init__() + + # Label_convertor + assert label_convertor is not None + self.label_convertor = build_convertor(label_convertor) + + # Preprocessor module, e.g., TPS + self.preprocessor = None + if preprocessor is not None: + self.preprocessor = build_preprocessor(preprocessor) + + # Backbone + assert backbone is not None + self.backbone = build_backbone(backbone) + + # Neck + assert neck is not None + self.neck = build_neck(neck) + + # Head + assert head is not None + head.update(num_classes=self.label_convertor.num_classes()) + self.head = build_head(head) + + # Loss + assert loss is not None + self.loss = build_loss(loss) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.init_weights(pretrained=pretrained) + + def init_weights(self, pretrained=None): + """Initialize the weights of recognizer.""" + super().init_weights(pretrained) + + if self.preprocessor is not None: + self.preprocessor.init_weights() + + self.backbone.init_weights(pretrained=pretrained) + + if self.neck is not None: + 
diff --git a/mmocr/models/textrecog/recognizer/seg_recognizer.py b/mmocr/models/textrecog/recognizer/seg_recognizer.py
new file mode 100644
index 00000000..e013e1d7
--- /dev/null
+++ b/mmocr/models/textrecog/recognizer/seg_recognizer.py
@@ -0,0 +1,153 @@
+from mmdet.models.builder import (DETECTORS, build_backbone, build_head,
+                                  build_loss, build_neck)
+from mmocr.models.builder import build_convertor, build_preprocessor
+from .base import BaseRecognizer
+
+
+@DETECTORS.register_module()
+class SegRecognizer(BaseRecognizer):
+    """Base class for segmentation based recognizer."""
+
+    def __init__(self,
+                 preprocessor=None,
+                 backbone=None,
+                 neck=None,
+                 head=None,
+                 loss=None,
+                 label_convertor=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super().__init__()
+
+        # Label_convertor
+        assert label_convertor is not None
+        self.label_convertor = build_convertor(label_convertor)
+
+        # Preprocessor module, e.g., TPS
+        self.preprocessor = None
+        if preprocessor is not None:
+            self.preprocessor = build_preprocessor(preprocessor)
+
+        # Backbone
+        assert backbone is not None
+        self.backbone = build_backbone(backbone)
+
+        # Neck
+        assert neck is not None
+        self.neck = build_neck(neck)
+
+        # Head
+        assert head is not None
+        head.update(num_classes=self.label_convertor.num_classes())
+        self.head = build_head(head)
+
+        # Loss
+        assert loss is not None
+        self.loss = build_loss(loss)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        self.init_weights(pretrained=pretrained)
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights of recognizer."""
+        super().init_weights(pretrained)
+
+        if self.preprocessor is not None:
+            self.preprocessor.init_weights()
+
+        self.backbone.init_weights(pretrained=pretrained)
+
+        if self.neck is not None:
+            self.neck.init_weights()
+
+        self.head.init_weights()
+
+    def extract_feat(self, img):
+        """Directly extract features from the backbone."""
+        if self.preprocessor is not None:
+            img = self.preprocessor(img)
+
+        x = self.backbone(img)
+
+        return x
+
+    def forward_train(self, img, img_metas, gt_kernels=None):
+        """
+        Args:
+            img (tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            img_metas (list[dict]): A list of image info dict where each dict
+                contains: 'img_shape', 'filename', and may also contain
+                'ori_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                :class:`mmdet.datasets.pipelines.Collect`.
+
+        Returns:
+            dict[str, tensor]: A dictionary of loss components.
+        """
+
+        feats = self.extract_feat(img)
+
+        out_neck = self.neck(feats)
+
+        out_head = self.head(out_neck)
+
+        loss_inputs = (out_neck, out_head, gt_kernels)
+
+        losses = self.loss(*loss_inputs)
+
+        return losses
+
+    def simple_test(self, img, img_metas, **kwargs):
+        """Test function without test time augmentation.
+
+        Args:
+            img (torch.Tensor): Image input tensor.
+            img_metas (list[dict]): List of image information.
+
+        Returns:
+            list[str]: Text label result of each image.
+        """
+
+        feat = self.extract_feat(img)
+
+        out_neck = self.neck(feat)
+
+        out_head = self.head(out_neck)
+
+        texts, scores = self.label_convertor.tensor2str(out_head, img_metas)
+
+        # flatten batch results
+        results = []
+        for text, score in zip(texts, scores):
+            results.append(dict(text=text, score=score))
+
+        return results
+
+    def merge_aug_results(self, aug_results):
+        out_text, out_score = '', -1
+        for result in aug_results:
+            text = result[0]['text']
+            score = sum(result[0]['score']) / max(1, len(text))
+            if score > out_score:
+                out_text = text
+                out_score = score
+        out_results = [dict(text=out_text, score=out_score)]
+        return out_results
+
+    def aug_test(self, imgs, img_metas, **kwargs):
+        """Test function with test time augmentation.
+
+        Args:
+            imgs (list[tensor]): Tensor should have shape NxCxHxW,
+                which contains all images in the batch.
+            img_metas (list[list[dict]]): The metadata of images.
+        """
+        aug_results = []
+        for img, img_meta in zip(imgs, img_metas):
+            result = self.simple_test(img, img_meta, **kwargs)
+            aug_results.append(result)
+
+        return self.merge_aug_results(aug_results)
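Note for reviewers: a hedged sketch of the supervision format `forward_train` expects above. Segmentation-based training takes per-image `BitmapMasks` built by the `OCRSegTargets` pipeline (here just zero-filled placeholders of the right shape), passed through `gt_kernels`; the recognizer itself is assumed already built.

```python
import numpy as np
from mmdet.core import BitmapMasks

h, w = 64, 256
attn_tgt = np.zeros((h, w), dtype=np.float32)  # character-region attention target
segm_tgt = np.zeros((h, w), dtype=np.float32)  # per-pixel character-class target
gt_kernels = [BitmapMasks([attn_tgt, segm_tgt], h, w)]  # one entry per image

# losses = seg_recognizer.forward_train(img, img_metas, gt_kernels=gt_kernels)
```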
+ """ + aug_results = [] + for img, img_meta in zip(imgs, img_metas): + result = self.simple_test(img, img_meta, **kwargs) + aug_results.append(result) + + return self.merge_aug_results(aug_results) diff --git a/mmocr/models/textrecog/recognizer/transformer.py b/mmocr/models/textrecog/recognizer/transformer.py new file mode 100644 index 00000000..86636424 --- /dev/null +++ b/mmocr/models/textrecog/recognizer/transformer.py @@ -0,0 +1,7 @@ +from mmdet.models.builder import DETECTORS +from .encode_decode_recognizer import EncodeDecodeRecognizer + + +@DETECTORS.register_module() +class TransformerNet(EncodeDecodeRecognizer): + """Implementation of Transformer based OCR.""" diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py new file mode 100644 index 00000000..43260d37 --- /dev/null +++ b/tests/test_dataset/test_base_dataset.py @@ -0,0 +1,74 @@ +import os.path as osp +import tempfile + +import numpy as np +import pytest + +from mmocr.datasets.base_dataset import BaseDataset + + +def _create_dummy_ann_file(ann_file): + ann_info1 = 'sample1.jpg hello' + ann_info2 = 'sample2.jpg world' + + with open(ann_file, 'w') as fw: + for ann_info in [ann_info1, ann_info2]: + fw.write(ann_info + '\n') + + +def _create_dummy_loader(): + loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict(type='LineStrParser', keys=['file_name', 'text'])) + return loader + + +def test_custom_dataset(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + ann_file = osp.join(tmp_dir.name, 'fake_data.txt') + _create_dummy_ann_file(ann_file) + loader = _create_dummy_loader() + + for mode in [True, False]: + dataset = BaseDataset(ann_file, loader, pipeline=[], test_mode=mode) + + # test len + assert len(dataset) == len(dataset.data_infos) + + # test set group flag + assert np.allclose(dataset.flag, [0, 0]) + + # test prepare_train_img + expect_results = { + 'img_info': { + 'file_name': 'sample1.jpg', + 'text': 'hello' + }, + 'img_prefix': '' + } + assert dataset.prepare_train_img(0) == expect_results + + # test prepare_test_img + assert dataset.prepare_test_img(0) == expect_results + + # test __getitem__ + assert dataset[0] == expect_results + + # test get_next_index + assert dataset._get_next_index(0) == 1 + + # test format_resuls + expect_results_copy = { + key: value + for key, value in expect_results.items() + } + dataset.format_results(expect_results) + assert expect_results_copy == expect_results + + # test evaluate + with pytest.raises(NotImplementedError): + dataset.evaluate(expect_results) + + tmp_dir.cleanup() diff --git a/tests/test_dataset/test_crop.py b/tests/test_dataset/test_crop.py new file mode 100644 index 00000000..99155344 --- /dev/null +++ b/tests/test_dataset/test_crop.py @@ -0,0 +1,96 @@ +import math + +import numpy as np +import pytest + +from mmocr.datasets.pipelines.crop import (box_jitter, convert_canonical, + crop_img, sort_vertex, warp_img) + + +def test_order_vertex(): + dummy_points_x = [20, 20, 120, 120] + dummy_points_y = [20, 40, 40, 20] + + with pytest.raises(AssertionError): + sort_vertex([], dummy_points_y) + with pytest.raises(AssertionError): + sort_vertex(dummy_points_x, []) + + ordered_points_x, ordered_points_y = sort_vertex(dummy_points_x, + dummy_points_y) + + expect_points_x = [20, 120, 120, 20] + expect_points_y = [20, 20, 40, 40] + + assert np.allclose(ordered_points_x, expect_points_x) + assert np.allclose(ordered_points_y, expect_points_y) + + +def test_convert_canonical(): + dummy_points_x = [120, 120, 20, 20] + 
diff --git a/tests/test_dataset/test_crop.py b/tests/test_dataset/test_crop.py
new file mode 100644
index 00000000..99155344
--- /dev/null
+++ b/tests/test_dataset/test_crop.py
@@ -0,0 +1,96 @@
+import math
+
+import numpy as np
+import pytest
+
+from mmocr.datasets.pipelines.crop import (box_jitter, convert_canonical,
+                                           crop_img, sort_vertex, warp_img)
+
+
+def test_order_vertex():
+    dummy_points_x = [20, 20, 120, 120]
+    dummy_points_y = [20, 40, 40, 20]
+
+    with pytest.raises(AssertionError):
+        sort_vertex([], dummy_points_y)
+    with pytest.raises(AssertionError):
+        sort_vertex(dummy_points_x, [])
+
+    ordered_points_x, ordered_points_y = sort_vertex(dummy_points_x,
+                                                     dummy_points_y)
+
+    expect_points_x = [20, 120, 120, 20]
+    expect_points_y = [20, 20, 40, 40]
+
+    assert np.allclose(ordered_points_x, expect_points_x)
+    assert np.allclose(ordered_points_y, expect_points_y)
+
+
+def test_convert_canonical():
+    dummy_points_x = [120, 120, 20, 20]
+    dummy_points_y = [20, 40, 40, 20]
+
+    with pytest.raises(AssertionError):
+        convert_canonical([], dummy_points_y)
+    with pytest.raises(AssertionError):
+        convert_canonical(dummy_points_x, [])
+
+    ordered_points_x, ordered_points_y = convert_canonical(
+        dummy_points_x, dummy_points_y)
+
+    expect_points_x = [20, 120, 120, 20]
+    expect_points_y = [20, 20, 40, 40]
+
+    assert np.allclose(ordered_points_x, expect_points_x)
+    assert np.allclose(ordered_points_y, expect_points_y)
+
+
+def test_box_jitter():
+    dummy_points_x = [20, 120, 120, 20]
+    dummy_points_y = [20, 20, 40, 40]
+
+    kwargs = dict(jitter_ratio_x=0.0, jitter_ratio_y=0.0)
+
+    with pytest.raises(AssertionError):
+        box_jitter([], dummy_points_y)
+    with pytest.raises(AssertionError):
+        box_jitter(dummy_points_x, [])
+    with pytest.raises(AssertionError):
+        box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_x=1.)
+    with pytest.raises(AssertionError):
+        box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_y=1.)
+
+    box_jitter(dummy_points_x, dummy_points_y, **kwargs)
+
+    assert np.allclose(dummy_points_x, [20, 120, 120, 20])
+    assert np.allclose(dummy_points_y, [20, 20, 40, 40])
+
+
+def test_opencv_crop():
+    dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
+    dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
+
+    cropped_img = warp_img(dummy_img, dummy_box)
+
+    with pytest.raises(AssertionError):
+        warp_img(dummy_img, [])
+    with pytest.raises(AssertionError):
+        warp_img(dummy_img, [20, 40, 40, 20])
+
+    assert math.isclose(cropped_img.shape[0], 20)
+    assert math.isclose(cropped_img.shape[1], 100)
+
+
+def test_min_rect_crop():
+    dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
+    dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
+
+    cropped_img = crop_img(dummy_img, dummy_box)
+
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, [])
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, [20, 40, 40, 20])
+
+    assert math.isclose(cropped_img.shape[0], 20)
+    assert math.isclose(cropped_img.shape[1], 100)
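Note for reviewers: `warp_img` presumably rectifies the quadrilateral with a perspective transform, which is why the 100x20 box above crops to a 20x100 patch. A self-contained sketch of that idea with OpenCV (an assumption about the mechanism, not the library's exact implementation):

```python
import cv2
import numpy as np

img = np.ones((600, 600, 3), dtype=np.uint8)
quad = np.float32([[20, 20], [120, 20], [120, 40], [20, 40]])  # clockwise quad
w = int(np.linalg.norm(quad[1] - quad[0]))  # 100
h = int(np.linalg.norm(quad[3] - quad[0]))  # 20
dst = np.float32([[0, 0], [w, 0], [w, h], [0, h]])

M = cv2.getPerspectiveTransform(quad, dst)
patch = cv2.warpPerspective(img, M, (w, h))
assert patch.shape[:2] == (20, 100)
```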
diff --git a/tests/test_dataset/test_detect_dataset.py b/tests/test_dataset/test_detect_dataset.py
new file mode 100644
index 00000000..83480c45
--- /dev/null
+++ b/tests/test_dataset/test_detect_dataset.py
@@ -0,0 +1,82 @@
+import json
+import os.path as osp
+import tempfile
+
+import numpy as np
+
+from mmocr.datasets.text_det_dataset import TextDetDataset
+
+
+def _create_dummy_ann_file(ann_file):
+    ann_info1 = {
+        'file_name': 'sample1.jpg',
+        'height': 640,
+        'width': 640,
+        'annotations': [{
+            'iscrowd': 0,
+            'category_id': 1,
+            'bbox': [50, 70, 80, 100],
+            'segmentation': [[50, 70, 80, 70, 80, 100, 50, 100]]
+        }, {
+            'iscrowd': 1,
+            'category_id': 1,
+            'bbox': [120, 140, 200, 200],
+            'segmentation': [[120, 140, 200, 140, 200, 200, 120, 200]]
+        }]
+    }
+
+    with open(ann_file, 'w') as fw:
+        fw.write(json.dumps(ann_info1) + '\n')
+
+
+def _create_dummy_loader():
+    loader = dict(
+        type='HardDiskLoader',
+        repeat=1,
+        parser=dict(
+            type='LineJsonParser',
+            keys=['file_name', 'height', 'width', 'annotations']))
+    return loader
+
+
+def test_detect_dataset():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+    _create_dummy_ann_file(ann_file)
+
+    # test initialization
+    loader = _create_dummy_loader()
+    dataset = TextDetDataset(ann_file, loader, pipeline=[])
+
+    # test _parse_anno_info
+    img_ann_info = dataset.data_infos[0]
+    ann = dataset._parse_anno_info(img_ann_info['annotations'])
+    assert np.allclose(ann['bboxes'], [[50., 70., 80., 100.]])
+    assert np.allclose(ann['labels'], [1])
+    assert np.allclose(ann['bboxes_ignore'], [[120, 140, 200, 200]])
+    assert np.allclose(ann['masks'], [[[50, 70, 80, 70, 80, 100, 50, 100]]])
+    assert np.allclose(ann['masks_ignore'],
+                       [[[120, 140, 200, 140, 200, 200, 120, 200]]])
+
+    tmp_dir.cleanup()
+
+    # test prepare_train_img
+    pipeline_results = dataset.prepare_train_img(0)
+    assert np.allclose(pipeline_results['bbox_fields'], [])
+    assert np.allclose(pipeline_results['mask_fields'], [])
+    assert np.allclose(pipeline_results['seg_fields'], [])
+    expect_img_info = {'filename': 'sample1.jpg', 'height': 640, 'width': 640}
+    assert pipeline_results['img_info'] == expect_img_info
+
+    # test evaluation
+    metrics = 'hmean-iou'
+    results = [{'boundary_result': [[50, 70, 80, 70, 80, 100, 50, 100, 1]]}]
+    eval_res = dataset.evaluate(results, metrics)
+
+    assert eval_res['hmean-iou:hmean'] == 1
diff --git a/tests/test_dataset/test_loader.py b/tests/test_dataset/test_loader.py
new file mode 100644
index 00000000..2d3de5c1
--- /dev/null
+++ b/tests/test_dataset/test_loader.py
@@ -0,0 +1,71 @@
+import json
+import os.path as osp
+import tempfile
+
+import pytest
+from tools.data.utils.txt2lmdb import converter
+
+from mmocr.datasets.utils.loader import HardDiskLoader, LmdbLoader, Loader
+
+
+def _create_dummy_line_str_file(ann_file):
+    ann_info1 = 'sample1.jpg hello'
+    ann_info2 = 'sample2.jpg world'
+
+    with open(ann_file, 'w') as fw:
+        for ann_info in [ann_info1, ann_info2]:
+            fw.write(ann_info + '\n')
+
+
+def _create_dummy_line_json_file(ann_file):
+    ann_info1 = {'filename': 'sample1.jpg', 'text': 'hello'}
+    ann_info2 = {'filename': 'sample2.jpg', 'text': 'world'}
+
+    with open(ann_file, 'w') as fw:
+        for ann_info in [ann_info1, ann_info2]:
+            fw.write(json.dumps(ann_info) + '\n')
+
+
+def test_loader():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+    _create_dummy_line_str_file(ann_file)
+
+    parser = dict(
+        type='LineStrParser',
+        keys=['filename', 'text'],
+        keys_idx=[0, 1],
+        separator=' ')
+
+    with pytest.raises(AssertionError):
+        Loader(ann_file, parser, repeat=0)
+    with pytest.raises(AssertionError):
+        Loader(ann_file, [], repeat=1)
+    with pytest.raises(AssertionError):
+        Loader('sample.txt', parser, repeat=1)
+    with pytest.raises(NotImplementedError):
+        loader = Loader(ann_file, parser, repeat=1)
+        print(loader)
+
+    # test text loader and line str parser
+    text_loader = HardDiskLoader(ann_file, parser, repeat=1)
+    assert len(text_loader) == 2
+    assert text_loader.ori_data_infos[0] == 'sample1.jpg hello'
+    assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}
+
+    # test text loader and linedict parser
+    _create_dummy_line_json_file(ann_file)
+    json_parser = dict(type='LineJsonParser', keys=['filename', 'text'])
+    text_loader = HardDiskLoader(ann_file, json_parser, repeat=1)
+    assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}
+
+    # test lmdb loader and line str parser
+    _create_dummy_line_str_file(ann_file)
+    lmdb_file = osp.join(tmp_dir.name, 'fake_data.lmdb')
+    converter(ann_file, lmdb_file)
+
+    lmdb_loader = LmdbLoader(lmdb_file, parser, repeat=1)
+    assert lmdb_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}
+
+    tmp_dir.cleanup()
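Note for reviewers: the loader test above round-trips annotations through the `txt2lmdb` converter added later in this patch. A sketch of the storage layout it exercises: each annotation line sits under its zero-based index, plus a `total_number` bookkeeping key.

```python
import lmdb

env = lmdb.open('toy_label_lmdb', map_size=2**30)  # illustrative path
lines = ['sample1.jpg hello', 'sample2.jpg world']
with env.begin(write=True) as txn:
    for i, line in enumerate(lines):
        txn.put(str(i).encode('utf-8'), line.encode('utf-8'))
    txn.put(b'total_number', str(len(lines)).encode('utf-8'))

with env.begin() as txn:
    assert txn.get(b'0').decode('utf-8') == 'sample1.jpg hello'
    assert int(txn.get(b'total_number')) == 2
```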
diff --git a/tests/test_dataset/test_ocr_dataset.py b/tests/test_dataset/test_ocr_dataset.py
new file mode 100644
index 00000000..1787db88
--- /dev/null
+++ b/tests/test_dataset/test_ocr_dataset.py
@@ -0,0 +1,51 @@
+import math
+import os.path as osp
+import tempfile
+
+from mmocr.datasets.ocr_dataset import OCRDataset
+
+
+def _create_dummy_ann_file(ann_file):
+    ann_info1 = 'sample1.jpg hello'
+    ann_info2 = 'sample2.jpg world'
+
+    with open(ann_file, 'w') as fw:
+        for ann_info in [ann_info1, ann_info2]:
+            fw.write(ann_info + '\n')
+
+
+def _create_dummy_loader():
+    loader = dict(
+        type='HardDiskLoader',
+        repeat=1,
+        parser=dict(type='LineStrParser', keys=['file_name', 'text']))
+    return loader
+
+
+def test_ocr_dataset():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+    _create_dummy_ann_file(ann_file)
+
+    # test initialization
+    loader = _create_dummy_loader()
+    dataset = OCRDataset(ann_file, loader, pipeline=[])
+
+    tmp_dir.cleanup()
+
+    # test pre_pipeline
+    img_info = dataset.data_infos[0]
+    results = dict(img_info=img_info)
+    dataset.pre_pipeline(results)
+    assert results['img_prefix'] == dataset.img_prefix
+    assert results['text'] == img_info['text']
+
+    # test evaluation
+    metric = 'acc'
+    results = [{'text': 'hello'}, {'text': 'worl'}]
+    eval_res = dataset.evaluate(results, metric)
+
+    assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
+    assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
+    assert math.isclose(eval_res['char_recall'], 0.9, abs_tol=1e-4)
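Note for reviewers: why the expected numbers above hold for predictions `['hello', 'worl']` against ground truths `['hello', 'world']`. One of two words matches exactly (word accuracy 0.5); all 9 predicted characters occur in the ground truth (precision 1.0) and 9 of 10 ground-truth characters are recovered (recall 0.9). The count-based matching rule below is an assumption that reproduces the asserted values, not the library's exact implementation.

```python
from collections import Counter

preds, gts = ['hello', 'worl'], ['hello', 'world']

def char_match(p, g):
    cp, cg = Counter(p.lower()), Counter(g.lower())
    return sum(min(cp[c], cg[c]) for c in cp)  # per-character overlap

word_acc = sum(p == g for p, g in zip(preds, gts)) / len(gts)   # 0.5
matched = sum(char_match(p, g) for p, g in zip(preds, gts))     # 5 + 4 = 9
char_precision = matched / sum(len(p) for p in preds)           # 9 / 9  = 1.0
char_recall = matched / sum(len(g) for g in gts)                # 9 / 10 = 0.9
```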
diff --git a/tests/test_dataset/test_ocr_seg_dataset.py b/tests/test_dataset/test_ocr_seg_dataset.py
new file mode 100644
index 00000000..0ecfcfdf
--- /dev/null
+++ b/tests/test_dataset/test_ocr_seg_dataset.py
@@ -0,0 +1,127 @@
+import json
+import math
+import os.path as osp
+import tempfile
+
+import pytest
+
+from mmocr.datasets.ocr_seg_dataset import OCRSegDataset
+
+
+def _create_dummy_ann_file(ann_file):
+    ann_info1 = {
+        'file_name': 'sample1.png',
+        'annotations': [{
+            'char_text': 'F',
+            'char_box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
+        }, {
+            'char_text': 'r',
+            'char_box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]
+        }, {
+            'char_text': 'o',
+            'char_box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]
+        }, {
+            'char_text': 'm',
+            'char_box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]
+        }, {
+            'char_text': ':',
+            'char_box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]
+        }],
+        'text': 'From:'
+    }
+    ann_info2 = {
+        'file_name': 'sample2.png',
+        'annotations': [{
+            'char_text': 'o',
+            'char_box': [0.0, 5.0, 7.0, 5.0, 9.0, 15.0, 2.0, 15.0]
+        }, {
+            'char_text': 'u',
+            'char_box': [7.0, 4.0, 14.0, 4.0, 18.0, 18.0, 11.0, 18.0]
+        }, {
+            'char_text': 't',
+            'char_box': [13.0, 1.0, 19.0, 2.0, 24.0, 18.0, 17.0, 18.0]
+        }],
+        'text': 'out'
+    }
+
+    with open(ann_file, 'w') as fw:
+        for ann_info in [ann_info1, ann_info2]:
+            fw.write(json.dumps(ann_info) + '\n')
+
+    return ann_info1, ann_info2
+
+
+def _create_dummy_loader():
+    loader = dict(
+        type='HardDiskLoader',
+        repeat=1,
+        parser=dict(
+            type='LineJsonParser', keys=['file_name', 'text', 'annotations']))
+    return loader
+
+
+def test_ocr_seg_dataset():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+    ann_info1, ann_info2 = _create_dummy_ann_file(ann_file)
+
+    # test initialization
+    loader = _create_dummy_loader()
+    dataset = OCRSegDataset(ann_file, loader, pipeline=[])
+
+    tmp_dir.cleanup()
+
+    # test pre_pipeline
+    img_info = dataset.data_infos[0]
+    results = dict(img_info=img_info)
+    dataset.pre_pipeline(results)
+    assert results['img_prefix'] == dataset.img_prefix
+
+    # test _parse_anno_info
+    annos = ann_info1['annotations']
+    with pytest.raises(AssertionError):
+        dataset._parse_anno_info(annos[0])
+    annos2 = ann_info2['annotations']
+    with pytest.raises(AssertionError):
+        dataset._parse_anno_info([{'char_text': 'i'}])
+    with pytest.raises(AssertionError):
+        dataset._parse_anno_info([{'char_box': [1, 2, 3, 4, 5, 6, 7, 8]}])
+    annos2[0]['char_box'] = [1, 2, 3]
+    with pytest.raises(AssertionError):
+        dataset._parse_anno_info(annos2)
+
+    return_anno = dataset._parse_anno_info(annos)
+    assert return_anno['chars'] == ['F', 'r', 'o', 'm', ':']
+    assert len(return_anno['char_rects']) == 5
+
+    # test prepare_train_img
+    expect_results = {
+        'img_info': {
+            'filename': 'sample1.png'
+        },
+        'img_prefix': '',
+        'ann_info': return_anno
+    }
+    data = dataset.prepare_train_img(0)
+    assert data == expect_results
+
+    # test evaluation
+    metric = 'acc'
+    results = [{'text': 'From:'}, {'text': 'ou'}]
+    eval_res = dataset.evaluate(results, metric)
+
+    assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
+    assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
+    assert math.isclose(eval_res['char_recall'], 0.857, abs_tol=1e-4)
diff --git a/tests/test_dataset/test_ocr_seg_target.py b/tests/test_dataset/test_ocr_seg_target.py
new file mode 100644
index 00000000..45b85352
--- /dev/null
+++ b/tests/test_dataset/test_ocr_seg_target.py
@@ -0,0 +1,93 @@
+import os.path as osp
+import tempfile
+
+import numpy as np
+import pytest
+
+from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets
+
+
+def _create_dummy_dict_file(dict_file):
+    chars = list('0123456789')
+    with open(dict_file, 'w') as fw:
+        for char in chars:
+            fw.write(char + '\n')
+
+
+def test_ocr_segm_targets():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy dict file
+    dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+    _create_dummy_dict_file(dict_file)
+    # dummy label convertor
+    label_convertor = dict(
+        type='SegConvertor',
+        dict_file=dict_file,
+        with_unknown=True,
+        lower=True)
+    # test init
+    with pytest.raises(AssertionError):
+        OCRSegTargets(None, 0.5, 0.5)
+    with pytest.raises(AssertionError):
+        OCRSegTargets(label_convertor, '1by2', 0.5)
+    with pytest.raises(AssertionError):
+        OCRSegTargets(label_convertor, 0.5, 2)
+
+    ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)
+    # test generate kernels
+    img_size = (8, 8)
+    pad_size = (8, 10)
+    char_boxes = [[2, 2, 6, 6]]
+    char_idxs = [2]
+
+    with pytest.raises(AssertionError):
+        ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs, 0.5,
+                                     True)
+    with pytest.raises(AssertionError):
+        ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
+                                     char_idxs, 0.5, True)
+    with pytest.raises(AssertionError):
+        ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2, 0.5,
+                                     True)
+
+    attn_tgt = ocr_seg_tgt.generate_kernels(
+        img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
+    expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+                       [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
+                       [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
+                       [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
+                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
+    assert np.allclose(attn_tgt, np.array(expect_attn_tgt, dtype=np.int32))
+
+    segm_tgt = ocr_seg_tgt.generate_kernels(
+        img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
+    expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+                       [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
+                       [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
+                       [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
+                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
+                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
+    assert np.allclose(segm_tgt, np.array(expect_segm_tgt, dtype=np.int32))
+
+    # test __call__
+    results = {}
+    results['img_shape'] = (4, 4, 3)
+    results['resize_shape'] = (8, 8, 3)
+    results['pad_shape'] = (8, 10)
+    results['ann_info'] = {}
+    results['ann_info']['char_rects'] = [[1, 1, 3, 3]]
+    results['ann_info']['chars'] = ['1']
+
+    results = ocr_seg_tgt(results)
+    assert results['mask_fields'] == ['gt_kernels']
+    assert np.allclose(results['gt_kernels'].masks[0],
+                       np.array(expect_attn_tgt, dtype=np.int32))
+    assert np.allclose(results['gt_kernels'].masks[1],
+                       np.array(expect_segm_tgt, dtype=np.int32))
+
+    tmp_dir.cleanup()
diff --git a/tests/test_dataset/test_ocr_transforms.py b/tests/test_dataset/test_ocr_transforms.py
new file mode 100644
index 00000000..a568b908
--- /dev/null
+++ b/tests/test_dataset/test_ocr_transforms.py
@@ -0,0 +1,94 @@
+import math
+import unittest.mock as mock
+
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+
+import mmocr.datasets.pipelines.ocr_transforms as transforms
+
+
+def test_resize_ocr():
+    input_img = np.ones((64, 256, 3), dtype=np.uint8)
+
+    rci = transforms.ResizeOCR(
+        32, min_width=32, max_width=160, keep_aspect_ratio=True)
+    results = {'img_shape': input_img.shape, 'img': input_img}
+
+    # test call
+    results = rci(results)
+    assert np.allclose([32, 160, 3], results['pad_shape'])
+    assert np.allclose([32, 160, 3], results['img'].shape)
+    assert 'valid_ratio' in results
+    assert math.isclose(results['valid_ratio'], 0.8)
+    assert math.isclose(np.sum(results['img'][:, 129:, :]), 0)
+
+    rci = transforms.ResizeOCR(
+        32, min_width=32, max_width=160, keep_aspect_ratio=False)
+    results = {'img_shape': input_img.shape, 'img': input_img}
+    results = rci(results)
+    assert math.isclose(results['valid_ratio'], 1)
+
+
+def test_to_tensor():
+    input_img = np.ones((64, 256, 3), dtype=np.uint8)
+
+    expect_output = TF.to_tensor(input_img)
+    rci = transforms.ToTensorOCR()
+
+    results = {'img': input_img}
+    results = rci(results)
+
+    assert np.allclose(results['img'].numpy(), expect_output.numpy())
+
+
+def test_normalize():
+    inputs = torch.zeros(3, 10, 10)
+
+    expect_output = torch.ones_like(inputs) * (-1)
+    rci = transforms.NormalizeOCR(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    results = {'img': inputs}
+    results = rci(results)
+
+    assert np.allclose(results['img'].numpy(), expect_output.numpy())
+
+
+@mock.patch('%s.transforms.np.random.random' % __name__)
+def test_online_crop(mock_random):
+    kwargs = dict(
+        box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
+        jitter_prob=0.5,
+        max_jitter_ratio_x=0.05,
+        max_jitter_ratio_y=0.02)
+
+    mock_random.side_effect = [0.1, 1, 1, 1]
+
+    src_img = np.ones((100, 100, 3), dtype=np.uint8)
+    results = {
+        'img': src_img,
+        'img_info': {
+            'x1': '20',
+            'y1': '20',
+            'x2': '40',
+            'y2': '20',
+            'x3': '40',
+            'y3': '40',
+            'x4': '20',
+            'y4': '40'
+        }
+    }
+
+    rci = transforms.OnlineCropOCR(**kwargs)
+
+    results = rci(results)
+
+    assert np.allclose(results['img_shape'], [20, 20, 3])
+
+    # test not crop
+    mock_random.side_effect = [0.1, 1, 1, 1]
+    results['img_info'] = {}
+    results['img'] = src_img
+
+    results = rci(results)
+    assert np.allclose(results['img'].shape, [100, 100, 3])
diff --git a/tests/test_dataset/test_parser.py b/tests/test_dataset/test_parser.py
new file mode 100644
index 00000000..881ac92e
--- /dev/null
+++ b/tests/test_dataset/test_parser.py
@@ -0,0 +1,59 @@
+import json
+
+import pytest
+
+from mmocr.datasets.utils.parser import LineJsonParser, LineStrParser
+
+
+def test_line_str_parser():
+    data_ret = ['sample1.jpg hello', 'sample2.jpg world']
+    keys = ['filename', 'text']
+    keys_idx = [0, 1]
+    separator = ' '
+
+    # test init
+    with pytest.raises(AssertionError):
+        parser = LineStrParser('filename', keys_idx, separator)
+    with pytest.raises(AssertionError):
+        parser = LineStrParser(keys, keys_idx, [' '])
+    with pytest.raises(AssertionError):
+        parser = LineStrParser(keys, [0], separator)
+
+    # test get_item
+    parser = LineStrParser(keys, keys_idx, separator)
+    assert parser.get_item(data_ret, 0) == \
+        {'filename': 'sample1.jpg', 'text': 'hello'}
+
+    with pytest.raises(Exception):
+        parser = LineStrParser(['filename', 'text', 'ignore'], [0, 1, 2],
+                               separator)
+        parser.get_item(data_ret, 0)
+
+
+def test_line_dict_parser():
+    data_ret = [
+        json.dumps({
+            'filename': 'sample1.jpg',
+            'text': 'hello'
+        }),
+        json.dumps({
+            'filename': 'sample2.jpg',
+            'text': 'world'
+        })
+    ]
+    keys = ['filename', 'text']
+
+    # test init
+    with pytest.raises(AssertionError):
+        parser = LineJsonParser('filename')
+    with pytest.raises(AssertionError):
+        parser = LineJsonParser([])
+
+    # test get_item
+    parser = LineJsonParser(keys)
+    assert parser.get_item(data_ret, 0) == \
+        {'filename': 'sample1.jpg', 'text': 'hello'}
+
+    with pytest.raises(Exception):
+        parser = LineJsonParser(['img_name', 'text'])
+        parser.get_item(data_ret, 0)
diff --git a/tests/test_dataset/test_test_time_aug.py b/tests/test_dataset/test_test_time_aug.py
new file mode 100644
index 00000000..22bf80c6
--- /dev/null
+++ b/tests/test_dataset/test_test_time_aug.py
@@ -0,0 +1,33 @@
+import numpy as np
+import pytest
+
+from mmocr.datasets.pipelines.test_time_aug import MultiRotateAugOCR
+
+
+def test_multi_rotate_aug_ocr():
+    input_img1 = np.ones((64, 256, 3), dtype=np.uint8)
+    input_img2 = np.ones((64, 32, 3), dtype=np.uint8)
+
+    rci = MultiRotateAugOCR(transforms=[], rotate_degrees=[0, 90, 270])
+
+    # test invalid arguments
+    with pytest.raises(AssertionError):
+        MultiRotateAugOCR(transforms=[], rotate_degrees=[45])
+    with pytest.raises(AssertionError):
+        MultiRotateAugOCR(transforms=[], rotate_degrees=[20.5])
+
+    # test call with input_img1
+    results = {'img_shape': input_img1.shape, 'img': input_img1}
+    results = rci(results)
+    assert np.allclose([64, 256, 3], results['img_shape'])
+    assert len(results['img']) == 1
+    assert len(results['img_shape']) == 1
+    assert np.allclose([64, 256, 3], results['img_shape'][0])
+
+    # test call with input_img2
+    results = {'img_shape': input_img2.shape, 'img': input_img2}
+    results = rci(results)
+    assert np.allclose([64, 32, 3], results['img_shape'])
+    assert len(results['img']) == 3
+    assert len(results['img_shape']) == 3
+    assert np.allclose([64, 32, 3], results['img_shape'][0])
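Note for reviewers: the index assertions in the convertor test below follow from how the dictionary is presumably laid out once special tokens are appended: the 7 file characters first, then the unknown, start/end, and padding tokens. The token spellings in this sketch are assumptions consistent with the asserted indices.

```python
chars = list('helowrd')
idx2char = chars + ['<UKN>', '<BOS/EOS>', '<PAD>']  # assumed token names

assert len(idx2char) == 10            # num_classes()
assert idx2char.index('<UKN>') == 7   # matches unknown_idx == 7 below
# 'hell' -> [0, 1, 2, 2]; padded: [8, 0, 1, 2, 2, 8, 9, 9, 9, 9]
```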
diff --git a/tests/test_models/test_label_convertor/test_attn_label_convertor.py b/tests/test_models/test_label_convertor/test_attn_label_convertor.py
new file mode 100644
index 00000000..00eaeacc
--- /dev/null
+++ b/tests/test_models/test_label_convertor/test_attn_label_convertor.py
@@ -0,0 +1,77 @@
+import os.path as osp
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+
+from mmocr.models.textrecog.convertors import AttnConvertor
+
+
+def _create_dummy_dict_file(dict_file):
+    characters = list('helowrd')
+    with open(dict_file, 'w') as fw:
+        for char in characters:
+            fw.write(char + '\n')
+
+
+def test_attn_label_convertor():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    dict_file = osp.join(tmp_dir.name, 'fake_dict.txt')
+    _create_dummy_dict_file(dict_file)
+
+    # test invalid arguments
+    with pytest.raises(AssertionError):
+        AttnConvertor(5)
+    with pytest.raises(AssertionError):
+        AttnConvertor('DICT90', dict_file, '1')
+    with pytest.raises(AssertionError):
+        AttnConvertor('DICT90', dict_file, True, '1')
+
+    label_convertor = AttnConvertor(dict_file=dict_file, max_seq_len=10)
+    # test init and parse_dict
+    assert label_convertor.num_classes() == 10
+    assert len(label_convertor.idx2char) == 10
+    assert label_convertor.idx2char[0] == 'h'
+    assert label_convertor.idx2char[1] == 'e'
+    assert label_convertor.idx2char[-3] == '<UKN>'
+    assert label_convertor.char2idx['h'] == 0
+    assert label_convertor.unknown_idx == 7
+
+    # test encode str to tensor
+    strings = ['hell']
+    targets_dict = label_convertor.str2tensor(strings)
+    assert torch.allclose(targets_dict['targets'][0],
+                          torch.LongTensor([0, 1, 2, 2]))
+    assert torch.allclose(targets_dict['padded_targets'][0],
+                          torch.LongTensor([8, 0, 1, 2, 2, 8, 9, 9, 9, 9]))
+
+    # test decode output to index
+    dummy_output = torch.Tensor([[[100, 2, 3, 4, 5, 6, 7, 8, 9],
+                                  [1, 100, 3, 4, 5, 6, 7, 8, 9],
+                                  [1, 2, 100, 4, 5, 6, 7, 8, 9],
+                                  [1, 2, 100, 4, 5, 6, 7, 8, 9],
+                                  [1, 2, 3, 4, 5, 6, 7, 8, 100],
+                                  [1, 2, 3, 4, 5, 6, 7, 100, 9],
+                                  [1, 2, 3, 4, 5, 6, 7, 100, 9],
+                                  [1, 2, 3, 4, 5, 6, 7, 100, 9],
+                                  [1, 2, 3, 4, 5, 6, 7, 100, 9],
+                                  [1, 2, 3, 4, 5, 6, 7, 100, 9]]])
+    indexes, scores = label_convertor.tensor2idx(dummy_output)
+    assert np.allclose(indexes, [[0, 1, 2, 2]])
+
+    # test encode_str_label_to_index
+    with pytest.raises(AssertionError):
+        label_convertor.str2idx('hell')
+    tmp_indexes = label_convertor.str2idx(strings)
+    assert np.allclose(tmp_indexes, [[0, 1, 2, 2]])
+
+    # test decode_index to str_label
+    input_indexes = [[0, 1, 2, 2]]
+    with pytest.raises(AssertionError):
+        label_convertor.idx2str('hell')
+    output_strings = label_convertor.idx2str(input_indexes)
+    assert output_strings[0] == 'hell'
+
+    tmp_dir.cleanup()
diff --git a/tests/test_models/test_label_convertor/test_ctc_label_convertor.py b/tests/test_models/test_label_convertor/test_ctc_label_convertor.py
new file mode 100644
index 00000000..07c9cbf0
--- /dev/null
+++ b/tests/test_models/test_label_convertor/test_ctc_label_convertor.py
@@ -0,0 +1,79 @@
+import os.path as osp
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+
+from mmocr.models.textrecog.convertors import BaseConvertor, CTCConvertor
+
+
+def _create_dummy_dict_file(dict_file):
+    chars = list('helowrd')
+    with open(dict_file, 'w') as fw:
+        for char in chars:
+            fw.write(char + '\n')
+
+
+def test_ctc_label_convertor():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+    _create_dummy_dict_file(dict_file)
+
+    # test invalid arguments
+    with pytest.raises(AssertionError):
+        CTCConvertor(5)
+
+    label_convertor = CTCConvertor(dict_file=dict_file, with_unknown=False)
+    # test init and parse_chars
+    assert label_convertor.num_classes() == 8
+    assert len(label_convertor.idx2char) == 8
+    assert label_convertor.idx2char[0] == '<BLK>'
+    assert label_convertor.char2idx['h'] == 1
+    assert label_convertor.unknown_idx is None
+
+    # test encode str to tensor
+    strings = ['hell']
+    expect_tensor = torch.IntTensor([1, 2, 3, 3])
+    targets_dict = label_convertor.str2tensor(strings)
+    assert torch.allclose(targets_dict['targets'][0], expect_tensor)
+    assert torch.allclose(targets_dict['flatten_targets'], expect_tensor)
+    assert torch.allclose(targets_dict['target_lengths'], torch.IntTensor([4]))
+
+    # test decode output to index
+    dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
+                                  [100, 2, 3, 4, 5, 6, 7, 8],
+                                  [1, 2, 100, 4, 5, 6, 7, 8],
+                                  [1, 2, 100, 4, 5, 6, 7, 8],
+                                  [100, 2, 3, 4, 5, 6, 7, 8],
+                                  [1, 2, 3, 100, 5, 6, 7, 8],
+                                  [100, 2, 3, 4, 5, 6, 7, 8],
+                                  [1, 2, 3, 100, 5, 6, 7, 8]]])
+    indexes, scores = label_convertor.tensor2idx(
+        dummy_output, img_metas=[{
+            'valid_ratio': 1.0
+        }])
+    assert np.allclose(indexes, [[1, 2, 3, 3]])
+
+    # test encode_str_label_to_index
+    with pytest.raises(AssertionError):
+        label_convertor.str2idx('hell')
+    tmp_indexes = label_convertor.str2idx(strings)
+    assert np.allclose(tmp_indexes, [[1, 2, 3, 3]])
+
+    # test decode_index_to_str_label
+    input_indexes = [[1, 2, 3, 3]]
+    with pytest.raises(AssertionError):
+        label_convertor.idx2str('hell')
+    output_strings = label_convertor.idx2str(input_indexes)
+    assert output_strings[0] == 'hell'
+
+    tmp_dir.cleanup()
+
+
+def test_base_label_convertor():
+    with pytest.raises(NotImplementedError):
+        label_convertor = BaseConvertor()
+        label_convertor.str2tensor(None)
+        label_convertor.tensor2idx(None)
diff --git a/tests/test_models/test_ocr_backbone.py b/tests/test_models/test_ocr_backbone.py
new file mode 100644
index 00000000..f49a334d
--- /dev/null
+++ b/tests/test_models/test_ocr_backbone.py
@@ -0,0 +1,36 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog.backbones import ResNet31OCR, VeryDeepVgg
+
+
+def test_resnet31_ocr_backbone():
+    """Test resnet backbone."""
+    with pytest.raises(AssertionError):
+        ResNet31OCR(2.5)
+
+    with pytest.raises(AssertionError):
+        ResNet31OCR(3, layers=5)
+
+    with pytest.raises(AssertionError):
+        ResNet31OCR(3, channels=5)
+
+    # Test ResNet31OCR forward
+    model = ResNet31OCR()
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 32, 160)
+    feat = model(imgs)
+    assert feat.shape == torch.Size([1, 512, 4, 40])
+
+
+def test_very_deep_vgg_ocr_backbone():
+
+    model = VeryDeepVgg()
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 32, 160)
+    feats = model(imgs)
+    assert feats.shape == torch.Size([1, 512, 1, 41])
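Note for reviewers: quick sanity arithmetic behind the backbone shape asserts above. `ResNet31OCR` downsamples height by 8 and width by 4 (so 32x160 becomes a 4x40 grid), while `VeryDeepVgg` collapses the height to 1 and yields width 41 for a 160-pixel-wide input; both emit 512 channels.

```python
h, w = 32, 160
assert (h // 8, w // 4) == (4, 40)  # ResNet31OCR output grid
# hence the asserted shapes (1, 512, 4, 40) and (1, 512, 1, 41)
```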
diff --git a/tests/test_models/test_ocr_decoder.py b/tests/test_models/test_ocr_decoder.py
new file mode 100644
index 00000000..9a063e26
--- /dev/null
+++ b/tests/test_models/test_ocr_decoder.py
@@ -0,0 +1,111 @@
+import math
+
+import pytest
+import torch
+
+from mmocr.models.textrecog.decoders import (BaseDecoder, ParallelSARDecoder,
+                                             ParallelSARDecoderWithBS,
+                                             SequentialSARDecoder, TFDecoder)
+from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode
+
+
+def _create_dummy_input():
+    feat = torch.rand(1, 512, 4, 40)
+    out_enc = torch.rand(1, 512)
+    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
+    img_metas = [{'valid_ratio': 1.0}]
+
+    return feat, out_enc, tgt_dict, img_metas
+
+
+def test_base_decoder():
+    decoder = BaseDecoder()
+    with pytest.raises(NotImplementedError):
+        decoder.forward_train(None, None, None, None)
+    with pytest.raises(NotImplementedError):
+        decoder.forward_test(None, None, None)
+
+
+def test_parallel_sar_decoder():
+    # test parallel sar decoder
+    decoder = ParallelSARDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
+    decoder.init_weights()
+    decoder.train()
+
+    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, [], True)
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, img_metas * 2, True)
+
+    out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
+    assert out_train.shape == torch.Size([1, 5, 36])
+
+    out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
+    assert out_test.shape == torch.Size([1, 5, 36])
+
+
+def test_sequential_sar_decoder():
+    # test sequential sar decoder
+    decoder = SequentialSARDecoder(
+        num_classes=37, padding_idx=36, max_seq_len=5)
+    decoder.init_weights()
+    decoder.train()
+
+    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, [])
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, img_metas * 2)
+
+    out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
+    assert out_train.shape == torch.Size([1, 5, 36])
+
+    out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
+    assert out_test.shape == torch.Size([1, 5, 36])
+
+
+def test_parallel_sar_decoder_with_beam_search():
+    with pytest.raises(AssertionError):
+        ParallelSARDecoderWithBS(beam_width='beam')
+    with pytest.raises(AssertionError):
+        ParallelSARDecoderWithBS(beam_width=0)
+
+    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
+    decoder = ParallelSARDecoderWithBS(
+        beam_width=1, num_classes=37, padding_idx=36, max_seq_len=5)
+    decoder.init_weights()
+    decoder.train()
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, [])
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, img_metas * 2)
+
+    out_test = decoder(feat, out_enc, tgt_dict, img_metas, train_mode=False)
+    assert out_test.shape == torch.Size([1, 5, 36])
+
+    # test decodenode
+    with pytest.raises(AssertionError):
+        DecodeNode(1, 1)
+    with pytest.raises(AssertionError):
+        DecodeNode([1, 2], ['4', '3'])
+    with pytest.raises(AssertionError):
+        DecodeNode([1, 2], [0.5])
+    decode_node = DecodeNode([1, 2], [0.7, 0.8])
+    assert math.isclose(decode_node.eval(), 1.5)
+
+
+def test_transformer_decoder():
+    decoder = TFDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
+    decoder.init_weights()
+    decoder.train()
+
+    out_enc = torch.rand(1, 128, 512)
+    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
+    img_metas = [{'valid_ratio': 1.0}]
+
+    out_train = decoder(None, out_enc, tgt_dict, img_metas, True)
+    assert out_train.shape == torch.Size([1, 5, 36])
+
+    out_test = decoder(None, out_enc, tgt_dict, img_metas, False)
+    assert out_test.shape == torch.Size([1, 5, 36])
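Note for reviewers: a minimal stand-in (an assumption that mirrors the asserted behaviour, not the library class) for the beam-search `DecodeNode` tested above. It stores index and score histories, and `eval()` returns the accumulated score, so `[0.7, 0.8]` evaluates to 1.5.

```python
class ToyDecodeNode:
    def __init__(self, indexes, scores):
        # both histories must be non-empty lists of equal length
        assert isinstance(indexes, list) and isinstance(scores, list)
        assert len(indexes) == len(scores) and len(indexes) > 0
        self.indexes, self.scores = indexes, scores

    def eval(self):
        return sum(self.scores)  # accumulated path score

assert ToyDecodeNode([1, 2], [0.7, 0.8]).eval() == 1.5
```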
diff --git a/tests/test_models/test_ocr_encoder.py b/tests/test_models/test_ocr_encoder.py
new file mode 100644
index 00000000..c5a26667
--- /dev/null
+++ b/tests/test_models/test_ocr_encoder.py
@@ -0,0 +1,53 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog.encoders import BaseEncoder, SAREncoder, TFEncoder
+
+
+def test_sar_encoder():
+    with pytest.raises(AssertionError):
+        SAREncoder(enc_bi_rnn='bi')
+    with pytest.raises(AssertionError):
+        SAREncoder(enc_do_rnn=2)
+    with pytest.raises(AssertionError):
+        SAREncoder(enc_gru='gru')
+    with pytest.raises(AssertionError):
+        SAREncoder(d_model=512.5)
+    with pytest.raises(AssertionError):
+        SAREncoder(d_enc=200.5)
+    with pytest.raises(AssertionError):
+        SAREncoder(mask='mask')
+
+    encoder = SAREncoder()
+    encoder.init_weights()
+    encoder.train()
+
+    feat = torch.randn(1, 512, 4, 40)
+    with pytest.raises(AssertionError):
+        encoder(feat)
+    img_metas = [{'valid_ratio': 1.0}]
+    with pytest.raises(AssertionError):
+        encoder(feat, img_metas * 2)
+    out_enc = encoder(feat, img_metas)
+
+    assert out_enc.shape == torch.Size([1, 512])
+
+
+def test_transformer_encoder():
+    tf_encoder = TFEncoder()
+    tf_encoder.init_weights()
+    tf_encoder.train()
+
+    feat = torch.randn(1, 512, 4, 40)
+    out_enc = tf_encoder(feat)
+    assert out_enc.shape == torch.Size([1, 160, 512])
+
+
+def test_base_encoder():
+    encoder = BaseEncoder()
+    encoder.init_weights()
+    encoder.train()
+
+    feat = torch.randn(1, 256, 4, 40)
+    out_enc = encoder(feat)
+    assert out_enc.shape == torch.Size([1, 256, 4, 40])
diff --git a/tests/test_models/test_ocr_head.py b/tests/test_models/test_ocr_head.py
new file mode 100644
index 00000000..52761405
--- /dev/null
+++ b/tests/test_models/test_ocr_head.py
@@ -0,0 +1,16 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog import SegHead
+
+
+def test_seg_head():
+    with pytest.raises(AssertionError):
+        SegHead(num_classes='100')
+    with pytest.raises(AssertionError):
+        SegHead(num_classes=-1)
+
+    seg_head = SegHead(num_classes=37)
+    out_neck = (torch.rand(1, 128, 32, 32), )
+    out_head = seg_head(out_neck)
+    assert out_head.shape == torch.Size([1, 37, 32, 32])
diff --git a/tests/test_models/test_ocr_layer.py b/tests/test_models/test_ocr_layer.py
new file mode 100644
index 00000000..da96273d
--- /dev/null
+++ b/tests/test_models/test_ocr_layer.py
@@ -0,0 +1,55 @@
+import torch
+
+from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck,
+                                           DecoderLayer, PositionalEncoding,
+                                           get_pad_mask, get_subsequent_mask)
+from mmocr.models.textrecog.layers.conv_layer import conv3x3
+
+
+def test_conv_layer():
+    conv3by3 = conv3x3(3, 6)
+    assert conv3by3.in_channels == 3
+    assert conv3by3.out_channels == 6
+    assert conv3by3.kernel_size == (3, 3)
+
+    x = torch.rand(1, 64, 224, 224)
+    # test basic block
+    basic_block = BasicBlock(64, 64)
+    assert basic_block.expansion == 1
+
+    out = basic_block(x)
+
+    assert out.shape == torch.Size([1, 64, 224, 224])
+
+    # test bottleneck
+    bottle_neck = Bottleneck(64, 64, downsample=True)
+    assert bottle_neck.expansion == 4
+
+    out = bottle_neck(x)
+
+    assert out.shape == torch.Size([1, 256, 224, 224])
+
+
+def test_transformer_layer():
+    # test decoder_layer
+    decoder_layer = DecoderLayer()
+    in_dec = torch.rand(1, 30, 512)
+    out_enc = torch.rand(1, 128, 512)
+    out_dec = decoder_layer(in_dec, out_enc)
+    assert out_dec.shape == torch.Size([1, 30, 512])
+
+    # test positional_encoding
+    pos_encoder = PositionalEncoding()
+    x = torch.rand(1, 30, 512)
+    out = pos_encoder(x)
+    assert out.size() == x.size()
+
+    # test get pad mask
+    seq = torch.rand(1, 30)
+    pad_idx = 0
+    out = get_pad_mask(seq, pad_idx)
+    assert out.shape == torch.Size([1, 1, 30])
+
+    # test get_subsequent_mask
+    out_mask = get_subsequent_mask(seq)
+    assert out_mask.shape == torch.Size([1, 30, 30])
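Note for reviewers: what `get_subsequent_mask` presumably computes for a length-30 sequence, sketched independently of the library: a lower-triangular (1, 30, 30) mask that lets position i attend only to positions at or before i.

```python
import torch

seq_len = 30
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(0)
assert mask.shape == (1, 30, 30)
assert bool(mask[0, 1, 0])       # position 1 can see position 0
assert not bool(mask[0, 0, 1])   # position 0 cannot see position 1
```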
diff --git a/tests/test_models/test_ocr_loss.py b/tests/test_models/test_ocr_loss.py
new file mode 100644
index 00000000..fa118f5e
--- /dev/null
+++ b/tests/test_models/test_ocr_loss.py
@@ -0,0 +1,122 @@
+import numpy as np
+import pytest
+import torch
+
+from mmdet.core import BitmapMasks
+from mmocr.models.common.losses import DiceLoss
+from mmocr.models.textrecog.losses import (CAFCNLoss, CELoss, CTCLoss, SARLoss,
+                                           TFLoss)
+
+
+def test_ctc_loss():
+    # test CTCLoss
+    ctc_loss = CTCLoss()
+    outputs = torch.zeros(2, 40, 37)
+    targets_dict = {
+        'flatten_targets': torch.IntTensor([1, 2, 3, 4, 5]),
+        'target_lengths': torch.LongTensor([2, 3])
+    }
+
+    losses = ctc_loss(outputs, targets_dict)
+    assert isinstance(losses, dict)
+    assert 'loss_ctc' in losses
+    assert torch.allclose(losses['loss_ctc'],
+                          torch.tensor(losses['loss_ctc'].item()).float())
+
+
+def test_ce_loss():
+    with pytest.raises(AssertionError):
+        CELoss(ignore_index='ignore')
+    with pytest.raises(AssertionError):
+        CELoss(reduction=1)
+    with pytest.raises(AssertionError):
+        CELoss(reduction='avg')
+
+    ce_loss = CELoss(ignore_index=0)
+    outputs = torch.rand(1, 10, 37)
+    targets_dict = {
+        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
+    }
+    losses = ce_loss(outputs, targets_dict)
+    assert isinstance(losses, dict)
+    assert 'loss_ce' in losses
+    assert losses['loss_ce'].size(1) == 10
+
+
+def test_sar_loss():
+    outputs = torch.rand(1, 10, 37)
+    targets_dict = {
+        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
+    }
+    sar_loss = SARLoss()
+    new_output, new_target = sar_loss.format(outputs, targets_dict)
+    assert new_output.shape == torch.Size([1, 37, 9])
+    assert new_target.shape == torch.Size([1, 9])
+
+
+def test_tf_loss():
+    with pytest.raises(AssertionError):
+        TFLoss(flatten=1.0)
+
+    outputs = torch.rand(1, 10, 37)
+    targets_dict = {
+        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
+    }
+    tf_loss = TFLoss(flatten=False)
+    new_output, new_target = tf_loss.format(outputs, targets_dict)
+    assert new_output.shape == torch.Size([1, 37, 9])
+    assert new_target.shape == torch.Size([1, 9])
+
+
+def test_cafcn_loss():
+    with pytest.raises(AssertionError):
+        CAFCNLoss(alpha='1')
+    with pytest.raises(AssertionError):
+        CAFCNLoss(attn_s2_downsample_ratio='2')
+    with pytest.raises(AssertionError):
+        CAFCNLoss(attn_s3_downsample_ratio='1.5')
+    with pytest.raises(AssertionError):
+        CAFCNLoss(seg_downsample_ratio='1.5')
+    with pytest.raises(AssertionError):
+        CAFCNLoss(attn_s2_downsample_ratio=2)
+    with pytest.raises(AssertionError):
+        CAFCNLoss(attn_s3_downsample_ratio=1.5)
+    with pytest.raises(AssertionError):
+        CAFCNLoss(seg_downsample_ratio=1.5)
+
+    bsz = 1
+    H = W = 64
+    out_neck = (torch.ones(bsz, 1, H // 4, W // 4) * 0.5,
+                torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
+                torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
+                torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
+                torch.ones(bsz, 1, H // 2, W // 2) * 0.5)
+    out_head = torch.rand(bsz, 37, H // 2, W // 2)
+
+    attn_tgt = np.zeros((H, W), dtype=np.float32)
+    segm_tgt = np.zeros((H, W), dtype=np.float32)
+    mask = np.ones((H, W), dtype=np.float32)
+    gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], H, W)
+
+    cafcn_loss = CAFCNLoss()
+    losses = cafcn_loss(out_neck, out_head, [gt_kernels])
+    assert isinstance(losses, dict)
+    assert 'loss_seg' in losses
+    assert torch.allclose(losses['loss_seg'],
+                          torch.tensor(losses['loss_seg'].item()).float())
+
+
+def test_dice_loss():
+    with pytest.raises(AssertionError):
+        DiceLoss(eps='1')
+
+    dice_loss = DiceLoss()
+    pred = torch.rand(1, 1, 32, 32)
+    gt = torch.rand(1, 1, 32, 32)
+
+    loss = dice_loss(pred, gt, None)
+    assert isinstance(loss, torch.Tensor)
+
+    mask = torch.rand(1, 1, 1, 1)
+    loss = dice_loss(pred, gt, mask)
+    assert isinstance(loss, torch.Tensor)
diff --git a/tests/test_models/test_ocr_neck.py b/tests/test_models/test_ocr_neck.py
new file mode 100644
index 00000000..8af3e971
--- /dev/null
+++ b/tests/test_models/test_ocr_neck.py
@@ -0,0 +1,48 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog.necks.cafcn_neck import (CAFCNNeck, CharAttn,
+                                                     FeatGenerator)
+
+
+def test_char_attn():
+    with pytest.raises(AssertionError):
+        CharAttn(in_channels=5.0)
+    with pytest.raises(AssertionError):
+        CharAttn(deformable='deformabel')
+
+    in_feat = torch.rand(1, 128, 32, 32)
+    char_attn = CharAttn()
+    out_feat_map, attn_map = char_attn(in_feat)
+    assert attn_map.shape == torch.Size([1, 1, 32, 32])
+    assert out_feat_map.shape == torch.Size([1, 128, 32, 32])
+
+
+@pytest.mark.skip(reason='TODO: re-enable after CI support pytorch>1.4')
+def test_feat_generator():
+    in_feat = torch.rand(1, 128, 32, 32)
+    feat_generator = FeatGenerator(in_channels=128, out_channels=128)
+
+    attn_map, feat_map = feat_generator(in_feat)
+    assert attn_map.shape == torch.Size([1, 1, 32, 32])
+    assert feat_map.shape == torch.Size([1, 128, 32, 32])
+
+
+@pytest.mark.skip(reason='TODO: re-enable after CI support pytorch>1.4')
+def test_cafcn_neck():
+    in_s1 = torch.rand(1, 64, 64, 64)
+    in_s2 = torch.rand(1, 128, 32, 32)
+    in_s3 = torch.rand(1, 256, 16, 16)
+    in_s4 = torch.rand(1, 512, 16, 16)
+    in_s5 = torch.rand(1, 512, 16, 16)
+
+    cafcn_neck = CAFCNNeck()
+    cafcn_neck.init_weights()
+    cafcn_neck.train()
+
+    out_neck = cafcn_neck((in_s1, in_s2, in_s3, in_s4, in_s5))
+    assert out_neck[0].shape == torch.Size([1, 1, 32, 32])
+    assert out_neck[1].shape == torch.Size([1, 1, 16, 16])
+    assert out_neck[2].shape == torch.Size([1, 1, 16, 16])
+    assert out_neck[3].shape == torch.Size([1, 1, 16, 16])
+    assert out_neck[4].shape == torch.Size([1, 128, 64, 64])
diff --git a/tests/test_models/test_recognizer.py b/tests/test_models/test_recognizer.py
new file mode 100644
index 00000000..54f66ca6
--- /dev/null
+++ b/tests/test_models/test_recognizer.py
@@ -0,0 +1,156 @@
+import os.path as osp
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+
+from mmdet.core import BitmapMasks
+from mmocr.models.textrecog.recognizer import (EncodeDecodeRecognizer,
+                                               SegRecognizer)
+
+
+def _create_dummy_dict_file(dict_file):
+    chars = list('helowrd')
+    with open(dict_file, 'w') as fw:
+        for char in chars:
+            fw.write(char + '\n')
+
+
+def test_base_recognizer():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+    _create_dummy_dict_file(dict_file)
+
+    label_convertor = dict(
+        type='CTCConvertor', dict_file=dict_file, with_unknown=False)
+
+    preprocessor = None
+    backbone = dict(type='VeryDeepVgg', leakyRelu=False)
+    encoder = None
+    decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
+    loss = dict(type='CTCLoss')
+
+    with pytest.raises(AssertionError):
+        EncodeDecodeRecognizer(backbone=None)
+    with pytest.raises(AssertionError):
+        EncodeDecodeRecognizer(decoder=None)
+    with pytest.raises(AssertionError):
+        EncodeDecodeRecognizer(loss=None)
+    with pytest.raises(AssertionError):
+        EncodeDecodeRecognizer(label_convertor=None)
+
+    recognizer = EncodeDecodeRecognizer(
+        preprocessor=preprocessor,
+        backbone=backbone,
+        encoder=encoder,
+        decoder=decoder,
+        loss=loss,
+        label_convertor=label_convertor)
+
+    recognizer.init_weights()
+    recognizer.train()
+
+    imgs = torch.rand(1, 3, 32, 160)
+
+    # test extract feat
+    feat = recognizer.extract_feat(imgs)
+    assert feat.shape == torch.Size([1, 512, 1, 41])
+
+    # test forward train
+    img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
+    losses = recognizer.forward_train(imgs, img_metas)
+    assert isinstance(losses, dict)
+    assert 'loss_ctc' in losses
+
+    # test simple test
+    results = recognizer.simple_test(imgs, img_metas)
+    assert isinstance(results, list)
+    assert isinstance(results[0], dict)
+    assert 'text' in results[0]
+    assert 'score' in results[0]
+
+    # test aug_test
+    aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
+    assert isinstance(aug_results, list)
+    assert isinstance(aug_results[0], dict)
+    assert 'text' in aug_results[0]
+    assert 'score' in aug_results[0]
+
+    tmp_dir.cleanup()
+
+
+@pytest.mark.skip(reason='TODO: re-enable after CI support pytorch>1.4')
+def test_seg_recognizer():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+    _create_dummy_dict_file(dict_file)
+
+    label_convertor = dict(
+        type='SegConvertor', dict_file=dict_file, with_unknown=False)
+
+    preprocessor = None
+    backbone = dict(type='ResNet31OCR')
+    neck = dict(type='FPNOCR')
+    head = dict(type='SegHead')
+    loss = dict(type='SegLoss')
+
+    with pytest.raises(AssertionError):
+        SegRecognizer(backbone=None)
+    with pytest.raises(AssertionError):
+        SegRecognizer(neck=None)
+    with pytest.raises(AssertionError):
+        SegRecognizer(head=None)
+    with pytest.raises(AssertionError):
+        SegRecognizer(loss=None)
+    with pytest.raises(AssertionError):
+        SegRecognizer(label_convertor=None)
+
+    recognizer = SegRecognizer(
+        preprocessor=preprocessor,
+        backbone=backbone,
+        neck=neck,
+        head=head,
+        loss=loss,
+        label_convertor=label_convertor)
+
+    recognizer.init_weights()
+    recognizer.train()
+
+    imgs = torch.rand(1, 3, 64, 256)
+
+    # test extract feat
+    feats = recognizer.extract_feat(imgs)
+    assert len(feats) == 5
+    assert feats[0].shape == torch.Size([1, 64, 32, 128])
+    assert feats[1].shape == torch.Size([1, 128, 16, 64])
+    assert feats[2].shape == torch.Size([1, 256, 8, 32])
+    assert feats[3].shape == torch.Size([1, 512, 8, 32])
+    assert feats[4].shape == torch.Size([1, 512, 8, 32])
+
+    attn_tgt = np.zeros((64, 256), dtype=np.float32)
+    segm_tgt = np.zeros((64, 256), dtype=np.float32)
+    gt_kernels = BitmapMasks([attn_tgt, segm_tgt], 64, 256)
+
+    # test forward train
+    img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
+    losses = recognizer.forward_train(imgs, img_metas, gt_kernels=[gt_kernels])
+    assert isinstance(losses, dict)
+
+    # test simple test
+    results = recognizer.simple_test(imgs, img_metas)
+    assert isinstance(results, list)
+    assert isinstance(results[0], dict)
+    assert 'text' in results[0]
+    assert 'score' in results[0]
+
+    # test aug_test
+    aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
+    assert isinstance(aug_results, list)
+    assert isinstance(aug_results[0], dict)
+    assert 'text' in aug_results[0]
+    assert 'score' in aug_results[0]
+
+    tmp_dir.cleanup()
diff --git a/tests/test_utils/test_text/test_text_utils.py b/tests/test_utils/test_text/test_text_utils.py
new file mode 100644
index 00000000..a1d4a50f
--- /dev/null
+++ b/tests/test_utils/test_text/test_text_utils.py
@@ -0,0 +1,66 @@
+"""Test text label visualization."""
+import os.path as osp
+import random
+import tempfile
+from unittest import mock
+
+import numpy as np
+import pytest
+
+import mmocr.core.visualize as visualize_utils
+
+
+def test_tile_image():
+    dummy_imgs, heights, widths = [], [], []
+    for _ in range(3):
+        h = random.randint(100, 300)
+        w = random.randint(100, 300)
+        heights.append(h)
+        widths.append(w)
+        dummy_img = np.ones((h, w, 3), dtype=np.uint8)
+        dummy_imgs.append(dummy_img)
+    joint_img = visualize_utils.tile_image(dummy_imgs)
+    assert joint_img.shape[0] == sum(heights)
+    assert joint_img.shape[1] == max(widths)
+
+    # test invalid arguments
+    with pytest.raises(AssertionError):
+        visualize_utils.tile_image(dummy_imgs[0])
+    with pytest.raises(AssertionError):
+        visualize_utils.tile_image([])
+
+
+@mock.patch('%s.visualize_utils.mmcv.imread' % __name__)
+@mock.patch('%s.visualize_utils.mmcv.imshow' % __name__)
+@mock.patch('%s.visualize_utils.mmcv.imwrite' % __name__)
+def test_show_text_label(mock_imwrite, mock_imshow, mock_imread):
+    img = np.ones((32, 160), dtype=np.uint8)
+    pred_label = 'hello'
+    gt_label = 'world'
+
+    tmp_dir = tempfile.TemporaryDirectory()
+    out_file = osp.join(tmp_dir.name, 'tmp.jpg')
+
+    # test invalid arguments
+    with pytest.raises(AssertionError):
+        visualize_utils.imshow_text_label(5, pred_label, gt_label)
+    with pytest.raises(AssertionError):
+        visualize_utils.imshow_text_label(img, pred_label, 4)
+    with pytest.raises(AssertionError):
+        visualize_utils.imshow_text_label(img, 3, gt_label)
+    with pytest.raises(AssertionError):
+        visualize_utils.imshow_text_label(
+            img, pred_label, gt_label, show=True, wait_time=0.1)
+
+    mock_imread.side_effect = [img, img]
+    visualize_utils.imshow_text_label(
+        img, pred_label, gt_label, out_file=out_file)
+    visualize_utils.imshow_text_label(
+        img, pred_label, gt_label, out_file=None, show=True)
+
+    # test showing img
+    mock_imshow.assert_called_once()
+    mock_imwrite.assert_called_once()
+
+    tmp_dir.cleanup()
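+
+# A minimal usage sketch (hedged) of the two helpers exercised above; the
+# image paths are hypothetical and not part of the test data:
+#
+#   import mmcv
+#   from mmocr.core.visualize import tile_image, imshow_text_label
+#
+#   imgs = [mmcv.imread(p) for p in ['a.jpg', 'b.jpg']]
+#   joint = tile_image(imgs)  # stacked top-to-bottom; width = max of inputs
+#   imshow_text_label(joint, 'pred text', 'gt text', out_file='vis.jpg')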
diff --git a/tools/data/textrecog/seg_synthtext_converter.py b/tools/data/textrecog/seg_synthtext_converter.py
new file mode 100644
index 00000000..64ed9701
--- /dev/null
+++ b/tools/data/textrecog/seg_synthtext_converter.py
@@ -0,0 +1,93 @@
+import argparse
+import codecs
+import json
+import os.path as osp
+
+import cv2
+
+
+def read_json(fpath):
+    with codecs.open(fpath, 'r', 'utf-8') as f:
+        obj = json.load(f)
+    return obj
+
+
+def parse_old_label(img_prefix, in_path):
+    imgid2imgname = {}
+    imgid2anno = {}
+    idx = 0
+    with open(in_path, 'r') as fr:
+        for line in fr:
+            line = line.strip().split()
+            img_full_path = osp.join(img_prefix, line[0])
+            if not osp.exists(img_full_path):
+                continue
+            img = cv2.imread(img_full_path)
+            h, w = img.shape[:2]
+            img_info = {}
+            img_info['file_name'] = line[0]
+            img_info['height'] = h
+            img_info['width'] = w
+            imgid2imgname[idx] = img_info
+            imgid2anno[idx] = []
+            for i in range(len(line[1:]) // 8):
+                seg = [int(x) for x in line[(1 + i * 8):(1 + (i + 1) * 8)]]
+                # even indices hold x, odd indices hold y of the 8-value quad
+                points_x = seg[0:8:2]
+                points_y = seg[1:8:2]
+                box = [
+                    min(points_x),
+                    min(points_y),
+                    max(points_x),
+                    max(points_y)
+                ]
+                new_anno = {}
+                new_anno['iscrowd'] = 0
+                new_anno['category_id'] = 1
+                new_anno['bbox'] = box
+                new_anno['segmentation'] = [seg]
+                imgid2anno[idx].append(new_anno)
+            idx += 1
+
+    return imgid2imgname, imgid2anno
+
+
+def gen_line_dict_file(out_path, imgid2imgname, imgid2anno):
+    with codecs.open(out_path, 'w', 'utf-8') as fw:
+        for key, value in imgid2imgname.items():
+            if key in imgid2anno:
+                anno = imgid2anno[key]
+                line_dict = {}
+                line_dict['file_name'] = value['file_name']
+                line_dict['height'] = value['height']
+                line_dict['width'] = value['width']
+                line_dict['annotations'] = anno
+                line_dict_str = json.dumps(line_dict)
+                fw.write(line_dict_str + '\n')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--img-prefix',
+        help='image prefix, to generate full image path with "image_name"')
+    parser.add_argument(
+        '--in-path',
+        help='mapping file with one "image_name ann_file" pair per line')
+    parser.add_argument(
+        '--out-path', help='output txt path with line-json format')
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    imgid2imgname, imgid2anno = parse_old_label(args.img_prefix, args.in_path)
+    gen_line_dict_file(args.out_path, imgid2imgname, imgid2anno)
+    print('Finished')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/data/utils/txt2lmdb.py b/tools/data/utils/txt2lmdb.py
new file mode 100644
index 00000000..e7554ccf
--- /dev/null
+++ b/tools/data/utils/txt2lmdb.py
@@ -0,0 +1,74 @@
+import argparse
+import shutil
+import sys
+import time
+from pathlib import Path
+
+import lmdb
+
+
+def converter(imglist, output, batch_size=1000, coding='utf-8'):
+    # read imglist
+    with open(imglist) as f:
+        lines = f.readlines()
+
+    # create lmdb database
+    if Path(output).is_dir():
+        while True:
+            print('%s already exists, delete it or not? [Y/n]' % output)
+            Yn = input().strip()
+            if Yn in ['Y', 'y']:
+                shutil.rmtree(output)
+                break
+            elif Yn in ['N', 'n']:
+                return
+    print('create database %s' % output)
+    Path(output).mkdir(parents=True, exist_ok=False)
+    env = lmdb.open(output, map_size=1099511627776)  # 1 TiB
+
+    # build lmdb
+    beg_time = time.strftime('%H:%M:%S')
+    for beg_index in range(0, len(lines), batch_size):
+        end_index = min(beg_index + batch_size, len(lines))
+        sys.stdout.write('\r[%s-%s], processing [%d-%d] / %d' %
+                         (beg_time, time.strftime('%H:%M:%S'), beg_index,
+                          end_index, len(lines)))
+        sys.stdout.flush()
+        batch = [(str(index).encode(coding), lines[index].encode(coding))
+                 for index in range(beg_index, end_index)]
+        with env.begin(write=True) as txn:
+            cursor = txn.cursor()
+            cursor.putmulti(batch, dupdata=False, overwrite=True)
+    sys.stdout.write('\n')
+    with env.begin(write=True) as txn:
+        key = 'total_number'.encode(coding)
+        value = str(len(lines)).encode(coding)
+        txn.put(key, value)
+    print('done', flush=True)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--imglist', '-i', required=True, help='input imglist path')
+    parser.add_argument(
+        '--output', '-o', required=True, help='output lmdb path')
+    parser.add_argument(
+        '--batch_size',
+        '-b',
+        type=int,
+        default=10000,
+        help='processing batch size, default 10000')
+    parser.add_argument(
+        '--coding',
+        '-c',
+        default='utf8',
+        help='bytes coding scheme, default utf8')
+    opt = parser.parse_args()
+
+    converter(
+        opt.imglist, opt.output, batch_size=opt.batch_size, coding=opt.coding)
+
+
+if __name__ == '__main__':
+    main()
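+
+# Usage sketch (the toy-dataset paths below exist in this repo; any other
+# paths are up to the user). The output follows the layout consumed by
+# LmdbLoader: str(line_index) -> raw text line, plus a 'total_number' key
+# holding the line count:
+#
+#   python tools/data/utils/txt2lmdb.py \
+#       -i tests/data/ocr_toy_dataset/label.txt \
+#       -o tests/data/ocr_toy_dataset/label.lmdb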
diff --git a/tools/ocr_test_imgs.py b/tools/ocr_test_imgs.py
new file mode 100644
index 00000000..a2d642b1
--- /dev/null
+++ b/tools/ocr_test_imgs.py
@@ -0,0 +1,132 @@
+import os.path as osp
+import shutil
+import time
+from argparse import ArgumentParser
+
+import mmcv
+import torch
+from mmcv.utils import ProgressBar
+
+from mmdet.apis import init_detector
+from mmdet.utils import get_root_logger
+from mmocr.apis import model_inference
+from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
+from mmocr.datasets import build_dataset  # noqa: F401
+from mmocr.models import build_detector  # noqa: F401
+
+
+def save_results(img_paths, pred_labels, gt_labels, res_dir):
+    """Save predicted results to txt file.
+
+    Args:
+        img_paths (list[str]): Image paths.
+        pred_labels (list[str]): Predicted labels.
+        gt_labels (list[str]): Ground-truth labels.
+        res_dir (str): Directory to save the result files.
+    """
+    assert len(img_paths) == len(pred_labels) == len(gt_labels)
+    res_file = osp.join(res_dir, 'results.txt')
+    correct_file = osp.join(res_dir, 'correct.txt')
+    wrong_file = osp.join(res_dir, 'wrong.txt')
+    with open(res_file, 'w') as fw, \
+            open(correct_file, 'w') as fw_correct, \
+            open(wrong_file, 'w') as fw_wrong:
+        for img_path, pred_label, gt_label in zip(img_paths, pred_labels,
+                                                  gt_labels):
+            fw.write(img_path + ' ' + pred_label + ' ' + gt_label + '\n')
+            if pred_label == gt_label:
+                fw_correct.write(img_path + ' ' + pred_label + ' ' +
+                                 gt_label + '\n')
+            else:
+                fw_wrong.write(img_path + ' ' + pred_label + ' ' + gt_label +
+                               '\n')
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('--img_root_path', type=str, help='Image root path')
+    parser.add_argument('--img_list', type=str, help='Image path list file')
+    parser.add_argument('--config', type=str, help='Config file')
+    parser.add_argument('--checkpoint', type=str, help='Checkpoint file')
+    parser.add_argument(
+        '--out_dir', type=str, default='./results', help='Dir to save results')
+    parser.add_argument(
+        '--show', action='store_true', help='show images instead of saving')
+    args = parser.parse_args()
+
+    # init the logger before other steps
+    mmcv.mkdir_or_exist(args.out_dir)  # ensure the log file dir exists
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(args.out_dir, f'{timestamp}.log')
+    logger = get_root_logger(log_file=log_file, log_level='INFO')
+
+    # build the model from a config file and a checkpoint file
+    device = 'cuda:' + str(torch.cuda.current_device())
+    model = init_detector(args.config, args.checkpoint, device=device)
+    if hasattr(model, 'module'):
+        model = model.module
+    if model.cfg.data.test['type'] == 'ConcatDataset':
+        model.cfg.data.test.pipeline = \
+            model.cfg.data.test['datasets'][0].pipeline
+
+    # Start Inference
+    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
+    mmcv.mkdir_or_exist(out_vis_dir)
+    correct_vis_dir = osp.join(args.out_dir, 'correct')
+    mmcv.mkdir_or_exist(correct_vis_dir)
+    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
+    mmcv.mkdir_or_exist(wrong_vis_dir)
+    img_paths, pred_labels, gt_labels = [], [], []
+    with open(args.img_list) as f:
+        total_img_num = sum(1 for _ in f)
+    progressbar = ProgressBar(task_num=total_img_num)
+    num_gt_label = 0
+    with open(args.img_list, 'r') as fr:
+        for line in fr:
+            progressbar.update()
+            item_list = line.strip().split()
+            img_file = item_list[0]
+            gt_label = ''
+            if len(item_list) >= 2:
+                gt_label = item_list[1]
+                num_gt_label += 1
+            img_path = osp.join(args.img_root_path, img_file)
+            if not osp.exists(img_path):
+                raise FileNotFoundError(img_path)
+            # Test a single image
+            result = model_inference(model, img_path)
+            pred_label = result['text']
+
+            out_img_name = '_'.join(img_file.split('/'))
+            out_file = osp.join(out_vis_dir, out_img_name)
+            kwargs_dict = {
+                'gt_label': gt_label,
+                'show': args.show,
+                'out_file': '' if args.show else out_file
+            }
+            model.show_result(img_path, result, **kwargs_dict)
+            if gt_label != '':
+                if gt_label == pred_label:
+                    dst_file = osp.join(correct_vis_dir, out_img_name)
+                else:
+                    dst_file = osp.join(wrong_vis_dir, out_img_name)
+                shutil.copy(out_file, dst_file)
+            img_paths.append(img_path)
+            gt_labels.append(gt_label)
+            pred_labels.append(pred_label)
+
+    # Save results
+    save_results(img_paths, pred_labels, gt_labels, args.out_dir)
+
+    # evaluate only when every image carries a ground-truth label
+    if num_gt_label == len(pred_labels):
+        eval_results = eval_ocr_metric(pred_labels, gt_labels)
+        logger.info('\n' + '-' * 100)
+        info = 'eval on testset with img_root_path ' + \
+               f'{args.img_root_path} and img_list {args.img_list}\n'
+        logger.info(info)
+        logger.info(eval_results)
+
+    print(f'\nInference done, and results saved in {args.out_dir}\n')
+
+
+if __name__ == '__main__':
+    main()
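+
+# The img_list file is expected to hold one image per line, optionally
+# followed by its ground-truth label, e.g. "1.jpg hello". A hedged example
+# invocation (all paths below are placeholders):
+#
+#   python tools/ocr_test_imgs.py \
+#       --img_root_path data/demo_imgs \
+#       --img_list data/demo_imgs/img_list.txt \
+#       --config CONFIG.py --checkpoint CKPT.pth \
+#       --out_dir results/demo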
diff --git a/tools/ocr_test_imgs.sh b/tools/ocr_test_imgs.sh
new file mode 100644
index 00000000..69d719a1
--- /dev/null
+++ b/tools/ocr_test_imgs.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+DATE=`date +%Y-%m-%d`
+TIME=`date +"%H-%M-%S"`
+
+if [ $# -lt 5 ]
+then
+    echo "Usage: bash $0 CONFIG CHECKPOINT IMG_PREFIX IMG_LIST RESULTS_DIR"
+    exit 1
+fi
+
+CONFIG_FILE=$1
+CHECKPOINT=$2
+IMG_ROOT_PATH=$3
+IMG_LIST=$4
+OUT_DIR=$5_${DATE}_${TIME}
+
+mkdir -p ${OUT_DIR} &&
+
+python tools/ocr_test_imgs.py \
+    --img_root_path ${IMG_ROOT_PATH} \
+    --img_list ${IMG_LIST} \
+    --config ${CONFIG_FILE} \
+    --checkpoint ${CHECKPOINT} \
+    --out_dir ${OUT_DIR}
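+
+# Example invocation (all arguments are placeholders); results land in a
+# timestamped copy of RESULTS_DIR, e.g. results/demo_2021-04-01_12-00-00:
+#
+#   bash tools/ocr_test_imgs.sh CONFIG.py CKPT.pth \
+#       data/demo_imgs data/demo_imgs/img_list.txt results/demo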