[Feature] Rec TTA (#1401)

* Support TTA for recognition

* updata readme

* updata abinet readme

* updata train_test doc for tta
pull/1731/head
liukuikun 2023-02-16 10:27:07 +08:00 committed by GitHub
parent 7cea6a6419
commit f820470415
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 809 additions and 107 deletions

View File

@ -46,3 +46,5 @@ visualizer = dict(
type='TextRecogLocalVisualizer',
name='visualizer',
vis_backends=vis_backends)
# TTA wrapper model: merges the predictions from all augmented views produced
# by ``tta_pipeline`` and keeps the best one (by average character score).
tta_model = dict(type='EncoderDecoderRecognizerTTAModel')

View File

@ -38,7 +38,9 @@ Linguistic knowledge is of great benefit to scene text recognition. However, how
| :--------------------------------------------: | :------------------------------------------------: | :----: | :----------: | :-------: | :-------: | :------------: | :----: | :----------------------------------------------- |
| | | IIIT5K | SVT | IC13-1015 | IC15-2077 | SVTP | CT80 | |
| [ABINet-Vision](/configs/textrecog/abinet/abinet-vision_20e_st-an_mj.py) | - | 0.9523 | 0.9196 | 0.9369 | 0.7896 | 0.8403 | 0.8437 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/abinet-vision_20e_st-an_mj_20220915_152445-85cfb03d.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/20220915_152445.log) |
| [ABINet-Vision-TTA](/configs/textrecog/abinet/abinet-vision_20e_st-an_mj.py) | - | 0.9523 | 0.9196 | 0.9360 | 0.8175 | 0.8450 | 0.8542 | |
| [ABINet](/configs/textrecog/abinet/abinet_20e_st-an_mj.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth) | 0.9603 | 0.9397 | 0.9557 | 0.8146 | 0.8868 | 0.8785 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/abinet_20e_st-an_mj_20221005_012617-ead8c139.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/20221005_012617.log) |
| [ABINet-TTA](/configs/textrecog/abinet/abinet_20e_st-an_mj.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth) | 0.9597 | 0.9397 | 0.9527 | 0.8426 | 0.8930 | 0.8854 | |
```{note}
1. ABINet allows its encoder to run and be trained without decoder and fuser. Its encoder is designed to recognize texts as a stand-alone model and therefore can work as an independent text recognizer. We release it as ABINet-Vision.

View File

@ -116,3 +116,50 @@ test_pipeline = [
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test-time augmentation pipeline. ``TestTimeAug`` composes the inner lists
# sequentially; a list with several entries forks one augmented view per
# entry. The views are merged by the TTA model
# (``EncoderDecoderRecognizerTTAModel``).
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[
            # Rotation branches (0/90/270 degrees CCW), each applied only
            # when width < height, i.e. the text looks vertical (assumes
            # img_shape is (h, w) -- confirm); horizontal images pass
            # through every branch unchanged.
            [
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            [dict(type='Resize', scale=(128, 32))],
            # add loading annotation after ``Resize`` because ground truth
            # does not need to do resize data transform
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]

View File

@ -35,9 +35,10 @@ A challenging aspect of scene text recognition is to handle text with distortion
## Results and models
| Methods | Backbone | | Regular Text | | | | Irregular Text | | download |
| :--------------------------------------------------------------: | :------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------: |
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [ASTER](/configs/textrecog/aster/aster_resnet45_6e_st_mj.py) | ResNet45 | 0.9357 | 0.8949 | 0.9281 | | 0.7665 | 0.8062 | 0.8507 | [model](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/aster_resnet45_6e_st_mj-cc56eca4.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/20221214_232605.log) |
| [ASTER-TTA](/configs/textrecog/aster/aster_resnet45_6e_st_mj.py) | ResNet45 | 0.9337 | 0.8949 | 0.9251 | | 0.7925 | 0.8109 | 0.8507 | |
## Citation

View File

@ -69,3 +69,42 @@ test_pipeline = [
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio',
'instances'))
]
# Test-time augmentation pipeline: three conditional rotation branches
# (0/90/270 deg CCW, triggered only when the image looks vertical, i.e.
# width < height -- assumes img_shape is (h, w)), then resize, annotation
# loading and packing. Views are merged by the TTA model.
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[[
            dict(
                type='ConditionApply',
                true_transforms=[
                    dict(
                        type='ImgAugWrapper',
                        args=[dict(cls='Rot90', k=0, keep_size=False)])
                ],
                condition="results['img_shape'][1]<results['img_shape'][0]"),
            dict(
                type='ConditionApply',
                true_transforms=[
                    dict(
                        type='ImgAugWrapper',
                        args=[dict(cls='Rot90', k=1, keep_size=False)])
                ],
                condition="results['img_shape'][1]<results['img_shape'][0]"),
            dict(
                type='ConditionApply',
                true_transforms=[
                    dict(
                        type='ImgAugWrapper',
                        args=[dict(cls='Rot90', k=3, keep_size=False)])
                ],
                condition="results['img_shape'][1]<results['img_shape'][0]"),
        ], [dict(type='Resize', scale=(256, 64))],
                    [dict(type='LoadOCRAnnotations', with_text=True)],
                    [
                        dict(
                            type='PackTextRecogInputs',
                            meta_keys=('img_path', 'ori_shape', 'img_shape',
                                       'valid_ratio', 'instances'))
                    ]])
]

View File

@ -34,9 +34,10 @@ Image-based sequence recognition has been a long-standing research topic in comp
## Results and models
| methods | | Regular Text | | | | Irregular Text | | download |
| :--------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------------------------: |
| methods | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [CRNN](/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py) | 0.8053 | 0.7991 | 0.8739 | | 0.5571 | 0.6093 | 0.5694 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/20220826_224120.log) |
| [CRNN-TTA](/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py) | 0.8013 | 0.7975 | 0.8631 | | 0.5763 | 0.6093 | 0.5764 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/20220826_224120.log) |
## Citation

View File

@ -51,3 +51,60 @@ test_pipeline = [
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test-time augmentation pipeline (images loaded as grayscale for CRNN).
# ``TestTimeAug`` composes the inner lists sequentially; a list with several
# entries forks one augmented view per entry. Views are merged by the TTA
# model (``EncoderDecoderRecognizerTTAModel``).
tta_pipeline = [
    dict(
        type='LoadImageFromFile',
        color_type='grayscale',
        file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[
            # Rotation branches (0/90/270 degrees CCW), each applied only
            # when width < height, i.e. the text looks vertical (assumes
            # img_shape is (h, w) -- confirm); horizontal images pass
            # through every branch unchanged.
            [
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            # Keep aspect ratio: fix height at 32, no upper width bound.
            [
                dict(
                    type='RescaleToHeight',
                    height=32,
                    min_width=32,
                    max_width=None,
                    width_divisor=16)
            ],
            # add loading annotation after ``Resize`` because ground truth
            # does not need to do resize data transform
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]

View File

@ -39,6 +39,7 @@ Attention-based scene text recognizers have gained huge success, which leverages
| :-------------------------------------------------------------: | :-----------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------: |
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [MASTER](/configs/textrecog/master/master_resnet31_12e_st_mj_sa.py) | R31-GCAModule | 0.9490 | 0.8887 | 0.9517 | | 0.7650 | 0.8465 | 0.8889 | [model](https://download.openmmlab.com/mmocr/textrecog/master/master_resnet31_12e_st_mj_sa/master_resnet31_12e_st_mj_sa_20220915_152443-f4a5cabc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/master/master_resnet31_12e_st_mj_sa/20220915_152443.log) |
| [MASTER-TTA](/configs/textrecog/master/master_resnet31_12e_st_mj_sa.py) | R31-GCAModule | 0.9450 | 0.8887 | 0.9478 | | 0.7906 | 0.8481 | 0.8958 | |
## Citation

View File

@ -109,3 +109,58 @@ test_pipeline = [
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test-time augmentation pipeline. ``TestTimeAug`` composes the inner lists
# sequentially; a list with several entries forks one augmented view per
# entry. Views are merged by the TTA model
# (``EncoderDecoderRecognizerTTAModel``).
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[
            # Rotation branches (0/90/270 degrees CCW), each applied only
            # when width < height, i.e. the text looks vertical (assumes
            # img_shape is (h, w) -- confirm); horizontal images pass
            # through every branch unchanged.
            [
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            [
                dict(
                    type='RescaleToHeight',
                    height=48,
                    min_width=48,
                    max_width=160,
                    width_divisor=16)
            ],
            # Pad every image to width 160.
            [dict(type='PadToWidth', width=160)],
            # add loading annotation after ``Resize`` because ground truth
            # does not need to do resize data transform
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]

View File

@ -38,8 +38,11 @@ Scene text recognition has attracted a great many researches due to its importan
| :---------------------------------------------------------: | :-------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-----------------------------------------------------------: |
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [NRTR](/configs/textrecog/nrtr/nrtr_modality-transform_6e_st_mj.py) | NRTRModalityTransform | 0.9147 | 0.8841 | 0.9369 | | 0.7246 | 0.7783 | 0.7500 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_modality-transform_6e_st_mj/nrtr_modality-transform_6e_st_mj_20220916_103322-bd9425be.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_modality-transform_6e_st_mj/20220916_103322.log) |
| [NRTR-TTA](/configs/textrecog/nrtr/nrtr_modality-transform_6e_st_mj.py) | NRTRModalityTransform | 0.9123 | 0.8825 | 0.9310 | | 0.7492 | 0.7798 | 0.7535 | |
| [NRTR](/configs/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj.py) | R31-1/8-1/4 | 0.9483 | 0.8918 | 0.9507 | | 0.7578 | 0.8016 | 0.8889 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj/nrtr_resnet31-1by8-1by4_6e_st_mj_20220916_103322-a6a2a123.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj/20220916_103322.log) |
| [NRTR-TTA](/configs/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj.py) | R31-1/8-1/4 | 0.9443 | 0.8903 | 0.9478 | | 0.7790 | 0.8078 | 0.8854 | |
| [NRTR](/configs/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj.py) | R31-1/16-1/8 | 0.9470 | 0.8918 | 0.9399 | | 0.7376 | 0.7969 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj/nrtr_resnet31-1by16-1by8_6e_st_mj_20220920_143358-43767036.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj/20220920_143358.log) |
| [NRTR-TTA](/configs/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj.py) | R31-1/16-1/8 | 0.9423 | 0.8903 | 0.9360 | | 0.7641 | 0.8016 | 0.8854 | |
## Citation

View File

@ -60,3 +60,58 @@ test_pipeline = [
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test-time augmentation pipeline. ``TestTimeAug`` composes the inner lists
# sequentially; a list with several entries forks one augmented view per
# entry. Views are merged by the TTA model
# (``EncoderDecoderRecognizerTTAModel``).
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[
            # Rotation branches (0/90/270 degrees CCW), each applied only
            # when width < height, i.e. the text looks vertical (assumes
            # img_shape is (h, w) -- confirm); horizontal images pass
            # through every branch unchanged.
            [
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            [
                dict(
                    type='RescaleToHeight',
                    height=32,
                    min_width=32,
                    max_width=160,
                    width_divisor=16)
            ],
            # Pad every image to width 160.
            [dict(type='PadToWidth', width=160)],
            # add loading annotation after ``Resize`` because ground truth
            # does not need to do resize data transform
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]

View File

@ -66,3 +66,58 @@ test_pipeline = [
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test-time augmentation pipeline. ``TestTimeAug`` composes the inner lists
# sequentially; a list with several entries forks one augmented view per
# entry. Views are merged by the TTA model
# (``EncoderDecoderRecognizerTTAModel``).
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[
            # Rotation branches (0/90/270 degrees CCW), each applied only
            # when width < height, i.e. the text looks vertical (assumes
            # img_shape is (h, w) -- confirm); horizontal images pass
            # through every branch unchanged.
            [
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            [
                dict(
                    type='RescaleToHeight',
                    height=32,
                    min_width=32,
                    max_width=160,
                    width_divisor=16)
            ],
            # Pad every image to width 160.
            [dict(type='PadToWidth', width=160)],
            # add loading annotation after ``Resize`` because ground truth
            # does not need to do resize data transform
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]

View File

@ -44,6 +44,7 @@ The attention-based encoder-decoder framework has recently achieved impressive r
| :------------------------------------------------------------------: | :--: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------: |
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [RobustScanner](/configs/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real.py) | 4 | 0.9510 | 0.9011 | 0.9320 | | 0.7578 | 0.8078 | 0.8750 | [model](https://download.openmmlab.com/mmocr/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real_20220915_152447-7fc35929.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real/20220915_152447.log) |
| [RobustScanner-TTA](/configs/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real.py) | 4 | 0.9487 | 0.9011 | 0.9261 | | 0.7805 | 0.8124 | 0.8819 | |
## References

View File

@ -66,3 +66,58 @@ test_pipeline = [
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test-time augmentation pipeline. ``TestTimeAug`` composes the inner lists
# sequentially; a list with several entries forks one augmented view per
# entry. Views are merged by the TTA model
# (``EncoderDecoderRecognizerTTAModel``).
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[
            # Rotation branches (0/90/270 degrees CCW), each applied only
            # when width < height, i.e. the text looks vertical (assumes
            # img_shape is (h, w) -- confirm); horizontal images pass
            # through every branch unchanged.
            [
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            [
                dict(
                    type='RescaleToHeight',
                    height=48,
                    min_width=48,
                    max_width=160,
                    width_divisor=4),
            ],
            # Pad every image to width 160.
            [dict(type='PadToWidth', width=160)],
            # add loading annotation after ``Resize`` because ground truth
            # does not need to do resize data transform
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]

View File

@ -44,7 +44,9 @@ Recognizing irregular text in natural scene images is challenging due to the lar
| :----------------------------------------------------: | :---------: | :------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :------------------------------------------------------: |
| | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 0.9533 | 0.8964 | 0.9369 | | 0.7602 | 0.8326 | 0.9062 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real_20220915_171910-04eb4e75.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real/20220915_171910.log) |
| [SAR-TTA](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 0.9510 | 0.8964 | 0.9340 | | 0.7862 | 0.8372 | 0.9132 | |
| [SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 0.9553 | 0.9073 | 0.9409 | | 0.7761 | 0.8093 | 0.8958 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real_20220915_185451-1fd6b1fc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real/20220915_185451.log) |
| [SAR-TTA](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 0.9530 | 0.9073 | 0.9389 | | 0.8002 | 0.8124 | 0.9028 | |
## Citation

View File

@ -71,3 +71,58 @@ test_pipeline = [
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test-time augmentation pipeline. ``TestTimeAug`` composes the inner lists
# sequentially; a list with several entries forks one augmented view per
# entry. Views are merged by the TTA model
# (``EncoderDecoderRecognizerTTAModel``).
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[
            # Rotation branches (0/90/270 degrees CCW), each applied only
            # when width < height, i.e. the text looks vertical (assumes
            # img_shape is (h, w) -- confirm); horizontal images pass
            # through every branch unchanged.
            [
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            [
                dict(
                    type='RescaleToHeight',
                    height=48,
                    min_width=48,
                    max_width=160,
                    width_divisor=4)
            ],
            # Pad every image to width 160.
            [dict(type='PadToWidth', width=160)],
            # add loading annotation after ``Resize`` because ground truth
            # does not need to do resize data transform
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]

View File

@ -38,7 +38,9 @@ Scene text recognition (STR) is the task of recognizing character sequences in n
| :--------------------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------------: |
| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [Satrn](/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py) | 0.9600 | 0.9181 | 0.9606 | | 0.8045 | 0.8837 | 0.8993 | [model](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow_5e_st_mj/satrn_shallow_5e_st_mj_20220915_152443-5fd04a4c.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow_5e_st_mj/20220915_152443.log) |
| [Satrn-TTA](/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py) | 0.9530 | 0.9181 | 0.9527 | | 0.8276 | 0.8884 | 0.9028 | |
| [Satrn_small](/configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py) | 0.9423 | 0.9011 | 0.9567 | | 0.7886 | 0.8574 | 0.8472 | [model](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow-small_5e_st_mj/satrn_shallow-small_5e_st_mj_20220915_152442-5591bf27.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow-small_5e_st_mj/20220915_152442.log) |
| [Satrn_small-TTA](/configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py) | 0.9380 | 0.8995 | 0.9488 | | 0.8122 | 0.8620 | 0.8507 | |
## Citation

View File

@ -54,7 +54,6 @@ train_pipeline = [
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# TODO Add Test Time Augmentation `MultiRotateAugOCR`
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(100, 32), keep_ratio=False),
@ -65,3 +64,50 @@ test_pipeline = [
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test-time augmentation pipeline. ``TestTimeAug`` composes the inner lists
# sequentially; a list with several entries forks one augmented view per
# entry. Views are merged by the TTA model
# (``EncoderDecoderRecognizerTTAModel``).
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[
            # Rotation branches (0/90/270 degrees CCW), each applied only
            # when width < height, i.e. the text looks vertical (assumes
            # img_shape is (h, w) -- confirm); horizontal images pass
            # through every branch unchanged.
            [
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            [dict(type='Resize', scale=(100, 32), keep_ratio=False)],
            # add loading annotation after ``Resize`` because ground truth
            # does not need to do resize data transform
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]

View File

@ -35,11 +35,13 @@ Dominant scene text recognition models commonly contain two building blocks, a v
## Results and Models
| Methods | | Regular Text | | | | Irregular Text | | download |
| :---------------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :--------------------------------------------------------------------------: |
| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [SVTR-tiny](/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
| [SVTR-small](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8553 | 0.9026 | 0.9448 | | 0.7496 | 0.8496 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/20230105_184454.log) |
| [SVTR-small-TTA](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8397 | 0.8964 | 0.9241 | | 0.7597 | 0.8124 | 0.8646 | |
| [SVTR-base](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8570 | 0.9181 | 0.9438 | | 0.7448 | 0.8388 | 0.9028 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/20221227_175415.log) |
| [SVTR-base-TTA](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8517 | 0.9011 | 0.9379 | | 0.7569 | 0.8279 | 0.8819 | |
| [SVTR-large](/configs/textrecog/svtr/svtr-large_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
```{note}

View File

@ -36,3 +36,130 @@ model = dict(
dictionary=dictionary),
data_preprocessor=dict(
type='TextRecogDataPreprocessor', mean=[127.5], std=[127.5]))
# Storage backend used by all ``LoadImageFromFile`` transforms below.
file_client_args = dict(backend='disk')
# Training pipeline: a chain of augmentations, each applied independently
# with probability 0.4, followed by a fixed resize and sample packing.
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=file_client_args,
        ignore_empty=True,
        min_size=5),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='TextRecogGeneralAug', ),
        ],
    ),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='CropHeight', ),
        ],
    ),
    # Only blur images whose shorter side exceeds 10 px (guards tiny crops).
    dict(
        type='ConditionApply',
        condition='min(results["img_shape"])>10',
        true_transforms=dict(
            type='RandomApply',
            prob=0.4,
            transforms=[
                dict(
                    type='TorchVisionWrapper',
                    op='GaussianBlur',
                    kernel_size=5,
                    sigma=1,
                ),
            ],
        )),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(
                type='TorchVisionWrapper',
                op='ColorJitter',
                brightness=0.5,
                saturation=0.5,
                contrast=0.5,
                hue=0.1),
        ]),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='ImageContentJitter', ),
        ],
    ),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(
                type='ImgAugWrapper',
                args=[dict(cls='AdditiveGaussianNoise', scale=0.1**0.5)]),
        ],
    ),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='ReversePixels', ),
        ],
    ),
    dict(type='Resize', scale=(256, 64)),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test pipeline: deterministic resize only, no augmentation.
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(256, 64)),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test-time augmentation pipeline: three conditional rotation branches
# (0/90/270 deg CCW, triggered only when the image looks vertical, i.e.
# width < height -- assumes img_shape is (h, w)), then resize, annotation
# loading and packing. Views are merged by the TTA model.
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        transforms=[[
            dict(
                type='ConditionApply',
                true_transforms=[
                    dict(
                        type='ImgAugWrapper',
                        args=[dict(cls='Rot90', k=0, keep_size=False)])
                ],
                condition="results['img_shape'][1]<results['img_shape'][0]"),
            dict(
                type='ConditionApply',
                true_transforms=[
                    dict(
                        type='ImgAugWrapper',
                        args=[dict(cls='Rot90', k=1, keep_size=False)])
                ],
                condition="results['img_shape'][1]<results['img_shape'][0]"),
            dict(
                type='ConditionApply',
                true_transforms=[
                    dict(
                        type='ImgAugWrapper',
                        args=[dict(cls='Rot90', k=3, keep_size=False)])
                ],
                condition="results['img_shape'][1]<results['img_shape'][0]"),
        ], [dict(type='Resize', scale=(256, 64))],
                    [dict(type='LoadOCRAnnotations', with_text=True)],
                    [
                        dict(
                            type='PackTextRecogInputs',
                            meta_keys=('img_path', 'ori_shape', 'img_shape',
                                       'valid_ratio'))
                    ]])
]

View File

@ -40,94 +40,6 @@ param_scheduler = [
convert_to_iter_based=True),
]
file_client_args = dict(backend='disk')
# NOTE(review): per the hunk header above ("@ -40,94 +40,6"), these pipeline
# definitions are the lines *removed* by this commit -- the dataloaders now
# reference ``_base_.train_pipeline`` / ``_base_.test_pipeline`` instead.
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=file_client_args,
        ignore_empty=True,
        min_size=5),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='TextRecogGeneralAug', ),
        ],
    ),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='CropHeight', ),
        ],
    ),
    # Only blur images whose shorter side exceeds 10 px (guards tiny crops).
    dict(
        type='ConditionApply',
        condition='min(results["img_shape"])>10',
        true_transforms=dict(
            type='RandomApply',
            prob=0.4,
            transforms=[
                dict(
                    type='TorchVisionWrapper',
                    op='GaussianBlur',
                    kernel_size=5,
                    sigma=1,
                ),
            ],
        )),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(
                type='TorchVisionWrapper',
                op='ColorJitter',
                brightness=0.5,
                saturation=0.5,
                contrast=0.5,
                hue=0.1),
        ]),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='ImageContentJitter', ),
        ],
    ),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(
                type='ImgAugWrapper',
                args=[dict(cls='AdditiveGaussianNoise', scale=0.1**0.5)]),
        ],
    ),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='ReversePixels', ),
        ],
    ),
    dict(type='Resize', scale=(256, 64)),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# Test pipeline: deterministic resize only, no augmentation.
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(256, 64)),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# dataset settings
train_list = [_base_.mjsynth_textrecog_test, _base_.synthtext_textrecog_train]
test_list = [
@ -147,7 +59,9 @@ train_dataloader = dict(
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='ConcatDataset', datasets=train_list, pipeline=train_pipeline))
type='ConcatDataset',
datasets=train_list,
pipeline=_base_.train_pipeline))
val_dataloader = dict(
batch_size=128,
@ -157,6 +71,8 @@ val_dataloader = dict(
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='ConcatDataset', datasets=test_list, pipeline=test_pipeline))
type='ConcatDataset',
datasets=test_list,
pipeline=_base_.test_pipeline))
test_dataloader = val_dataloader

View File

@ -36,6 +36,7 @@ The following table lists all the arguments supported by `train.py`. Args withou
| --cfg-options | str | Override some settings in the configs. [Example](<>) |
| --launcher | str | Option for launcher\['none', 'pytorch', 'slurm', 'mpi'\]. |
| --local_rank | int | Rank of the local machine, used for distributed training; defaults to 0. |
| --tta | bool | Whether to use test time augmentation. |
### Test
@ -308,3 +309,15 @@ The visualization-related parameters in `tools/test.py` are described as follows
| --show | bool | Whether to show the visualization results. |
| --show-dir | str | Path to save the visualization results. |
| --wait-time | float | Interval of visualization (s), defaults to 2. |
### Test Time Augmentation
Test time augmentation (TTA) is a technique that improves a model's performance by applying data augmentation to the input image at test time. It is a simple yet effective method. In MMOCR, TTA is enabled by adding the `--tta` flag to the test command, as shown below:
```{note}
TTA is only supported for text recognition models.
```
```bash
python tools/test.py configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py checkpoints/crnn_mini-vgg_5e_mj.pth --tta
```

View File

@ -66,6 +66,7 @@ CUDA_VISIBLE_DEVICES=0 python tools/test.py configs/textdet/dbnet/dbnet_resnet50
| --cfg-options | str | 用于覆写配置文件中的指定参数。[示例](#添加示例) |
| --launcher | str | 启动器选项,可选项目为 \['none', 'pytorch', 'slurm', 'mpi'\]。 |
| --local_rank | int | 本地机器编号,用于多机多卡分布式训练,默认为 0。 |
| --tta | bool | 是否使用测试时数据增强 |
## 多卡机器训练及测试
@ -308,3 +309,16 @@ python tools/test.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.p
| --show | bool | 是否绘制可视化结果。 |
| --show-dir | str | 可视化图片存储路径。 |
| --wait-time | float | 可视化间隔时间(秒),默认为 2。 |
### 测试时数据增强
测试时增强,指的是在推理(预测)阶段,将原始图片进行水平翻转、垂直翻转、对角线翻转、旋转角度等数据增强操作,得到多张图,分别进行推理,再对多个结果进行综合分析,得到最终输出结果。
为此MMOCR 提供了一键式测试时数据增强,仅需在测试时添加 `--tta` 参数即可。
```{note}
TTA 仅支持文本识别模型。
```
```bash
python tools/test.py configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py checkpoints/crnn_mini-vgg_5e_mj.pth --tta
```

View File

@ -4,6 +4,7 @@ from .aster import ASTER
from .base import BaseRecognizer
from .crnn import CRNN
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
from .encoder_decoder_recognizer_tta import EncoderDecoderRecognizerTTAModel
from .master import MASTER
from .nrtr import NRTR
from .robust_scanner import RobustScanner
@ -13,5 +14,6 @@ from .svtr import SVTR
__all__ = [
'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR',
'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER'
'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER',
'EncoderDecoderRecognizerTTAModel'
]

View File

@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import numpy as np
from mmengine.model import BaseTTAModel
from mmocr.registry import MODELS
from mmocr.utils.typing_utils import RecSampleList
@MODELS.register_module()
class EncoderDecoderRecognizerTTAModel(BaseTTAModel):
    """Merge augmented recognition results.

    Among all augmented predictions of one image, the prediction whose
    average per-character score is highest is kept as the final result.

    Examples:
        >>> tta_model = dict(
        >>>     type='EncoderDecoderRecognizerTTAModel')
        >>>
        >>> tta_pipeline = [
        >>>     dict(
        >>>         type='LoadImageFromFile',
        >>>         color_type='grayscale',
        >>>         file_client_args=file_client_args),
        >>>     dict(
        >>>         type='TestTimeAug',
        >>>         transforms=[
        >>>             [
        >>>                 dict(
        >>>                     type='ConditionApply',
        >>>                     true_transforms=[
        >>>                         dict(
        >>>                             type='ImgAugWrapper',
        >>>                             args=[dict(cls='Rot90', k=0, keep_size=False)])  # noqa: E501
        >>>                     ],
        >>>                     condition="results['img_shape'][1]<results['img_shape'][0]"  # noqa: E501
        >>>                 ),
        >>>                 dict(
        >>>                     type='ConditionApply',
        >>>                     true_transforms=[
        >>>                         dict(
        >>>                             type='ImgAugWrapper',
        >>>                             args=[dict(cls='Rot90', k=1, keep_size=False)])  # noqa: E501
        >>>                     ],
        >>>                     condition="results['img_shape'][1]<results['img_shape'][0]"  # noqa: E501
        >>>                 ),
        >>>                 dict(
        >>>                     type='ConditionApply',
        >>>                     true_transforms=[
        >>>                         dict(
        >>>                             type='ImgAugWrapper',
        >>>                             args=[dict(cls='Rot90', k=3, keep_size=False)])
        >>>                     ],
        >>>                     condition="results['img_shape'][1]<results['img_shape'][0]"
        >>>                 ),
        >>>             ],
        >>>             [
        >>>                 dict(
        >>>                     type='RescaleToHeight',
        >>>                     height=32,
        >>>                     min_width=32,
        >>>                     max_width=None,
        >>>                     width_divisor=16)
        >>>             ],
        >>>             # add loading annotation after ``Resize`` because ground truth
        >>>             # does not need to do resize data transform
        >>>             [dict(type='LoadOCRAnnotations', with_text=True)],
        >>>             [
        >>>                 dict(
        >>>                     type='PackTextRecogInputs',
        >>>                     meta_keys=('img_path', 'ori_shape', 'img_shape',
        >>>                                'valid_ratio'))
        >>>             ]
        >>>         ])
        >>> ]
    """

    def merge_preds(self,
                    data_samples_list: List[RecSampleList]) -> RecSampleList:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (List[RecSampleList]): List of predictions of
                all enhanced data. The shape of data_samples_list is (B, M),
                where B is the batch size and M is the number of augmented
                data.

        Returns:
            RecSampleList: Merged prediction.
        """
        merged = []
        for aug_samples in data_samples_list:
            # Mean per-character confidence for every augmented view; an
            # empty score sequence is treated as an average of zero.
            mean_scores = [
                sum(sample.pred_text.score) /
                max(1, len(sample.pred_text.score))
                for sample in aug_samples
            ]
            # ``max`` over indices keeps the first occurrence on ties,
            # matching ``np.argmax`` semantics of the previous revision.
            best_idx = max(
                range(len(mean_scores)), key=mean_scores.__getitem__)
            merged.append(aug_samples[best_idx])
        return merged

View File

@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
import torch.nn as nn
from mmengine.structures import LabelData
from mmocr.models.textrecog.recognizers import EncoderDecoderRecognizerTTAModel
from mmocr.structures import TextRecogDataSample
class DummyModel(nn.Module):
    """Identity network used as a stand-in module for the TTA wrapper."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Identity mapping: return the input untouched.
        return x

    def test_step(self, x):
        # Inference entry point expected by the TTA wrapper; delegates to
        # ``forward`` so both paths behave identically.
        return self.forward(x)
class TestEncoderDecoderRecognizerTTAModel(TestCase):

    def test_merge_preds(self):
        """merge_preds keeps the augmentation with the highest mean score."""

        def build_sample(scores, text):
            # Wrap raw scores/text into the data sample structure that
            # ``merge_preds`` reads.
            return TextRecogDataSample(
                pred_text=LabelData(score=torch.tensor(scores), text=text))

        low = build_sample([0.1, 0.2, 0.3, 0.4, 0.5], 'abcde')
        mid = build_sample([0.2, 0.3, 0.4, 0.5, 0.6], 'bcdef')
        high = build_sample([0.3, 0.4, 0.5, 0.6, 0.7], 'cdefg')
        # A batch of three images, each with the same three augmented views.
        batch_aug_data_samples = [[low, mid, high] for _ in range(3)]
        model = EncoderDecoderRecognizerTTAModel(module=DummyModel())
        for merged in model.merge_preds(batch_aug_data_samples):
            self.assertEqual(merged.pred_text.text, 'cdefg')

View File

@ -45,6 +45,8 @@ def parse_args():
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='Job launcher')
parser.add_argument(
'--tta', action='store_true', help='Test time augmentation')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
@ -107,6 +109,11 @@ def main():
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)
if args.tta:
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
cfg.tta_model.module = cfg.model
cfg.model = cfg.tta_model
# save predictions
if args.save_preds:
dump_metric = dict(