# Mirror of https://github.com/open-mmlab/mmocr.git (Python, 100 lines, 2.7 KiB)
# Normalisation settings shared by the train and test pipelines (both splice
# them into their NormalizeOCR step): a mean and std of 0.5 per channel.
img_norm_cfg = {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}
# Training-time transform list, applied in order:
#   1. load the image file,
#   2. ResizeOCR with height=32, min_width=32, max_width=160 and the aspect
#      ratio kept (exact resizing rule is ResizeOCR's contract),
#   3. convert to a tensor,
#   4. normalise with ``img_norm_cfg``,
#   5. collect ``img`` plus meta keys; unlike the test pipeline, the
#      ground-truth ``text`` is among them.
train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {
        'type': 'ResizeOCR',
        'height': 32,
        'min_width': 32,
        'max_width': 160,
        'keep_aspect_ratio': True,
    },
    {'type': 'ToTensorOCR'},
    {'type': 'NormalizeOCR', **img_norm_cfg},
    {
        'type': 'Collect',
        'keys': ['img'],
        'meta_keys': [
            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
        ],
    },
]
# Inference-time transform list. After loading, the inner ``transforms`` are
# wrapped by MultiRotateAugOCR with angles 0, 90 and 270 (presumably one pass
# per angle; see MultiRotateAugOCR). The inner steps match ``train_pipeline``
# except that ``text`` is not collected.
test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {
        'type': 'MultiRotateAugOCR',
        'rotate_degrees': [0, 90, 270],
        'transforms': [
            {
                'type': 'ResizeOCR',
                'height': 32,
                'min_width': 32,
                'max_width': 160,
                'keep_aspect_ratio': True,
            },
            {'type': 'ToTensorOCR'},
            {'type': 'NormalizeOCR', **img_norm_cfg},
            {
                'type': 'Collect',
                'keys': ['img'],
                'meta_keys': [
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ],
            },
        ],
    },
]
# Dataset class and image root used by every dataset entry in this file.
dataset_type, img_prefix = 'OCRDataset', 'tests/data/ocr_toy_dataset/imgs'
# Plain-text label file backing the first training set.
train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
# First training set: the plain-text label file read through HardDiskLoader
# with repeat=100. Each annotation line goes to LineStrParser with separator
# ' ', and columns 0 and 1 map to ``filename`` and ``text``.
train1 = {
    'type': dataset_type,
    'img_prefix': img_prefix,
    'ann_file': train_anno_file1,
    'loader': {
        'type': 'HardDiskLoader',
        'repeat': 100,
        'parser': {
            'type': 'LineStrParser',
            'keys': ['filename', 'text'],
            'keys_idx': [0, 1],
            'separator': ' ',
        },
    },
    'pipeline': train_pipeline,
    'test_mode': False,
}
# LMDB label file backing the second training set.
train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
# Second training set: same parser settings and repeat count as ``train1``,
# but annotations come through LmdbLoader.
train2 = {
    'type': dataset_type,
    'img_prefix': img_prefix,
    'ann_file': train_anno_file2,
    'loader': {
        'type': 'LmdbLoader',
        'repeat': 100,
        'parser': {
            'type': 'LineStrParser',
            'keys': ['filename', 'text'],
            'keys_idx': [0, 1],
            'separator': ' ',
        },
    },
    'pipeline': train_pipeline,
    'test_mode': False,
}
# Evaluation annotations: the same LMDB file used by ``train2``.
test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
# Evaluation dataset: no repetition (repeat=1), the rotation-augmented
# ``test_pipeline``, and test_mode switched on.
test = {
    'type': dataset_type,
    'img_prefix': img_prefix,
    'ann_file': test_anno_file1,
    'loader': {
        'type': 'LmdbLoader',
        'repeat': 1,
        'parser': {
            'type': 'LineStrParser',
            'keys': ['filename', 'text'],
            'keys_idx': [0, 1],
            'separator': ' ',
        },
    },
    'pipeline': test_pipeline,
    'test_mode': True,
}
# Top-level data section: 16 samples and 2 workers per GPU. Both training
# sets are concatenated; validation and testing each wrap the same ``test``
# dataset object in a ConcatDataset.
data = {
    'samples_per_gpu': 16,
    'workers_per_gpu': 2,
    'train': {'type': 'ConcatDataset', 'datasets': [train1, train2]},
    'val': {'type': 'ConcatDataset', 'datasets': [test]},
    'test': {'type': 'ConcatDataset', 'datasets': [test]},
}
# Evaluation hook settings: interval 1 (the unit is decided by the runner's
# evaluation hook, presumably epochs — confirm) using the 'acc' metric.
evaluation = {'interval': 1, 'metric': 'acc'}