# Mirror of https://github.com/open-mmlab/mmocr.git
# Original file: 97 lines, 3.1 KiB, Python.
# Parent configs this file inherits from (mmcv `_base_` convention): the
# training schedule (presumably 1200 epochs, per its filename) and the shared
# runtime settings.
_base_ = [
    '../../_base_/schedules/schedule_1200e.py',
    '../../_base_/runtime_10e.py',
]
# DBNet text detector: ResNet-18 backbone -> FPNC neck -> DBHead.
model = dict(
    type='DBNet',
    # Backbone initialised from the torchvision ResNet-18 checkpoint.
    pretrained='torchvision://resnet18',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        # Feed the output of all four stages to the neck.
        out_indices=(0, 1, 2, 3),
        # By the mmdet ResNet convention, -1 freezes no stage.
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=False,
        style='caffe',
    ),
    neck=dict(
        type='FPNC',
        # One entry per backbone output stage; these are ResNet-18's stage
        # widths, so they must change together with `depth` above.
        in_channels=[64, 128, 256, 512],
        lateral_channels=256,
    ),
    bbox_head=dict(
        type='DBHead',
        text_repr_type='quad',
        # Presumably the FPNC output width; confirm if lateral_channels changes.
        in_channels=256,
        loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True),
    ),
    train_cfg=None,
    test_cfg=None,
)
dataset_type = 'IcdarDataset'
# Root directory of the ICDAR2015 annotation files and images.
data_root = 'data/icdar2015/'

# Per-channel normalisation used by the `Normalize` pipeline steps. These are
# the ImageNet mean/std scaled to the 0-255 range (e.g. 0.485 * 255 = 123.675).
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
)
# To visualise images, uncomment the identity normalisation below so that
# pixel values are left unscaled.
# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
# Training-time data pipeline, applied in order to every sample.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # Load text boxes and masks; poly2mask=False keeps masks as polygons
    # (per the mmdet LoadAnnotations convention) rather than bitmaps.
    dict(
        type='LoadTextAnnotations',
        with_bbox=True,
        with_mask=True,
        poly2mask=False,
    ),
    dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
    dict(type='Normalize', **img_norm_cfg),
    # Geometric augmentation via imgaug: horizontal flip, small rotation,
    # random rescale.
    dict(
        type='ImgAug',
        args=[
            ['Fliplr', 0.5],
            dict(cls='Affine', rotate=[-10, 10]),
            ['Resize', [0.5, 3.0]],
        ],
    ),
    # Random 640x640 crop.
    dict(type='EastRandomCrop', target_size=(640, 640)),
    # Presumably produces the gt_shrink / gt_thr maps and masks collected
    # below.
    dict(type='DBNetTargets', shrink_ratio=0.4),
    dict(type='Pad', size_divisor=32),
    # To dump images and ground truths for inspection, set visualize flag=True.
    dict(
        type='CustomFormatBundle',
        keys=['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'],
        visualize=dict(flag=False, boundary_key='gt_shrink'),
    ),
    dict(
        type='Collect',
        keys=['img', 'gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'],
    ),
]
# Evaluation / inference pipeline (single scale, no flip).
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        # NOTE(review): in mmdet 2.x, MultiScaleFlipAug stores this scale in
        # results['scale'], and the inner Resize then uses it instead of its
        # own img_scale. If that applies here, (2944, 736) below is ignored
        # and images are fitted to (1333, 736). Confirm which scale is intended
        # before changing either value.
        img_scale=(1333, 736),
        flip=False,
        transforms=[
            dict(type='Resize', img_scale=(2944, 736), keep_ratio=True),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ],
    ),
]
# Dataset and dataloader settings for train / val / test splits.
#
# `data_root` ends with '/', so every sub-path is appended WITHOUT a leading
# slash. The previous '/instances_training.json' and '/imgs' suffixes produced
# malformed 'data/icdar2015//...' paths, which break string-keyed file backends
# and path-prefix mappings.
data = dict(
    samples_per_gpu=16,
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'instances_training.json',
        # For debugging, load only the first k images:
        # select_first_k=200,
        img_prefix=data_root + 'imgs',
        pipeline=train_pipeline,
    ),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'instances_test.json',
        img_prefix=data_root + 'imgs',
        # select_first_k=100,
        pipeline=test_pipeline,
    ),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'instances_test.json',
        img_prefix=data_root + 'imgs',
        # select_first_k=100,
        pipeline=test_pipeline,
    ),
)
# Periodic evaluation with the hmean-iou metric, every 100 intervals (epochs,
# assuming the base schedule uses an epoch-based runner).
evaluation = dict(
    interval=100,
    metric='hmean-iou',
)