[Config] Rename textsnake (#1297)

* [Config] Rename textsnake

* Update configs/textdet/textsnake/textsnake_resnet50_fpn-unet_1200e_ctw1500.py

Co-authored-by: Xinyu Wang <45810070+xinke-wang@users.noreply.github.com>

* update metafile

* fix linting

Co-authored-by: Xinyu Wang <45810070+xinke-wang@users.noreply.github.com>
pull/1307/head
Tong Gao 2022-08-22 15:02:21 +08:00 committed by GitHub
parent b2e06c04f5
commit 908ebf1bcf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 62 deletions

View File

@ -1,17 +1,31 @@
_base_ = [
'textsnake_r50_fpn_unet.py',
'../../_base_/det_datasets/ctw1500.py',
'../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_sgd_1200e.py',
]
# dataset settings
train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}
file_client_args = dict(backend='disk')
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=20),
logger=dict(type='LoggerHook', interval=20))
model = dict(
type='TextSnake',
backbone=dict(
type='mmdet.ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=True),
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
norm_eval=True,
style='caffe'),
neck=dict(
type='FPN_UNet', in_channels=[256, 512, 1024, 2048], out_channels=32),
det_head=dict(
type='TextSnakeHead',
in_channels=32,
module_loss=dict(type='TextSnakeModuleLoss'),
postprocessor=dict(
type='TextSnakePostprocessor', text_repr_type='poly')),
data_preprocessor=dict(
type='TextDetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=32))
train_pipeline = [
dict(
@ -74,24 +88,3 @@ test_pipeline = [
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='ConcatDataset', datasets=train_list, pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='ConcatDataset', datasets=test_list, pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='HmeanIOUMetric')
test_evaluator = val_evaluator
visualizer = dict(type='TextDetLocalVisualizer', name='visualizer')

View File

@ -14,9 +14,9 @@ Collections:
README: configs/textdet/textsnake/README.md
Models:
- Name: textsnake_r50_fpn_unet_1200e_ctw1500
- Name: textsnake_resnet50_fpn-unet_1200e_ctw1500
In Collection: TextSnake
Config: configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
Config: configs/textdet/textsnake/textsnake_resnet50_fpn-unet_1200e_ctw1500.py
Metadata:
Training Data: CTW1500
Results:

View File

@ -1,26 +0,0 @@
model = dict(
type='TextSnake',
backbone=dict(
type='mmdet.ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=True),
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
norm_eval=True,
style='caffe'),
neck=dict(
type='FPN_UNet', in_channels=[256, 512, 1024, 2048], out_channels=32),
det_head=dict(
type='TextSnakeHead',
in_channels=32,
module_loss=dict(type='TextSnakeModuleLoss'),
postprocessor=dict(
type='TextSnakePostprocessor', text_repr_type='poly')),
data_preprocessor=dict(
type='TextDetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=32))

View File

@ -0,0 +1,28 @@
_base_ = [
'_base_textsnake_resnet50_fpn-unet.py',
'../../_base_/det_datasets/ctw1500.py',
'../../_base_/textdet_default_runtime.py',
'../../_base_/schedules/schedule_sgd_1200e.py',
]
# dataset settings
ctw_det_train = _base_.ctw_det_train
ctw_det_train.pipeline = _base_.train_pipeline
ctw_det_test = _base_.ctw_det_test
ctw_det_test.pipeline = _base_.test_pipeline
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=ctw_det_train)
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=ctw_det_test)
test_dataloader = val_dataloader