mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Rec TTA (#1401)
* Support TTA for recognition * updata readme * updata abinet readme * updata train_test doc for ttapull/1731/head
parent
7cea6a6419
commit
f820470415
|
@ -46,3 +46,5 @@ visualizer = dict(
|
|||
type='TextRecogLocalVisualizer',
|
||||
name='visualizer',
|
||||
vis_backends=vis_backends)
|
||||
|
||||
tta_model = dict(type='EncoderDecoderRecognizerTTAModel')
|
||||
|
|
|
@ -38,7 +38,9 @@ Linguistic knowledge is of great benefit to scene text recognition. However, how
|
|||
| :--------------------------------------------: | :------------------------------------------------: | :----: | :----------: | :-------: | :-------: | :------------: | :----: | :----------------------------------------------- |
|
||||
| | | IIIT5K | SVT | IC13-1015 | IC15-2077 | SVTP | CT80 | |
|
||||
| [ABINet-Vision](/configs/textrecog/abinet/abinet-vision_20e_st-an_mj.py) | - | 0.9523 | 0.9196 | 0.9369 | 0.7896 | 0.8403 | 0.8437 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/abinet-vision_20e_st-an_mj_20220915_152445-85cfb03d.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/20220915_152445.log) |
|
||||
| [ABINet-Vision-TTA](/configs/textrecog/abinet/abinet-vision_20e_st-an_mj.py) | - | 0.9523 | 0.9196 | 0.9360 | 0.8175 | 0.8450 | 0.8542 | |
|
||||
| [ABINet](/configs/textrecog/abinet/abinet_20e_st-an_mj.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth) | 0.9603 | 0.9397 | 0.9557 | 0.8146 | 0.8868 | 0.8785 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/abinet_20e_st-an_mj_20221005_012617-ead8c139.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/20221005_012617.log) |
|
||||
| [ABINet-TTA](/configs/textrecog/abinet/abinet_20e_st-an_mj.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth) | 0.9597 | 0.9397 | 0.9527 | 0.8426 | 0.8930 | 0.8854 | |
|
||||
|
||||
```{note}
|
||||
1. ABINet allows its encoder to run and be trained without decoder and fuser. Its encoder is designed to recognize texts as a stand-alone model and therefore can work as an independent text recognizer. We release it as ABINet-Vision.
|
||||
|
|
|
@ -116,3 +116,50 @@ test_pipeline = [
|
|||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
],
|
||||
[dict(type='Resize', scale=(128, 32))],
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio'))
|
||||
]
|
||||
])
|
||||
]
|
||||
|
|
|
@ -35,9 +35,10 @@ A challenging aspect of scene text recognition is to handle text with distortion
|
|||
## Results and models
|
||||
|
||||
| Methods | Backbone | | Regular Text | | | | Irregular Text | | download |
|
||||
| :----------------------------------------------------------: | :------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-----------------------------------------------------------------------: |
|
||||
| :--------------------------------------------------------------: | :------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------: |
|
||||
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
|
||||
| [ASTER](/configs/textrecog/aster/aster_resnet45_6e_st_mj.py) | ResNet45 | 0.9357 | 0.8949 | 0.9281 | | 0.7665 | 0.8062 | 0.8507 | [model](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/aster_resnet45_6e_st_mj-cc56eca4.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/20221214_232605.log) |
|
||||
| [ASTER-TTA](/configs/textrecog/aster/aster_resnet45_6e_st_mj.py) | ResNet45 | 0.9337 | 0.8949 | 0.9251 | | 0.7925 | 0.8109 | 0.8507 | |
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -69,3 +69,42 @@ test_pipeline = [
|
|||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio',
|
||||
'instances'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"),
|
||||
], [dict(type='Resize', scale=(256, 64))],
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio', 'instances'))
|
||||
]])
|
||||
]
|
||||
|
|
|
@ -34,9 +34,10 @@ Image-based sequence recognition has been a long-standing research topic in comp
|
|||
## Results and models
|
||||
|
||||
| methods | | Regular Text | | | | Irregular Text | | download |
|
||||
| :----------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------------------------: |
|
||||
| :--------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------------------------: |
|
||||
| methods | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
|
||||
| [CRNN](/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py) | 0.8053 | 0.7991 | 0.8739 | | 0.5571 | 0.6093 | 0.5694 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/20220826_224120.log) |
|
||||
| [CRNN-TTA](/configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py) | 0.8013 | 0.7975 | 0.8631 | | 0.5763 | 0.6093 | 0.5764 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/20220826_224120.log) |
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -51,3 +51,60 @@ test_pipeline = [
|
|||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(
|
||||
type='LoadImageFromFile',
|
||||
color_type='grayscale',
|
||||
file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='RescaleToHeight',
|
||||
height=32,
|
||||
min_width=32,
|
||||
max_width=None,
|
||||
width_divisor=16)
|
||||
],
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio'))
|
||||
]
|
||||
])
|
||||
]
|
||||
|
|
|
@ -39,6 +39,7 @@ Attention-based scene text recognizers have gained huge success, which leverages
|
|||
| :-------------------------------------------------------------: | :-----------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------: |
|
||||
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
|
||||
| [MASTER](/configs/textrecog/master/master_resnet31_12e_st_mj_sa.py) | R31-GCAModule | 0.9490 | 0.8887 | 0.9517 | | 0.7650 | 0.8465 | 0.8889 | [model](https://download.openmmlab.com/mmocr/textrecog/master/master_resnet31_12e_st_mj_sa/master_resnet31_12e_st_mj_sa_20220915_152443-f4a5cabc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/master/master_resnet31_12e_st_mj_sa/20220915_152443.log) |
|
||||
| [MASTER-TTA](/configs/textrecog/master/master_resnet31_12e_st_mj_sa.py) | R31-GCAModule | 0.9450 | 0.8887 | 0.9478 | | 0.7906 | 0.8481 | 0.8958 | |
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -109,3 +109,58 @@ test_pipeline = [
|
|||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='RescaleToHeight',
|
||||
height=48,
|
||||
min_width=48,
|
||||
max_width=160,
|
||||
width_divisor=16)
|
||||
],
|
||||
[dict(type='PadToWidth', width=160)],
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio'))
|
||||
]
|
||||
])
|
||||
]
|
||||
|
|
|
@ -38,8 +38,11 @@ Scene text recognition has attracted a great many researches due to its importan
|
|||
| :---------------------------------------------------------: | :-------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-----------------------------------------------------------: |
|
||||
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
|
||||
| [NRTR](/configs/textrecog/nrtr/nrtr_modality-transform_6e_st_mj.py) | NRTRModalityTransform | 0.9147 | 0.8841 | 0.9369 | | 0.7246 | 0.7783 | 0.7500 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_modality-transform_6e_st_mj/nrtr_modality-transform_6e_st_mj_20220916_103322-bd9425be.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_modality-transform_6e_st_mj/20220916_103322.log) |
|
||||
| [NRTR-TTA](/configs/textrecog/nrtr/nrtr_modality-transform_6e_st_mj.py) | NRTRModalityTransform | 0.9123 | 0.8825 | 0.9310 | | 0.7492 | 0.7798 | 0.7535 | |
|
||||
| [NRTR](/configs/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj.py) | R31-1/8-1/4 | 0.9483 | 0.8918 | 0.9507 | | 0.7578 | 0.8016 | 0.8889 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj/nrtr_resnet31-1by8-1by4_6e_st_mj_20220916_103322-a6a2a123.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj/20220916_103322.log) |
|
||||
| [NRTR-TTA](/configs/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj.py) | R31-1/8-1/4 | 0.9443 | 0.8903 | 0.9478 | | 0.7790 | 0.8078 | 0.8854 | |
|
||||
| [NRTR](/configs/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj.py) | R31-1/16-1/8 | 0.9470 | 0.8918 | 0.9399 | | 0.7376 | 0.7969 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj/nrtr_resnet31-1by16-1by8_6e_st_mj_20220920_143358-43767036.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj/20220920_143358.log) |
|
||||
| [NRTR-TTA](/configs/textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj.py) | R31-1/16-1/8 | 0.9423 | 0.8903 | 0.9360 | | 0.7641 | 0.8016 | 0.8854 | |
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -60,3 +60,58 @@ test_pipeline = [
|
|||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='RescaleToHeight',
|
||||
height=32,
|
||||
min_width=32,
|
||||
max_width=160,
|
||||
width_divisor=16)
|
||||
],
|
||||
[dict(type='PadToWidth', width=160)],
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio'))
|
||||
]
|
||||
])
|
||||
]
|
||||
|
|
|
@ -66,3 +66,58 @@ test_pipeline = [
|
|||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='RescaleToHeight',
|
||||
height=32,
|
||||
min_width=32,
|
||||
max_width=160,
|
||||
width_divisor=16)
|
||||
],
|
||||
[dict(type='PadToWidth', width=160)],
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio'))
|
||||
]
|
||||
])
|
||||
]
|
||||
|
|
|
@ -44,6 +44,7 @@ The attention-based encoder-decoder framework has recently achieved impressive r
|
|||
| :------------------------------------------------------------------: | :--: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :-------------------------------------------------------------------: |
|
||||
| | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
|
||||
| [RobustScanner](/configs/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real.py) | 4 | 0.9510 | 0.9011 | 0.9320 | | 0.7578 | 0.8078 | 0.8750 | [model](https://download.openmmlab.com/mmocr/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real_20220915_152447-7fc35929.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real/20220915_152447.log) |
|
||||
| [RobustScanner-TTA](/configs/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real.py) | 4 | 0.9487 | 0.9011 | 0.9261 | | 0.7805 | 0.8124 | 0.8819 | |
|
||||
|
||||
## References
|
||||
|
||||
|
|
|
@ -66,3 +66,58 @@ test_pipeline = [
|
|||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='RescaleToHeight',
|
||||
height=48,
|
||||
min_width=48,
|
||||
max_width=160,
|
||||
width_divisor=4),
|
||||
],
|
||||
[dict(type='PadToWidth', width=160)],
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio'))
|
||||
]
|
||||
])
|
||||
]
|
||||
|
|
|
@ -44,7 +44,9 @@ Recognizing irregular text in natural scene images is challenging due to the lar
|
|||
| :----------------------------------------------------: | :---------: | :------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :------------------------------------------------------: |
|
||||
| | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
|
||||
| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 0.9533 | 0.8964 | 0.9369 | | 0.7602 | 0.8326 | 0.9062 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real_20220915_171910-04eb4e75.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real/20220915_171910.log) |
|
||||
| [SAR-TTA](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 0.9510 | 0.8964 | 0.9340 | | 0.7862 | 0.8372 | 0.9132 | |
|
||||
| [SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 0.9553 | 0.9073 | 0.9409 | | 0.7761 | 0.8093 | 0.8958 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real_20220915_185451-1fd6b1fc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real/20220915_185451.log) |
|
||||
| [SAR-TTA](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 0.9530 | 0.9073 | 0.9389 | | 0.8002 | 0.8124 | 0.9028 | |
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -71,3 +71,58 @@ test_pipeline = [
|
|||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='RescaleToHeight',
|
||||
height=48,
|
||||
min_width=48,
|
||||
max_width=160,
|
||||
width_divisor=4)
|
||||
],
|
||||
[dict(type='PadToWidth', width=160)],
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio'))
|
||||
]
|
||||
])
|
||||
]
|
||||
|
|
|
@ -38,7 +38,9 @@ Scene text recognition (STR) is the task of recognizing character sequences in n
|
|||
| :--------------------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------------: |
|
||||
| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
|
||||
| [Satrn](/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py) | 0.9600 | 0.9181 | 0.9606 | | 0.8045 | 0.8837 | 0.8993 | [model](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow_5e_st_mj/satrn_shallow_5e_st_mj_20220915_152443-5fd04a4c.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow_5e_st_mj/20220915_152443.log) |
|
||||
| [Satrn-TTA](/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py) | 0.9530 | 0.9181 | 0.9527 | | 0.8276 | 0.8884 | 0.9028 | |
|
||||
| [Satrn_small](/configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py) | 0.9423 | 0.9011 | 0.9567 | | 0.7886 | 0.8574 | 0.8472 | [model](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow-small_5e_st_mj/satrn_shallow-small_5e_st_mj_20220915_152442-5591bf27.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow-small_5e_st_mj/20220915_152442.log) |
|
||||
| [Satrn_small-TTA](/configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py) | 0.9380 | 0.8995 | 0.9488 | | 0.8122 | 0.8620 | 0.8507 | |
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -54,7 +54,6 @@ train_pipeline = [
|
|||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
# TODO Add Test Time Augmentation `MultiRotateAugOCR`
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='Resize', scale=(100, 32), keep_ratio=False),
|
||||
|
@ -65,3 +64,50 @@ test_pipeline = [
|
|||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"
|
||||
),
|
||||
],
|
||||
[dict(type='Resize', scale=(100, 32), keep_ratio=False)],
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio'))
|
||||
]
|
||||
])
|
||||
]
|
||||
|
|
|
@ -35,11 +35,13 @@ Dominant scene text recognition models commonly contain two building blocks, a v
|
|||
## Results and Models
|
||||
|
||||
| Methods | | Regular Text | | | | Irregular Text | | download |
|
||||
| :-----------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :------------------------------------------------------------------------------: |
|
||||
| :---------------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :--------------------------------------------------------------------------: |
|
||||
| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
|
||||
| [SVTR-tiny](/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
|
||||
| [SVTR-small](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8553 | 0.9026 | 0.9448 | | 0.7496 | 0.8496 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/20230105_184454.log) |
|
||||
| [SVTR-small-TTA](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8397 | 0.8964 | 0.9241 | | 0.7597 | 0.8124 | 0.8646 | |
|
||||
| [SVTR-base](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8570 | 0.9181 | 0.9438 | | 0.7448 | 0.8388 | 0.9028 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/20221227_175415.log) |
|
||||
| [SVTR-base-TTA](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8517 | 0.9011 | 0.9379 | | 0.7569 | 0.8279 | 0.8819 | |
|
||||
| [SVTR-large](/configs/textrecog/svtr/svtr-large_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
|
||||
|
||||
```{note}
|
||||
|
|
|
@ -36,3 +36,130 @@ model = dict(
|
|||
dictionary=dictionary),
|
||||
data_preprocessor=dict(
|
||||
type='TextRecogDataPreprocessor', mean=[127.5], std=[127.5]))
|
||||
|
||||
file_client_args = dict(backend='disk')
|
||||
|
||||
train_pipeline = [
|
||||
dict(
|
||||
type='LoadImageFromFile',
|
||||
file_client_args=file_client_args,
|
||||
ignore_empty=True,
|
||||
min_size=5),
|
||||
dict(type='LoadOCRAnnotations', with_text=True),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(type='TextRecogGeneralAug', ),
|
||||
],
|
||||
),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(type='CropHeight', ),
|
||||
],
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
condition='min(results["img_shape"])>10',
|
||||
true_transforms=dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(
|
||||
type='TorchVisionWrapper',
|
||||
op='GaussianBlur',
|
||||
kernel_size=5,
|
||||
sigma=1,
|
||||
),
|
||||
],
|
||||
)),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(
|
||||
type='TorchVisionWrapper',
|
||||
op='ColorJitter',
|
||||
brightness=0.5,
|
||||
saturation=0.5,
|
||||
contrast=0.5,
|
||||
hue=0.1),
|
||||
]),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(type='ImageContentJitter', ),
|
||||
],
|
||||
),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='AdditiveGaussianNoise', scale=0.1**0.5)]),
|
||||
],
|
||||
),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(type='ReversePixels', ),
|
||||
],
|
||||
),
|
||||
dict(type='Resize', scale=(256, 64)),
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='Resize', scale=(256, 64)),
|
||||
dict(type='LoadOCRAnnotations', with_text=True),
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[[
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=0, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=1, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
true_transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='Rot90', k=3, keep_size=False)])
|
||||
],
|
||||
condition="results['img_shape'][1]<results['img_shape'][0]"),
|
||||
], [dict(type='Resize', scale=(256, 64))],
|
||||
[dict(type='LoadOCRAnnotations', with_text=True)],
|
||||
[
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape',
|
||||
'valid_ratio'))
|
||||
]])
|
||||
]
|
||||
|
|
|
@ -40,94 +40,6 @@ param_scheduler = [
|
|||
convert_to_iter_based=True),
|
||||
]
|
||||
|
||||
file_client_args = dict(backend='disk')

# Training pipeline: a stack of independent augmentations, each applied with
# probability 0.4, followed by a fixed-size resize and packing.
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=file_client_args,
        # Skip samples whose image is missing or smaller than 5 pixels —
        # presumably to filter out corrupt/degenerate crops; confirm against
        # LoadImageFromFile's documentation.
        ignore_empty=True,
        min_size=5),
    dict(type='LoadOCRAnnotations', with_text=True),
    # Generic geometric/photometric recognition augmentation.
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='TextRecogGeneralAug', ),
        ],
    ),
    # Randomly crop along the height dimension.
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='CropHeight', ),
        ],
    ),
    # Gaussian blur, but only when both image sides exceed 10 px, so tiny
    # crops are not blurred into illegibility.
    dict(
        type='ConditionApply',
        condition='min(results["img_shape"])>10',
        true_transforms=dict(
            type='RandomApply',
            prob=0.4,
            transforms=[
                dict(
                    type='TorchVisionWrapper',
                    op='GaussianBlur',
                    kernel_size=5,
                    sigma=1,
                ),
            ],
        )),
    # Color jitter via torchvision.
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(
                type='TorchVisionWrapper',
                op='ColorJitter',
                brightness=0.5,
                saturation=0.5,
                contrast=0.5,
                hue=0.1),
        ]),
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='ImageContentJitter', ),
        ],
    ),
    # Additive Gaussian noise with std sqrt(0.1).
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(
                type='ImgAugWrapper',
                args=[dict(cls='AdditiveGaussianNoise', scale=0.1**0.5)]),
        ],
    ),
    # Invert pixel values.
    dict(
        type='RandomApply',
        prob=0.4,
        transforms=[
            dict(type='ReversePixels', ),
        ],
    ),
    dict(type='Resize', scale=(256, 64)),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
|
||||
|
||||
# Deterministic evaluation pipeline: no augmentation, fixed-size resize.
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(256, 64)),
    # Annotations are loaded after ``Resize`` because the ground-truth text
    # does not need any geometric transform.
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]
|
||||
|
||||
# dataset settings
|
||||
# Training datasets: MJSynth + SynthText.
# NOTE(review): ``mjsynth_textrecog_test`` is used to build the *train* list —
# presumably this should be ``mjsynth_textrecog_train``; confirm against the
# MJSynth dataset base config.
train_list = [_base_.mjsynth_textrecog_test, _base_.synthtext_textrecog_train]
|
||||
test_list = [
|
||||
|
@ -147,7 +59,9 @@ train_dataloader = dict(
|
|||
pin_memory=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type='ConcatDataset', datasets=train_list, pipeline=train_pipeline))
|
||||
type='ConcatDataset',
|
||||
datasets=train_list,
|
||||
pipeline=_base_.train_pipeline))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=128,
|
||||
|
@ -157,6 +71,8 @@ val_dataloader = dict(
|
|||
drop_last=False,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type='ConcatDataset', datasets=test_list, pipeline=test_pipeline))
|
||||
type='ConcatDataset',
|
||||
datasets=test_list,
|
||||
pipeline=_base_.test_pipeline))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
|
|
@ -36,6 +36,7 @@ The following table lists all the arguments supported by `train.py`. Args withou
|
|||
| --cfg-options | str | Override some settings in the configs. [Example](<>) |
|
||||
| --launcher | str | Option for launcher, one of \['none', 'pytorch', 'slurm', 'mpi'\]. |
|
||||
| --local_rank | int | Rank of the local machine, used for distributed training, defaults to 0. |
|
||||
| --tta | bool | Whether to use test time augmentation. |
|
||||
|
||||
### Test
|
||||
|
||||
|
@ -308,3 +309,15 @@ The visualization-related parameters in `tools/test.py` are described as follows
|
|||
| --show | bool | Whether to show the visualization results. |
|
||||
| --show-dir | str | Path to save the visualization results. |
|
||||
| --wait-time | float | Interval of visualization (s), defaults to 2. |
|
||||
|
||||
### Test Time Augmentation
|
||||
|
||||
Test time augmentation (TTA) is a technique that improves a model's performance by applying data augmentation to the input image at test time and merging the resulting predictions. It is a simple yet effective way to boost accuracy. In MMOCR, TTA can be enabled as follows:
|
||||
|
||||
```{note}
|
||||
TTA is only supported for text recognition models.
|
||||
```
|
||||
|
||||
```bash
|
||||
python tools/test.py configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py checkpoints/crnn_mini-vgg_5e_mj.pth --tta
|
||||
```
|
||||
|
|
|
@ -66,6 +66,7 @@ CUDA_VISIBLE_DEVICES=0 python tools/test.py configs/textdet/dbnet/dbnet_resnet50
|
|||
| --cfg-options | str | 用于覆写配置文件中的指定参数。[示例](#添加示例) |
|
||||
| --launcher | str | 启动器选项,可选项目为 \['none', 'pytorch', 'slurm', 'mpi'\]。 |
|
||||
| --local_rank | int | 本地机器编号,用于多机多卡分布式训练,默认为 0。 |
|
||||
| --tta | bool | 是否使用测试时数据增强。 |
|
||||
|
||||
## 多卡机器训练及测试
|
||||
|
||||
|
@ -308,3 +309,16 @@ python tools/test.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.p
|
|||
| --show | bool | 是否绘制可视化结果。 |
|
||||
| --show-dir | str | 可视化图片存储路径。 |
|
||||
| --wait-time | float | 可视化间隔时间(秒),默认为 2。 |
|
||||
|
||||
### 测试时数据增强
|
||||
|
||||
测试时增强,指的是在推理(预测)阶段,将原始图片进行水平翻转、垂直翻转、对角线翻转、旋转角度等数据增强操作,得到多张图,分别进行推理,再对多个结果进行综合分析,得到最终输出结果。
|
||||
为此,MMOCR 提供了一键式测试时数据增强,仅需在测试时添加 `--tta` 参数即可。
|
||||
|
||||
```{note}
|
||||
TTA 仅支持文本识别模型。
|
||||
```
|
||||
|
||||
```bash
|
||||
python tools/test.py configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py checkpoints/crnn_mini-vgg_5e_mj.pth --tta
|
||||
```
|
||||
|
|
|
@ -4,6 +4,7 @@ from .aster import ASTER
|
|||
from .base import BaseRecognizer
|
||||
from .crnn import CRNN
|
||||
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
|
||||
from .encoder_decoder_recognizer_tta import EncoderDecoderRecognizerTTAModel
|
||||
from .master import MASTER
|
||||
from .nrtr import NRTR
|
||||
from .robust_scanner import RobustScanner
|
||||
|
@ -13,5 +14,6 @@ from .svtr import SVTR
|
|||
|
||||
__all__ = [
|
||||
'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR',
|
||||
'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER'
|
||||
'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER',
|
||||
'EncoderDecoderRecognizerTTAModel'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from mmengine.model import BaseTTAModel
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.utils.typing_utils import RecSampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
class EncoderDecoderRecognizerTTAModel(BaseTTAModel):
    """Test-time-augmentation wrapper for encoder-decoder recognizers.

    For every input image the TTA pipeline produces several augmented views;
    the wrapped recognizer predicts text for each view, and this class keeps
    the single prediction whose mean per-character score is the highest.

    Examples:
        >>> tta_model = dict(type='EncoderDecoderRecognizerTTAModel')
        >>>
        >>> # Pair it with a ``TestTimeAug`` pipeline whose first transform
        >>> # group emits several views of each image, e.g. ``ConditionApply``
        >>> # branches wrapping ``ImgAugWrapper`` ``Rot90`` with k in
        >>> # (0, 1, 3) for images where
        >>> # ``results['img_shape'][1] < results['img_shape'][0]``, followed
        >>> # by groups for resizing, ``LoadOCRAnnotations`` (placed after the
        >>> # resize because ground truth needs no geometric transform) and
        >>> # ``PackTextRecogInputs``.
    """

    def merge_preds(self,
                    data_samples_list: List[RecSampleList]) -> RecSampleList:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (List[RecSampleList]): List of predictions of
                all enhanced data. The shape of data_samples_list is (B, M),
                where B is the batch size and M is the number of augmented
                data.

        Returns:
            RecSampleList: Merged prediction.
        """

        def mean_char_score(data_sample):
            # Average the per-character scores; ``max(1, ...)`` guards
            # against division by zero for empty predictions.
            char_scores = data_sample.pred_text.score
            return sum(char_scores) / max(1, len(char_scores))

        # ``max`` returns the first sample with the highest mean score, which
        # matches the original argmax tie-breaking.
        return [
            max(data_samples, key=mean_char_score)
            for data_samples in data_samples_list
        ]
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmocr.models.textrecog.recognizers import EncoderDecoderRecognizerTTAModel
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
    """Identity module standing in for a recognizer in TTA tests."""

    def forward(self, x):
        # Pass the input through unchanged; the TTA wrapper only needs
        # *some* module to wrap.
        return x

    def test_step(self, x):
        # Mirror the ``test_step`` entry point a recognizer exposes.
        return self.forward(x)
|
||||
|
||||
|
||||
class TestEncoderDecoderRecognizerTTAModel(TestCase):
    """Unit test for the recognizer TTA wrapper's prediction merging."""

    def test_merge_preds(self):
        # Three augmented views of the same image with strictly increasing
        # mean character scores; the last view ('cdefg') must always win.
        specs = [
            ([0.1, 0.2, 0.3, 0.4, 0.5], 'abcde'),
            ([0.2, 0.3, 0.4, 0.5, 0.6], 'bcdef'),
            ([0.3, 0.4, 0.5, 0.6, 0.7], 'cdefg'),
        ]
        aug_data_samples = [
            TextRecogDataSample(
                pred_text=LabelData(score=torch.tensor(scores), text=text))
            for scores, text in specs
        ]
        # A batch of three images, each with the same three augmented views.
        batch_aug_data_samples = [aug_data_samples] * 3
        model = EncoderDecoderRecognizerTTAModel(module=DummyModel())
        for pred in model.merge_preds(batch_aug_data_samples):
            self.assertEqual(pred.pred_text.text, 'cdefg')
|
|
@ -45,6 +45,8 @@ def parse_args():
|
|||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||
default='none',
|
||||
help='Job launcher')
|
||||
parser.add_argument(
|
||||
'--tta', action='store_true', help='Test time augmentation')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
|
@ -107,6 +109,11 @@ def main():
|
|||
if args.show or args.show_dir:
|
||||
cfg = trigger_visualization_hook(cfg, args)
|
||||
|
||||
if args.tta:
|
||||
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
|
||||
cfg.tta_model.module = cfg.model
|
||||
cfg.model = cfg.tta_model
|
||||
|
||||
# save predictions
|
||||
if args.save_preds:
|
||||
dump_metric = dict(
|
||||
|
|
Loading…
Reference in New Issue