Add ABINet cfg

This commit is contained in:
gaotongxiao 2022-07-12 08:27:55 +00:00
parent b8d472b77b
commit 2d478ea244
7 changed files with 179 additions and 279 deletions

View File

@ -1,70 +1,16 @@
# NOTE(review): this span looks like a rendered diff hunk in which removed and
# added lines appear together without +/- markers, so it does not read as one
# coherent config file -- confirm against the repository before reusing it.
#
# num_chars depends on the configuration of label_convertor. The actual
# dictionary size is 36 + 1 (<BOS/EOS>).
# TODO: Automatically update num_chars based on the configuration of
# label_convertor
num_chars = 37
max_seq_len = 26
# Label convertor settings; passed to the model below as `label_convertor`.
label_convertor = dict(
type='ABIConvertor',
dict_type='DICT36',
with_unknown=False,
with_padding=False,
lower=True,
)
# NOTE(review): presumably the post-change file inherits everything from
# 'abinet_vision_only.py' and only overrides parts of `model` -- verify.
_base_ = 'abinet_vision_only.py'
model = dict(
type='ABINet',
backbone=dict(type='ResNetABI'),
# Vision branch: a transformer encoder followed by the vision decoder.
encoder=dict(
type='ABIVisionModel',
encoder=dict(
type='TransformerEncoder',
n_layers=3,
n_head=8,
d_model=512,
d_inner=2048,
dropout=0.1,
# 8 * 32 equals attn_height * attn_width of the decoder below.
max_len=8 * 32,
),
decoder=dict(
type='ABIVisionDecoder',
in_channels=512,
num_channels=64,
attn_height=8,
attn_width=32,
attn_mode='nearest',
use_result='feature',
num_chars=num_chars,
max_seq_len=max_seq_len,
init_cfg=dict(type='Xavier', layer='Conv2d')),
),
# Language branch.
decoder=dict(
type='ABILanguageDecoder',
d_model=512,
n_head=8,
d_inner=2048,
n_layers=4,
dropout=0.1,
detach_tokens=True,
use_self_attn=False,
# With num_chars = 37 this evaluates to 36, i.e. the last class index.
pad_idx=num_chars - 1,
num_chars=num_chars,
max_seq_len=max_seq_len,
init_cfg=None),
# Fuser settings; shares num_chars and max_seq_len with both branches.
fuser=dict(
type='ABIFuser',
d_model=512,
num_chars=num_chars,
init_cfg=None,
max_seq_len=max_seq_len,
),
# Encoder, decoder and fusion loss terms are all weighted 1.0.
loss=dict(
type='ABILoss',
enc_weight=1.0,
dec_weight=1.0,
fusion_weight=1.0,
num_classes=num_chars),
label_convertor=label_convertor,
max_seq_len=max_seq_len,
iter_size=3)
# NOTE(review): the `model = dict(...)` call is already closed on the line
# above, and the closing parentheses at the end of this span are unbalanced.
# Presumably these lines are the added side of the hunk (a `decoder` override
# containing `num_iters` and `language_decoder`) -- verify before use.
num_iters=3,
language_decoder=dict(
type='ABILanguageDecoder',
d_model=512,
n_head=8,
d_inner=2048,
n_layers=4,
dropout=0.1,
detach_tokens=True,
use_self_attn=False,
)), )

View File

@ -0,0 +1,40 @@
# Character dictionary settings. Start and end tokens are both enabled and
# flagged as shared (`same_start_end=True`); padding and unknown tokens are
# disabled.
dictionary = dict(
    type='Dictionary',
    dict_file='dicts/lower_english_digits.txt',
    with_start=True,
    with_end=True,
    same_start_end=True,
    with_padding=False,
    with_unknown=False,
)

# Vision-only ABINet: backbone -> encoder -> fuser holding a vision decoder.
model = dict(
    type='ABINet',
    backbone=dict(type='ResNetABI'),
    encoder=dict(
        type='ABIEncoder',
        n_layers=3,
        n_head=8,
        d_model=512,
        d_inner=2048,
        dropout=0.1,
        # 8 * 32 equals attn_height * attn_width of the vision decoder below.
        max_len=8 * 32,
    ),
    decoder=dict(
        type='ABIFuser',
        vision_decoder=dict(
            type='ABIVisionDecoder',
            in_channels=512,
            num_channels=64,
            attn_height=8,
            attn_width=32,
            attn_mode='nearest',
            init_cfg=dict(type='Xavier', layer='Conv2d'),
        ),
        loss_module=dict(type='ABILoss', letter_case='lower'),
        postprocessor=dict(type='AttentionPostprocessor'),
    ),
    # `dictionary` and `max_seq_len` are keys of `model` itself, not of
    # `decoder`.
    dictionary=dictionary,
    max_seq_len=26,
    data_preprocessor=dict(
        type='TextRecogDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
    ),
)

View File

@ -1,97 +0,0 @@
# Per-channel normalisation statistics, expanded into every `NormalizeOCR`
# step below.
img_norm_cfg = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    # Fixed 32 x 128 resize (min_width == max_width, aspect ratio not kept).
    dict(
        type='ResizeOCR',
        height=32,
        min_width=128,
        max_width=128,
        keep_aspect_ratio=False,
        width_downsample_ratio=0.25,
    ),
    # Geometric augmentation: `RandomWrapper` with p=0.5 around a
    # `OneOfWrapper` listing three transforms.
    dict(
        type='RandomWrapper',
        p=0.5,
        transforms=[
            dict(
                type='OneOfWrapper',
                transforms=[
                    dict(
                        type='RandomRotate',
                        rotate_ratio=1.0,
                        max_angle=15,
                    ),
                    dict(
                        type='TorchVisionWrapper',
                        op='RandomAffine',
                        degrees=15,
                        translate=(0.3, 0.3),
                        scale=(0.5, 2.0),
                        shear=(-45, 45),
                    ),
                    dict(
                        type='TorchVisionWrapper',
                        op='RandomPerspective',
                        distortion_scale=0.5,
                        p=1,
                    ),
                ],
            ),
        ],
    ),
    # Quality degradation: `RandomWrapper` with p=0.25 around rescale and
    # noise/blur steps.
    dict(
        type='RandomWrapper',
        p=0.25,
        transforms=[
            dict(type='PyramidRescale'),
            dict(
                type='Albu',
                transforms=[
                    dict(type='GaussNoise', var_limit=(20, 20), p=0.5),
                    dict(type='MotionBlur', blur_limit=6, p=0.5),
                ],
            ),
        ],
    ),
    # Colour jitter: `RandomWrapper` with p=0.25.
    dict(
        type='RandomWrapper',
        p=0.25,
        transforms=[
            dict(
                type='TorchVisionWrapper',
                op='ColorJitter',
                brightness=0.5,
                saturation=0.5,
                contrast=0.5,
                hue=0.1,
            ),
        ],
    ),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'filename',
            'ori_shape',
            'img_shape',
            'text',
            'valid_ratio',
            'resize_shape',
        ],
    ),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    # `MultiRotateAugOCR` lists rotations of 0, 90 and 270 degrees and wraps
    # the resize / tensor / normalise / collect steps.
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=32,
                min_width=128,
                max_width=128,
                keep_aspect_ratio=False,
                width_downsample_ratio=0.25,
            ),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename',
                    'ori_shape',
                    'img_shape',
                    'valid_ratio',
                    'resize_shape',
                    'img_norm_cfg',
                    'ori_filename',
                ],
            ),
        ],
    ),
]

View File

@ -5,6 +5,8 @@ val_cfg = dict(type='ValLoop')
# Loop type used for testing.
test_cfg = dict(type='TestLoop')
# learning policy
# NOTE(review): two `LinearLR` warm-up entries follow (`end=1`, then `end=2`
# with `convert_to_iter_based=True`). This looks like the removed and the added
# side of a diff hunk shown together -- confirm which one is intended.
param_scheduler = [
dict(type='LinearLR', end=1, start_factor=0.001),
dict(
type='LinearLR', end=2, start_factor=0.001,
convert_to_iter_based=True),
# Step schedule with milestones 16 and 18, ending at 20.
dict(type='MultiStepLR', milestones=[16, 18], end=20),
]

View File

@ -1,35 +1,3 @@
# NOTE(review): `_base_` is assigned twice in this span (the list below, then a
# one-line list further down). Presumably these are the old and new sides of a
# diff hunk shown together -- confirm against the repository.
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_adam_step_20e.py',
'../../_base_/recog_pipelines/abinet_pipeline.py',
'../../_base_/recog_models/abinet.py',
# '../../_base_/recog_datasets/ST_MJ_alphanumeric_train.py',
'../../_base_/recog_datasets/toy_data.py'
# '../../_base_/recog_datasets/academic_test.py'
]
_base_ = ['../../_base_/recog_models/abinet.py', 'base.py']
# The `{{_base_.<name>}}` placeholders refer to variables of the `_base_`
# files; presumably the config loader substitutes them before evaluation
# (they are not usable as plain Python values) -- TODO confirm.
train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}
train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}
# Data settings: the top-level `samples_per_gpu` is 192, while the val and
# test dataloaders set `samples_per_gpu=1`.
data = dict(
samples_per_gpu=192,
workers_per_gpu=8,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline),
# `val` and `test` both use `test_list` with `test_pipeline`.
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='acc')
# Checkpoint file name; presumably used to initialise weights -- verify.
load_from = 'abinet_pretrain-1bed979b.pth'

View File

@ -1,81 +1 @@
# NOTE(review): `_base_` is assigned at the top of this span and again on its
# last line. Presumably everything up to `evaluation` is the removed side of a
# diff hunk and the final `_base_` line is the added side -- confirm against
# the repository.
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_adam_step_20e.py',
'../../_base_/recog_pipelines/abinet_pipeline.py',
'../../_base_/recog_datasets/toy_data.py'
# '../../_base_/recog_datasets/ST_MJ_alphanumeric_train.py',
# '../../_base_/recog_datasets/academic_test.py'
]
# The `{{_base_.<name>}}` placeholders refer to variables of the `_base_`
# files; presumably the config loader substitutes them before evaluation
# -- TODO confirm.
train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}
train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}
# Model
num_chars = 37
max_seq_len = 26
# Label convertor settings; passed to the model below as `label_convertor`.
label_convertor = dict(
type='ABIConvertor',
dict_type='DICT36',
with_unknown=False,
with_padding=False,
lower=True,
)
# Vision-only variant: this `model` has no language decoder or fuser entry.
model = dict(
type='ABINet',
backbone=dict(type='ResNetABI'),
encoder=dict(
type='ABIVisionModel',
encoder=dict(
type='TransformerEncoder',
n_layers=3,
n_head=8,
d_model=512,
d_inner=2048,
dropout=0.1,
# 8 * 32 equals attn_height * attn_width of the decoder below.
max_len=8 * 32,
),
decoder=dict(
type='ABIVisionDecoder',
in_channels=512,
num_channels=64,
attn_height=8,
attn_width=32,
attn_mode='nearest',
use_result='feature',
num_chars=num_chars,
max_seq_len=max_seq_len,
init_cfg=dict(type='Xavier', layer='Conv2d')),
),
# Encoder, decoder and fusion loss terms are all weighted 1.0.
loss=dict(
type='ABILoss',
enc_weight=1.0,
dec_weight=1.0,
fusion_weight=1.0,
num_classes=num_chars),
label_convertor=label_convertor,
max_seq_len=max_seq_len,
iter_size=1)
# Data settings: the top-level `samples_per_gpu` is 192, while the val and
# test dataloaders set `samples_per_gpu=1`.
data = dict(
samples_per_gpu=192,
workers_per_gpu=8,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline),
# `val` and `test` both use `test_list` with `test_pipeline`.
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='acc')
_base_ = ['../../_base_/recog_models/abinet_vision_only.py', 'base.py']

View File

@ -0,0 +1,121 @@
# Shared runtime / schedule / data settings for the ABINet configs.
_base_ = [
    '../../_base_/default_runtime.py',
    '../../_base_/schedules/schedule_adam_step_20e.py',
]

default_hooks = dict(logger=dict(type='LoggerHook', interval=100))

# dataset settings
dataset_type = 'OCRDataset'
data_root = 'data/recog/'
file_client_args = dict(backend='disk')

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(type='Resize', scale=(128, 32)),
    # Geometric augmentation: `RandomApply` with prob=0.5 around a
    # `RandomChoice` listing three transforms.
    dict(
        type='RandomApply',
        prob=0.5,
        transforms=[
            dict(
                type='RandomChoice',
                transforms=[
                    dict(
                        type='RandomRotate',
                        max_angle=15,
                    ),
                    dict(
                        type='TorchVisionWrapper',
                        op='RandomAffine',
                        degrees=15,
                        translate=(0.3, 0.3),
                        scale=(0.5, 2.0),
                        shear=(-45, 45),
                    ),
                    dict(
                        type='TorchVisionWrapper',
                        op='RandomPerspective',
                        distortion_scale=0.5,
                        p=1,
                    ),
                ],
            ),
        ],
    ),
    # Quality degradation: `RandomApply` with prob=0.25 around rescale and
    # noise/blur steps.
    dict(
        type='RandomApply',
        prob=0.25,
        transforms=[
            dict(type='PyramidRescale'),
            dict(
                type='mmdet.Albu',
                transforms=[
                    dict(type='GaussNoise', var_limit=(20, 20), p=0.5),
                    dict(type='MotionBlur', blur_limit=6, p=0.5),
                ],
            ),
        ],
    ),
    # Colour jitter: `RandomApply` with prob=0.25.
    dict(
        type='RandomApply',
        prob=0.25,
        transforms=[
            dict(
                type='TorchVisionWrapper',
                op='ColorJitter',
                brightness=0.5,
                saturation=0.5,
                contrast=0.5,
                hue=0.1,
            ),
        ],
    ),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'),
    ),
]

# Test-time pipeline: load, resize and pack only; `meta_keys` additionally
# lists 'instances'.
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(type='Resize', scale=(128, 32)),
    dict(
        type='PackTextRecogInputs',
        meta_keys=(
            'img_path',
            'ori_shape',
            'img_shape',
            'valid_ratio',
            'instances',
        ),
    ),
]

# The two training sets below share `data_root` and `train_pipeline`.
dataset_mj = dict(
    type=dataset_type,
    data_root=data_root,
    data_prefix=dict(img_path='mnt/ramdisk/max/90kDICT32px/'),
    ann_file='data/MJ/label.json',
    pipeline=train_pipeline,
)
dataset_st = dict(
    type=dataset_type,
    data_root=data_root,
    data_prefix=dict(
        img_path='SynthText/synthtext/SynthText_patch_horizontal/'),
    ann_file='data/ST/alphanumeric_labels.json',
    pipeline=train_pipeline,
)

train_dataloader = dict(
    batch_size=192 * 4,
    num_workers=32,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='ConcatDataset',
        datasets=[dataset_mj, dataset_st],
    ),
)

val_dataloader = dict(
    batch_size=192,
    num_workers=16,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(img_path='testset/testset/IIIT5K/'),
        ann_file='label.json',
        test_mode=True,
        pipeline=test_pipeline,
    ),
)
# Testing reuses the validation dataloader object as-is.
test_dataloader = val_dataloader

val_evaluator = dict(type='WordMetric', mode=['ignore_case_symbol'])
# Testing reuses the validation evaluator object as-is.
test_evaluator = val_evaluator

visualizer = dict(type='TextRecogLocalVisualizer', name='visualizer')