mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
add fce ic15 (#258)
This commit is contained in:
parent
4882c8a317
commit
18e7ecc379
@ -17,6 +17,12 @@
|
|||||||
|
|
||||||
### CTW1500
|
### CTW1500
|
||||||
|
|
||||||
| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
|
| Method | Backbone | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
|
||||||
| :--------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
| :--------------------------------------------------------------------: |:----------------:| :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||||
| [FCENet](/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1500 |(736, 1080)| 0.828 | 0.875 | 0.851 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) |
|
| [FCENet](/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py) | ResNet50 + DCNv2 | ImageNet | CTW1500 Train | CTW1500 Test | 1500 |(736, 1080)| 0.828 | 0.875 | 0.851 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) |
|
||||||
|
|
||||||
|
### ICDAR2015
|
||||||
|
|
||||||
|
| Method | Backbone | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
|
||||||
|
| :--------------------------------------------------------------------: | :--------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||||
|
| [FCENet](/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py) | ResNet50 | ImageNet | IC15 Train | IC15 Test | 1500 |(2260, 2260)| 0.819 | 0.880 | 0.849 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) |
|
||||||
|
135
configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py
Normal file
135
configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
fourier_degree = 5
|
||||||
|
model = dict(
|
||||||
|
type='FCENet',
|
||||||
|
pretrained='torchvision://resnet50',
|
||||||
|
backbone=dict(
|
||||||
|
type='ResNet',
|
||||||
|
depth=50,
|
||||||
|
num_stages=4,
|
||||||
|
out_indices=(1, 2, 3),
|
||||||
|
frozen_stages=-1,
|
||||||
|
norm_cfg=dict(type='BN', requires_grad=True),
|
||||||
|
norm_eval=False,
|
||||||
|
style='pytorch'),
|
||||||
|
neck=dict(
|
||||||
|
type='FPN',
|
||||||
|
in_channels=[512, 1024, 2048],
|
||||||
|
out_channels=256,
|
||||||
|
add_extra_convs=True,
|
||||||
|
extra_convs_on_inputs=False, # use P5
|
||||||
|
num_outs=3,
|
||||||
|
relu_before_extra_convs=True,
|
||||||
|
act_cfg=None),
|
||||||
|
bbox_head=dict(
|
||||||
|
type='FCEHead',
|
||||||
|
in_channels=256,
|
||||||
|
scales=(8, 16, 32),
|
||||||
|
loss=dict(type='FCELoss'),
|
||||||
|
alpha=1.2,
|
||||||
|
beta=1.0,
|
||||||
|
text_repr_type='quad',
|
||||||
|
fourier_degree=fourier_degree,
|
||||||
|
))
|
||||||
|
|
||||||
|
train_cfg = None
|
||||||
|
test_cfg = None
|
||||||
|
|
||||||
|
dataset_type = 'IcdarDataset'
|
||||||
|
data_root = 'data/icdar2015/'
|
||||||
|
|
||||||
|
img_norm_cfg = dict(
|
||||||
|
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='LoadTextAnnotations',
|
||||||
|
with_bbox=True,
|
||||||
|
with_mask=True,
|
||||||
|
poly2mask=False),
|
||||||
|
dict(
|
||||||
|
type='ColorJitter',
|
||||||
|
brightness=32.0 / 255,
|
||||||
|
saturation=0.5,
|
||||||
|
contrast=0.5),
|
||||||
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
dict(type='RandomScaling', size=800, scale=(3. / 4, 5. / 2)),
|
||||||
|
dict(
|
||||||
|
type='RandomCropFlip', crop_ratio=0.5, iter_num=1, min_area_ratio=0.2),
|
||||||
|
dict(
|
||||||
|
type='RandomCropPolyInstances',
|
||||||
|
instance_key='gt_masks',
|
||||||
|
crop_ratio=0.8,
|
||||||
|
min_side_ratio=0.3),
|
||||||
|
dict(
|
||||||
|
type='RandomRotatePolyInstances',
|
||||||
|
rotate_ratio=0.5,
|
||||||
|
max_angle=30,
|
||||||
|
pad_with_fixed_color=False),
|
||||||
|
dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
|
||||||
|
dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
|
||||||
|
dict(type='Pad', size_divisor=32),
|
||||||
|
dict(
|
||||||
|
type='FCENetTargets',
|
||||||
|
fourier_degree=fourier_degree,
|
||||||
|
level_proportion_range=((0, 0.4), (0.3, 0.7), (0.6, 1.0))),
|
||||||
|
dict(
|
||||||
|
type='CustomFormatBundle',
|
||||||
|
keys=['p3_maps', 'p4_maps', 'p5_maps'],
|
||||||
|
visualize=dict(flag=False, boundary_key=None)),
|
||||||
|
dict(type='Collect', keys=['img', 'p3_maps', 'p4_maps', 'p5_maps'])
|
||||||
|
]
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='MultiScaleFlipAug',
|
||||||
|
img_scale=(2260, 2260),
|
||||||
|
flip=False,
|
||||||
|
transforms=[
|
||||||
|
dict(type='Resize', img_scale=(1280, 800), keep_ratio=True),
|
||||||
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
dict(type='Pad', size_divisor=32),
|
||||||
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
dict(type='Collect', keys=['img']),
|
||||||
|
])
|
||||||
|
]
|
||||||
|
data = dict(
|
||||||
|
samples_per_gpu=8,
|
||||||
|
workers_per_gpu=2,
|
||||||
|
train=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
ann_file=data_root + '/instances_training.json',
|
||||||
|
img_prefix=data_root + '/imgs',
|
||||||
|
pipeline=train_pipeline),
|
||||||
|
val=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
ann_file=data_root + '/instances_test.json',
|
||||||
|
img_prefix=data_root + '/imgs',
|
||||||
|
pipeline=test_pipeline),
|
||||||
|
test=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
ann_file=data_root + '/instances_test.json',
|
||||||
|
img_prefix=data_root + '/imgs',
|
||||||
|
pipeline=test_pipeline))
|
||||||
|
evaluation = dict(interval=5, metric='hmean-iou')
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optimizer = dict(type='SGD', lr=1e-3, momentum=0.90, weight_decay=5e-4)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True)
|
||||||
|
total_epochs = 1500
|
||||||
|
|
||||||
|
checkpoint_config = dict(interval=5)
|
||||||
|
# yapf:disable
|
||||||
|
log_config = dict(
|
||||||
|
interval=20,
|
||||||
|
hooks=[
|
||||||
|
dict(type='TextLoggerHook')
|
||||||
|
|
||||||
|
])
|
||||||
|
# yapf:enable
|
||||||
|
dist_params = dict(backend='nccl')
|
||||||
|
log_level = 'INFO'
|
||||||
|
load_from = None
|
||||||
|
resume_from = None
|
||||||
|
workflow = [('train', 1)]
|
@ -300,7 +300,7 @@ class FCENetTargets(TextSnakeTargets):
|
|||||||
|
|
||||||
for ind, proportion_range in enumerate(lv_proportion_range):
|
for ind, proportion_range in enumerate(lv_proportion_range):
|
||||||
if proportion_range[0] < proportion < proportion_range[1]:
|
if proportion_range[0] < proportion < proportion_range[1]:
|
||||||
lv_text_polys[ind].append(
|
lv_ignore_polys[ind].append(
|
||||||
[ignore_poly[0] / lv_size_divs[ind]])
|
[ignore_poly[0] / lv_size_divs[ind]])
|
||||||
|
|
||||||
for ind, size_divisor in enumerate(lv_size_divs):
|
for ind, size_divisor in enumerate(lv_size_divs):
|
||||||
|
@ -43,6 +43,7 @@ class FCEHead(HeadMixin, nn.Module):
|
|||||||
nms_thr=0.1,
|
nms_thr=0.1,
|
||||||
alpha=1.0,
|
alpha=1.0,
|
||||||
beta=1.0,
|
beta=1.0,
|
||||||
|
text_repr_type='poly',
|
||||||
train_cfg=None,
|
train_cfg=None,
|
||||||
test_cfg=None):
|
test_cfg=None):
|
||||||
|
|
||||||
@ -63,6 +64,7 @@ class FCEHead(HeadMixin, nn.Module):
|
|||||||
self.nms_thr = nms_thr
|
self.nms_thr = nms_thr
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
|
self.text_repr_type = text_repr_type
|
||||||
self.train_cfg = train_cfg
|
self.train_cfg = train_cfg
|
||||||
self.test_cfg = test_cfg
|
self.test_cfg = test_cfg
|
||||||
self.out_channels_cls = 4
|
self.out_channels_cls = 4
|
||||||
@ -129,6 +131,6 @@ class FCEHead(HeadMixin, nn.Module):
|
|||||||
scale=scale,
|
scale=scale,
|
||||||
alpha=self.alpha,
|
alpha=self.alpha,
|
||||||
beta=self.beta,
|
beta=self.beta,
|
||||||
text_repr_type='poly',
|
text_repr_type=self.text_repr_type,
|
||||||
score_thr=self.score_thr,
|
score_thr=self.score_thr,
|
||||||
nms_thr=self.nms_thr)
|
nms_thr=self.nms_thr)
|
||||||
|
@ -419,7 +419,7 @@ def fcenet_decode(preds,
|
|||||||
"""
|
"""
|
||||||
assert isinstance(preds, list)
|
assert isinstance(preds, list)
|
||||||
assert len(preds) == 2
|
assert len(preds) == 2
|
||||||
assert text_repr_type == 'poly'
|
assert text_repr_type in ['poly', 'quad']
|
||||||
|
|
||||||
cls_pred = preds[0][0]
|
cls_pred = preds[0][0]
|
||||||
tr_pred = cls_pred[0:2].softmax(dim=0).data.cpu().numpy()
|
tr_pred = cls_pred[0:2].softmax(dim=0).data.cpu().numpy()
|
||||||
@ -460,6 +460,16 @@ def fcenet_decode(preds,
|
|||||||
boundaries = boundaries + polygons
|
boundaries = boundaries + polygons
|
||||||
|
|
||||||
boundaries = poly_nms(boundaries, nms_thr)
|
boundaries = poly_nms(boundaries, nms_thr)
|
||||||
|
|
||||||
|
if text_repr_type == 'quad':
|
||||||
|
new_boundaries = []
|
||||||
|
for boundary in boundaries:
|
||||||
|
poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
|
||||||
|
score = boundary[-1]
|
||||||
|
points = cv2.boxPoints(cv2.minAreaRect(poly))
|
||||||
|
points = np.int0(points)
|
||||||
|
new_boundaries.append(points.reshape(-1).tolist() + [score])
|
||||||
|
|
||||||
return boundaries
|
return boundaries
|
||||||
|
|
||||||
|
|
||||||
|
@ -372,8 +372,10 @@ def test_textsnake(cfg_file):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize('cfg_file', [
|
||||||
'cfg_file', ['textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py'])
|
'textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
|
||||||
|
'textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py'
|
||||||
|
])
|
||||||
def test_fcenet(cfg_file):
|
def test_fcenet(cfg_file):
|
||||||
model = _get_detector_cfg(cfg_file)
|
model = _get_detector_cfg(cfg_file)
|
||||||
model['pretrained'] = None
|
model['pretrained'] = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user