From f82047041578d1629d53083780c2c7928ca7b96c Mon Sep 17 00:00:00 2001
From: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Date: Thu, 16 Feb 2023 10:27:07 +0800
Subject: [PATCH] [Feature] Rec TTA (#1401)

* Support TTA for recognition

* update readme

* update abinet readme

* update train_test doc for tta
---
 configs/textrecog/_base_/default_runtime.py   |   2 +
 configs/textrecog/abinet/README.md            |   2 +
 .../textrecog/abinet/_base_abinet-vision.py   |  47 +++++++
 configs/textrecog/aster/README.md             |   9 +-
 configs/textrecog/aster/_base_aster.py        |  39 ++++++
 configs/textrecog/crnn/README.md              |   9 +-
 configs/textrecog/crnn/_base_crnn_mini-vgg.py |  57 ++++++++
 configs/textrecog/master/README.md            |   1 +
 .../textrecog/master/_base_master_resnet31.py |  55 ++++++++
 configs/textrecog/nrtr/README.md              |   3 +
 .../nrtr/_base_nrtr_modality-transform.py     |  55 ++++++++
 configs/textrecog/nrtr/_base_nrtr_resnet31.py |  55 ++++++++
 configs/textrecog/robust_scanner/README.md    |   1 +
 .../_base_robustscanner_resnet31.py           |  55 ++++++++
 configs/textrecog/sar/README.md               |   2 +
 .../_base_sar_resnet31_parallel-decoder.py    |  55 ++++++++
 configs/textrecog/satrn/README.md             |   2 +
 .../textrecog/satrn/_base_satrn_shallow.py    |  48 ++++++-
 configs/textrecog/svtr/README.md              |  16 ++-
 configs/textrecog/svtr/_base_svtr-tiny.py     | 127 ++++++++++++++++++
 configs/textrecog/svtr/svtr-tiny_20e_st_mj.py |  96 +------------
 docs/en/user_guides/train_test.md             |  13 ++
 docs/zh_cn/user_guides/train_test.md          |  14 ++
 .../models/textrecog/recognizers/__init__.py  |   4 +-
 .../encoder_decoder_recognizer_tta.py         | 100 ++++++++++++++
 .../test_encoder_decoder_recognizer_tta.py    |  42 ++++++
 tools/test.py                                 |   7 +
 27 files changed, 809 insertions(+), 107 deletions(-)
 create mode 100644 mmocr/models/textrecog/recognizers/encoder_decoder_recognizer_tta.py
 create mode 100644 tests/test_models/test_textrecog/test_recognizers/test_encoder_decoder_recognizer_tta.py

diff --git a/configs/textrecog/_base_/default_runtime.py b/configs/textrecog/_base_/default_runtime.py
index 3dcd1051..f3ce4e1a 100644
--- a/configs/textrecog/_base_/default_runtime.py
+++ b/configs/textrecog/_base_/default_runtime.py
@@ -46,3 +46,5 @@ visualizer = dict(
     type='TextRecogLocalVisualizer',
     name='visualizer',
     vis_backends=vis_backends)
+
+tta_model = dict(type='EncoderDecoderRecognizerTTAModel')
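The `tta_model` entry registered above stays dormant during ordinary testing; it only takes effect when `tools/test.py` (patched at the end of this diff) is run with `--tta`. A minimal sketch of that wiring, mirroring the `--tta` branch added to `tools/test.py` below — the config path is only an example:

```python
# Sketch: how the --tta flag rewires a test config (mirrors tools/test.py below).
from mmengine.config import Config

cfg = Config.fromfile('configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py')  # example

# Swap the plain test pipeline for the augmented one ...
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
# ... and wrap the recognizer with the TTA merger defined in this patch.
cfg.tta_model.module = cfg.model
cfg.model = cfg.tta_model
```

Since `tta_model` lives in `_base_/default_runtime.py`, every recognizer config that inherits from it gains TTA support without further changes.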
diff --git a/configs/textrecog/abinet/README.md b/configs/textrecog/abinet/README.md
index 8920c0f3..6a7faadb 100644
--- a/configs/textrecog/abinet/README.md
+++ b/configs/textrecog/abinet/README.md
@@ -38,7 +38,9 @@ Linguistic knowledge is of great benefit to scene text recognition. However, how
 | :--------------------------------------------: | :------------------------------------------------: | :----: | :----------: | :-------: | :-------: | :------------: | :----: | :----------------------------------------------- |
 | | | IIIT5K | SVT | IC13-1015 | IC15-2077 | SVTP | CT80 | |
 | [ABINet-Vision](/configs/textrecog/abinet/abinet-vision_20e_st-an_mj.py) | - | 0.9523 | 0.9196 | 0.9369 | 0.7896 | 0.8403 | 0.8437 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/abinet-vision_20e_st-an_mj_20220915_152445-85cfb03d.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/20220915_152445.log) |
+| [ABINet-Vision-TTA](/configs/textrecog/abinet/abinet-vision_20e_st-an_mj.py) | - | 0.9523 | 0.9196 | 0.9360 | 0.8175 | 0.8450 | 0.8542 | |
 | [ABINet](/configs/textrecog/abinet/abinet_20e_st-an_mj.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth) | 0.9603 | 0.9397 | 0.9557 | 0.8146 | 0.8868 | 0.8785 | [model](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/abinet_20e_st-an_mj_20221005_012617-ead8c139.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/20221005_012617.log) |
+| [ABINet-TTA](/configs/textrecog/abinet/abinet_20e_st-an_mj.py) | [Pretrained](https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_pretrain-45deac15.pth) | 0.9597 | 0.9397 | 0.9527 | 0.8426 | 0.8930 | 0.8854 | |
 
 ```{note}
 1. ABINet allows its encoder to run and be trained without decoder and fuser. Its encoder is designed to recognize texts as a stand-alone model and therefore can work as an independent text recognizer. We release it as ABINet-Vision.
diff --git a/configs/textrecog/abinet/_base_abinet-vision.py b/configs/textrecog/abinet/_base_abinet-vision.py
index ef9a482f..66954ff8 100644
--- a/configs/textrecog/abinet/_base_abinet-vision.py
+++ b/configs/textrecog/abinet/_base_abinet-vision.py
@@ -116,3 +116,50 @@ test_pipeline = [
         type='PackTextRecogInputs',
         meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
 ]
+
+tta_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(
+        type='TestTimeAug',
+        transforms=[
+            [
+                dict(
+                    type='ConditionApply',
+                    true_transforms=[
+                        dict(
+                            type='ImgAugWrapper',
+                            args=[dict(cls='Rot90', k=0, keep_size=False)])
+                    ],
+                    condition="results['img_shape'][1]<results['img_shape'][0]"
+                ),
+                dict(
+                    type='ConditionApply',
+                    true_transforms=[
+                        dict(
+                            type='ImgAugWrapper',
+                            args=[dict(cls='Rot90', k=1, keep_size=False)])
+                    ],
+                    condition="results['img_shape'][1]<results['img_shape'][0]"
+                ),
+                dict(
+                    type='ConditionApply',
+                    true_transforms=[
+                        dict(
+                            type='ImgAugWrapper',
+                            args=[dict(cls='Rot90', k=3, keep_size=False)])
+                    ],
+                    condition="results['img_shape'][1]<results['img_shape'][0]"
+                ),
+            ],
+            [dict(type='Resize', scale=(128, 32))],
+            # add loading annotation after ``Resize`` because ground truth
+            # does not need to do resize data transform
+            [dict(type='LoadOCRAnnotations', with_text=True)],
+            [
+                dict(
+                    type='PackTextRecogInputs',
+                    meta_keys=('img_path', 'ori_shape', 'img_shape',
+                               'valid_ratio'))
+            ]
+        ])
+]
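The `Rot90` triplet guarded by `ConditionApply` above is the pattern this patch replicates across all recognizer base configs: crops that are taller than they are wide — likely vertical text — are additionally tried rotated by 90° and 270°, and the merge step later keeps the best-scoring reading. A standalone NumPy sketch of the fan-out (illustrative only, not the actual `ImgAugWrapper`/`TestTimeAug` implementation):

```python
# Illustrative fan-out performed by the three ConditionApply/Rot90 branches:
# each branch emits one view; the rotation only fires when width < height.
import numpy as np

def tta_views(img: np.ndarray) -> list:
    h, w = img.shape[:2]
    views = []
    for k in (0, 1, 3):  # k=0 keeps the crop; k=1/k=3 are the two 90° readings
        # condition="results['img_shape'][1]<results['img_shape'][0]", i.e. w < h
        views.append(np.rot90(img, k) if w < h else img)
    return views
```

For wide (landscape) crops every branch is a no-op, so all three views coincide and TTA degenerates to ordinary inference on that sample.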
diff --git a/configs/textrecog/svtr/README.md b/configs/textrecog/svtr/README.md
--- a/configs/textrecog/svtr/README.md
+++ b/configs/textrecog/svtr/README.md
@@ -41,10 +41,12 @@
-| Methods | | Regular Text | | | | Irregular Text | | download |
-| :-----------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :---------------------------------------------------------------------: |
-| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
-| [SVTR-tiny](/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
-| [SVTR-small](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8553 | 0.9026 | 0.9448 | | 0.7496 | 0.8496 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/20230105_184454.log) |
-| [SVTR-base](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8570 | 0.9181 | 0.9438 | | 0.7448 | 0.8388 | 0.9028 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/20221227_175415.log) |
-| [SVTR-large](/configs/textrecog/svtr/svtr-large_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
+| Methods | | Regular Text | | | | Irregular Text | | download |
+| :---------------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :--------------------------------------------------------------------------: |
+| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
+| [SVTR-tiny](/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
+| [SVTR-small](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8553 | 0.9026 | 0.9448 | | 0.7496 | 0.8496 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/20230105_184454.log) |
+| [SVTR-small-TTA](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8397 | 0.8964 | 0.9241 | | 0.7597 | 0.8124 | 0.8646 | |
+| [SVTR-base](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8570 | 0.9181 | 0.9438 | | 0.7448 | 0.8388 | 0.9028 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/20221227_175415.log) |
+| [SVTR-base-TTA](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8517 | 0.9011 | 0.9379 | | 0.7569 | 0.8279 | 0.8819 | |
+| [SVTR-large](/configs/textrecog/svtr/svtr-large_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
 
 ```{note}
 The implementation and configuration follow the original code and paper, but there is still a gap between the reproduced results and the official ones. We appreciate any suggestions to improve its performance.
diff --git a/configs/textrecog/svtr/_base_svtr-tiny.py b/configs/textrecog/svtr/_base_svtr-tiny.py
index dcfd7867..024d36ed 100644
--- a/configs/textrecog/svtr/_base_svtr-tiny.py
+++ b/configs/textrecog/svtr/_base_svtr-tiny.py
@@ -36,3 +36,130 @@ model = dict(
         dictionary=dictionary),
     data_preprocessor=dict(
         type='TextRecogDataPreprocessor', mean=[127.5], std=[127.5]))
+
+file_client_args = dict(backend='disk')
+
+train_pipeline = [
+    dict(
+        type='LoadImageFromFile',
+        file_client_args=file_client_args,
+        ignore_empty=True,
+        min_size=5),
+    dict(type='LoadOCRAnnotations', with_text=True),
+    dict(
+        type='RandomApply',
+        prob=0.4,
+        transforms=[
+            dict(type='TextRecogGeneralAug', ),
+        ],
+    ),
+    dict(
+        type='RandomApply',
+        prob=0.4,
+        transforms=[
+            dict(type='CropHeight', ),
+        ],
+    ),
+    dict(
+        type='ConditionApply',
+        condition='min(results["img_shape"])>10',
+        true_transforms=dict(
+            type='RandomApply',
+            prob=0.4,
+            transforms=[
+                dict(
+                    type='TorchVisionWrapper',
+                    op='GaussianBlur',
+                    kernel_size=5,
+                    sigma=1,
+                ),
+            ],
+        )),
+    dict(
+        type='RandomApply',
+        prob=0.4,
+        transforms=[
+            dict(
+                type='TorchVisionWrapper',
+                op='ColorJitter',
+                brightness=0.5,
+                saturation=0.5,
+                contrast=0.5,
+                hue=0.1),
+        ]),
+    dict(
+        type='RandomApply',
+        prob=0.4,
+        transforms=[
+            dict(type='ImageContentJitter', ),
+        ],
+    ),
+    dict(
+        type='RandomApply',
+        prob=0.4,
+        transforms=[
+            dict(
+                type='ImgAugWrapper',
+                args=[dict(cls='AdditiveGaussianNoise', scale=0.1**0.5)]),
+        ],
+    ),
+    dict(
+        type='RandomApply',
+        prob=0.4,
+        transforms=[
+            dict(type='ReversePixels', ),
+        ],
+    ),
+    dict(type='Resize', scale=(256, 64)),
+    dict(
+        type='PackTextRecogInputs',
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='Resize', scale=(256, 64)),
+    dict(type='LoadOCRAnnotations', with_text=True),
+    dict(
+        type='PackTextRecogInputs',
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
+]
+
+tta_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(
+        type='TestTimeAug',
+        transforms=[[
+            dict(
+                type='ConditionApply',
+                true_transforms=[
+                    dict(
+                        type='ImgAugWrapper',
+                        args=[dict(cls='Rot90', k=0, keep_size=False)])
+                ],
+                condition="results['img_shape'][1]<results['img_shape'][0]"),
+            dict(
+                type='ConditionApply',
+                true_transforms=[
+                    dict(
+                        type='ImgAugWrapper',
+                        args=[dict(cls='Rot90', k=1, keep_size=False)])
+                ],
+                condition="results['img_shape'][1]<results['img_shape'][0]"),
+            dict(
+                type='ConditionApply',
+                true_transforms=[
+                    dict(
+                        type='ImgAugWrapper',
+                        args=[dict(cls='Rot90', k=3, keep_size=False)])
+                ],
+                condition="results['img_shape'][1]<results['img_shape'][0]"),
+        ],
+                    [dict(type='Resize', scale=(256, 64))],
+                    [dict(type='LoadOCRAnnotations', with_text=True)],
+                    [
+                        dict(
+                            type='PackTextRecogInputs',
+                            meta_keys=('img_path', 'ori_shape', 'img_shape',
+                                       'valid_ratio'))
+                    ]])
+]
diff --git a/docs/en/user_guides/train_test.md b/docs/en/user_guides/train_test.md
--- a/docs/en/user_guides/train_test.md
+++ b/docs/en/user_guides/train_test.md
@@ -67,5 +67,6 @@
 | --launcher | str | Option for launcher, \['none', 'pytorch', 'slurm', 'mpi'\]. |
 | --local_rank | int | Rank of the local machine, used for distributed training. Defaults to 0. |
+| --tta | bool | Whether to use test time augmentation. |
 
 ### Test
@@ -308,3 +309,15 @@ The visualization-related parameters in `tools/test.py` are described as follows
 | --show | bool | Whether to show the visualization results. |
 | --show-dir | str | Path to save the visualization results. |
 | --wait-time | float | Interval of visualization (s), defaults to 2. |
+
+### Test Time Augmentation
+
+Test time augmentation (TTA) improves a model's performance by applying data augmentation to the input image at test time and merging the predictions made on the augmented views. It is simple yet effective. In MMOCR, it is enabled by appending `--tta` to the test command:
+
+```{note}
+TTA is only supported for text recognition models.
+```
+
+```bash
+python tools/test.py configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py checkpoints/crnn_mini-vgg_5e_mj.pth --tta
+```
diff --git a/docs/zh_cn/user_guides/train_test.md b/docs/zh_cn/user_guides/train_test.md
index 0a24cd7f..7ab634f8 100644
--- a/docs/zh_cn/user_guides/train_test.md
+++ b/docs/zh_cn/user_guides/train_test.md
@@ -66,6 +66,7 @@
 | --cfg-options | str | 用于覆写配置文件中的指定参数。[示例](#添加示例) |
 | --launcher | str | 启动器选项,可选项目为 \['none', 'pytorch', 'slurm', 'mpi'\]。 |
 | --local_rank | int | 本地机器编号,用于多机多卡分布式训练,默认为 0。 |
+| --tta | bool | 是否使用测试时数据增强 |
 
 ## 多卡机器训练及测试
@@ -308,3 +309,16 @@ python tools/test.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.p
 | --show | bool | 是否绘制可视化结果。 |
 | --show-dir | str | 可视化图片存储路径。 |
 | --wait-time | float | 可视化间隔时间(秒),默认为 2。 |
+
+### 测试时数据增强
+
+测试时增强,指的是在推理(预测)阶段,将原始图片进行水平翻转、垂直翻转、对角线翻转、旋转角度等数据增强操作,得到多张图,分别进行推理,再对多个结果进行综合分析,得到最终输出结果。
+为此,MMOCR 提供了一键式测试时数据增强,仅需在测试时添加 `--tta` 参数即可。
+
+```{note}
+TTA 仅支持文本识别模型。
+```
+
+```bash
+python tools/test.py configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py checkpoints/crnn_mini-vgg_5e_mj.pth --tta
+```
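Both documentation sections above describe the merge step only loosely ("merging the predictions"). Concretely, the wrapper introduced next averages the per-character scores of each augmented view's prediction and keeps the candidate with the highest mean. A small worked example — the values are illustrative and match the unit test added later in this patch:

```python
# Worked example of the TTA merge rule used by EncoderDecoderRecognizerTTAModel:
# average the per-character scores of each view, keep the best-scoring text.
import numpy as np

candidates = {
    'abcde': [0.1, 0.2, 0.3, 0.4, 0.5],  # mean 0.3
    'bcdef': [0.2, 0.3, 0.4, 0.5, 0.6],  # mean 0.4
    'cdefg': [0.3, 0.4, 0.5, 0.6, 0.7],  # mean 0.5 -> selected
}
best = max(candidates, key=lambda text: float(np.mean(candidates[text])))
assert best == 'cdefg'
```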
diff --git a/mmocr/models/textrecog/recognizers/__init__.py b/mmocr/models/textrecog/recognizers/__init__.py
index a2f81941..d9016492 100644
--- a/mmocr/models/textrecog/recognizers/__init__.py
+++ b/mmocr/models/textrecog/recognizers/__init__.py
@@ -4,6 +4,7 @@ from .aster import ASTER
 from .base import BaseRecognizer
 from .crnn import CRNN
 from .encoder_decoder_recognizer import EncoderDecoderRecognizer
+from .encoder_decoder_recognizer_tta import EncoderDecoderRecognizerTTAModel
 from .master import MASTER
 from .nrtr import NRTR
 from .robust_scanner import RobustScanner
@@ -13,5 +14,6 @@ from .svtr import SVTR
 
 __all__ = [
     'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR',
-    'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER'
+    'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER',
+    'EncoderDecoderRecognizerTTAModel'
 ]
diff --git a/mmocr/models/textrecog/recognizers/encoder_decoder_recognizer_tta.py b/mmocr/models/textrecog/recognizers/encoder_decoder_recognizer_tta.py
new file mode 100644
index 00000000..b73db22d
--- /dev/null
+++ b/mmocr/models/textrecog/recognizers/encoder_decoder_recognizer_tta.py
@@ -0,0 +1,100 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import numpy as np
+from mmengine.model import BaseTTAModel
+
+from mmocr.registry import MODELS
+from mmocr.utils.typing_utils import RecSampleList
+
+
+@MODELS.register_module()
+class EncoderDecoderRecognizerTTAModel(BaseTTAModel):
+    """Merge augmented recognition results. It will select the best result
+    according to the average scores from all augmented results.
+
+    Examples:
+        >>> tta_model = dict(
+        >>>     type='EncoderDecoderRecognizerTTAModel')
+        >>>
+        >>> tta_pipeline = [
+        >>>     dict(
+        >>>         type='LoadImageFromFile',
+        >>>         color_type='grayscale',
+        >>>         file_client_args=file_client_args),
+        >>>     dict(
+        >>>         type='TestTimeAug',
+        >>>         transforms=[
+        >>>             [
+        >>>                 dict(
+        >>>                     type='ConditionApply',
+        >>>                     true_transforms=[
+        >>>                         dict(
+        >>>                             type='ImgAugWrapper',
+        >>>                             args=[dict(cls='Rot90', k=0, keep_size=False)]) # noqa: E501
+        >>>                     ],
+        >>>                     condition="results['img_shape'][1]<results['img_shape'][0]",
+        >>>                 ),
+        >>>                 dict(
+        >>>                     type='ConditionApply',
+        >>>                     true_transforms=[
+        >>>                         dict(
+        >>>                             type='ImgAugWrapper',
+        >>>                             args=[dict(cls='Rot90', k=1, keep_size=False)]) # noqa: E501
+        >>>                     ],
+        >>>                     condition="results['img_shape'][1]<results['img_shape'][0]",
+        >>>                 ),
+        >>>                 dict(
+        >>>                     type='ConditionApply',
+        >>>                     true_transforms=[
+        >>>                         dict(
+        >>>                             type='ImgAugWrapper',
+        >>>                             args=[dict(cls='Rot90', k=3, keep_size=False)])
+        >>>                     ],
+        >>>                     condition="results['img_shape'][1]<results['img_shape'][0]",
+        >>>                 ),
+        >>>             ],
+        >>>             [
+        >>>                 dict(
+        >>>                     type='RescaleToHeight',
+        >>>                     height=32,
+        >>>                     min_width=32,
+        >>>                     max_width=None,
+        >>>                     width_divisor=16)
+        >>>             ],
+        >>>             # add loading annotation after ``Resize`` because ground truth
+        >>>             # does not need to do resize data transform
+        >>>             [dict(type='LoadOCRAnnotations', with_text=True)],
+        >>>             [
+        >>>                 dict(
+        >>>                     type='PackTextRecogInputs',
+        >>>                     meta_keys=('img_path', 'ori_shape', 'img_shape',
+        >>>                                'valid_ratio'))
+        >>>             ]
+        >>>         ])
+        >>> ]
+    """
+
+    def merge_preds(self,
+                    data_samples_list: List[RecSampleList]) -> RecSampleList:
+        """Merge predictions of enhanced data into one prediction.
+
+        Args:
+            data_samples_list (List[RecSampleList]): List of predictions of
+                all enhanced data. The shape of data_samples_list is (B, M),
+                where B is the batch size and M is the number of augmented
+                data.
+
+        Returns:
+            RecSampleList: Merged prediction.
+        """
+        predictions = list()
+        for data_samples in data_samples_list:
+            scores = [
+                data_sample.pred_text.score for data_sample in data_samples
+            ]
+            average_scores = np.array(
+                [sum(score) / max(1, len(score)) for score in scores])
+            max_idx = np.argmax(average_scores)
+            predictions.append(data_samples[max_idx])
+        return predictions
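For completeness, a hypothetical sketch of driving the wrapper by hand — in practice mmengine's Runner builds it from the `tta_model` config entry (see `tools/test.py` below); `recognizer` and `data_batch` are stand-ins for a built MMOCR recognizer and a batch produced by a `tta_pipeline`:

```python
# Hypothetical snippet: using the TTA wrapper directly. `recognizer` and
# `data_batch` are assumed to exist; the Runner normally handles all of this.
from mmocr.models.textrecog.recognizers import EncoderDecoderRecognizerTTAModel

tta_model = EncoderDecoderRecognizerTTAModel(module=recognizer)
# BaseTTAModel.test_step runs recognizer.test_step once per augmented view,
# regroups the outputs per sample, and hands them to merge_preds() above.
merged = tta_model.test_step(data_batch)
```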
diff --git a/tests/test_models/test_textrecog/test_recognizers/test_encoder_decoder_recognizer_tta.py b/tests/test_models/test_textrecog/test_recognizers/test_encoder_decoder_recognizer_tta.py
new file mode 100644
index 00000000..2c2da3f8
--- /dev/null
+++ b/tests/test_models/test_textrecog/test_recognizers/test_encoder_decoder_recognizer_tta.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+import torch.nn as nn
+from mmengine.structures import LabelData
+
+from mmocr.models.textrecog.recognizers import EncoderDecoderRecognizerTTAModel
+from mmocr.structures import TextRecogDataSample
+
+
+class DummyModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x
+
+    def test_step(self, x):
+        return self.forward(x)
+
+
+class TestEncoderDecoderRecognizerTTAModel(TestCase):
+
+    def test_merge_preds(self):
+
+        data_sample1 = TextRecogDataSample(
+            pred_text=LabelData(
+                score=torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]), text='abcde'))
+        data_sample2 = TextRecogDataSample(
+            pred_text=LabelData(
+                score=torch.tensor([0.2, 0.3, 0.4, 0.5, 0.6]), text='bcdef'))
+        data_sample3 = TextRecogDataSample(
+            pred_text=LabelData(
+                score=torch.tensor([0.3, 0.4, 0.5, 0.6, 0.7]), text='cdefg'))
+        aug_data_samples = [data_sample1, data_sample2, data_sample3]
+        batch_aug_data_samples = [aug_data_samples] * 3
+        model = EncoderDecoderRecognizerTTAModel(module=DummyModel())
+        preds = model.merge_preds(batch_aug_data_samples)
+        for pred in preds:
+            self.assertEqual(pred.pred_text.text, 'cdefg')
diff --git a/tools/test.py b/tools/test.py
index 3699e99a..555867b5 100755
--- a/tools/test.py
+++ b/tools/test.py
@@ -45,6 +45,8 @@ def parse_args():
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
         help='Job launcher')
+    parser.add_argument(
+        '--tta', action='store_true', help='Test time augmentation')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
@@ -107,6 +109,11 @@ def main():
     if args.show or args.show_dir:
         cfg = trigger_visualization_hook(cfg, args)
 
+    if args.tta:
+        cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
+        cfg.tta_model.module = cfg.model
+        cfg.model = cfg.tta_model
+
     # save predictions
     if args.save_preds:
         dump_metric = dict(