# Mirror of https://github.com/open-mmlab/mmocr.git
# 113 lines, 3.0 KiB, Python
# Base config files this config builds on: the shared default runtime
# settings and the NRTR recognition model definition.
# NOTE(review): paths are presumably resolved relative to this file by the
# config loader -- confirm against the loader in use.
_base_ = [
    '../../_base_/default_runtime.py',
    '../../_base_/recog_models/nrtr.py',
]

# optimizer
optimizer = dict(type='Adam', lr=1e-3)
# No gradient clipping is configured.
optimizer_config = dict(grad_clip=None)
# learning policy
# Step schedule with milestones 3 and 4 (presumably epoch indices, given
# that the run length below is expressed in epochs -- confirm).
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 6

# Per-channel normalization statistics, shared by the train and test
# pipelines through ``**img_norm_cfg``.
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: load the image, apply colour jitter, resize to a fixed
# height of 32 with the width bounded to [32, 100] (aspect ratio not kept),
# convert to a tensor, normalize, then collect the image together with its
# meta information. 'text' is included in the meta keys here (the test
# pipeline does not collect it).
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(
        type='ResizeOCR',
        height=32,
        min_width=32,
        max_width=100,
        keep_aspect_ratio=False),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
        ]),
]

# Test pipeline: load the image, then run the nested transforms once per
# rotation listed in ``rotate_degrees`` (0, 90 and 270). The nested
# transforms mirror the training resize / tensor / normalize steps (no
# colour jitter) and collect the image without the 'text' meta key.
# ``img_norm_cfg`` is the normalization dict defined earlier in this file.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=32,
                min_width=32,
                max_width=100,
                keep_aspect_ratio=False),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ]),
        ])
]

# Dataset class and image directory shared by every dataset entry below;
# the paths point at the toy dataset under tests/data.
dataset_type = 'OCRDataset'
img_prefix = 'tests/data/ocr_toy_dataset/imgs'

# Training source 1: plain-text annotation file read with HardDiskLoader.
train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
train1 = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=train_anno_file1,
    loader=dict(
        type='HardDiskLoader',
        # NOTE(review): presumably repeats the small annotation list to
        # lengthen an epoch -- confirm against the loader implementation.
        repeat=100,
        # Each annotation line is split on ' '; field 0 is mapped to
        # 'filename' and field 1 to 'text'.
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

# Training source 2: the LMDB annotation file read with LmdbLoader. Apart
# from the annotation file and loader type it uses the same settings as
# ``train1`` (same image prefix, parser, pipeline and repeat count).
train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
train2 = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=train_anno_file2,
    loader=dict(
        type='LmdbLoader',
        repeat=100,
        # Field 0 of each record is mapped to 'filename', field 1 to 'text',
        # with ' ' as the separator.
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

# Evaluation source: the same LMDB annotation file as ``train2``, read once
# (repeat=1) through the test pipeline with test_mode enabled.
test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
test = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=test_anno_file1,
    loader=dict(
        type='LmdbLoader',
        repeat=1,
        # Field 0 of each record is mapped to 'filename', field 1 to 'text',
        # with ' ' as the separator.
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=test_pipeline,
    test_mode=True)

# Data settings: 16 samples and 2 workers per GPU. Training concatenates
# both training sources; validation and testing both use the single
# ``test`` dataset entry.
data = dict(
    samples_per_gpu=16,
    workers_per_gpu=2,
    train=dict(type='ConcatDataset', datasets=[train1, train2]),
    val=dict(type='ConcatDataset', datasets=[test]),
    test=dict(type='ConcatDataset', datasets=[test]))

# Evaluation settings: interval of 1 with the 'acc' metric
# (NOTE(review): the interval is presumably counted in epochs -- confirm).
evaluation = dict(interval=1, metric='acc')