[Config] fix pan config (#1203)

pull/1204/head
liukuikun 2022-07-25 22:21:58 +08:00 committed by GitHub
parent 870f062394
commit 83ba24cad6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 152 additions and 56 deletions

View File

@ -1,6 +1,5 @@
model = dict(
type='PANet',
pretrained='torchvision://resnet50',
backbone=dict(
type='mmdet.ResNet',
depth=50,
@ -9,13 +8,20 @@ model = dict(
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='caffe'),
style='caffe',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
),
neck=dict(type='FPEM_FFM', in_channels=[256, 512, 1024, 2048]),
bbox_head=dict(
det_head=dict(
type='PANHead',
in_channels=[128, 128, 128, 128],
out_channels=6,
module_loss=dict(type='PANModuleLoss', speedup_bbox_thr=32),
hidden_dim=128,
out_channel=6,
module_loss=dict(
type='PANModuleLoss',
loss_text=dict(type='MaskedSquareDiceLoss'),
loss_kernel=dict(type='MaskedSquareDiceLoss'),
),
postprocessor=dict(type='PANPostprocessor', text_repr_type='poly')),
data_preprocessor=dict(
type='TextDetDataPreprocessor',

View File

@ -1,34 +1,78 @@
_base_ = [
'panet_r18_fpem_ffm.py', '../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_adam_600e.py',
'_base_panet_r18_fpem_ffm.py',
'../../_base_/default_runtime.py',
'../../_base_/det_datasets/ctw1500.py',
'../../_base_/det_pipelines/panet_pipeline.py'
'../../_base_/schedules/schedule_adam_600e.py',
]
model = {{_base_.model_poly}}
# dataset settings
train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}
file_client_args = dict(backend='disk')
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=20), )
train_pipeline_ctw1500 = {{_base_.train_pipeline_ctw1500}}
test_pipeline_ctw1500 = {{_base_.test_pipeline_ctw1500}}
train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args=file_client_args,
color_type='color_ignore_orientation'),
dict(
type='LoadOCRAnnotations',
with_polygon=True,
with_bbox=True,
with_label=True,
),
dict(type='ShortScaleAspectJitter', short_size=640, scale_divisor=32),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='RandomRotate', max_angle=10),
dict(type='TextDetRandomCrop', target_size=(640, 640)),
dict(type='Pad', size=(640, 640)),
dict(
type='TorchVisionWrapper',
op='ColorJitter',
brightness=32.0 / 255,
saturation=0.5),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline_ctw1500),
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline_ctw1500),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline_ctw1500))
test_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args=file_client_args,
color_type='color_ignore_orientation'),
# TODO Replace with mmcv.RescaleToShort when it's ready
dict(
type='ShortScaleAspectJitter',
short_size=640,
scale_divisor=1,
ratio_range=(1.0, 1.0),
aspect_ratio_range=(1.0, 1.0)),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor',
'instances'))
]
model = dict(det_head=dict(module_loss=dict(shrink_ratio=(1, 0.7))))
train_dataloader = dict(
batch_size=64,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='ConcatDataset', datasets=train_list, pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='ConcatDataset', datasets=test_list, pipeline=test_pipeline))
test_dataloader = val_dataloader
evaluation = dict(interval=10, metric='hmean-iou')
val_evaluator = dict(
type='HmeanIOUMetric', pred_score_thrs=dict(start=0.3, stop=1, step=0.05))
test_evaluator = val_evaluator
visualizer = dict(type='TextDetLocalVisualizer', name='visualizer')

View File

@ -1,5 +1,5 @@
_base_ = [
'panet_r18_fpem_ffm.py',
'_base_panet_r18_fpem_ffm.py',
'../../_base_/default_runtime.py',
'../../_base_/det_datasets/icdar2015.py',
'../../_base_/schedules/schedule_adam_600e.py',
@ -56,8 +56,8 @@ test_pipeline = [
]
train_dataloader = dict(
batch_size=8,
num_workers=4,
batch_size=64,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(

View File

@ -1,32 +1,78 @@
_base_ = [
'panet_r50_fpem_ffm.py', '../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_adam_600e.py',
'_base_panet_r50_fpem_ffm.py',
'../../_base_/default_runtime.py',
'../../_base_/det_datasets/icdar2017.py',
'../../_base_/det_pipelines/panet_pipeline.py'
'../../_base_/schedules/schedule_adam_600e.py',
]
# dataset settings
train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}
file_client_args = dict(backend='disk')
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=20), )
train_pipeline_icdar2017 = {{_base_.train_pipeline_icdar2017}}
test_pipeline_icdar2017 = {{_base_.test_pipeline_icdar2017}}
train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args=file_client_args,
color_type='color_ignore_orientation'),
dict(
type='LoadOCRAnnotations',
with_polygon=True,
with_bbox=True,
with_label=True,
),
dict(type='ShortScaleAspectJitter', short_size=800, scale_divisor=32),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='RandomRotate', max_angle=10),
dict(type='TextDetRandomCrop', target_size=(800, 800)),
dict(type='Pad', size=(800, 800)),
dict(
type='TorchVisionWrapper',
op='ColorJitter',
brightness=32.0 / 255,
saturation=0.5),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline_icdar2017),
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline_icdar2017),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline_icdar2017))
test_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args=file_client_args,
color_type='color_ignore_orientation'),
# TODO Replace with mmcv.RescaleToShort when it's ready
dict(
type='ShortScaleAspectJitter',
short_size=800,
scale_divisor=1,
ratio_range=(1.0, 1.0),
aspect_ratio_range=(1.0, 1.0)),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor',
'instances'))
]
evaluation = dict(interval=10, metric='hmean-iou')
train_dataloader = dict(
batch_size=64,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='ConcatDataset', datasets=train_list, pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='ConcatDataset', datasets=test_list, pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(
type='HmeanIOUMetric', pred_score_thrs=dict(start=0.3, stop=1, step=0.05))
test_evaluator = val_evaluator
visualizer = dict(type='TextDetLocalVisualizer', name='visualizer')