diff --git a/configs/_base_/recog_models/abinet.py b/configs/_base_/recog_models/abinet.py index 19c6b667..75d143c0 100644 --- a/configs/_base_/recog_models/abinet.py +++ b/configs/_base_/recog_models/abinet.py @@ -1,70 +1,16 @@ -# num_chars depends on the configuration of label_convertor. The actual -# dictionary size is 36 + 1 (). -# TODO: Automatically update num_chars based on the configuration of -# label_convertor -num_chars = 37 -max_seq_len = 26 - -label_convertor = dict( - type='ABIConvertor', - dict_type='DICT36', - with_unknown=False, - with_padding=False, - lower=True, -) +_base_ = 'abinet_vision_only.py' model = dict( - type='ABINet', - backbone=dict(type='ResNetABI'), - encoder=dict( - type='ABIVisionModel', - encoder=dict( - type='TransformerEncoder', - n_layers=3, - n_head=8, - d_model=512, - d_inner=2048, - dropout=0.1, - max_len=8 * 32, - ), - decoder=dict( - type='ABIVisionDecoder', - in_channels=512, - num_channels=64, - attn_height=8, - attn_width=32, - attn_mode='nearest', - use_result='feature', - num_chars=num_chars, - max_seq_len=max_seq_len, - init_cfg=dict(type='Xavier', layer='Conv2d')), - ), decoder=dict( - type='ABILanguageDecoder', d_model=512, - n_head=8, - d_inner=2048, - n_layers=4, - dropout=0.1, - detach_tokens=True, - use_self_attn=False, - pad_idx=num_chars - 1, - num_chars=num_chars, - max_seq_len=max_seq_len, - init_cfg=None), - fuser=dict( - type='ABIFuser', - d_model=512, - num_chars=num_chars, - init_cfg=None, - max_seq_len=max_seq_len, - ), - loss=dict( - type='ABILoss', - enc_weight=1.0, - dec_weight=1.0, - fusion_weight=1.0, - num_classes=num_chars), - label_convertor=label_convertor, - max_seq_len=max_seq_len, - iter_size=3) + num_iters=3, + language_decoder=dict( + type='ABILanguageDecoder', + d_model=512, + n_head=8, + d_inner=2048, + n_layers=4, + dropout=0.1, + detach_tokens=True, + use_self_attn=False, + )), ) diff --git a/configs/_base_/recog_models/abinet_vision_only.py b/configs/_base_/recog_models/abinet_vision_only.py new file mode 100644 index 00000000..473475f7 --- /dev/null +++ b/configs/_base_/recog_models/abinet_vision_only.py @@ -0,0 +1,40 @@ +dictionary = dict( + type='Dictionary', + dict_file='dicts/lower_english_digits.txt', + with_start=True, + with_end=True, + same_start_end=True, + with_padding=False, + with_unknown=False) + +model = dict( + type='ABINet', + backbone=dict(type='ResNetABI'), + encoder=dict( + type='ABIEncoder', + n_layers=3, + n_head=8, + d_model=512, + d_inner=2048, + dropout=0.1, + max_len=8 * 32, + ), + decoder=dict( + type='ABIFuser', + vision_decoder=dict( + type='ABIVisionDecoder', + in_channels=512, + num_channels=64, + attn_height=8, + attn_width=32, + attn_mode='nearest', + init_cfg=dict(type='Xavier', layer='Conv2d')), + loss_module=dict(type='ABILoss', letter_case='lower'), + postprocessor=dict(type='AttentionPostprocessor'), + ), + dictionary=dictionary, + max_seq_len=26, + data_preprocessor=dict( + type='TextRecogDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375])) diff --git a/configs/_base_/recog_pipelines/abinet_pipeline.py b/configs/_base_/recog_pipelines/abinet_pipeline.py deleted file mode 100644 index ac519f50..00000000 --- a/configs/_base_/recog_pipelines/abinet_pipeline.py +++ /dev/null @@ -1,97 +0,0 @@ -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict( - type='ResizeOCR', - height=32, - min_width=128, - max_width=128, - keep_aspect_ratio=False, - width_downsample_ratio=0.25), - dict( - type='RandomWrapper', - p=0.5, - transforms=[ - dict( - type='OneOfWrapper', - transforms=[ - dict( - type='RandomRotate', - rotate_ratio=1.0, - max_angle=15, - ), - dict( - type='TorchVisionWrapper', - op='RandomAffine', - degrees=15, - translate=(0.3, 0.3), - scale=(0.5, 2.), - shear=(-45, 45), - ), - dict( - type='TorchVisionWrapper', - op='RandomPerspective', - distortion_scale=0.5, - p=1, - ), - ]) - ], - ), - dict( - type='RandomWrapper', - p=0.25, - transforms=[ - dict(type='PyramidRescale'), - dict( - type='Albu', - transforms=[ - dict(type='GaussNoise', var_limit=(20, 20), p=0.5), - dict(type='MotionBlur', blur_limit=6, p=0.5), - ]), - ]), - dict( - type='RandomWrapper', - p=0.25, - transforms=[ - dict( - type='TorchVisionWrapper', - op='ColorJitter', - brightness=0.5, - saturation=0.5, - contrast=0.5, - hue=0.1), - ]), - dict(type='ToTensorOCR'), - dict(type='NormalizeOCR', **img_norm_cfg), - dict( - type='Collect', - keys=['img'], - meta_keys=[ - 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio', - 'resize_shape' - ]), -] -test_pipeline = [ - dict(type='LoadImageFromFile'), - dict( - type='MultiRotateAugOCR', - rotate_degrees=[0, 90, 270], - transforms=[ - dict( - type='ResizeOCR', - height=32, - min_width=128, - max_width=128, - keep_aspect_ratio=False, - width_downsample_ratio=0.25), - dict(type='ToTensorOCR'), - dict(type='NormalizeOCR', **img_norm_cfg), - dict( - type='Collect', - keys=['img'], - meta_keys=[ - 'filename', 'ori_shape', 'img_shape', 'valid_ratio', - 'resize_shape', 'img_norm_cfg', 'ori_filename' - ]), - ]) -] diff --git a/configs/_base_/schedules/schedule_adam_step_20e.py b/configs/_base_/schedules/schedule_adam_step_20e.py index 487836ce..98997fc0 100644 --- a/configs/_base_/schedules/schedule_adam_step_20e.py +++ b/configs/_base_/schedules/schedule_adam_step_20e.py @@ -5,6 +5,8 @@ val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') # learning policy param_scheduler = [ - dict(type='LinearLR', end=1, start_factor=0.001), + dict( + type='LinearLR', end=2, start_factor=0.001, + convert_to_iter_based=True), dict(type='MultiStepLR', milestones=[16, 18], end=20), ] diff --git a/configs/textrecog/abinet/abinet_academic.py b/configs/textrecog/abinet/abinet_academic.py index 4abb87a6..ef7f02dd 100644 --- a/configs/textrecog/abinet/abinet_academic.py +++ b/configs/textrecog/abinet/abinet_academic.py @@ -1,35 +1,3 @@ -_base_ = [ - '../../_base_/default_runtime.py', - '../../_base_/schedules/schedule_adam_step_20e.py', - '../../_base_/recog_pipelines/abinet_pipeline.py', - '../../_base_/recog_models/abinet.py', - # '../../_base_/recog_datasets/ST_MJ_alphanumeric_train.py', - '../../_base_/recog_datasets/toy_data.py' - # '../../_base_/recog_datasets/academic_test.py' -] +_base_ = ['../../_base_/recog_models/abinet.py', 'base.py'] -train_list = {{_base_.train_list}} -test_list = {{_base_.test_list}} - -train_pipeline = {{_base_.train_pipeline}} -test_pipeline = {{_base_.test_pipeline}} - -data = dict( - samples_per_gpu=192, - workers_per_gpu=8, - val_dataloader=dict(samples_per_gpu=1), - test_dataloader=dict(samples_per_gpu=1), - train=dict( - type='UniformConcatDataset', - datasets=train_list, - pipeline=train_pipeline), - val=dict( - type='UniformConcatDataset', - datasets=test_list, - pipeline=test_pipeline), - test=dict( - type='UniformConcatDataset', - datasets=test_list, - pipeline=test_pipeline)) - -evaluation = dict(interval=1, metric='acc') +load_from = 'abinet_pretrain-1bed979b.pth' diff --git a/configs/textrecog/abinet/abinet_vision_only_academic.py b/configs/textrecog/abinet/abinet_vision_only_academic.py index 318144d2..d820c1a7 100644 --- a/configs/textrecog/abinet/abinet_vision_only_academic.py +++ b/configs/textrecog/abinet/abinet_vision_only_academic.py @@ -1,81 +1 @@ -_base_ = [ - '../../_base_/default_runtime.py', - '../../_base_/schedules/schedule_adam_step_20e.py', - '../../_base_/recog_pipelines/abinet_pipeline.py', - '../../_base_/recog_datasets/toy_data.py' - # '../../_base_/recog_datasets/ST_MJ_alphanumeric_train.py', - # '../../_base_/recog_datasets/academic_test.py' -] - -train_list = {{_base_.train_list}} -test_list = {{_base_.test_list}} - -train_pipeline = {{_base_.train_pipeline}} -test_pipeline = {{_base_.test_pipeline}} - -# Model -num_chars = 37 -max_seq_len = 26 -label_convertor = dict( - type='ABIConvertor', - dict_type='DICT36', - with_unknown=False, - with_padding=False, - lower=True, -) - -model = dict( - type='ABINet', - backbone=dict(type='ResNetABI'), - encoder=dict( - type='ABIVisionModel', - encoder=dict( - type='TransformerEncoder', - n_layers=3, - n_head=8, - d_model=512, - d_inner=2048, - dropout=0.1, - max_len=8 * 32, - ), - decoder=dict( - type='ABIVisionDecoder', - in_channels=512, - num_channels=64, - attn_height=8, - attn_width=32, - attn_mode='nearest', - use_result='feature', - num_chars=num_chars, - max_seq_len=max_seq_len, - init_cfg=dict(type='Xavier', layer='Conv2d')), - ), - loss=dict( - type='ABILoss', - enc_weight=1.0, - dec_weight=1.0, - fusion_weight=1.0, - num_classes=num_chars), - label_convertor=label_convertor, - max_seq_len=max_seq_len, - iter_size=1) - -data = dict( - samples_per_gpu=192, - workers_per_gpu=8, - val_dataloader=dict(samples_per_gpu=1), - test_dataloader=dict(samples_per_gpu=1), - train=dict( - type='UniformConcatDataset', - datasets=train_list, - pipeline=train_pipeline), - val=dict( - type='UniformConcatDataset', - datasets=test_list, - pipeline=test_pipeline), - test=dict( - type='UniformConcatDataset', - datasets=test_list, - pipeline=test_pipeline)) - -evaluation = dict(interval=1, metric='acc') +_base_ = ['../../_base_/recog_models/abinet_vision_only.py', 'base.py'] diff --git a/configs/textrecog/abinet/base.py b/configs/textrecog/abinet/base.py new file mode 100644 index 00000000..2d040bf7 --- /dev/null +++ b/configs/textrecog/abinet/base.py @@ -0,0 +1,121 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/schedules/schedule_adam_step_20e.py', +] + +default_hooks = dict(logger=dict(type='LoggerHook', interval=100)) + +# dataset settings +dataset_type = 'OCRDataset' +data_root = 'data/recog/' +file_client_args = dict(backend='disk') + +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='LoadOCRAnnotations', with_text=True), + dict(type='Resize', scale=(128, 32)), + dict( + type='RandomApply', + prob=0.5, + transforms=[ + dict( + type='RandomChoice', + transforms=[ + dict( + type='RandomRotate', + max_angle=15, + ), + dict( + type='TorchVisionWrapper', + op='RandomAffine', + degrees=15, + translate=(0.3, 0.3), + scale=(0.5, 2.), + shear=(-45, 45), + ), + dict( + type='TorchVisionWrapper', + op='RandomPerspective', + distortion_scale=0.5, + p=1, + ), + ]) + ], + ), + dict( + type='RandomApply', + prob=0.25, + transforms=[ + dict(type='PyramidRescale'), + dict( + type='mmdet.Albu', + transforms=[ + dict(type='GaussNoise', var_limit=(20, 20), p=0.5), + dict(type='MotionBlur', blur_limit=6, p=0.5), + ]), + ]), + dict( + type='RandomApply', + prob=0.25, + transforms=[ + dict( + type='TorchVisionWrapper', + op='ColorJitter', + brightness=0.5, + saturation=0.5, + contrast=0.5, + hue=0.1), + ]), + dict( + type='PackTextRecogInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) +] +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='LoadOCRAnnotations', with_text=True), + dict(type='Resize', scale=(128, 32)), + dict( + type='PackTextRecogInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', + 'instances')) +] + +dataset_mj = dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='mnt/ramdisk/max/90kDICT32px/'), + ann_file='data/MJ/label.json', + pipeline=train_pipeline) +dataset_st = dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='SynthText/synthtext/SynthText_patch_horizontal/'), + ann_file='data/ST/alphanumeric_labels.json', + pipeline=train_pipeline) + +train_dataloader = dict( + batch_size=192 * 4, + num_workers=32, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict(type='ConcatDataset', datasets=[dataset_mj, dataset_st])) + +val_dataloader = dict( + batch_size=192, + num_workers=16, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='testset/testset/IIIT5K/'), + ann_file='label.json', + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='WordMetric', mode=['ignore_case_symbol']) +test_evaluator = val_evaluator +visualizer = dict(type='TextRecogLocalVisualizer', name='visualizer')