diff --git a/configs/_base_/recog_datasets/ST_SA_MJ_real_train.py b/configs/_base_/recog_datasets/ST_SA_MJ_real_train.py index b0eaeddd..9a4ecdbc 100644 --- a/configs/_base_/recog_datasets/ST_SA_MJ_real_train.py +++ b/configs/_base_/recog_datasets/ST_SA_MJ_real_train.py @@ -6,7 +6,7 @@ data_root = 'data/rec' train_img_prefix1 = 'icdar_2011' train_img_prefix2 = 'icdar_2013' train_img_prefix3 = 'icdar_2015' -train_img_prefix4 = 'coco_text' +train_img_prefix4 = 'coco_text_v1' train_img_prefix5 = 'IIIT5K' train_img_prefix6 = 'synthtext_add' train_img_prefix7 = 'SynthText/synthtext/SynthText_patch_horizontal' diff --git a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py index 50ba3b24..b706eb7e 100644 --- a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py +++ b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py @@ -7,13 +7,15 @@ _base_ = [ ] # dataset settings -train_list = {{_base_.train_list}} -test_list = {{_base_.test_list}} file_client_args = dict(backend='disk') default_hooks = dict(logger=dict(type='LoggerHook', interval=100)) train_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='LoadImageFromFile', + file_client_args=file_client_args, + ignore_empty=True, + min_size=5), dict(type='LoadOCRAnnotations', with_text=True), dict(type='Resize', scale=(160, 48), keep_ratio=False), dict( @@ -36,32 +38,52 @@ test_pipeline = [ 'instances')) ] +# dataset settings +ic11_rec_train = _base_.ic11_rec_train +ic13_rec_train = _base_.ic13_rec_train +ic15_rec_train = _base_.ic15_rec_train +cocov1_rec_train = _base_.cocov1_rec_train +iiit5k_rec_train = _base_.iiit5k_rec_train +st_add_rec_train = _base_.st_add_rec_train +st_rec_train = _base_.st_rec_train +mj_rec_trian = _base_.mj_rec_trian + +ic11_rec_train.pipeline = train_pipeline +ic13_rec_train.pipeline = train_pipeline +ic15_rec_train.pipeline = train_pipeline +cocov1_rec_train.pipeline = train_pipeline +iiit5k_rec_train.pipeline = train_pipeline +st_add_rec_train.pipeline = train_pipeline +st_rec_train.pipeline = train_pipeline +mj_rec_trian.pipeline = train_pipeline +repeat_ic11 = dict(type='RepeatDataset', dataset=ic11_rec_train, times=20) +repeat_ic13 = dict(type='RepeatDataset', dataset=ic13_rec_train, times=20) +repeat_ic15 = dict(type='RepeatDataset', dataset=ic15_rec_train, times=20) +repeat_cocov1 = dict(type='RepeatDataset', dataset=cocov1_rec_train, times=20) +repeat_iiit5k = dict(type='RepeatDataset', dataset=iiit5k_rec_train, times=20) + train_dataloader = dict( batch_size=64, - num_workers=8, + num_workers=16, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), dataset=dict( - type='ConcatDataset', datasets=train_list, pipeline=train_pipeline)) + type='ConcatDataset', + datasets=[ + repeat_ic11, repeat_ic13, repeat_ic15, repeat_cocov1, + repeat_iiit5k, st_add_rec_train, st_rec_train, mj_rec_trian + ])) -val_dataloader = dict( - batch_size=1, - num_workers=4, - persistent_workers=True, - drop_last=False, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type='ConcatDataset', datasets=test_list, pipeline=test_pipeline)) -test_dataloader = val_dataloader - -val_evaluator = [ - dict( - type='WordMetric', mode=['exact', 'ignore_case', - 'ignore_case_symbol']), - dict(type='CharMetric') -] -test_evaluator = val_evaluator +test_cfg = dict(type='MultiTestLoop') +val_cfg = dict(type='MultiValLoop') +val_dataloader = _base_.val_dataloader +test_dataloader = _base_.test_dataloader +for dataloader in test_dataloader: + dataloader['dataset']['pipeline'] = test_pipeline +for dataloader in val_dataloader: + dataloader['dataset']['pipeline'] = test_pipeline visualizer = dict(type='TextRecogLocalVisualizer', name='visualizer') + dictionary = dict( type='Dictionary', dict_file='dicts/english_digits_symbols.txt', @@ -90,6 +112,6 @@ model = dict( pred_concat=True, postprocessor=dict(type='AttentionPostprocessor'), module_loss=dict( - type='CEModuleLoss', ignore_first_char=True, reduction='mean')), - dictionary=dictionary, - max_seq_len=30) + type='CEModuleLoss', ignore_first_char=True, reduction='mean'), + dictionary=dictionary, + max_seq_len=30))