diff --git a/configs/textdet/fcenet/README.md b/configs/textdet/fcenet/README.md index 5e111a33..505509ec 100644 --- a/configs/textdet/fcenet/README.md +++ b/configs/textdet/fcenet/README.md @@ -17,6 +17,12 @@ ### CTW1500 -| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | -| :--------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| [FCENet](/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1500 |(736, 1080)| 0.828 | 0.875 | 0.851 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) | +| Method | Backbone | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :--------------------------------------------------------------------: |:----------------:| :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [FCENet](/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py) | ResNet50 + DCNv2 | ImageNet | CTW1500 Train | CTW1500 Test | 1500 |(736, 1080)| 0.828 | 0.875 | 0.851 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) | + +### ICDAR2015 + +| Method | Backbone | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :--------------------------------------------------------------------: | :--------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [FCENet](/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py) | ResNet50 | ImageNet | IC15 Train | IC15 Test | 1500 |(2260, 2260)| 0.819 | 0.880 | 0.849 | [model](https://download.openmmlab.com/mmocr/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/fcenet/20210511_181328.log.json) | diff --git a/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py b/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py new file mode 100644 index 00000000..2b4fd6d1 --- /dev/null +++ b/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py @@ -0,0 +1,135 @@ +fourier_degree = 5 +model = dict( + type='FCENet', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[512, 1024, 2048], + out_channels=256, + add_extra_convs=True, + extra_convs_on_inputs=False, # use P5 + num_outs=3, + relu_before_extra_convs=True, + act_cfg=None), + bbox_head=dict( + type='FCEHead', + in_channels=256, + scales=(8, 16, 32), + loss=dict(type='FCELoss'), + alpha=1.2, + beta=1.0, + text_repr_type='quad', + fourier_degree=fourier_degree, + )) + +train_cfg = None +test_cfg = None + +dataset_type = 'IcdarDataset' +data_root = 'data/icdar2015/' + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict( + type='ColorJitter', + brightness=32.0 / 255, + saturation=0.5, + contrast=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='RandomScaling', size=800, scale=(3. / 4, 5. / 2)), + dict( + type='RandomCropFlip', crop_ratio=0.5, iter_num=1, min_area_ratio=0.2), + dict( + type='RandomCropPolyInstances', + instance_key='gt_masks', + crop_ratio=0.8, + min_side_ratio=0.3), + dict( + type='RandomRotatePolyInstances', + rotate_ratio=0.5, + max_angle=30, + pad_with_fixed_color=False), + dict(type='SquareResizePad', target_size=800, pad_ratio=0.6), + dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), + dict(type='Pad', size_divisor=32), + dict( + type='FCENetTargets', + fourier_degree=fourier_degree, + level_proportion_range=((0, 0.4), (0.3, 0.7), (0.6, 1.0))), + dict( + type='CustomFormatBundle', + keys=['p3_maps', 'p4_maps', 'p5_maps'], + visualize=dict(flag=False, boundary_key=None)), + dict(type='Collect', keys=['img', 'p3_maps', 'p4_maps', 'p5_maps']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2260, 2260), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(1280, 800), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) +evaluation = dict(interval=5, metric='hmean-iou') + +# optimizer +optimizer = dict(type='SGD', lr=1e-3, momentum=0.90, weight_decay=5e-4) +optimizer_config = dict(grad_clip=None) +lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True) +total_epochs = 1500 + +checkpoint_config = dict(interval=5) +# yapf:disable +log_config = dict( + interval=20, + hooks=[ + dict(type='TextLoggerHook') + + ]) +# yapf:enable +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py b/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py index 6992afa6..3d774312 100644 --- a/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py +++ b/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py @@ -300,7 +300,7 @@ class FCENetTargets(TextSnakeTargets): for ind, proportion_range in enumerate(lv_proportion_range): if proportion_range[0] < proportion < proportion_range[1]: - lv_text_polys[ind].append( + lv_ignore_polys[ind].append( [ignore_poly[0] / lv_size_divs[ind]]) for ind, size_divisor in enumerate(lv_size_divs): diff --git a/mmocr/models/textdet/dense_heads/fce_head.py b/mmocr/models/textdet/dense_heads/fce_head.py index ef737e8d..b178fb4f 100644 --- a/mmocr/models/textdet/dense_heads/fce_head.py +++ b/mmocr/models/textdet/dense_heads/fce_head.py @@ -43,6 +43,7 @@ class FCEHead(HeadMixin, nn.Module): nms_thr=0.1, alpha=1.0, beta=1.0, + text_repr_type='poly', train_cfg=None, test_cfg=None): @@ -63,6 +64,7 @@ class FCEHead(HeadMixin, nn.Module): self.nms_thr = nms_thr self.alpha = alpha self.beta = beta + self.text_repr_type = text_repr_type self.train_cfg = train_cfg self.test_cfg = test_cfg self.out_channels_cls = 4 @@ -129,6 +131,6 @@ class FCEHead(HeadMixin, nn.Module): scale=scale, alpha=self.alpha, beta=self.beta, - text_repr_type='poly', + text_repr_type=self.text_repr_type, score_thr=self.score_thr, nms_thr=self.nms_thr) diff --git a/mmocr/models/textdet/postprocess/wrapper.py b/mmocr/models/textdet/postprocess/wrapper.py index ee9ac9fa..8de62b35 100644 --- a/mmocr/models/textdet/postprocess/wrapper.py +++ b/mmocr/models/textdet/postprocess/wrapper.py @@ -419,7 +419,7 @@ def fcenet_decode(preds, """ assert isinstance(preds, list) assert len(preds) == 2 - assert text_repr_type == 'poly' + assert text_repr_type in ['poly', 'quad'] cls_pred = preds[0][0] tr_pred = cls_pred[0:2].softmax(dim=0).data.cpu().numpy() @@ -460,6 +460,16 @@ def fcenet_decode(preds, boundaries = boundaries + polygons boundaries = poly_nms(boundaries, nms_thr) + + if text_repr_type == 'quad': + new_boundaries = [] + for boundary in boundaries: + poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32) + score = boundary[-1] + points = cv2.boxPoints(cv2.minAreaRect(poly)) + points = np.int0(points) + new_boundaries.append(points.reshape(-1).tolist() + [score]) + return boundaries diff --git a/tests/test_models/test_detector.py b/tests/test_models/test_detector.py index 856515a5..3c5f5ded 100644 --- a/tests/test_models/test_detector.py +++ b/tests/test_models/test_detector.py @@ -372,8 +372,10 @@ def test_textsnake(cfg_file): @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') -@pytest.mark.parametrize( - 'cfg_file', ['textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py']) +@pytest.mark.parametrize('cfg_file', [ + 'textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py', + 'textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py' +]) def test_fcenet(cfg_file): model = _get_detector_cfg(cfg_file) model['pretrained'] = None