mirror of https://github.com/open-mmlab/mmocr.git
Feature/iss 33 (#34)
* fix #33: update dataset.md * fix #33: pytest for transformer relatedpull/2/head
parent
ff1fc429cd
commit
0f00378f9a
|
@ -48,7 +48,7 @@ v1.0 was released on 07/04/2021.
|
|||
|
||||
## Benchmark and Model Zoo
|
||||
|
||||
Please refer to [MODEL_ZOO.md](MODEL_ZOO.md) for more details.
|
||||
Please refer to [modelzoo.md](modelzoo.md) for more details.
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ label_convertor = dict(
|
|||
type='AttnConvertor', dict_type='DICT36', with_unknown=True, lower=True)
|
||||
|
||||
model = dict(
|
||||
type='TransformerNet',
|
||||
type='NRTR',
|
||||
backbone=dict(type='NRTRModalityTransform'),
|
||||
encoder=dict(type='TFEncoder'),
|
||||
decoder=dict(type='TFDecoder'),
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
## Introduction
|
||||
|
||||
[ALGORITHM]
|
||||
|
||||
### Train Dataset
|
||||
|
||||
| trainset | instance_num | repeat_num | note |
|
||||
| :--------: | :----------: | :--------: | :---: |
|
||||
| icdar_2011 | 3567 | 20 | real |
|
||||
| icdar_2013 | 848 | 20 | real |
|
||||
| icdar2015 | 4468 | 20 | real |
|
||||
| coco_text | 42142 | 20 | real |
|
||||
| IIIT5K | 2000 | 20 | real |
|
||||
| SynthText | 2400000 | 1 | synth |
|
||||
|
||||
### Test Dataset
|
||||
|
||||
| testset | instance_num | note |
|
||||
| :-----: | :----------: | :-------------------------: |
|
||||
| IIIT5K | 3000 | regular |
|
||||
| SVT | 647 | regular |
|
||||
| IC13 | 1015 | regular |
|
||||
| IC15 | 2077 | irregular |
|
||||
| SVTP | 645 | irregular, 639 in [[1]](#1) |
|
||||
| CT80 | 288 | irregular |
|
||||
|
||||
## Results and models
|
||||
|
||||
| methods | | Regular Text | | | | Irregular Text | | download |
|
||||
| :---------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :------------------: |
|
||||
| | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
|
||||
| Transformer | 93.3 | 85.8 | 91.3 | | 73.2 | 76.6 | 87.8 | [model]() \| [log]() |
|
|
@ -1,12 +0,0 @@
|
|||
_base_ = [
|
||||
'../../_base_/default_runtime.py',
|
||||
'../../_base_/recog_models/transformer.py',
|
||||
'../../_base_/recog_datasets/toy_dataset.py'
|
||||
]
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='Adadelta', lr=1)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[3, 4])
|
||||
total_epochs = 5
|
|
@ -92,10 +92,12 @@ This page lists the datasets which are commonly used in text detection, text rec
|
|||
│ │ ├── image
|
||||
│ ├── Synth90k
|
||||
│ │ ├── shuffle_labels.txt
|
||||
│ │ ├── label.lmdb
|
||||
│ │ ├── mnt
|
||||
│ ├── SynthText
|
||||
│ │ ├── shuffle_labels.txt
|
||||
│ │ ├── instances_train.txt
|
||||
│ │ ├── label.lmdb
|
||||
│ │ ├── synthtext
|
||||
│ ├── SynthAdd
|
||||
│ │ ├── label.txt
|
||||
|
@ -113,9 +115,9 @@ This page lists the datasets which are commonly used in text detection, text rec
|
|||
| ct80 | | - |-|[test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt)||
|
||||
| svt | | [homepage](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) | |
|
||||
| svtp | | - | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) | |
|
||||
| Synth90k | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt) | - | |
|
||||
| SynthText | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) | - | |
|
||||
| SynthAdd | | [SynthText_Add.zip](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/SynthText_Add.zip) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)|- | |
|
||||
| Synth90k | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt) \| [label.lmdb](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/label.lmdb) | - | |
|
||||
| SynthText | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.lmdb](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.lmdb) | - | |
|
||||
| SynthAdd | | [SynthText_Add.zip](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)|- | |
|
||||
|
||||
- For `icdar_2013`:
|
||||
- Step1: Download `Challenge2_Test_Task3_Images.zip` and `Challenge2_Training_Task3_Images_GT.zip` from [homepage](https://rrc.cvc.uab.es/?ch=2&com=downloads)
|
||||
|
@ -173,7 +175,7 @@ This page lists the datasets which are commonly used in text detection, text rec
|
|||
ln -s /path/to/SynthText SynthText
|
||||
```
|
||||
- For `SynthAdd`:
|
||||
- Step1: Download `SynthText_Add.zip` from [SynthText_Add.zip](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/SynthText_Add.zip)
|
||||
- Step1: Download `SynthText_Add.zip` from [SynthAdd](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x))
|
||||
- Step2: Download [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)
|
||||
- Step3:
|
||||
```bash
|
||||
|
|
|
@ -25,9 +25,6 @@ def model_inference(model, img):
|
|||
data = test_pipeline(data)
|
||||
data = collate([data], samples_per_gpu=1)
|
||||
|
||||
# just get the actual data from DataContainer
|
||||
# data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
|
||||
|
||||
if next(model.parameters()).is_cuda:
|
||||
# scatter to specified GPU
|
||||
data = scatter(data, [device])[0]
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import mmocr.utils as utils
|
||||
import numpy as np
|
||||
from shapely.geometry import LineString, Point, Polygon
|
||||
|
||||
import mmocr.utils as utils
|
||||
|
||||
|
||||
def sort_vertex(points_x, points_y):
|
||||
"""Sort box vertices in clockwise order from left-top first.
|
||||
|
||||
|
|
|
@ -72,9 +72,11 @@ class TFDecoder(BaseDecoder):
|
|||
return output
|
||||
|
||||
def forward_train(self, feat, out_enc, targets_dict, img_metas):
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
]
|
||||
valid_ratios = None
|
||||
if img_metas is not None:
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
]
|
||||
n, c, h, w = out_enc.size()
|
||||
src_mask = None
|
||||
if valid_ratios is not None:
|
||||
|
@ -91,9 +93,11 @@ class TFDecoder(BaseDecoder):
|
|||
return outputs
|
||||
|
||||
def forward_test(self, feat, out_enc, img_metas):
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
]
|
||||
valid_ratios = None
|
||||
if img_metas is not None:
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
]
|
||||
n, c, h, w = out_enc.size()
|
||||
src_mask = None
|
||||
if valid_ratios is not None:
|
||||
|
|
|
@ -30,9 +30,11 @@ class TFEncoder(BaseEncoder):
|
|||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, feat, img_metas=None):
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
]
|
||||
valid_ratios = [1.0 for _ in range(feat.size(0))]
|
||||
if img_metas is not None:
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
]
|
||||
n, c, h, w = feat.size()
|
||||
mask = feat.new_zeros((n, h, w))
|
||||
for i, valid_ratio in enumerate(valid_ratios):
|
||||
|
|
|
@ -138,7 +138,6 @@ class MultiHeadAttention(nn.Module):
|
|||
self.proj_drop = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
assert mask.dim() in [2, 3, 4]
|
||||
batch_size, len_q, _ = q.size()
|
||||
_, len_k, _ = k.size()
|
||||
|
||||
|
|
|
@ -2,12 +2,12 @@ from .base import BaseRecognizer
|
|||
from .cafcn import CAFCNNet
|
||||
from .crnn import CRNNNet
|
||||
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||
from .nrtr import NRTR
|
||||
from .robust_scanner import RobustScanner
|
||||
from .sar import SARNet
|
||||
from .seg_recognizer import SegRecognizer
|
||||
from .transformer import TransformerNet
|
||||
|
||||
__all__ = [
|
||||
'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet',
|
||||
'TransformerNet', 'SegRecognizer', 'RobustScanner', 'CAFCNNet'
|
||||
'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR',
|
||||
'SegRecognizer', 'RobustScanner', 'CAFCNNet'
|
||||
]
|
||||
|
|
|
@ -3,5 +3,5 @@ from .encode_decode_recognizer import EncodeDecodeRecognizer
|
|||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class TransformerNet(EncodeDecodeRecognizer):
|
||||
"""Implementation of Transformer based OCR."""
|
||||
class NRTR(EncodeDecodeRecognizer):
|
||||
"""Implementation of `NRTR <https://arxiv.org/pdf/1806.00926.pdf>`_"""
|
|
@ -3,8 +3,9 @@ import math
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmocr.datasets.pipelines.crop import (box_jitter, convert_canonical,
|
||||
crop_img, sort_vertex, warp_img)
|
||||
from mmocr.datasets.pipelines.box_utils import convert_canonical
|
||||
from mmocr.datasets.pipelines.crop import (box_jitter, crop_img, sort_vertex,
|
||||
warp_img)
|
||||
|
||||
|
||||
def test_order_vertex():
|
||||
|
|
|
@ -100,7 +100,7 @@ def test_transformer_decoder():
|
|||
decoder.init_weights()
|
||||
decoder.train()
|
||||
|
||||
out_enc = torch.rand(1, 128, 512)
|
||||
out_enc = torch.rand(1, 512, 1, 25)
|
||||
tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
|
||||
img_metas = [{'valid_ratio': 1.0}]
|
||||
tgt_dict['padded_targets'] = tgt_dict['padded_targets']
|
||||
|
|
|
@ -38,9 +38,10 @@ def test_transformer_encoder():
|
|||
tf_encoder.init_weights()
|
||||
tf_encoder.train()
|
||||
|
||||
feat = torch.randn(1, 512, 4, 40)
|
||||
feat = torch.randn(1, 512, 1, 25)
|
||||
out_enc = tf_encoder(feat)
|
||||
assert out_enc.shape == torch.Size([1, 160, 512])
|
||||
print('hello', out_enc.size())
|
||||
assert out_enc.shape == torch.Size([1, 512, 1, 25])
|
||||
|
||||
|
||||
def test_base_encoder():
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import torch
|
||||
|
||||
from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck,
|
||||
DecoderLayer, PositionalEncoding,
|
||||
PositionalEncoding,
|
||||
TransformerDecoderLayer,
|
||||
get_pad_mask, get_subsequent_mask)
|
||||
from mmocr.models.textrecog.layers.conv_layer import conv3x3
|
||||
|
||||
|
@ -32,7 +33,7 @@ def test_conv_layer():
|
|||
|
||||
def test_transformer_layer():
|
||||
# test decoder_layer
|
||||
decoder_layer = DecoderLayer()
|
||||
decoder_layer = TransformerDecoderLayer()
|
||||
in_dec = torch.rand(1, 30, 512)
|
||||
out_enc = torch.rand(1, 128, 512)
|
||||
out_dec = decoder_layer(in_dec, out_enc)
|
||||
|
|
Loading…
Reference in New Issue