# Mirror of https://github.com/open-mmlab/mmocr.git
# (160 lines, 5.0 KiB, Python)
# Character dictionary config; the same object is handed to the decoder
# through ``dictionary=dictionary`` inside ``model``.
dictionary = dict(
    type='Dictionary',
    # ``{{ fileDirname }}`` is a template placeholder rather than a literal
    # directory name (presumably expanded by the config loader to the
    # directory containing this file -- confirm against the mmengine
    # version in use).
    dict_file='{{ fileDirname }}/../../../dicts/lower_english_digits.txt',
    with_start=True,
    with_end=True,
    # NOTE(review): presumably makes the start and end markers share one
    # token -- confirm against the ``Dictionary`` implementation.
    same_start_end=True,
    with_padding=False,
    with_unknown=False)

# Recognizer config: ResNetABI backbone -> ABIEncoder -> ABIFuser decoder
# (vision decoder only in this file), plus input normalisation.
model = dict(
    type='ABINet',
    backbone=dict(type='ResNetABI'),
    encoder=dict(
        type='ABIEncoder',
        n_layers=3,
        n_head=8,
        d_model=512,
        d_inner=2048,
        dropout=0.1,
        # Equals attn_height * attn_width of the vision decoder below
        # (8 * 32); presumably the flattened feature-map length -- confirm.
        max_len=8 * 32,
    ),
    decoder=dict(
        type='ABIFuser',
        vision_decoder=dict(
            type='ABIVisionDecoder',
            # Same value as the encoder's ``d_model``.
            in_channels=512,
            num_channels=64,
            attn_height=8,
            attn_width=32,
            attn_mode='nearest',
            init_cfg=dict(type='Xavier', layer='Conv2d')),
        module_loss=dict(type='ABIModuleLoss', letter_case='lower'),
        postprocessor=dict(type='AttentionPostprocessor'),
        # Reuses the module-level ``dictionary`` config defined above.
        dictionary=dictionary,
        max_seq_len=26,
    ),
    data_preprocessor=dict(
        type='TextRecogDataPreprocessor',
        # NOTE(review): these look like the usual ImageNet per-channel
        # statistics on a 0-255 scale; channel order is not visible here.
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375]))

# Training data pipeline: load image and text annotation, resize, apply
# three independent ``RandomApply`` augmentation groups, then pack.
train_pipeline = [
    dict(type='LoadImageFromFile', ignore_empty=True, min_size=2),
    dict(type='LoadOCRAnnotations', with_text=True),
    # NOTE(review): scale is presumably (width, height) -- confirm against
    # the ``Resize`` transform in use.
    dict(type='Resize', scale=(128, 32)),
    # Geometric augmentation group (prob=0.5): a ``RandomChoice`` over
    # rotation, affine and perspective transforms.
    dict(
        type='RandomApply',
        prob=0.5,
        transforms=[
            dict(
                type='RandomChoice',
                transforms=[
                    dict(
                        type='RandomRotate',
                        max_angle=15,
                    ),
                    dict(
                        type='TorchVisionWrapper',
                        op='RandomAffine',
                        degrees=15,
                        translate=(0.3, 0.3),
                        scale=(0.5, 2.),
                        shear=(-45, 45),
                    ),
                    dict(
                        type='TorchVisionWrapper',
                        op='RandomPerspective',
                        distortion_scale=0.5,
                        p=1,
                    ),
                ])
        ],
    ),
    # Image-quality group (prob=0.25): pyramid rescale followed by
    # albumentations noise / motion blur, each with its own p=0.5.
    dict(
        type='RandomApply',
        prob=0.25,
        transforms=[
            dict(type='PyramidRescale'),
            dict(
                type='mmdet.Albu',
                transforms=[
                    dict(type='GaussNoise', var_limit=(20, 20), p=0.5),
                    dict(type='MotionBlur', blur_limit=7, p=0.5),
                ]),
        ]),
    # Colour group (prob=0.25): torchvision ``ColorJitter``.
    dict(
        type='RandomApply',
        prob=0.25,
        transforms=[
            dict(
                type='TorchVisionWrapper',
                op='ColorJitter',
                brightness=0.5,
                saturation=0.5,
                contrast=0.5,
                hue=0.1),
        ]),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]

# Evaluation data pipeline: load image, resize, load text annotation, pack.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(128, 32)),
    # Annotations are loaded after ``Resize`` because the ground truth does
    # not need to go through the resize transform.
    dict(type='LoadOCRAnnotations', with_text=True),
    dict(
        type='PackTextRecogInputs',
        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]

# Test-time-augmentation pipeline: load the image, then hand it to
# ``TestTimeAug``. ``transforms`` is a list of lists, one inner list per
# stage (NOTE(review): presumably one augmented view is produced per
# combination of the alternatives in each stage -- confirm against the
# ``TestTimeAug`` documentation).
tta_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='TestTimeAug',
        transforms=[
            # Stage 1: three ``Rot90`` alternatives (k=0, 1, 3). Each is
            # guarded by the same condition comparing ``img_shape[1]``
            # with ``img_shape[0]`` (presumably width < height, i.e. the
            # rotation only fires for tall images -- confirm).
            [
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    true_transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            # Stage 2: resize to the same scale as the other pipelines.
            [dict(type='Resize', scale=(128, 32))],
            # Stage 3: annotations are loaded after ``Resize`` because the
            # ground truth does not need to go through the resize transform.
            [dict(type='LoadOCRAnnotations', with_text=True)],
            # Stage 4: pack the sample for the recognizer.
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]
