refactor tps config ()

* refactor tps config

* recover tps

* add ckpt of tps
pull/175/head
Hongbin Sun 2021-05-12 14:14:24 +08:00 committed by GitHub
parent 18c54afbdc
commit df5493a79e
6 changed files with 156 additions and 55 deletions


@ -28,10 +28,13 @@
| IIIT5K | 3000 | regular |
| SVT | 647 | regular |
| IC13 | 1015 | regular |
| IC15 | 2077 | irregular |
| SVTP | 645 | irregular |
| CT80 | 288 | irregular |
## Results and models
| methods | | Regular Text | | | | Irregular Text | | download |
| :------------------------------------------------------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| methods | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
| [CRNN](/configs/textrecog/crnn/crnn_academic_dataset.py) | 80.5 | 81.5 | 86.5 | | - | - | - | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic-a723a1c5.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/20210326_111035.log.json) |
| [CRNN](/configs/textrecog/crnn/crnn_academic_dataset.py) | 80.5 | 81.5 | 86.5 | | 54.1 | 59.1 | 55.6 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic-a723a1c5.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/20210326_111035.log.json) |


@ -95,13 +95,20 @@ train1 = dict(
test_mode=False)
test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'icdar_2013/'
test_img_prefix2 = test_prefix + 'IIIT5K/'
test_img_prefix3 = test_prefix + 'svt/'
test_ann_file1 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file2 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file3 = test_prefix + 'svt/test_label.txt'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'
test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'
test1 = dict(
type=dataset_type,
@ -126,12 +133,28 @@ test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3
test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4
test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5
test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
train=dict(type='ConcatDataset', datasets=[train1]),
val=dict(type='ConcatDataset', datasets=[test1, test2, test3]),
test=dict(type='ConcatDataset', datasets=[test1, test2, test3]))
val=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]),
test=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]))
evaluation = dict(interval=1, metric='acc')
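The six `test*` blocks above differ only in their image prefix and annotation file. For readers adapting this config, the same wiring can be written as a loop; a minimal sketch is shown below, with `test1` stubbed in place of the full shared template defined earlier in the config (the `'OCRDataset'` type here is a placeholder, not part of the committed file):

```python
# Illustrative only: generate the six per-dataset test configs in a loop.
test1 = dict(type='OCRDataset', test_mode=True)  # stub of the shared template

test_prefix = 'data/mixture/'
test_sets = [
    ('IIIT5K', 'test_label.txt'),
    ('svt', 'test_label.txt'),
    ('icdar_2013', 'test_label_1015.txt'),
    ('icdar_2015', 'test_label.txt'),
    ('svtp', 'test_label.txt'),
    ('ct80', 'test_label.txt'),
]
test_datasets = []
for name, label_file in test_sets:
    cfg = {key: value for key, value in test1.items()}  # same copy idiom as the config
    cfg['img_prefix'] = test_prefix + name + '/'
    cfg['ann_file'] = test_prefix + name + '/' + label_file
    test_datasets.append(cfg)

# val/test would then be dict(type='ConcatDataset', datasets=test_datasets)
```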


@ -1,9 +1,20 @@
# Thin-Plate-Spline (TPS) transformation
# CRNN with TPS based STN
## Introduction
[ALGORITHM]
```bibtex
@article{shi2016end,
title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition},
author={Shi, Baoguang and Bai, Xiang and Yao, Cong},
journal={IEEE transactions on pattern analysis and machine intelligence},
year={2016}
}
```
[PREPROCESSOR]
```bibtex
@article{shi2016robust,
title={Robust Scene Text Recognition with Automatic Rectification},
@ -13,14 +24,28 @@
}
```
## About using TPS in other models
## Results and Models
- Simply change `cfg.model.preprocessor` from `None` to
```python
dict(
type='TPSPreprocessor',
num_fiducial=20,
img_size=(32, 100),
rectified_img_size=(32, 100),
num_img_channel=1
)
```
### Train Dataset
| trainset | instance_num | repeat_num | note |
| :------: | :----------: | :--------: | :---: |
| Syn90k | 8919273 | 1 | synth |
### Test Dataset
| testset | instance_num | note |
| :-----: | :----------: | :-----: |
| IIIT5K | 3000 | regular |
| SVT | 647 | regular |
| IC13 | 1015 | regular |
| IC15 | 2077 | irregular |
| SVTP | 645 | irregular |
| CT80 | 288 | irregular |
## Results and models
| methods | | Regular Text | | | | Irregular Text | | download |
| :------------------------------------------------------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
| [CRNN-STN](/configs/textrecog/tps/crnn_tps_academic_dataset.py) | 80.8 | 81.3 | 85.0 | | 59.6 | 68.1 | 53.8 | [model](https://download.openmmlab.com/mmocr/textrecog/tps/crnn_tps_academic_dataset_20210510-d221a905.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/tps/20210510_204353.log.json) |
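Although this commit removes the old "About using TPS in other models" note, the idea still applies: the rectification module can be attached to another recognizer by setting `model.preprocessor`. A minimal sketch, assuming an `_base_` override of the existing CRNN config (the relative base path is inferred from the links above, not stated in this commit):

```python
# Hypothetical override config: reuse the CRNN recipe and only swap in the
# TPS-based STN preprocessor (values match the TPSPreprocessor defaults).
_base_ = ['../crnn/crnn_academic_dataset.py']  # assumed relative location

model = dict(
    preprocessor=dict(
        type='TPSPreprocessor',
        num_fiducial=20,
        img_size=(32, 100),
        rectified_img_size=(32, 100),
        num_img_channel=1))
```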


@ -26,7 +26,7 @@ model = dict(
img_size=(32, 100),
rectified_img_size=(32, 100),
num_img_channel=1),
backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1),
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
encoder=None,
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
loss=dict(type='CTCLoss'),
@ -68,9 +68,9 @@ test_pipeline = [
dict(
type='ResizeOCR',
height=32,
min_width=4,
max_width=None,
keep_aspect_ratio=True),
min_width=32,
max_width=100,
keep_aspect_ratio=False),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
@ -100,13 +100,20 @@ train1 = dict(
test_mode=False)
test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'icdar_2013/'
test_img_prefix2 = test_prefix + 'IIIT5K/'
test_img_prefix3 = test_prefix + 'svt/'
test_ann_file1 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file2 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file3 = test_prefix + 'svt/test_label.txt'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'
test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'
test1 = dict(
type=dataset_type,
@ -131,12 +138,28 @@ test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3
test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4
test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5
test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
train=dict(type='ConcatDataset', datasets=[train1]),
val=dict(type='ConcatDataset', datasets=[test1, test2, test3]),
test=dict(type='ConcatDataset', datasets=[test1, test2, test3]))
val=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]),
test=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]))
evaluation = dict(interval=1, metric='acc')


@ -23,26 +23,35 @@ from .base_preprocessor import BasePreprocessor
@PREPROCESSOR.register_module()
class TPSPreprocessor(BasePreprocessor):
"""Rectification Network of RARE, namely TPS based STN."""
"""Rectification Network of RARE, namely TPS based STN in.
<https://arxiv.org/pdf/1603.03915.pdf>`_.
Args:
num_fiducial (int): Number of fiducial points of TPS-STN.
img_size (tuple(int, int)): Size (height, width) of the input image.
rectified_img_size (tuple(int, int)):
Size (height, width) of the rectified image.
num_img_channel (int): Number of channels of the input image.
Output:
batch_rectified_img: Rectified image with size
[batch_size x num_img_channel x rectified_img_height
x rectified_img_width]
"""
def __init__(self,
num_fiducial,
img_size,
rectified_img_size,
num_fiducial=20,
img_size=(32, 100),
rectified_img_size=(32, 100),
num_img_channel=1):
""" Based on RARE TPS
Args:
num_fiducial (int): number of fiducial points of TPS-STN
img_size (int, int): (height, width) of the input image
rectified_img_size (int, int):
(height, width) of the rectified image
num_img_channel (int): the number of channels of the input image
output:
batch_rectified_img: rectified image
[batch_size x num_img_channel x rectified_img_height
x rectified_img_width]
"""
super().__init__()
assert isinstance(num_fiducial, int)
assert num_fiducial > 0
assert isinstance(img_size, tuple)
assert isinstance(rectified_img_size, tuple)
assert isinstance(num_img_channel, int)
self.num_fiducial = num_fiducial
self.img_size = img_size
self.rectified_img_size = rectified_img_size
@ -71,13 +80,15 @@ class TPSPreprocessor(BasePreprocessor):
return batch_rectified_img
def init_weights(self):
pass
class LocalizationNetwork(nn.Module):
"""Localization Network of RARE, which predicts C' (K x 2) from input
(img_width x img_height)"""
(img_width x img_height)
Args:
num_fiducial (int): Number of fiducial points of TPS-STN.
num_img_channel (int): Number of channels of the input image.
"""
def __init__(self, num_fiducial, num_img_channel):
super().__init__()
@ -128,7 +139,8 @@ class LocalizationNetwork(nn.Module):
Args:
batch_img (tensor): Batch Input Image
[batch_size x num_img_channel x img_height x img_width]
output:
Output:
batch_C_prime : Predicted coordinates of fiducial points for
input batch [batch_size x num_fiducial x 2]
"""
@ -141,8 +153,13 @@ class LocalizationNetwork(nn.Module):
class GridGenerator(nn.Module):
"""Grid Generator of RARE, which produces P_prime by multipling T with
P."""
"""Grid Generator of RARE, which produces P_prime by multipling T with P.
Args:
num_fiducial (int): Number of fiducial points of TPS-STN.
rectified_img_size (tuple(int, int)):
Size (height, width) of the rectified image.
"""
def __init__(self, num_fiducial, rectified_img_size):
"""Generate P_hat and inv_delta_C for later."""


@ -1,3 +1,4 @@
import pytest
import torch
from mmocr.models.textrecog.preprocessor import (BasePreprocessor,
@ -5,6 +6,15 @@ from mmocr.models.textrecog.preprocessor import (BasePreprocessor,
def test_tps_preprocessor():
with pytest.raises(AssertionError):
TPSPreprocessor(num_fiducial=-1)
with pytest.raises(AssertionError):
TPSPreprocessor(img_size=32)
with pytest.raises(AssertionError):
TPSPreprocessor(rectified_img_size=100)
with pytest.raises(AssertionError):
TPSPreprocessor(num_img_channel='bgr')
tps_preprocessor = TPSPreprocessor(
num_fiducial=20,
img_size=(32, 100),