Feature/iss 33 (#34)

* fix #33: update dataset.md

* fix #33: pytest for transformer related
pull/2/head
Hongbin Sun 2021-04-05 23:54:57 +08:00 committed by GitHub
parent ff1fc429cd
commit 0f00378f9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 41 additions and 76 deletions

View File

@ -48,7 +48,7 @@ v1.0 was released on 07/04/2021.
## Benchmark and Model Zoo
Please refer to [MODEL_ZOO.md](MODEL_ZOO.md) for more details.
Please refer to [modelzoo.md](modelzoo.md) for more details.
## Installation

View File

@ -2,7 +2,7 @@ label_convertor = dict(
type='AttnConvertor', dict_type='DICT36', with_unknown=True, lower=True)
model = dict(
type='TransformerNet',
type='NRTR',
backbone=dict(type='NRTRModalityTransform'),
encoder=dict(type='TFEncoder'),
decoder=dict(type='TFDecoder'),

View File

@ -1,32 +0,0 @@
## Introduction
[ALGORITHM]
### Train Dataset
| trainset | instance_num | repeat_num | note |
| :--------: | :----------: | :--------: | :---: |
| icdar_2011 | 3567 | 20 | real |
| icdar_2013 | 848 | 20 | real |
| icdar2015 | 4468 | 20 | real |
| coco_text | 42142 | 20 | real |
| IIIT5K | 2000 | 20 | real |
| SynthText | 2400000 | 1 | synth |
### Test Dataset
| testset | instance_num | note |
| :-----: | :----------: | :-------------------------: |
| IIIT5K | 3000 | regular |
| SVT | 647 | regular |
| IC13 | 1015 | regular |
| IC15 | 2077 | irregular |
| SVTP | 645 | irregular, 639 in [[1]](#1) |
| CT80 | 288 | irregular |
## Results and models
| methods | | Regular Text | | | | Irregular Text | | download |
| :---------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :------------------: |
| | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | |
| Transformer | 93.3 | 85.8 | 91.3 | | 73.2 | 76.6 | 87.8 | [model]() \| [log]() |

View File

@ -1,12 +0,0 @@
# Base config files this config builds on: runtime defaults, the transformer
# recognition model, and the toy recognition dataset (paths are relative).
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/recog_models/transformer.py',
'../../_base_/recog_datasets/toy_dataset.py'
]
# optimizer: Adadelta with lr=1
optimizer = dict(type='Adadelta', lr=1)
# grad_clip=None -- presumably disables gradient clipping; confirm against
# the training framework's optimizer hook.
optimizer_config = dict(grad_clip=None)
# learning policy: 'step' schedule with milestones 3 and 4 (presumably epoch
# indices -- confirm), for a run of 5 epochs in total.
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

View File

@ -92,10 +92,12 @@ This page lists the datasets which are commonly used in text detection, text rec
│ │ ├── image
│   ├── Synth90k
│ │ ├── shuffle_labels.txt
│ │ ├── label.lmdb
│ │ ├── mnt
│   ├── SynthText
│ │ ├── shuffle_labels.txt
│ │ ├── instances_train.txt
│ │ ├── label.lmdb
│ │ ├── synthtext
│   ├── SynthAdd
│ │ ├── label.txt
@ -113,9 +115,9 @@ This page lists the datasets which are commonly used in text detection, text rec
| ct80 | | - |-|[test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt)||
| svt | | [homepage](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) | |
| svtp | | - | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) | |
| Synth90k | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt) | - | |
| SynthText | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) | - | |
| SynthAdd | | [SynthText_Add.zip](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/SynthText_Add.zip) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)|- | |
| Synth90k | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt) \| [label.lmdb](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/label.lmdb) | - | |
| SynthText | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.lmdb](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.lmdb) | - | |
| SynthAdd | | [SynthText_Add.zip](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)|- | |
- For `icdar_2013`:
- Step1: Download `Challenge2_Test_Task3_Images.zip` and `Challenge2_Training_Task3_Images_GT.zip` from [homepage](https://rrc.cvc.uab.es/?ch=2&com=downloads)
@ -173,7 +175,7 @@ This page lists the datasets which are commonly used in text detection, text rec
ln -s /path/to/SynthText SynthText
```
- For `SynthAdd`:
- Step1: Download `SynthText_Add.zip` from [SynthText_Add.zip](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/SynthText_Add.zip)
- Step1: Download `SynthText_Add.zip` from [SynthAdd](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x)
- Step2: Download [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)
- Step3:
```bash

View File

@ -25,9 +25,6 @@ def model_inference(model, img):
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
# just get the actual data from DataContainer
# data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]

View File

@ -1,7 +1,9 @@
import mmocr.utils as utils
import numpy as np
from shapely.geometry import LineString, Point, Polygon
import mmocr.utils as utils
def sort_vertex(points_x, points_y):
"""Sort box vertices in clockwise order from left-top first.

View File

@ -72,9 +72,11 @@ class TFDecoder(BaseDecoder):
return output
def forward_train(self, feat, out_enc, targets_dict, img_metas):
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
n, c, h, w = out_enc.size()
src_mask = None
if valid_ratios is not None:
@ -91,9 +93,11 @@ class TFDecoder(BaseDecoder):
return outputs
def forward_test(self, feat, out_enc, img_metas):
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
n, c, h, w = out_enc.size()
src_mask = None
if valid_ratios is not None:

View File

@ -30,9 +30,11 @@ class TFEncoder(BaseEncoder):
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, feat, img_metas=None):
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
valid_ratios = [1.0 for _ in range(feat.size(0))]
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
]
n, c, h, w = feat.size()
mask = feat.new_zeros((n, h, w))
for i, valid_ratio in enumerate(valid_ratios):

View File

@ -138,7 +138,6 @@ class MultiHeadAttention(nn.Module):
self.proj_drop = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
assert mask.dim() in [2, 3, 4]
batch_size, len_q, _ = q.size()
_, len_k, _ = k.size()

View File

@ -2,12 +2,12 @@ from .base import BaseRecognizer
from .cafcn import CAFCNNet
from .crnn import CRNNNet
from .encode_decode_recognizer import EncodeDecodeRecognizer
from .nrtr import NRTR
from .robust_scanner import RobustScanner
from .sar import SARNet
from .seg_recognizer import SegRecognizer
from .transformer import TransformerNet
__all__ = [
'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet',
'TransformerNet', 'SegRecognizer', 'RobustScanner', 'CAFCNNet'
'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR',
'SegRecognizer', 'RobustScanner', 'CAFCNNet'
]

View File

@ -3,5 +3,5 @@ from .encode_decode_recognizer import EncodeDecodeRecognizer
@DETECTORS.register_module()
class TransformerNet(EncodeDecodeRecognizer):
"""Implementation of Transformer based OCR."""
class NRTR(EncodeDecodeRecognizer):
"""Implementation of `NRTR <https://arxiv.org/pdf/1806.00926.pdf>`_"""

View File

@ -3,8 +3,9 @@ import math
import numpy as np
import pytest
from mmocr.datasets.pipelines.crop import (box_jitter, convert_canonical,
crop_img, sort_vertex, warp_img)
from mmocr.datasets.pipelines.box_utils import convert_canonical
from mmocr.datasets.pipelines.crop import (box_jitter, crop_img, sort_vertex,
warp_img)
def test_order_vertex():

View File

@ -100,7 +100,7 @@ def test_transformer_decoder():
decoder.init_weights()
decoder.train()
out_enc = torch.rand(1, 128, 512)
out_enc = torch.rand(1, 512, 1, 25)
tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
img_metas = [{'valid_ratio': 1.0}]
tgt_dict['padded_targets'] = tgt_dict['padded_targets']

View File

@ -38,9 +38,10 @@ def test_transformer_encoder():
tf_encoder.init_weights()
tf_encoder.train()
feat = torch.randn(1, 512, 4, 40)
feat = torch.randn(1, 512, 1, 25)
out_enc = tf_encoder(feat)
assert out_enc.shape == torch.Size([1, 160, 512])
print('hello', out_enc.size())
assert out_enc.shape == torch.Size([1, 512, 1, 25])
def test_base_encoder():

View File

@ -1,7 +1,8 @@
import torch
from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck,
DecoderLayer, PositionalEncoding,
PositionalEncoding,
TransformerDecoderLayer,
get_pad_mask, get_subsequent_mask)
from mmocr.models.textrecog.layers.conv_layer import conv3x3
@ -32,7 +33,7 @@ def test_conv_layer():
def test_transformer_layer():
# test decoder_layer
decoder_layer = DecoderLayer()
decoder_layer = TransformerDecoderLayer()
in_dec = torch.rand(1, 30, 512)
out_enc = torch.rand(1, 128, 512)
out_dec = decoder_layer(in_dec, out_enc)