# Base config files this config builds on. Values defined there are read
# below through `_base_` (train_list, val_dataloader, test_dataloader).
_base_ = [
    'satrn.py',
    '../../_base_/recog_datasets/ST_MJ_train.py',
    '../../_base_/recog_datasets/academic_test.py',
    '../../_base_/default_runtime.py',
    '../../_base_/schedules/schedule_adam_step_5e.py',
]

# dataset settings
train_list = {{_base_.train_list}}
file_client_args = dict(backend='disk')
default_hooks = dict(logger=dict(type='LoggerHook', interval=50))

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='Adam', lr=3e-4),
)

# Encoder and decoder both use d_model=512 with 8 heads, so the per-head
# key/value sizes are written as 512 // 8 and the inner size as 512 * 4.
model = dict(
    type='SATRN',
    backbone=dict(type='ShallowCNN', input_channels=3, hidden_dim=512),
    encoder=dict(
        type='SATRNEncoder',
        n_layers=12,
        n_head=8,
        d_k=512 // 8,
        d_v=512 // 8,
        d_model=512,
        n_position=100,
        d_inner=512 * 4,
        dropout=0.1,
    ),
    decoder=dict(
        type='NRTRDecoder',
        n_layers=6,
        d_embedding=512,
        n_head=8,
        d_model=512,
        d_inner=512 * 4,
        d_k=512 // 8,
        d_v=512 // 8,
        module_loss=dict(
            type='CEModuleLoss',
            flatten=True,
            ignore_first_char=True,
        ),
        max_seq_len=25,
        postprocessor=dict(type='AttentionPostprocessor'),
    ),
)

# Training samples: load image and text annotation, resize to a fixed
# scale of (100, 32) with keep_ratio disabled, then pack the inputs.
train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(type='Resize', scale=(100, 32), keep_ratio=False),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'),
    ),
]

# TODO Add Test Time Augmentation `MultiRotateAugOCR`
# Unlike train_pipeline, this pipeline has no LoadOCRAnnotations step and
# lists 'instances' among the packed meta_keys.
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(100, 32), keep_ratio=False),
    dict(
        type='PackTextRecogInputs',
        meta_keys=(
            'img_path',
            'ori_shape',
            'img_shape',
            'valid_ratio',
            'instances',
        ),
    ),
]

# All datasets in train_list are concatenated into one training set.
train_dataloader = dict(
    batch_size=64,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='ConcatDataset',
        datasets=train_list,
        pipeline=train_pipeline,
    ),
)

test_cfg = dict(type='MultiTestLoop')
val_cfg = dict(type='MultiValLoop')

# The val/test dataloaders come from the base config; each is iterated
# below and its dataset pipeline is replaced with test_pipeline.
val_dataloader = _base_.val_dataloader
test_dataloader = _base_.test_dataloader

for dataloader in test_dataloader:
    dataloader['dataset']['pipeline'] = test_pipeline

for dataloader in val_dataloader:
    dataloader['dataset']['pipeline'] = test_pipeline

visualizer = dict(type='TextRecogLocalVisualizer', name='visualizer')