From 3992f0d78e4fda57b08de51cb7a29745f0bf9dfe Mon Sep 17 00:00:00 2001
From: wangxinyu
Date: Mon, 13 Jun 2022 07:00:39 +0000
Subject: [PATCH] [SATRN] SATRN Config

---
 configs/_base_/recog_models/satrn.py      |  20 +++--
 configs/textrecog/satrn/satrn_academic.py | 102 ++++++++++++++--------
 configs/textrecog/satrn/satrn_small.py    |  51 +----------
 3 files changed, 81 insertions(+), 92 deletions(-)

diff --git a/configs/_base_/recog_models/satrn.py b/configs/_base_/recog_models/satrn.py
index f7a6de86..a9d6195e 100644
--- a/configs/_base_/recog_models/satrn.py
+++ b/configs/_base_/recog_models/satrn.py
@@ -1,11 +1,17 @@
-label_convertor = dict(
-    type='AttnConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+dictionary = dict(
+    type='Dictionary',
+    dict_file='dicts/english_digits_symbols.txt',
+    with_padding=True,
+    with_unknown=True,
+    same_start_end=True,
+    with_start=True,
+    with_end=True)
 
 model = dict(
     type='SATRN',
     backbone=dict(type='ShallowCNN'),
-    encoder=dict(type='SatrnEncoder'),
-    decoder=dict(type='TFDecoder'),
-    loss=dict(type='TFLoss'),
-    label_convertor=label_convertor,
-    max_seq_len=40)
+    encoder=dict(type='SATRNEncoder'),
+    decoder=dict(type='NRTRDecoder', loss=dict(type='CELoss')),
+    dictionary=dictionary,
+    preprocess_cfg=dict(
+        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]))
diff --git a/configs/textrecog/satrn/satrn_academic.py b/configs/textrecog/satrn/satrn_academic.py
index 00a664e2..0045e24f 100644
--- a/configs/textrecog/satrn/satrn_academic.py
+++ b/configs/textrecog/satrn/satrn_academic.py
@@ -1,24 +1,21 @@
 _base_ = [
     '../../_base_/default_runtime.py',
-    '../../_base_/recog_pipelines/satrn_pipeline.py',
-    '../../_base_/recog_datasets/ST_MJ_train.py',
-    '../../_base_/recog_datasets/academic_test.py'
+    '../../_base_/schedules/schedule_adam_step_5e.py',
+    '../../_base_/recog_models/satrn.py'
 ]
 
-train_list = {{_base_.train_list}}
-test_list = {{_base_.test_list}}
+default_hooks = dict(logger=dict(type='LoggerHook', interval=50))
 
-train_pipeline = {{_base_.train_pipeline}}
-test_pipeline = {{_base_.test_pipeline}}
-
-label_convertor = dict(
-    type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+# dataset settings
+dataset_type = 'OCRDataset'
+data_root = 'tests/data/ocr_toy_dataset'
+file_client_args = dict(backend='petrel')
 
 model = dict(
     type='SATRN',
     backbone=dict(type='ShallowCNN', input_channels=3, hidden_dim=512),
     encoder=dict(
-        type='SatrnEncoder',
+        type='SATRNEncoder',
         n_layers=12,
         n_head=8,
         d_k=512 // 8,
@@ -35,34 +32,65 @@ model = dict(
         d_model=512,
         d_inner=512 * 4,
         d_k=512 // 8,
-        d_v=512 // 8),
-    loss=dict(type='TFLoss'),
-    label_convertor=label_convertor,
-    max_seq_len=25)
+        d_v=512 // 8,
+        loss=dict(type='CELoss', flatten=True, ignore_first_char=True),
+        max_seq_len=25,
+        postprocessor=dict(type='AttentionPostprocessor')))
 
 # optimizer
-optimizer = dict(type='Adam', lr=3e-4)
-optimizer_config = dict(grad_clip=None)
-# learning policy
-lr_config = dict(policy='step', step=[3, 4])
-total_epochs = 6
+optim_wrapper = dict(type='OptimWrapper', optimizer=dict(type='Adam', lr=3e-4))
 
-data = dict(
-    samples_per_gpu=64,
-    workers_per_gpu=4,
-    val_dataloader=dict(samples_per_gpu=1),
-    test_dataloader=dict(samples_per_gpu=1),
-    train=dict(
-        type='UniformConcatDataset',
-        datasets=train_list,
-        pipeline=train_pipeline),
-    val=dict(
-        type='UniformConcatDataset',
-        datasets=test_list,
-        pipeline=test_pipeline),
-    test=dict(
-        type='UniformConcatDataset',
-        datasets=test_list,
+train_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='LoadOCRAnnotations', with_text=True),
+    dict(type='Resize', scale=(100, 32), keep_ratio=False),
+    dict(
+        type='PackTextRecogInputs',
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
+]
+
+# TODO Add Test Time Augmentation `MultiRotateAugOCR`
+test_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='Resize', scale=(100, 32), keep_ratio=False),
+    dict(
+        type='PackTextRecogInputs',
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio',
+                   'instances'))
+]
+
+train_dataloader = dict(
+    batch_size=64,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_prefix=dict(img_path=None),
+        ann_file='train_label.json',
+        pipeline=train_pipeline))
+
+val_dataloader = dict(
+    batch_size=64,
+    num_workers=4,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_prefix=dict(img_path=None),
+        ann_file='test_label.json',
+        test_mode=True,
         pipeline=test_pipeline))
+test_dataloader = val_dataloader
 
-evaluation = dict(interval=1, metric='acc')
+val_evaluator = [
+    dict(
+        type='WordMetric', mode=['exact', 'ignore_case',
+                                 'ignore_case_symbol']),
+    dict(type='CharMetric')
+]
+test_evaluator = val_evaluator
+visualizer = dict(type='TextRecogLocalVisualizer', name='visualizer')
diff --git a/configs/textrecog/satrn/satrn_small.py b/configs/textrecog/satrn/satrn_small.py
index 96f86797..1798e2f1 100644
--- a/configs/textrecog/satrn/satrn_small.py
+++ b/configs/textrecog/satrn/satrn_small.py
@@ -1,24 +1,9 @@
-_base_ = [
-    '../../_base_/default_runtime.py',
-    '../../_base_/recog_pipelines/satrn_pipeline.py',
-    '../../_base_/recog_datasets/ST_MJ_train.py',
-    '../../_base_/recog_datasets/academic_test.py'
-]
-
-train_list = {{_base_.train_list}}
-test_list = {{_base_.test_list}}
-
-train_pipeline = {{_base_.train_pipeline}}
-test_pipeline = {{_base_.test_pipeline}}
-
-label_convertor = dict(
-    type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+_base_ = ['satrn_academic.py']
 
 model = dict(
-    type='SATRN',
     backbone=dict(type='ShallowCNN', input_channels=3, hidden_dim=256),
     encoder=dict(
-        type='SatrnEncoder',
+        type='SATRNEncoder',
         n_layers=6,
         n_head=8,
         d_k=256 // 8,
@@ -35,34 +20,4 @@ model = dict(
         d_model=256,
         d_inner=256 * 4,
         d_k=256 // 8,
-        d_v=256 // 8),
-    loss=dict(type='TFLoss'),
-    label_convertor=label_convertor,
-    max_seq_len=25)
-
-# optimizer
-optimizer = dict(type='Adam', lr=3e-4)
-optimizer_config = dict(grad_clip=None)
-# learning policy
-lr_config = dict(policy='step', step=[3, 4])
-total_epochs = 6
-
-data = dict(
-    samples_per_gpu=64,
-    workers_per_gpu=4,
-    val_dataloader=dict(samples_per_gpu=1),
-    test_dataloader=dict(samples_per_gpu=1),
-    train=dict(
-        type='UniformConcatDataset',
-        datasets=train_list,
-        pipeline=train_pipeline),
-    val=dict(
-        type='UniformConcatDataset',
-        datasets=test_list,
-        pipeline=test_pipeline),
-    test=dict(
-        type='UniformConcatDataset',
-        datasets=test_list,
-        pipeline=test_pipeline))
-
-evaluation = dict(interval=1, metric='acc')
+        d_v=256 // 8))