mirror of https://github.com/open-mmlab/mmocr.git
[Configs] Totaltext cfgs for DB and FCE (#1633)
* fcenet configs * dbnet config * update fcenet config * update dbnet config * Add readme and metafile
pull/1647/head
parent
e1aa1f6f42
commit
89606a1cf1
|
@ -23,6 +23,12 @@ Recently, segmentation-based methods are quite popular in scene text detection,
|
|||
| [DBNet_r50dcn](/configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py) | ResNet50-DCN | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/tmp_1.0_pretrain/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-ed322016.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.8784 | 0.8315 | 0.8543 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015_20220828_124917-452c443c.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015/20220828_124917.log) |
|
||||
| [DBNet_r50-oclip](/configs/textdet/dbnet/dbnet_resnet50-oclip_fpnc_1200e_icdar2015.py) | [ResNet50-oCLIP](https://download.openmmlab.com/mmocr/backbone/resnet50-oclip-7ba0c533.pth) | - | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.9052 | 0.8272 | 0.8644 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet50-oclip_1200e_icdar2015/dbnet_resnet50-oclip_1200e_icdar2015_20221102_115917-bde8c87a.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet50-oclip_1200e_icdar2015/20221102_115917.log) |
|
||||
|
||||
### Total Text
|
||||
|
||||
| Method | Backbone | Pretrained Model | Training set | Test set | #epochs | Test size | Precision | Recall | Hmean | Download |
|
||||
| :----------------------------------------------------: | :------: | :--------------: | :-------------: | :------------: | :-----: | :-------: | :-------: | :----: | :----: | :------------------------------------------------------: |
|
||||
| [DBNet_r18](/configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_totaltext.py) | ResNet18 | - | Totaltext Train | Totaltext Test | 1200 | 736 | 0.8640 | 0.7770 | 0.8182 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet18_fpnc_1200e_totaltext/dbnet_resnet18_fpnc_1200e_totaltext-3ed3233c.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet18_fpnc_1200e_totaltext/20221219_201038.log) |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
_base_ = [
|
||||
'_base_dbnet_resnet18_fpnc.py',
|
||||
'../_base_/datasets/totaltext.py',
|
||||
'../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_sgd_1200e.py',
|
||||
]
|
||||
|
||||
file_client_args = dict(backend='disk')
|
||||
|
||||
train_pipeline = [
|
||||
dict(
|
||||
type='LoadImageFromFile',
|
||||
file_client_args=file_client_args,
|
||||
color_type='color_ignore_orientation'),
|
||||
dict(
|
||||
type='LoadOCRAnnotations',
|
||||
with_polygon=True,
|
||||
with_bbox=True,
|
||||
with_label=True,
|
||||
),
|
||||
dict(type='FixInvalidPolygon', min_poly_points=4),
|
||||
dict(
|
||||
type='TorchVisionWrapper',
|
||||
op='ColorJitter',
|
||||
brightness=32.0 / 255,
|
||||
saturation=0.5),
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[['Fliplr', 0.5],
|
||||
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
|
||||
dict(type='RandomCrop', min_side_ratio=0.1),
|
||||
dict(type='Resize', scale=(640, 640), keep_ratio=True),
|
||||
dict(type='Pad', size=(640, 640)),
|
||||
dict(
|
||||
type='PackTextDetInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape'))
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(
|
||||
type='LoadImageFromFile',
|
||||
file_client_args=file_client_args,
|
||||
color_type='color_ignore_orientation'),
|
||||
dict(type='Resize', scale=(1333, 736), keep_ratio=True),
|
||||
dict(
|
||||
type='LoadOCRAnnotations',
|
||||
with_polygon=True,
|
||||
with_bbox=True,
|
||||
with_label=True,
|
||||
),
|
||||
dict(type='FixInvalidPolygon', min_poly_points=4),
|
||||
dict(
|
||||
type='PackTextDetInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
|
||||
]
|
||||
|
||||
# dataset settings
|
||||
totaltext_textdet_train = _base_.totaltext_textdet_train
|
||||
totaltext_textdet_test = _base_.totaltext_textdet_test
|
||||
totaltext_textdet_train.pipeline = train_pipeline
|
||||
totaltext_textdet_test.pipeline = test_pipeline
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=16,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=totaltext_textdet_train)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=1,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=totaltext_textdet_test)
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
auto_scale_lr = dict(base_batch_size=16)
|
|
@ -62,3 +62,15 @@ Models:
|
|||
Metrics:
|
||||
hmean-iou: 0.8644
|
||||
Weights: https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet50-oclip_1200e_icdar2015/dbnet_resnet50-oclip_1200e_icdar2015_20221102_115917-bde8c87a.pth
|
||||
|
||||
- Name: dbnet_resnet18_fpnc_1200e_totaltext
|
||||
In Collection: DBNet
|
||||
Config: configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_totaltext.py
|
||||
Metadata:
|
||||
Training Data: Totaltext
|
||||
Results:
|
||||
- Task: Text Detection
|
||||
Dataset: Totaltext
|
||||
Metrics:
|
||||
hmean-iou: 0.8182
|
||||
Weights: https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_resnet18_fpnc_1200e_totaltext/dbnet_resnet18_fpnc_1200e_totaltext-3ed3233c.pth
|
||||
|
|
|
@ -18,16 +18,22 @@ One of the main challenges for arbitrary-shaped text detection is to design a go
|
|||
|
||||
| Method | Backbone | Pretrained Model | Training set | Test set | #epochs | Test size | Precision | Recall | Hmean | Download |
|
||||
| :------------------------------------: | :---------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :---------: | :-------: | :----: | :----: | :---------------------------------------: |
|
||||
| [FCENet](/configs/textdet/fcenet/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500.py) | ResNet50 + DCNv2 | - | CTW1500 Train | CTW1500 Test | 1500 | (736, 1080) | 0.8689 | 0.8296 | 0.8488 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500_20220825_221510-4d705392.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500/20220825_221510.log) |
|
||||
| [FCENet_r50dcn](/configs/textdet/fcenet/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500.py) | ResNet50 + DCNv2 | - | CTW1500 Train | CTW1500 Test | 1500 | (736, 1080) | 0.8689 | 0.8296 | 0.8488 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500_20220825_221510-4d705392.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500/20220825_221510.log) |
|
||||
| [FCENet_r50-oclip](/configs/textdet/fcenet/fcenet_resnet50-oclip-dcnv2_fpn_1500e_ctw1500.py) | [ResNet50-oCLIP](https://download.openmmlab.com/mmocr/backbone/resnet50-oclip-7ba0c533.pth) | - | CTW1500 Train | CTW1500 Test | 1500 | (736, 1080) | 0.8383 | 0.801 | 0.8192 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50-oclip_fpn_1500e_ctw1500/fcenet_resnet50-oclip_fpn_1500e_ctw1500_20221102_121909-101df7e6.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50-oclip_fpn_1500e_ctw1500/20221102_121909.log) |
|
||||
|
||||
### ICDAR2015
|
||||
|
||||
| Method | Backbone | Pretrained Model | Training set | Test set | #epochs | Test size | Precision | Recall | Hmean | Download |
|
||||
| :---------------------------------------------------: | :------------: | :--------------: | :----------: | :-------: | :-----: | :----------: | :-------: | :----: | :----: | :------------------------------------------------------: |
|
||||
| [FCENet](/configs/textdet/fcenet/fcenet_resnet50_fpn_1500e_icdar2015.py) | ResNet50 | - | IC15 Train | IC15 Test | 1500 | (2260, 2260) | 0.8243 | 0.8834 | 0.8528 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50_fpn_1500e_icdar2015/fcenet_resnet50_fpn_1500e_icdar2015_20220826_140941-167d9042.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50_fpn_1500e_icdar2015/20220826_140941.log) |
|
||||
| [FCENet_r50](/configs/textdet/fcenet/fcenet_resnet50_fpn_1500e_icdar2015.py) | ResNet50 | - | IC15 Train | IC15 Test | 1500 | (2260, 2260) | 0.8243 | 0.8834 | 0.8528 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50_fpn_1500e_icdar2015/fcenet_resnet50_fpn_1500e_icdar2015_20220826_140941-167d9042.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50_fpn_1500e_icdar2015/20220826_140941.log) |
|
||||
| [FCENet_r50-oclip](/configs/textdet/fcenet/fcenet_resnet50-oclip_fpn_1500e_icdar2015.py) | ResNet50-oCLIP | - | IC15 Train | IC15 Test | 1500 | (2260, 2260) | 0.9176 | 0.8098 | 0.8604 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50-oclip_fpn_1500e_icdar2015/fcenet_resnet50-oclip_fpn_1500e_icdar2015_20221101_150145-5a6fc412.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50-oclip_fpn_1500e_icdar2015/20221101_150145.log) |
|
||||
|
||||
### Total Text
|
||||
|
||||
| Method | Backbone | Pretrained Model | Training set | Test set | #epochs | Test size | Precision | Recall | Hmean | Download |
|
||||
| :---------------------------------------------------: | :------: | :--------------: | :-------------: | :------------: | :-----: | :---------: | :-------: | :----: | :----: | :-----------------------------------------------------: |
|
||||
| [FCENet_r50](/configs/textdet/fcenet/fcenet_resnet50_fpn_1500e_totaltext.py) | ResNet50 | - | Totaltext Train | Totaltext Test | 1500 | (1280, 960) | 0.8485 | 0.7810 | 0.8134 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50_fpn_1500e_totaltext/fcenet_resnet50_fpn_1500e_totaltext-91bd37af.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50_fpn_1500e_totaltext/20221219_201107.log) |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
_base_ = [
|
||||
'_base_fcenet_resnet50_fpn.py',
|
||||
'../_base_/datasets/totaltext.py',
|
||||
'../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_sgd_base.py',
|
||||
]
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
type='CheckpointHook',
|
||||
save_best='icdar/hmean',
|
||||
rule='greater',
|
||||
_delete_=True))
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
|
||||
dict(
|
||||
type='LoadOCRAnnotations',
|
||||
with_polygon=True,
|
||||
with_bbox=True,
|
||||
with_label=True,
|
||||
),
|
||||
dict(type='FixInvalidPolygon'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(800, 800),
|
||||
ratio_range=(0.75, 2.5),
|
||||
keep_ratio=True),
|
||||
dict(
|
||||
type='TextDetRandomCropFlip',
|
||||
crop_ratio=0.5,
|
||||
iter_num=1,
|
||||
min_area_ratio=0.2),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
transforms=[dict(type='RandomCrop', min_side_ratio=0.3)],
|
||||
prob=0.8),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
transforms=[
|
||||
dict(
|
||||
type='RandomRotate',
|
||||
max_angle=30,
|
||||
pad_with_fixed_color=False,
|
||||
use_canvas=True)
|
||||
],
|
||||
prob=0.5),
|
||||
dict(
|
||||
type='RandomChoice',
|
||||
transforms=[[
|
||||
dict(type='Resize', scale=800, keep_ratio=True),
|
||||
dict(type='SourceImagePad', target_scale=800)
|
||||
],
|
||||
dict(type='Resize', scale=800, keep_ratio=False)],
|
||||
prob=[0.6, 0.4]),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='TorchVisionWrapper',
|
||||
op='ColorJitter',
|
||||
brightness=32.0 / 255,
|
||||
saturation=0.5,
|
||||
contrast=0.5),
|
||||
dict(
|
||||
type='PackTextDetInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
|
||||
dict(type='Resize', scale=(1280, 960), keep_ratio=True),
|
||||
# load annotations after ``Resize`` because the ground truth
|
||||
# does not need to go through the resize data transform
|
||||
dict(
|
||||
type='LoadOCRAnnotations',
|
||||
with_polygon=True,
|
||||
with_bbox=True,
|
||||
with_label=True),
|
||||
dict(type='FixInvalidPolygon'),
|
||||
dict(
|
||||
type='PackTextDetInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
|
||||
]
|
||||
|
||||
optim_wrapper = dict(optimizer=dict(lr=1e-3, weight_decay=5e-4))
|
||||
train_cfg = dict(max_epochs=1500)
|
||||
# learning policy
|
||||
param_scheduler = [
|
||||
dict(type='StepLR', gamma=0.8, step_size=200, end=1200),
|
||||
]
|
||||
|
||||
# dataset settings
|
||||
totaltext_textdet_train = _base_.totaltext_textdet_train
|
||||
totaltext_textdet_test = _base_.totaltext_textdet_test
|
||||
totaltext_textdet_train.pipeline = train_pipeline
|
||||
totaltext_textdet_test.pipeline = test_pipeline
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=16,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=totaltext_textdet_train)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=1,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=totaltext_textdet_test)
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
auto_scale_lr = dict(base_batch_size=16)
|
||||
|
||||
find_unused_parameters = True
|
|
@ -62,3 +62,15 @@ Models:
|
|||
Metrics:
|
||||
hmean-iou: 0.8604
|
||||
Weights: https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50-oclip_fpn_1500e_icdar2015/fcenet_resnet50-oclip_fpn_1500e_icdar2015_20221101_150145-5a6fc412.pth
|
||||
|
||||
- Name: fcenet_resnet50_fpn_1500e_totaltext
|
||||
In Collection: FCENet
|
||||
Config: configs/textdet/fcenet/fcenet_resnet50_fpn_1500e_totaltext.py
|
||||
Metadata:
|
||||
Training Data: Totaltext
|
||||
Results:
|
||||
- Task: Text Detection
|
||||
Dataset: Totaltext
|
||||
Metrics:
|
||||
hmean-iou: 0.8134
|
||||
Weights: https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_resnet50_fpn_1500e_totaltext/fcenet_resnet50_fpn_1500e_totaltext-91bd37af.pth
|
||||
|
|
Loading…
Reference in New Issue