refactor tps config (#135)

* refactor tps config

* recover tps

* add ckpt of tps
pull/175/head
Hongbin Sun 2021-05-12 14:14:24 +08:00 committed by GitHub
parent 18c54afbdc
commit df5493a79e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 156 additions and 55 deletions

View File

@ -28,10 +28,13 @@
| IIIT5K | 3000 | regular | | IIIT5K | 3000 | regular |
| SVT | 647 | regular | | SVT | 647 | regular |
| IC13 | 1015 | regular | | IC13 | 1015 | regular |
| IC15 | 2077 |irregular|
| SVTP | 645 |irregular|
| CT80 | 288 |irregular|
## Results and models ## Results and models
| methods | | Regular Text | | | | Irregular Text | | download | | methods | | Regular Text | | | | Irregular Text | | download |
| :------------------------------------------------------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | :------------------------------------------------------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| methods | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | | methods | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
| [CRNN](/configs/textrecog/crnn/crnn_academic_dataset.py) | 80.5 | 81.5 | 86.5 | | - | - | - | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic-a723a1c5.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/20210326_111035.log.json) | | [CRNN](/configs/textrecog/crnn/crnn_academic_dataset.py) | 80.5 | 81.5 | 86.5 | | 54.1 | 59.1 | 55.6 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic-a723a1c5.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/20210326_111035.log.json) |

View File

@ -95,13 +95,20 @@ train1 = dict(
test_mode=False) test_mode=False)
test_prefix = 'data/mixture/' test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'icdar_2013/'
test_img_prefix2 = test_prefix + 'IIIT5K/'
test_img_prefix3 = test_prefix + 'svt/'
test_ann_file1 = test_prefix + 'icdar_2013/test_label_1015.txt' test_img_prefix1 = test_prefix + 'IIIT5K/'
test_ann_file2 = test_prefix + 'IIIT5K/test_label.txt' test_img_prefix2 = test_prefix + 'svt/'
test_ann_file3 = test_prefix + 'svt/test_label.txt' test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'
test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'
test1 = dict( test1 = dict(
type=dataset_type, type=dataset_type,
@ -126,12 +133,28 @@ test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3 test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3 test3['ann_file'] = test_ann_file3
test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4
test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5
test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6
data = dict( data = dict(
samples_per_gpu=64, samples_per_gpu=64,
workers_per_gpu=4, workers_per_gpu=4,
train=dict(type='ConcatDataset', datasets=[train1]), train=dict(type='ConcatDataset', datasets=[train1]),
val=dict(type='ConcatDataset', datasets=[test1, test2, test3]), val=dict(
test=dict(type='ConcatDataset', datasets=[test1, test2, test3])) type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]),
test=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]))
evaluation = dict(interval=1, metric='acc') evaluation = dict(interval=1, metric='acc')

View File

@ -1,9 +1,20 @@
# Thin-Plate-Spline (TPS) transformation # CRNN with TPS based STN
## Introduction ## Introduction
[ALGORITHM] [ALGORITHM]
```bibtex
@article{shi2016end,
title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition},
author={Shi, Baoguang and Bai, Xiang and Yao, Cong},
journal={IEEE transactions on pattern analysis and machine intelligence},
year={2016}
}
```
[PREPROCESSOR]
```bibtex ```bibtex
@article{shi2016robust, @article{shi2016robust,
title={Robust Scene Text Recognition with Automatic Rectification}, title={Robust Scene Text Recognition with Automatic Rectification},
@ -13,14 +24,28 @@
} }
``` ```
## About using TPS in other models ## Results and Models
- Simply change `cfg.model.preprocessor` from `None` to ### Train Dataset
```python
dict( | trainset | instance_num | repeat_num | note |
type='TPSPreprocessor', | :------: | :----------: | :--------: | :---: |
num_fiducial=20, | Syn90k | 8919273 | 1 | synth |
img_size=(32, 100),
rectified_img_size=(32, 100), ### Test Dataset
num_img_channel=1
) | testset | instance_num | note |
| :-----: | :----------: | :-----: |
| IIIT5K | 3000 | regular |
| SVT | 647 | regular |
| IC13 | 1015 | regular |
| IC15 | 2077 |irregular|
| SVTP | 645 |irregular|
| CT80 | 288 |irregular|
## Results and models
| methods | | Regular Text | | | | Irregular Text | | download |
| :------------------------------------------------------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
| [CRNN-STN](/configs/textrecog/tps/crnn_tps_academic_dataset.py) | 80.8 | 81.3 | 85.0 | | 59.6 | 68.1 | 53.8 | [model](https://download.openmmlab.com/mmocr/textrecog/tps/crnn_tps_academic_dataset_20210510-d221a905.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/tps/20210510_204353.log.json) |

View File

@ -26,7 +26,7 @@ model = dict(
img_size=(32, 100), img_size=(32, 100),
rectified_img_size=(32, 100), rectified_img_size=(32, 100),
num_img_channel=1), num_img_channel=1),
backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1), backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
encoder=None, encoder=None,
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True), decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
loss=dict(type='CTCLoss'), loss=dict(type='CTCLoss'),
@ -68,9 +68,9 @@ test_pipeline = [
dict( dict(
type='ResizeOCR', type='ResizeOCR',
height=32, height=32,
min_width=4, min_width=32,
max_width=None, max_width=100,
keep_aspect_ratio=True), keep_aspect_ratio=False),
dict(type='ToTensorOCR'), dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg), dict(type='NormalizeOCR', **img_norm_cfg),
dict( dict(
@ -100,13 +100,20 @@ train1 = dict(
test_mode=False) test_mode=False)
test_prefix = 'data/mixture/' test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'icdar_2013/'
test_img_prefix2 = test_prefix + 'IIIT5K/'
test_img_prefix3 = test_prefix + 'svt/'
test_ann_file1 = test_prefix + 'icdar_2013/test_label_1015.txt' test_img_prefix1 = test_prefix + 'IIIT5K/'
test_ann_file2 = test_prefix + 'IIIT5K/test_label.txt' test_img_prefix2 = test_prefix + 'svt/'
test_ann_file3 = test_prefix + 'svt/test_label.txt' test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'
test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'
test1 = dict( test1 = dict(
type=dataset_type, type=dataset_type,
@ -131,12 +138,28 @@ test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3 test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3 test3['ann_file'] = test_ann_file3
test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4
test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5
test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6
data = dict( data = dict(
samples_per_gpu=64, samples_per_gpu=64,
workers_per_gpu=4, workers_per_gpu=4,
train=dict(type='ConcatDataset', datasets=[train1]), train=dict(type='ConcatDataset', datasets=[train1]),
val=dict(type='ConcatDataset', datasets=[test1, test2, test3]), val=dict(
test=dict(type='ConcatDataset', datasets=[test1, test2, test3])) type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]),
test=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]))
evaluation = dict(interval=1, metric='acc') evaluation = dict(interval=1, metric='acc')

View File

@ -23,26 +23,35 @@ from .base_preprocessor import BasePreprocessor
@PREPROCESSOR.register_module() @PREPROCESSOR.register_module()
class TPSPreprocessor(BasePreprocessor): class TPSPreprocessor(BasePreprocessor):
"""Rectification Network of RARE, namely TPS based STN.""" """Rectification Network of RARE, namely TPS based STN in.
`Robust Scene Text Recognition with Automatic Rectification <https://arxiv.org/pdf/1603.03915.pdf>`_.
Args:
num_fiducial (int): Number of fiducial points of TPS-STN.
img_size (tuple(int, int)): Size (height, width) of the input image.
rectified_img_size (tuple(int, int)):
Size (height, width) of the rectified image.
num_img_channel (int): Number of channels of the input image.
Output:
batch_rectified_img: Rectified image with size
[batch_size x num_img_channel x rectified_img_height
x rectified_img_width]
"""
def __init__(self, def __init__(self,
num_fiducial, num_fiducial=20,
img_size, img_size=(32, 100),
rectified_img_size, rectified_img_size=(32, 100),
num_img_channel=1): num_img_channel=1):
""" Based on RARE TPS
Args:
num_fiducial (int): number of fiducial points of TPS-STN
img_size (int, int): (height, width) of the input image
rectified_img_size (int, int):
(height, width) of the rectified image
num_img_channel (int): the number of channels of the input image
output:
batch_rectified_img: rectified image
[batch_size x num_img_channel x rectified_img_height
x rectified_img_width]
"""
super().__init__() super().__init__()
assert isinstance(num_fiducial, int)
assert num_fiducial > 0
assert isinstance(img_size, tuple)
assert isinstance(rectified_img_size, tuple)
assert isinstance(num_img_channel, int)
self.num_fiducial = num_fiducial self.num_fiducial = num_fiducial
self.img_size = img_size self.img_size = img_size
self.rectified_img_size = rectified_img_size self.rectified_img_size = rectified_img_size
@ -71,13 +80,15 @@ class TPSPreprocessor(BasePreprocessor):
return batch_rectified_img return batch_rectified_img
def init_weights(self):
pass
class LocalizationNetwork(nn.Module): class LocalizationNetwork(nn.Module):
"""Localization Network of RARE, which predicts C' (K x 2) from input """Localization Network of RARE, which predicts C' (K x 2) from input
(img_width x img_height)""" (img_width x img_height)
Args:
num_fiducial (int): Number of fiducial points of TPS-STN.
num_img_channel (int): Number of channels of the input image.
"""
def __init__(self, num_fiducial, num_img_channel): def __init__(self, num_fiducial, num_img_channel):
super().__init__() super().__init__()
@ -128,7 +139,8 @@ class LocalizationNetwork(nn.Module):
Args: Args:
batch_img (tensor): Batch Input Image batch_img (tensor): Batch Input Image
[batch_size x num_img_channel x img_height x img_width] [batch_size x num_img_channel x img_height x img_width]
output:
Output:
batch_C_prime : Predicted coordinates of fiducial points for batch_C_prime : Predicted coordinates of fiducial points for
input batch [batch_size x num_fiducial x 2] input batch [batch_size x num_fiducial x 2]
""" """
@ -141,8 +153,13 @@ class LocalizationNetwork(nn.Module):
class GridGenerator(nn.Module): class GridGenerator(nn.Module):
"""Grid Generator of RARE, which produces P_prime by multiplying T with """Grid Generator of RARE, which produces P_prime by multiplying T with P.
P."""
Args:
num_fiducial (int): Number of fiducial points of TPS-STN.
rectified_img_size (tuple(int, int)):
Size (height, width) of the rectified image.
"""
def __init__(self, num_fiducial, rectified_img_size): def __init__(self, num_fiducial, rectified_img_size):
"""Generate P_hat and inv_delta_C for later.""" """Generate P_hat and inv_delta_C for later."""

View File

@ -1,3 +1,4 @@
import pytest
import torch import torch
from mmocr.models.textrecog.preprocessor import (BasePreprocessor, from mmocr.models.textrecog.preprocessor import (BasePreprocessor,
@ -5,6 +6,15 @@ from mmocr.models.textrecog.preprocessor import (BasePreprocessor,
def test_tps_preprocessor(): def test_tps_preprocessor():
with pytest.raises(AssertionError):
TPSPreprocessor(num_fiducial=-1)
with pytest.raises(AssertionError):
TPSPreprocessor(img_size=32)
with pytest.raises(AssertionError):
TPSPreprocessor(rectified_img_size=100)
with pytest.raises(AssertionError):
TPSPreprocessor(num_img_channel='bgr')
tps_preprocessor = TPSPreprocessor( tps_preprocessor = TPSPreprocessor(
num_fiducial=20, num_fiducial=20,
img_size=(32, 100), img_size=(32, 100),