From 6ca74049253e8382191b86e7181ac3ae5ee9d290 Mon Sep 17 00:00:00 2001 From: Xinyu Wang <45810070+xinke-wang@users.noreply.github.com> Date: Mon, 22 Aug 2022 12:45:00 +0800 Subject: [PATCH] [Config] Update satrn config (#1300) * [Config] Add textrec_default_runtime * [Config] Add textrec_default_runtime * add vis hook * update satrn cfg * update * update Co-authored-by: gaotongxiao --- .../textrecog/satrn/_base_satrn_shallow.py | 67 +++++++++++ configs/textrecog/satrn/metafile.yml | 8 +- configs/textrecog/satrn/satrn.py | 22 ---- configs/textrecog/satrn/satrn_academic.py | 107 ------------------ ...all.py => satrn_shallow-small_5e_st_mj.py} | 2 +- .../textrecog/satrn/satrn_shallow_5e_st_mj.py | 45 ++++++++ 6 files changed, 117 insertions(+), 134 deletions(-) create mode 100644 configs/textrecog/satrn/_base_satrn_shallow.py delete mode 100644 configs/textrecog/satrn/satrn.py delete mode 100644 configs/textrecog/satrn/satrn_academic.py rename configs/textrecog/satrn/{satrn_small.py => satrn_shallow-small_5e_st_mj.py} (92%) create mode 100644 configs/textrecog/satrn/satrn_shallow_5e_st_mj.py diff --git a/configs/textrecog/satrn/_base_satrn_shallow.py b/configs/textrecog/satrn/_base_satrn_shallow.py new file mode 100644 index 00000000..7c76ca92 --- /dev/null +++ b/configs/textrecog/satrn/_base_satrn_shallow.py @@ -0,0 +1,67 @@ +file_client_args = dict(backend='disk') + +dictionary = dict( + type='Dictionary', + dict_file='dicts/english_digits_symbols.txt', + with_padding=True, + with_unknown=True, + same_start_end=True, + with_start=True, + with_end=True) + +model = dict( + type='SATRN', + backbone=dict(type='ShallowCNN', input_channels=3, hidden_dim=512), + encoder=dict( + type='SATRNEncoder', + n_layers=12, + n_head=8, + d_k=512 // 8, + d_v=512 // 8, + d_model=512, + n_position=100, + d_inner=512 * 4, + dropout=0.1), + decoder=dict( + type='NRTRDecoder', + n_layers=6, + d_embedding=512, + n_head=8, + d_model=512, + d_inner=512 * 4, + d_k=512 // 8, + d_v=512 // 8, + module_loss=dict( + type='CEModuleLoss', flatten=True, ignore_first_char=True), + dictionary=dictionary, + max_seq_len=25, + postprocessor=dict(type='AttentionPostprocessor')), + data_preprocessor=dict( + type='TextRecogDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375])) + +train_pipeline = [ + dict( + type='LoadImageFromFile', + file_client_args=file_client_args, + ignore_empty=True, + min_size=5), + dict(type='LoadOCRAnnotations', with_text=True), + dict(type='Resize', scale=(100, 32), keep_ratio=False), + dict( + type='PackTextRecogInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) +] + +# TODO Add Test Time Augmentation `MultiRotateAugOCR` +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(100, 32), keep_ratio=False), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), + dict( + type='PackTextRecogInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) +] diff --git a/configs/textrecog/satrn/metafile.yml b/configs/textrecog/satrn/metafile.yml index d7ed7a7c..8961fbcc 100644 --- a/configs/textrecog/satrn/metafile.yml +++ b/configs/textrecog/satrn/metafile.yml @@ -17,9 +17,9 @@ Collections: README: configs/textrecog/satrn/README.md Models: - - Name: satrn_academic + - Name: satrn_shallow_5e_st_mj In Collection: SATRN - Config: configs/textrecog/satrn/satrn_academic.py + Config: configs/textrecog/satrn/satrn_shallow_5e_st_mj.py Metadata: Training Data: - SynthText @@ -51,9 +51,9 @@ Models: word_acc: 90.6 Weights: https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_academic_20211009-cb8b1580.pth - - Name: satrn_small + - Name: satrn_shallow-small_5e_st_mj In Collection: SATRN - Config: configs/textrecog/satrn/satrn_small.py + Config: configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py Metadata: Training Data: - SynthText diff --git a/configs/textrecog/satrn/satrn.py b/configs/textrecog/satrn/satrn.py deleted file mode 100644 index 5549d210..00000000 --- a/configs/textrecog/satrn/satrn.py +++ /dev/null @@ -1,22 +0,0 @@ -dictionary = dict( - type='Dictionary', - dict_file='dicts/english_digits_symbols.txt', - with_padding=True, - with_unknown=True, - same_start_end=True, - with_start=True, - with_end=True) - -model = dict( - type='SATRN', - backbone=dict(type='ShallowCNN'), - encoder=dict(type='SATRNEncoder'), - decoder=dict( - type='NRTRDecoder', - module_loss=dict(type='CEModuleLoss'), - dictionary=dictionary, - max_seq_len=40), - data_preprocessor=dict( - type='TextRecogDataPreprocessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375])) diff --git a/configs/textrecog/satrn/satrn_academic.py b/configs/textrecog/satrn/satrn_academic.py deleted file mode 100644 index 1fb7854b..00000000 --- a/configs/textrecog/satrn/satrn_academic.py +++ /dev/null @@ -1,107 +0,0 @@ -_base_ = [ - '../../_base_/recog_datasets/mjsynth.py', - '../../_base_/recog_datasets/synthtext.py', - '../../_base_/recog_datasets/cute80.py', - '../../_base_/recog_datasets/iiit5k.py', - '../../_base_/recog_datasets/svt.py', - '../../_base_/recog_datasets/svtp.py', - '../../_base_/recog_datasets/icdar2013.py', - '../../_base_/recog_datasets/icdar2015.py', - '../../_base_/default_runtime.py', - '../../_base_/schedules/schedule_adam_step_5e.py', - 'satrn.py', -] - -# dataset settings -train_list = [_base_.mj_rec_train, _base_.st_rec_train] -test_list = [ - _base_.cute80_rec_test, _base_.iiit5k_rec_test, _base_.svt_rec_test, - _base_.svtp_rec_test, _base_.ic13_rec_test, _base_.ic15_rec_test -] -file_client_args = dict(backend='disk') -default_hooks = dict(logger=dict(type='LoggerHook', interval=50)) - -# optimizer -optim_wrapper = dict(type='OptimWrapper', optimizer=dict(type='Adam', lr=3e-4)) - -model = dict( - type='SATRN', - backbone=dict(type='ShallowCNN', input_channels=3, hidden_dim=512), - encoder=dict( - type='SATRNEncoder', - n_layers=12, - n_head=8, - d_k=512 // 8, - d_v=512 // 8, - d_model=512, - n_position=100, - d_inner=512 * 4, - dropout=0.1), - decoder=dict( - type='NRTRDecoder', - n_layers=6, - d_embedding=512, - n_head=8, - d_model=512, - d_inner=512 * 4, - d_k=512 // 8, - d_v=512 // 8, - module_loss=dict( - type='CEModuleLoss', flatten=True, ignore_first_char=True), - max_seq_len=25, - postprocessor=dict(type='AttentionPostprocessor'))) - -train_pipeline = [ - dict( - type='LoadImageFromFile', - file_client_args=file_client_args, - ignore_empty=True, - min_size=5), - dict(type='LoadOCRAnnotations', with_text=True), - dict(type='Resize', scale=(100, 32), keep_ratio=False), - dict( - type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) -] - -# TODO Add Test Time Augmentation `MultiRotateAugOCR` -test_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), - dict(type='Resize', scale=(100, 32), keep_ratio=False), - # add loading annotation after ``Resize`` because ground truth - # does not need to do resize data transform - dict(type='LoadOCRAnnotations', with_text=True), - dict( - type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) -] - -train_dataloader = dict( - batch_size=64, - num_workers=8, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - dataset=dict( - type='ConcatDataset', datasets=train_list, pipeline=train_pipeline)) -test_dataloader = dict( - batch_size=1, - num_workers=4, - persistent_workers=True, - drop_last=False, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type='ConcatDataset', datasets=test_list, pipeline=test_pipeline)) -val_dataloader = test_dataloader - -val_evaluator = dict( - type='MultiDatasetsEvaluator', - metrics=[ - dict( - type='WordMetric', - mode=['exact', 'ignore_case', 'ignore_case_symbol']), - dict(type='CharMetric') - ], - dataset_prefixes=['CUTE80', 'IIIT5K', 'SVT', 'SVTP', 'IC13', 'IC15']) -test_evaluator = val_evaluator - -visualizer = dict(type='TextRecogLocalVisualizer', name='visualizer') diff --git a/configs/textrecog/satrn/satrn_small.py b/configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py similarity index 92% rename from configs/textrecog/satrn/satrn_small.py rename to configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py index 1798e2f1..a72201c9 100644 --- a/configs/textrecog/satrn/satrn_small.py +++ b/configs/textrecog/satrn/satrn_shallow-small_5e_st_mj.py @@ -1,4 +1,4 @@ -_base_ = ['satrn_academic.py'] +_base_ = ['satrn_shallow_5e_st_mj.py'] model = dict( backbone=dict(type='ShallowCNN', input_channels=3, hidden_dim=256), diff --git a/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py b/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py new file mode 100644 index 00000000..4832bcb3 --- /dev/null +++ b/configs/textrecog/satrn/satrn_shallow_5e_st_mj.py @@ -0,0 +1,45 @@ +_base_ = [ + '../../_base_/recog_datasets/mjsynth.py', + '../../_base_/recog_datasets/synthtext.py', + '../../_base_/recog_datasets/cute80.py', + '../../_base_/recog_datasets/iiit5k.py', + '../../_base_/recog_datasets/svt.py', + '../../_base_/recog_datasets/svtp.py', + '../../_base_/recog_datasets/icdar2013.py', + '../../_base_/recog_datasets/icdar2015.py', + '../../_base_/textrec_default_runtime.py', + '../../_base_/schedules/schedule_adam_step_5e.py', + '_base_satrn_shallow.py', +] + +# dataset settings +train_list = [_base_.mj_rec_train, _base_.st_rec_train] +test_list = [ + _base_.cute80_rec_test, _base_.iiit5k_rec_test, _base_.svt_rec_test, + _base_.svtp_rec_test, _base_.ic13_rec_test, _base_.ic15_rec_test +] + +train_dataset = dict( + type='ConcatDataset', datasets=train_list, pipeline=_base_.train_pipeline) +test_dataset = dict( + type='ConcatDataset', datasets=test_list, pipeline=_base_.test_pipeline) + +# optimizer +optim_wrapper = dict(type='OptimWrapper', optimizer=dict(type='Adam', lr=3e-4)) + +train_dataloader = dict( + batch_size=64, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=train_dataset) + +test_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=test_dataset) + +val_dataloader = test_dataloader