[Feature] Rec TTA (#1401)

* Support TTA for recognition

* updata readme

* updata abinet readme

* updata train_test doc for tta
This commit is contained in:
liukuikun 2023-02-16 10:27:07 +08:00 committed by GitHub
parent 7cea6a6419
commit f820470415
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 809 additions and 107 deletions

View File

@ -46,3 +46,5 @@ visualizer = dict(
type='TextRecogLocalVisualizer', type='TextRecogLocalVisualizer',
name='visualizer', name='visualizer',
vis_backends=vis_backends) vis_backends=vis_backends)
tta_model = dict(type='EncoderDecoderRecognizerTTAModel')

View File

@ -38,7 +38,9 @@ Linguistic knowledge is of great benefit to scene text recognition. However, how
| :--------------------------------------------: | :------------------------------------------------: | :----: | :----------: | :-------: | :-------: | :------------: | :----: | :----------------------------------------------- | | :--------------------------------------------: | :------------------------------------------------: | :----: | :----------: | :-------: | :-------: | :------------: | :----: | :----------------------------------------------- |
| | | IIIT5K | SVT | IC13-1015 | IC15-2077 | SVTP | CT80 | | | | | IIIT5K | SVT | IC13-1015 | IC15-2077 | SVTP | CT80 | |
| [ABINet-Vision](/configs/textrecog/abinet/abinet-vision_20e_st-an_mj.py) | - | 0.9523 | 0.9196 | 0.9369 | 0.7896 | 0.8403 | 0.8437 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/abinet-vision_20e_st-an_mj_20220915_152445-85cfb03d.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/20220915_152445.log) | | [ABINet-Vision](/configs/textrecog/abinet/abinet-vision_20e_st-an_mj.py) | - | 0.9523 | 0.9196 | 0.9369 | 0.7896 | 0.8403 | 0.8437 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/abinet-vision_20e_st-an_mj_20220915_152445-85cfb03d.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/20220915_152445.log) |
| [ABINet-Vision-TTA](/configs/textrecog/abinet/abinet-vision_20e_st-an_mj.py) | - | 0.9523 | 0.9196 | 0.9360 | 0.8175 | 0.8450 | 0.8542 | |
| [ABINet](/configs/textrecog/abinet/abinet_20e_st-an_mj.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth) | 0.9603 | 0.9397 | 0.9557 | 0.8146 | 0.8868 | 0.8785 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/abinet_20e_st-an_mj_20221005_012617-ead8c139.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/20221005_012617.log) | | [ABINet](/configs/textrecog/abinet/abinet_20e_st-an_mj.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth) | 0.9603 | 0.9397 | 0.9557 | 0.8146 | 0.8868 | 0.8785 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/abinet_20e_st-an_mj_20221005_012617-ead8c139.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/20221005_012617.log) |
| [ABINet-TTA](/configs/textrecog/abinet/abinet_20e_st-an_mj.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth) | 0.9597 | 0.9397 | 0.9527 | 0.8426 | 0.8930 | 0.8854 | |
```{note} ```{note}
1. ABINet allows its encoder to run and be trained without decoder and fuser. Its encoder is designed to recognize texts as a stand-alone model and therefore can work as an independent text recognizer. We release it as ABINet-Vision. 1. ABINet allows its encoder to run and be trained without decoder and fuser. Its encoder is designed to recognize texts as a stand-alone model and therefore can work as an independent text recognizer. We release it as ABINet-Vision.

View File

@ -116,3 +116,50 @@ test_pipeline = [
type='PackTextRecogInputs', type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
] ]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[
[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
],
[dict(type='Resize', scale=(128, 32))],
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]
])
]

View File

@ -34,10 +34,11 @@ A challenging aspect of scene text recognition is to handle text with distortion
## Results and models ## Results and models
| Methods | Backbone | | Regular Text | | | | Irregular Text | | download | | Methods | Backbone | | Regular Text | | | | Irregular Text | | download |
| :----------------------------------------------------------: | :------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-----------------------------------------------------------------------: | | :--------------------------------------------------------------: | :------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------: |
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | | | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [ASTER](/configs/textrecog/aster/aster_resnet45_6e_st_mj.py) | ResNet45 | 0.9357 | 0.8949 | 0.9281 | | 0.7665 | 0.8062 | 0.8507 | [model](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/aster_resnet45_6e_st_mj-cc56eca4.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/20221214_232605.log) | | [ASTER](/configs/textrecog/aster/aster_resnet45_6e_st_mj.py) | ResNet45 | 0.9357 | 0.8949 | 0.9281 | | 0.7665 | 0.8062 | 0.8507 | [model](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/aster_resnet45_6e_st_mj-cc56eca4.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/20221214_232605.log) |
| [ASTER-TTA](/configs/textrecog/aster/aster_resnet45_6e_st_mj.py) | ResNet45 | 0.9337 | 0.8949 | 0.9251 | | 0.7925 | 0.8109 | 0.8507 | |
## Citation ## Citation

View File

@ -69,3 +69,42 @@ test_pipeline = [
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio',
'instances')) 'instances'))
] ]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"),
], [dict(type='Resize', scale=(256, 64))],
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio', 'instances'))
]])
]

View File

@ -33,10 +33,11 @@ Image-based sequence recognition has been a long-standing research topic in comp
## Results and models ## Results and models
| methods | | Regular Text | | | | Irregular Text | | download | | methods | | Regular Text | | | | Irregular Text | | download |
| :----------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------------------------: | | :--------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------------------------: |
| methods | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | | | methods | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [CRNN](/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py) | 0.8053 | 0.7991 | 0.8739 | | 0.5571 | 0.6093 | 0.5694 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/20220826_224120.log) | | [CRNN](/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py) | 0.8053 | 0.7991 | 0.8739 | | 0.5571 | 0.6093 | 0.5694 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/20220826_224120.log) |
| [CRNN-TTA](/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py) | 0.8013 | 0.7975 | 0.8631 | | 0.5763 | 0.6093 | 0.5764 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/20220826_224120.log) |
## Citation ## Citation

View File

@ -51,3 +51,60 @@ test_pipeline = [
type='PackTextRecogInputs', type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
] ]
tta_pipeline = [
dict(
type='LoadImageFromFile',
color_type='grayscale',
file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[
[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
],
[
dict(
type='RescaleToHeight',
height=32,
min_width=32,
max_width=None,
width_divisor=16)
],
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]
])
]

View File

@ -39,6 +39,7 @@ Attention-based scene text recognizers have gained huge success, which leverages
| :-------------------------------------------------------------: | :-----------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------: | | :-------------------------------------------------------------: | :-----------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------: |
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | | | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [MASTER](/configs/textrecog/master/master_resnet31_12e_st_mj_sa.py) | R31-GCAModule | 0.9490 | 0.8887 | 0.9517 | | 0.7650 | 0.8465 | 0.8889 | [model](https://download.openmmlab.com/mmocr/textrecog/master/master_resnet31_12e_st_mj_sa/master_resnet31_12e_st_mj_sa_20220915_152443-f4a5cabc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/master/master_resnet31_12e_st_mj_sa/20220915_152443.log) | | [MASTER](/configs/textrecog/master/master_resnet31_12e_st_mj_sa.py) | R31-GCAModule | 0.9490 | 0.8887 | 0.9517 | | 0.7650 | 0.8465 | 0.8889 | [model](https://download.openmmlab.com/mmocr/textrecog/master/master_resnet31_12e_st_mj_sa/master_resnet31_12e_st_mj_sa_20220915_152443-f4a5cabc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/master/master_resnet31_12e_st_mj_sa/20220915_152443.log) |
| [MASTER-TTA](/configs/textrecog/master/master_resnet31_12e_st_mj_sa.py) | R31-GCAModule | 0.9450 | 0.8887 | 0.9478 | | 0.7906 | 0.8481 | 0.8958 | |
## Citation ## Citation

View File

@ -109,3 +109,58 @@ test_pipeline = [
type='PackTextRecogInputs', type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
] ]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[
[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
],
[
dict(
type='RescaleToHeight',
height=48,
min_width=48,
max_width=160,
width_divisor=16)
],
[dict(type='PadToWidth', width=160)],
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]
])
]

View File

@ -38,8 +38,11 @@ Scene text recognition has attracted a great many researches due to its importan
| :---------------------------------------------------------: | :-------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-----------------------------------------------------------: | | :---------------------------------------------------------: | :-------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-----------------------------------------------------------: |
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | | | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [NRTR](/configs/textrecog/nrtr/nrtr_modality-transform_6e_st_mj.py) | NRTRModalityTransform | 0.9147 | 0.8841 | 0.9369 | | 0.7246 | 0.7783 | 0.7500 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_modality-transform_6e_st_mj/nrtr_modality-transform_6e_st_mj_20220916_103322-bd9425be.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_modality-transform_6e_st_mj/20220916_103322.log) | | [NRTR](/configs/textrecog/nrtr/nrtr_modality-transform_6e_st_mj.py) | NRTRModalityTransform | 0.9147 | 0.8841 | 0.9369 | | 0.7246 | 0.7783 | 0.7500 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_modality-transform_6e_st_mj/nrtr_modality-transform_6e_st_mj_20220916_103322-bd9425be.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_modality-transform_6e_st_mj/20220916_103322.log) |
| [NRTR-TTA](/configs/textrecog/nrtr/nrtr_modality-transform_6e_st_mj.py) | NRTRModalityTransform | 0.9123 | 0.8825 | 0.9310 | | 0.7492 | 0.7798 | 0.7535 | |
| [NRTR](/configs/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj.py) | R31-1/8-1/4 | 0.9483 | 0.8918 | 0.9507 | | 0.7578 | 0.8016 | 0.8889 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj/nrtr_resnet31-1by8-1by4_6e_st_mj_20220916_103322-a6a2a123.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj/20220916_103322.log) | | [NRTR](/configs/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj.py) | R31-1/8-1/4 | 0.9483 | 0.8918 | 0.9507 | | 0.7578 | 0.8016 | 0.8889 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj/nrtr_resnet31-1by8-1by4_6e_st_mj_20220916_103322-a6a2a123.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj/20220916_103322.log) |
| [NRTR-TTA](/configs/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj.py) | R31-1/8-1/4 | 0.9443 | 0.8903 | 0.9478 | | 0.7790 | 0.8078 | 0.8854 | |
| [NRTR](/configs/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj.py) | R31-1/16-1/8 | 0.9470 | 0.8918 | 0.9399 | | 0.7376 | 0.7969 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj/nrtr_resnet31-1by16-1by8_6e_st_mj_20220920_143358-43767036.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj/20220920_143358.log) | | [NRTR](/configs/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj.py) | R31-1/16-1/8 | 0.9470 | 0.8918 | 0.9399 | | 0.7376 | 0.7969 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj/nrtr_resnet31-1by16-1by8_6e_st_mj_20220920_143358-43767036.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj/20220920_143358.log) |
| [NRTR-TTA](/configs/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj.py) | R31-1/16-1/8 | 0.9423 | 0.8903 | 0.9360 | | 0.7641 | 0.8016 | 0.8854 | |
## Citation ## Citation

View File

@ -60,3 +60,58 @@ test_pipeline = [
type='PackTextRecogInputs', type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
] ]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[
[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
],
[
dict(
type='RescaleToHeight',
height=32,
min_width=32,
max_width=160,
width_divisor=16)
],
[dict(type='PadToWidth', width=160)],
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]
])
]

View File

@ -66,3 +66,58 @@ test_pipeline = [
type='PackTextRecogInputs', type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
] ]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[
[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
],
[
dict(
type='RescaleToHeight',
height=32,
min_width=32,
max_width=160,
width_divisor=16)
],
[dict(type='PadToWidth', width=160)],
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]
])
]

View File

@ -44,6 +44,7 @@ The attention-based encoder-decoder framework has recently achieved impressive r
| :------------------------------------------------------------------: | :--: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------: | | :------------------------------------------------------------------: | :--: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------: |
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | | | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [RobustScanner](/configs/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real.py) | 4 | 0.9510 | 0.9011 | 0.9320 | | 0.7578 | 0.8078 | 0.8750 | [model](https://download.openmmlab.com/mmocr/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real_20220915_152447-7fc35929.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real/20220915_152447.log) | | [RobustScanner](/configs/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real.py) | 4 | 0.9510 | 0.9011 | 0.9320 | | 0.7578 | 0.8078 | 0.8750 | [model](https://download.openmmlab.com/mmocr/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real_20220915_152447-7fc35929.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real/20220915_152447.log) |
| [RobustScanner-TTA](/configs/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real.py) | 4 | 0.9487 | 0.9011 | 0.9261 | | 0.7805 | 0.8124 | 0.8819 | |
## References ## References

View File

@ -66,3 +66,58 @@ test_pipeline = [
type='PackTextRecogInputs', type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
] ]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[
[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
],
[
dict(
type='RescaleToHeight',
height=48,
min_width=48,
max_width=160,
width_divisor=4),
],
[dict(type='PadToWidth', width=160)],
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]
])
]

View File

@ -44,7 +44,9 @@ Recognizing irregular text in natural scene images is challenging due to the lar
| :----------------------------------------------------: | :---------: | :------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :------------------------------------------------------: | | :----------------------------------------------------: | :---------: | :------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :------------------------------------------------------: |
| | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | | | | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 0.9533 | 0.8964 | 0.9369 | | 0.7602 | 0.8326 | 0.9062 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real_20220915_171910-04eb4e75.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real/20220915_171910.log) | | [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 0.9533 | 0.8964 | 0.9369 | | 0.7602 | 0.8326 | 0.9062 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real_20220915_171910-04eb4e75.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real/20220915_171910.log) |
| [SAR-TTA](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 0.9510 | 0.8964 | 0.9340 | | 0.7862 | 0.8372 | 0.9132 | |
| [SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 0.9553 | 0.9073 | 0.9409 | | 0.7761 | 0.8093 | 0.8958 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real_20220915_185451-1fd6b1fc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real/20220915_185451.log) | | [SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 0.9553 | 0.9073 | 0.9409 | | 0.7761 | 0.8093 | 0.8958 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real_20220915_185451-1fd6b1fc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real/20220915_185451.log) |
| [SAR-TTA](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 0.9530 | 0.9073 | 0.9389 | | 0.8002 | 0.8124 | 0.9028 | |
## Citation ## Citation

View File

@ -71,3 +71,58 @@ test_pipeline = [
type='PackTextRecogInputs', type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
] ]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[
[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
],
[
dict(
type='RescaleToHeight',
height=48,
min_width=48,
max_width=160,
width_divisor=4)
],
[dict(type='PadToWidth', width=160)],
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]
])
]

View File

@ -38,7 +38,9 @@ Scene text recognition (STR) is the task of recognizing character sequences in n
| :--------------------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------------: | | :--------------------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------------: |
| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [Satrn](/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py) | 0.9600 | 0.9181 | 0.9606 | | 0.8045 | 0.8837 | 0.8993 | [model](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow_5e_st_mj/satrn_shallow_5e_st_mj_20220915_152443-5fd04a4c.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow_5e_st_mj/20220915_152443.log) | | [Satrn](/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py) | 0.9600 | 0.9181 | 0.9606 | | 0.8045 | 0.8837 | 0.8993 | [model](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow_5e_st_mj/satrn_shallow_5e_st_mj_20220915_152443-5fd04a4c.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow_5e_st_mj/20220915_152443.log) |
| [Satrn-TTA](/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py) | 0.9530 | 0.9181 | 0.9527 | | 0.8276 | 0.8884 | 0.9028 | |
| [Satrn_small](/configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py) | 0.9423 | 0.9011 | 0.9567 | | 0.7886 | 0.8574 | 0.8472 | [model](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow-small_5e_st_mj/satrn_shallow-small_5e_st_mj_20220915_152442-5591bf27.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow-small_5e_st_mj/20220915_152442.log) | | [Satrn_small](/configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py) | 0.9423 | 0.9011 | 0.9567 | | 0.7886 | 0.8574 | 0.8472 | [model](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow-small_5e_st_mj/satrn_shallow-small_5e_st_mj_20220915_152442-5591bf27.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow-small_5e_st_mj/20220915_152442.log) |
| [Satrn_small-TTA](/configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py) | 0.9380 | 0.8995 | 0.9488 | | 0.8122 | 0.8620 | 0.8507 | |
## Citation ## Citation

View File

@ -54,7 +54,6 @@ train_pipeline = [
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
] ]
# TODO Add Test Time Augmentation `MultiRotateAugOCR`
test_pipeline = [ test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args), dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(100, 32), keep_ratio=False), dict(type='Resize', scale=(100, 32), keep_ratio=False),
@ -65,3 +64,50 @@ test_pipeline = [
type='PackTextRecogInputs', type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
] ]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[
[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"
),
],
[dict(type='Resize', scale=(100, 32), keep_ratio=False)],
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]
])
]

View File

@ -34,13 +34,15 @@ Dominant scene text recognition models commonly contain two building blocks, a v
## Results and Models ## Results and Models
| Methods | | Regular Text | | | | Irregular Text | | download | | Methods | | Regular Text | | | | Irregular Text | | download |
| :-----------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :------------------------------------------------------------------------------: | | :---------------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :--------------------------------------------------------------------------: |
| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
| [SVTR-tiny](/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) | | [SVTR-tiny](/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
| [SVTR-small](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8553 | 0.9026 | 0.9448 | | 0.7496 | 0.8496 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/20230105_184454.log) | | [SVTR-small](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8553 | 0.9026 | 0.9448 | | 0.7496 | 0.8496 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/20230105_184454.log) |
| [SVTR-base](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8570 | 0.9181 | 0.9438 | | 0.7448 | 0.8388 | 0.9028 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/20221227_175415.log) | | [SVTR-small-TTA](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8397 | 0.8964 | 0.9241 | | 0.7597 | 0.8124 | 0.8646 | |
| [SVTR-large](/configs/textrecog/svtr/svtr-large_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) | | [SVTR-base](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8570 | 0.9181 | 0.9438 | | 0.7448 | 0.8388 | 0.9028 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/20221227_175415.log) |
| [SVTR-base-TTA](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8517 | 0.9011 | 0.9379 | | 0.7569 | 0.8279 | 0.8819 | |
| [SVTR-large](/configs/textrecog/svtr/svtr-large_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
```{note} ```{note}
The implementation and configuration follow the original code and paper, but there is still a gap between the reproduced results and the official ones. We appreciate any suggestions to improve its performance. The implementation and configuration follow the original code and paper, but there is still a gap between the reproduced results and the official ones. We appreciate any suggestions to improve its performance.

View File

@ -36,3 +36,130 @@ model = dict(
dictionary=dictionary), dictionary=dictionary),
data_preprocessor=dict( data_preprocessor=dict(
type='TextRecogDataPreprocessor', mean=[127.5], std=[127.5])) type='TextRecogDataPreprocessor', mean=[127.5], std=[127.5]))
file_client_args = dict(backend='disk')
train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args=file_client_args,
ignore_empty=True,
min_size=5),
dict(type='LoadOCRAnnotations', with_text=True),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(type='TextRecogGeneralAug', ),
],
),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(type='CropHeight', ),
],
),
dict(
type='ConditionApply',
condition='min(results["img_shape"])>10',
true_transforms=dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(
type='TorchVisionWrapper',
op='GaussianBlur',
kernel_size=5,
sigma=1,
),
],
)),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(
type='TorchVisionWrapper',
op='ColorJitter',
brightness=0.5,
saturation=0.5,
contrast=0.5,
hue=0.1),
]),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(type='ImageContentJitter', ),
],
),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='AdditiveGaussianNoise', scale=0.1**0.5)]),
],
),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(type='ReversePixels', ),
],
),
dict(type='Resize', scale=(256, 64)),
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(256, 64)),
dict(type='LoadOCRAnnotations', with_text=True),
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=0, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=1, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"),
dict(
type='ConditionApply',
true_transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='Rot90', k=3, keep_size=False)])
],
condition="results['img_shape'][1]<results['img_shape'][0]"),
], [dict(type='Resize', scale=(256, 64))],
[dict(type='LoadOCRAnnotations', with_text=True)],
[
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]])
]

View File

@ -40,94 +40,6 @@ param_scheduler = [
convert_to_iter_based=True), convert_to_iter_based=True),
] ]
file_client_args = dict(backend='disk')
train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args=file_client_args,
ignore_empty=True,
min_size=5),
dict(type='LoadOCRAnnotations', with_text=True),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(type='TextRecogGeneralAug', ),
],
),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(type='CropHeight', ),
],
),
dict(
type='ConditionApply',
condition='min(results["img_shape"])>10',
true_transforms=dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(
type='TorchVisionWrapper',
op='GaussianBlur',
kernel_size=5,
sigma=1,
),
],
)),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(
type='TorchVisionWrapper',
op='ColorJitter',
brightness=0.5,
saturation=0.5,
contrast=0.5,
hue=0.1),
]),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(type='ImageContentJitter', ),
],
),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(
type='ImgAugWrapper',
args=[dict(cls='AdditiveGaussianNoise', scale=0.1**0.5)]),
],
),
dict(
type='RandomApply',
prob=0.4,
transforms=[
dict(type='ReversePixels', ),
],
),
dict(type='Resize', scale=(256, 64)),
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(256, 64)),
dict(type='LoadOCRAnnotations', with_text=True),
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
# dataset settings # dataset settings
train_list = [_base_.mjsynth_textrecog_test, _base_.synthtext_textrecog_train] train_list = [_base_.mjsynth_textrecog_test, _base_.synthtext_textrecog_train]
test_list = [ test_list = [
@ -147,7 +59,9 @@ train_dataloader = dict(
pin_memory=True, pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True), sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict( dataset=dict(
type='ConcatDataset', datasets=train_list, pipeline=train_pipeline)) type='ConcatDataset',
datasets=train_list,
pipeline=_base_.train_pipeline))
val_dataloader = dict( val_dataloader = dict(
batch_size=128, batch_size=128,
@ -157,6 +71,8 @@ val_dataloader = dict(
drop_last=False, drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False), sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict( dataset=dict(
type='ConcatDataset', datasets=test_list, pipeline=test_pipeline)) type='ConcatDataset',
datasets=test_list,
pipeline=_base_.test_pipeline))
test_dataloader = val_dataloader test_dataloader = val_dataloader

View File

@ -36,6 +36,7 @@ The following table lists all the arguments supported by `train.py`. Args withou
| --cfg-options | str | Override some settings in the configs. [Example](<>) | | --cfg-options | str | Override some settings in the configs. [Example](<>) |
| --launcher | str | Option for launcher\['none', 'pytorch', 'slurm', 'mpi'\]. | | --launcher | str | Option for launcher\['none', 'pytorch', 'slurm', 'mpi'\]. |
| --local_rank | int | Rank of local machineused for distributed trainingdefaults to 0。 | | --local_rank | int | Rank of local machineused for distributed trainingdefaults to 0。 |
| --tta | bool | Whether to use test time augmentation. |
### Test ### Test
@ -308,3 +309,15 @@ The visualization-related parameters in `tools/test.py` are described as follows
| --show | bool | Whether to show the visualization results. | | --show | bool | Whether to show the visualization results. |
| --show-dir | str | Path to save the visualization results. | | --show-dir | str | Path to save the visualization results. |
| --wait-time | float | Interval of visualization (s), defaults to 2. | | --wait-time | float | Interval of visualization (s), defaults to 2. |
### Test Time Augmentation
Test time augmentation (TTA) is a technique that is used to improve the performance of a model by performing data augmentation on the input image at test time. It is a simple yet effective method to improve the performance of a model. In MMOCR, we support TTA in the following ways:
```{note}
TTA is only supported for text recognition models.
```
```bash
python tools/test.py configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py checkpoints/crnn_mini-vgg_5e_mj.pth --tta
```

View File

@ -66,6 +66,7 @@ CUDA_VISIBLE_DEVICES=0 python tools/test.py configs/textdet/dbnet/dbnet_resnet50
| --cfg-options | str | 用于覆写配置文件中的指定参数。[示例](#添加示例) | | --cfg-options | str | 用于覆写配置文件中的指定参数。[示例](#添加示例) |
| --launcher | str | 启动器选项,可选项目为 \['none', 'pytorch', 'slurm', 'mpi'\]。 | | --launcher | str | 启动器选项,可选项目为 \['none', 'pytorch', 'slurm', 'mpi'\]。 |
| --local_rank | int | 本地机器编号,用于多机多卡分布式训练,默认为 0。 | | --local_rank | int | 本地机器编号,用于多机多卡分布式训练,默认为 0。 |
| --tta | bool | 是否使用测试时数据增强 |
## 多卡机器训练及测试 ## 多卡机器训练及测试
@ -308,3 +309,16 @@ python tools/test.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.p
| --show | bool | 是否绘制可视化结果。 | | --show | bool | 是否绘制可视化结果。 |
| --show-dir | str | 可视化图片存储路径。 | | --show-dir | str | 可视化图片存储路径。 |
| --wait-time | float | 可视化间隔时间(秒),默认为 2。 | | --wait-time | float | 可视化间隔时间(秒),默认为 2。 |
### 测试时数据增强
测试时增强,指的是在推理(预测)阶段,将原始图片进行水平翻转、垂直翻转、对角线翻转、旋转角度等数据增强操作,得到多张图,分别进行推理,再对多个结果进行综合分析,得到最终输出结果。
为此MMOCR 提供了一键式测试时数据增强,仅需在测试时添加 `--tta` 参数即可。
```{note}
TTA 仅支持文本识别模型。
```
```bash
python tools/test.py configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py checkpoints/crnn_mini-vgg_5e_mj.pth --tta
```

View File

@ -4,6 +4,7 @@ from .aster import ASTER
from .base import BaseRecognizer from .base import BaseRecognizer
from .crnn import CRNN from .crnn import CRNN
from .encoder_decoder_recognizer import EncoderDecoderRecognizer from .encoder_decoder_recognizer import EncoderDecoderRecognizer
from .encoder_decoder_recognizer_tta import EncoderDecoderRecognizerTTAModel
from .master import MASTER from .master import MASTER
from .nrtr import NRTR from .nrtr import NRTR
from .robust_scanner import RobustScanner from .robust_scanner import RobustScanner
@ -13,5 +14,6 @@ from .svtr import SVTR
__all__ = [ __all__ = [
'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR', 'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR',
'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER' 'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER',
'EncoderDecoderRecognizerTTAModel'
] ]

View File

@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import numpy as np
from mmengine.model import BaseTTAModel
from mmocr.registry import MODELS
from mmocr.utils.typing_utils import RecSampleList
@MODELS.register_module()
class EncoderDecoderRecognizerTTAModel(BaseTTAModel):
"""Merge augmented recognition results. It will select the best result
according average scores from all augmented results.
Examples:
>>> tta_model = dict(
>>> type='EncoderDecoderRecognizerTTAModel')
>>>
>>> tta_pipeline = [
>>> dict(
>>> type='LoadImageFromFile',
>>> color_type='grayscale',
>>> file_client_args=file_client_args),
>>> dict(
>>> type='TestTimeAug',
>>> transforms=[
>>> [
>>> dict(
>>> type='ConditionApply',
>>> true_transforms=[
>>> dict(
>>> type='ImgAugWrapper',
>>> args=[dict(cls='Rot90', k=0, keep_size=False)]) # noqa: E501
>>> ],
>>> condition="results['img_shape'][1]<results['img_shape'][0]" # noqa: E501
>>> ),
>>> dict(
>>> type='ConditionApply',
>>> true_transforms=[
>>> dict(
>>> type='ImgAugWrapper',
>>> args=[dict(cls='Rot90', k=1, keep_size=False)]) # noqa: E501
>>> ],
>>> condition="results['img_shape'][1]<results['img_shape'][0]" # noqa: E501
>>> ),
>>> dict(
>>> type='ConditionApply',
>>> true_transforms=[
>>> dict(
>>> type='ImgAugWrapper',
>>> args=[dict(cls='Rot90', k=3, keep_size=False)])
>>> ],
>>> condition="results['img_shape'][1]<results['img_shape'][0]"
>>> ),
>>> ],
>>> [
>>> dict(
>>> type='RescaleToHeight',
>>> height=32,
>>> min_width=32,
>>> max_width=None,
>>> width_divisor=16)
>>> ],
>>> # add loading annotation after ``Resize`` because ground truth
>>> # does not need to do resize data transform
>>> [dict(type='LoadOCRAnnotations', with_text=True)],
>>> [
>>> dict(
>>> type='PackTextRecogInputs',
>>> meta_keys=('img_path', 'ori_shape', 'img_shape',
>>> 'valid_ratio'))
>>> ]
>>> ])
>>> ]
"""
def merge_preds(self,
data_samples_list: List[RecSampleList]) -> RecSampleList:
"""Merge predictions of enhanced data to one prediction.
Args:
data_samples_list (List[RecSampleList]): List of predictions of
all enhanced data. The shape of data_samples_list is (B, M),
where B is the batch size and M is the number of augmented
data.
Returns:
RecSampleList: Merged prediction.
"""
predictions = list()
for data_samples in data_samples_list:
scores = [
data_sample.pred_text.score for data_sample in data_samples
]
average_scores = np.array(
[sum(score) / max(1, len(score)) for score in scores])
max_idx = np.argmax(average_scores)
predictions.append(data_samples[max_idx])
return predictions

View File

@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
import torch.nn as nn
from mmengine.structures import LabelData
from mmocr.models.textrecog.recognizers import EncoderDecoderRecognizerTTAModel
from mmocr.structures import TextRecogDataSample
class DummyModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
def test_step(self, x):
return self.forward(x)
class TestEncoderDecoderRecognizerTTAModel(TestCase):
def test_merge_preds(self):
data_sample1 = TextRecogDataSample(
pred_text=LabelData(
score=torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]), text='abcde'))
data_sample2 = TextRecogDataSample(
pred_text=LabelData(
score=torch.tensor([0.2, 0.3, 0.4, 0.5, 0.6]), text='bcdef'))
data_sample3 = TextRecogDataSample(
pred_text=LabelData(
score=torch.tensor([0.3, 0.4, 0.5, 0.6, 0.7]), text='cdefg'))
aug_data_samples = [data_sample1, data_sample2, data_sample3]
batch_aug_data_samples = [aug_data_samples] * 3
model = EncoderDecoderRecognizerTTAModel(module=DummyModel())
preds = model.merge_preds(batch_aug_data_samples)
for pred in preds:
self.assertEqual(pred.pred_text.text, 'cdefg')

View File

@ -45,6 +45,8 @@ def parse_args():
choices=['none', 'pytorch', 'slurm', 'mpi'], choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none', default='none',
help='Job launcher') help='Job launcher')
parser.add_argument(
'--tta', action='store_true', help='Test time augmentation')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ: if 'LOCAL_RANK' not in os.environ:
@ -107,6 +109,11 @@ def main():
if args.show or args.show_dir: if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args) cfg = trigger_visualization_hook(cfg, args)
if args.tta:
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
cfg.tta_model.module = cfg.model
cfg.model = cfg.tta_model
# save predictions # save predictions
if args.save_preds: if args.save_preds:
dump_metric = dict( dump_metric = dict(