mirror of https://github.com/open-mmlab/mmocr.git
[Model] Add SVTR framework and configs (#1621)
* [Model] Add SVTR framework and configs * update * update transform names * update base config * fix cfg * update cfgs * fix * update cfg * update decoder * fix encoder * fix encoder * fix * update cfg * update namepull/1663/head
parent
b0557c2c55
commit
0aa5d7be6d
|
@ -147,6 +147,7 @@ Supported algorithms:
|
|||
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
|
||||
- [x] [SAR](configs/textrecog/sar/README.md) (AAAI'2019)
|
||||
- [x] [SATRN](configs/textrecog/satrn/README.md) (CVPR'2020 Workshop on Text and Documents in the Deep Learning Era)
|
||||
- [x] [SVTR](configs/textrecog/svtr/README.md) (IJCAI'2022)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -147,6 +147,7 @@ pip3 install -e .
|
|||
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
|
||||
- [x] [SAR](configs/textrecog/sar/README.md) (AAAI'2019)
|
||||
- [x] [SATRN](configs/textrecog/satrn/README.md) (CVPR'2020 Workshop on Text and Documents in the Deep Learning Era)
|
||||
- [x] [SVTR](configs/textrecog/svtr/README.md) (IJCAI'2022)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# SVTR
|
||||
|
||||
> [SVTR: Scene Text Recognition with a Single Visual Model](https://arxiv.org/abs/2205.00159)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Dominant scene text recognition models commonly contain two building blocks, a visual model for feature extraction and a sequence model for text transcription. This hybrid architecture, although accurate, is complex and less efficient. In this study, we propose a Single Visual model for Scene Text recognition within the patch-wise image tokenization framework, which dispenses with the sequential modeling entirely. The method, termed SVTR, firstly decomposes an image text into small patches named character components. Afterward, hierarchical stages are recurrently carried out by component-level mixing, merging and/or combining. Global and local mixing blocks are devised to perceive the inter-character and intra-character patterns, leading to a multi-grained character component perception. Thus, characters are recognized by a simple linear prediction. Experimental results on both English and Chinese scene text recognition tasks demonstrate the effectiveness of SVTR. SVTR-L (Large) achieves highly competitive accuracy in English and outperforms existing methods by a large margin in Chinese, while running faster. In addition, SVTR-T (Tiny) is an effective and much smaller model, which shows appealing speed at inference.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/22607038/210541576-025df5d5-f4d2-4037-82e0-246cf8cd3c25.png"/>
|
||||
</div>
|
||||
|
||||
## Dataset
|
||||
|
||||
### Train Dataset
|
||||
|
||||
| trainset | instance_num | repeat_num | source |
|
||||
| :-------: | :----------: | :--------: | :----: |
|
||||
| SynthText | 7266686 | 1 | synth |
|
||||
| Syn90k | 8919273 | 1 | synth |
|
||||
|
||||
### Test Dataset
|
||||
|
||||
| testset | instance_num | type |
|
||||
| :-----: | :----------: | :-------: |
|
||||
| IIIT5K | 3000 | regular |
|
||||
| SVT | 647 | regular |
|
||||
| IC13 | 1015 | regular |
|
||||
| IC15 | 2077 | irregular |
|
||||
| SVTP | 645 | irregular |
|
||||
| CT80 | 288 | irregular |
|
||||
|
||||
## Results and Models
|
||||
|
||||
| Methods | | Regular Text | | | | Irregular Text | | download |
|
||||
| :-----------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :------------------------------------------------------------------------------: |
|
||||
| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
|
||||
| [SVTR-tiny](/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
|
||||
| [SVTR-small](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8553 | 0.9026 | 0.9448 | | 0.7496 | 0.8496 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/20230105_184454.log) |
|
||||
| [SVTR-base](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8570 | 0.9181 | 0.9438 | | 0.7448 | 0.8388 | 0.9028 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/20221227_175415.log) |
|
||||
| [SVTR-large](/configs/textrecog/svtr/svtr-large_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
|
||||
|
||||
```{note}
|
||||
The implementation and configuration follow the original code and paper, but there is still a gap between the reproduced results and the official ones. We appreciate any suggestions to improve its performance.
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{ijcai2022p124,
|
||||
title = {SVTR: Scene Text Recognition with a Single Visual Model},
|
||||
author = {Du, Yongkun and Chen, Zhineng and Jia, Caiyan and Yin, Xiaoting and Zheng, Tianlun and Li, Chenxia and Du, Yuning and Jiang, Yu-Gang},
|
||||
booktitle = {Proceedings of the Thirty-First International Joint Conference on
|
||||
Artificial Intelligence, {IJCAI-22}},
|
||||
publisher = {International Joint Conferences on Artificial Intelligence Organization},
|
||||
editor = {Lud De Raedt},
|
||||
pages = {884--890},
|
||||
year = {2022},
|
||||
month = {7},
|
||||
note = {Main Track},
|
||||
doi = {10.24963/ijcai.2022/124},
|
||||
url = {https://doi.org/10.24963/ijcai.2022/124},
|
||||
}
|
||||
|
||||
```
|
|
@ -0,0 +1,38 @@
|
|||
dictionary = dict(
|
||||
type='Dictionary',
|
||||
dict_file='{{ fileDirname }}/../../../dicts/lower_english_digits.txt',
|
||||
with_padding=True,
|
||||
with_unknown=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type='SVTR',
|
||||
preprocessor=dict(
|
||||
type='STN',
|
||||
in_channels=3,
|
||||
resized_image_size=(32, 64),
|
||||
output_image_size=(32, 100),
|
||||
num_control_points=20,
|
||||
margins=[0.05, 0.05]),
|
||||
encoder=dict(
|
||||
type='SVTREncoder',
|
||||
img_size=[32, 100],
|
||||
in_channels=3,
|
||||
out_channels=192,
|
||||
embed_dims=[64, 128, 256],
|
||||
depth=[3, 6, 3],
|
||||
num_heads=[2, 4, 8],
|
||||
mixer_types=['Local'] * 6 + ['Global'] * 6,
|
||||
window_size=[[7, 11], [7, 11], [7, 11]],
|
||||
merging_types='Conv',
|
||||
prenorm=False,
|
||||
max_seq_len=25),
|
||||
decoder=dict(
|
||||
type='SVTRDecoder',
|
||||
in_channels=192,
|
||||
module_loss=dict(
|
||||
type='CTCModuleLoss', letter_case='lower', zero_infinity=True),
|
||||
postprocessor=dict(type='CTCPostProcessor'),
|
||||
dictionary=dictionary),
|
||||
data_preprocessor=dict(
|
||||
type='TextRecogDataPreprocessor', mean=[127.5], std=[127.5]))
|
|
@ -0,0 +1,89 @@
|
|||
Collections:
|
||||
- Name: SVTR
|
||||
Metadata:
|
||||
Training Data: OCRDataset
|
||||
Training Techniques:
|
||||
- AdamW
|
||||
Training Resources: 4x Tesla A100
|
||||
Epochs: 20
|
||||
Batch Size: 2048
|
||||
Architecture:
|
||||
- STN
|
||||
- SVTREncoder
|
||||
- SVTRDecoder
|
||||
Paper:
|
||||
URL: https://arxiv.org/pdf/2205.00159.pdf
|
||||
Title: 'SVTR: Scene Text Recognition with a Single Visual Model'
|
||||
README: configs/textrecog/svtr/README.md
|
||||
|
||||
Models:
|
||||
- Name: svtr-small_20e_st_mj
|
||||
Alias: svtr-small
|
||||
In Collection: SVTR
|
||||
Config: configs/textrecog/svtr/svtr-small_20e_st_mj.py
|
||||
Metadata:
|
||||
Training Data:
|
||||
- SynthText
|
||||
- Syn90k
|
||||
Results:
|
||||
- Task: Text Recognition
|
||||
Dataset: IIIT5K
|
||||
Metrics:
|
||||
word_acc: 0.8553
|
||||
- Task: Text Recognition
|
||||
Dataset: SVT
|
||||
Metrics:
|
||||
word_acc: 0.9026
|
||||
- Task: Text Recognition
|
||||
Dataset: ICDAR2013
|
||||
Metrics:
|
||||
word_acc: 0.9448
|
||||
- Task: Text Recognition
|
||||
Dataset: ICDAR2015
|
||||
Metrics:
|
||||
word_acc: 0.7496
|
||||
- Task: Text Recognition
|
||||
Dataset: SVTP
|
||||
Metrics:
|
||||
word_acc: 0.8496
|
||||
- Task: Text Recognition
|
||||
Dataset: CT80
|
||||
Metrics:
|
||||
word_acc: 0.8854
|
||||
Weights: https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth
|
||||
|
||||
- Name: svtr-base_20e_st_mj
|
||||
Alias: svtr-base
|
||||
Batch Size: 1024
|
||||
In Collection: SVTR
|
||||
Config: configs/textrecog/svtr/svtr-base_20e_st_mj.py
|
||||
Metadata:
|
||||
Training Data:
|
||||
- SynthText
|
||||
- Syn90k
|
||||
Results:
|
||||
- Task: Text Recognition
|
||||
Dataset: IIIT5K
|
||||
Metrics:
|
||||
word_acc: 0.8570
|
||||
- Task: Text Recognition
|
||||
Dataset: SVT
|
||||
Metrics:
|
||||
word_acc: 0.9181
|
||||
- Task: Text Recognition
|
||||
Dataset: ICDAR2013
|
||||
Metrics:
|
||||
word_acc: 0.9438
|
||||
- Task: Text Recognition
|
||||
Dataset: ICDAR2015
|
||||
Metrics:
|
||||
word_acc: 0.7448
|
||||
- Task: Text Recognition
|
||||
Dataset: SVTP
|
||||
Metrics:
|
||||
word_acc: 0.8388
|
||||
- Task: Text Recognition
|
||||
Dataset: CT80
|
||||
Metrics:
|
||||
word_acc: 0.9028
|
||||
Weights: https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth
|
|
@ -0,0 +1,17 @@
|
|||
_base_ = [
|
||||
'svtr-tiny_20e_st_mj.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
preprocessor=dict(output_image_size=(48, 160), ),
|
||||
encoder=dict(
|
||||
img_size=[48, 160],
|
||||
max_seq_len=40,
|
||||
out_channels=256,
|
||||
embed_dims=[128, 256, 384],
|
||||
depth=[3, 6, 9],
|
||||
num_heads=[4, 8, 12],
|
||||
mixer_types=['Local'] * 8 + ['Global'] * 10),
|
||||
decoder=dict(in_channels=256))
|
||||
|
||||
train_dataloader = dict(batch_size=256, )
|
|
@ -0,0 +1,19 @@
|
|||
_base_ = [
|
||||
'svtr-tiny_20e_st_mj.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
preprocessor=dict(output_image_size=(48, 160), ),
|
||||
encoder=dict(
|
||||
img_size=[48, 160],
|
||||
max_seq_len=40,
|
||||
out_channels=384,
|
||||
embed_dims=[192, 256, 512],
|
||||
depth=[3, 9, 9],
|
||||
num_heads=[6, 8, 16],
|
||||
mixer_types=['Local'] * 10 + ['Global'] * 11),
|
||||
decoder=dict(in_channels=384))
|
||||
|
||||
train_dataloader = dict(batch_size=128, )
|
||||
|
||||
optim_wrapper = dict(optimizer=dict(lr=2.5 / (10**4)))
|
|
@ -0,0 +1,10 @@
|
|||
_base_ = [
|
||||
'svtr-tiny_20e_st_mj.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
encoder=dict(
|
||||
embed_dims=[96, 192, 256],
|
||||
depth=[3, 6, 6],
|
||||
num_heads=[3, 6, 8],
|
||||
mixer_types=['Local'] * 8 + ['Global'] * 7))
|
|
@ -0,0 +1,162 @@
|
|||
_base_ = [
|
||||
'_base_svtr-tiny.py',
|
||||
'../_base_/default_runtime.py',
|
||||
'../_base_/datasets/mjsynth.py',
|
||||
'../_base_/datasets/synthtext.py',
|
||||
'../_base_/datasets/cute80.py',
|
||||
'../_base_/datasets/iiit5k.py',
|
||||
'../_base_/datasets/svt.py',
|
||||
'../_base_/datasets/svtp.py',
|
||||
'../_base_/datasets/icdar2013.py',
|
||||
'../_base_/datasets/icdar2015.py',
|
||||
'../_base_/schedules/schedule_adam_base.py',
|
||||
]
|
||||
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=20, val_interval=1)
|
||||
|
||||
optim_wrapper = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=5 / (10**4) * 2048 / 2048,
|
||||
betas=(0.9, 0.99),
|
||||
eps=8e-8,
|
||||
weight_decay=0.05))
|
||||
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.5,
|
||||
end_factor=1.,
|
||||
end=2,
|
||||
verbose=False,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=19,
|
||||
begin=2,
|
||||
end=20,
|
||||
verbose=False,
|
||||
convert_to_iter_based=True),
|
||||
]
|
||||
|
||||
file_client_args = dict(backend='disk')
|
||||
|
||||
train_pipeline = [
|
||||
dict(
|
||||
type='LoadImageFromFile',
|
||||
file_client_args=file_client_args,
|
||||
ignore_empty=True,
|
||||
min_size=5),
|
||||
dict(type='LoadOCRAnnotations', with_text=True),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(type='TextRecogGeneralAug', ),
|
||||
],
|
||||
),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(type='CropHeight', ),
|
||||
],
|
||||
),
|
||||
dict(
|
||||
type='ConditionApply',
|
||||
condition='min(results["img_shape"])>10',
|
||||
true_transforms=dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(
|
||||
type='TorchVisionWrapper',
|
||||
op='GaussianBlur',
|
||||
kernel_size=5,
|
||||
sigma=1,
|
||||
),
|
||||
],
|
||||
)),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(
|
||||
type='TorchVisionWrapper',
|
||||
op='ColorJitter',
|
||||
brightness=0.5,
|
||||
saturation=0.5,
|
||||
contrast=0.5,
|
||||
hue=0.1),
|
||||
]),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(type='ImageContentJitter', ),
|
||||
],
|
||||
),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(
|
||||
type='ImgAugWrapper',
|
||||
args=[dict(cls='AdditiveGaussianNoise', scale=0.1**0.5)]),
|
||||
],
|
||||
),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
prob=0.4,
|
||||
transforms=[
|
||||
dict(type='ReversePixels', ),
|
||||
],
|
||||
),
|
||||
dict(type='Resize', scale=(256, 64)),
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='Resize', scale=(256, 64)),
|
||||
dict(type='LoadOCRAnnotations', with_text=True),
|
||||
dict(
|
||||
type='PackTextRecogInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
|
||||
]
|
||||
|
||||
# dataset settings
|
||||
train_list = [_base_.mjsynth_textrecog_test, _base_.synthtext_textrecog_train]
|
||||
test_list = [
|
||||
_base_.cute80_textrecog_test, _base_.iiit5k_textrecog_test,
|
||||
_base_.svt_textrecog_test, _base_.svtp_textrecog_test,
|
||||
_base_.icdar2013_textrecog_test, _base_.icdar2015_textrecog_test
|
||||
]
|
||||
|
||||
val_evaluator = dict(
|
||||
dataset_prefixes=['CUTE80', 'IIIT5K', 'SVT', 'SVTP', 'IC13', 'IC15'])
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=512,
|
||||
num_workers=24,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type='ConcatDataset', datasets=train_list, pipeline=train_pipeline))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=128,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type='ConcatDataset', datasets=test_list, pipeline=test_pipeline))
|
||||
|
||||
test_dataloader = val_dataloader
|
|
@ -56,9 +56,9 @@ class SVTRDecoder(BaseDecoder):
|
|||
"""Forward for training.
|
||||
|
||||
Args:
|
||||
feat (torch.Tensor, optional): The feature map from backbone of
|
||||
shape :math:`(N, E, H, W)`. Defaults to None.
|
||||
out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
|
||||
feat (torch.Tensor, optional): The feature map. Defaults to None.
|
||||
out_enc (torch.Tensor, optional): Encoder output from encoder of
|
||||
shape :math:`(N, 1, H, W)`. Defaults to None.
|
||||
data_samples (Sequence[TextRecogDataSample]): Batch of
|
||||
TextRecogDataSample, containing gt_text information. Defaults
|
||||
to None.
|
||||
|
@ -67,8 +67,8 @@ class SVTRDecoder(BaseDecoder):
|
|||
Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
|
||||
:math:`C` is ``num_classes``.
|
||||
"""
|
||||
assert feat.size(2) == 1, 'feature height must be 1'
|
||||
x = feat.squeeze(2)
|
||||
assert out_enc.size(2) == 1, 'feature height must be 1'
|
||||
x = out_enc.squeeze(2)
|
||||
x = x.permute(0, 2, 1)
|
||||
predicts = self.decoder(x)
|
||||
return predicts
|
||||
|
@ -82,9 +82,9 @@ class SVTRDecoder(BaseDecoder):
|
|||
"""Forward for testing.
|
||||
|
||||
Args:
|
||||
feat (torch.Tensor, optional): The feature map from backbone of
|
||||
shape :math:`(N, E, H, W)`. Defaults to None.
|
||||
out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
|
||||
feat (torch.Tensor, optional): The feature map. Defaults to None.
|
||||
out_enc (torch.Tensor, optional): Encoder output from encoder of
|
||||
shape :math:`(N, 1, H, W)`. Defaults to None.
|
||||
data_samples (Sequence[TextRecogDataSample]): Batch of
|
||||
TextRecogDataSample, containing gt_text information. Defaults
|
||||
to None.
|
||||
|
|
|
@ -11,6 +11,7 @@ from mmengine.model import BaseModule
|
|||
from mmengine.model.weight_init import trunc_normal_init
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
|
||||
|
||||
class OverlapPatchEmbed(BaseModule):
|
||||
|
@ -612,11 +613,16 @@ class SVTREncoder(BaseModule):
|
|||
x = self.layer_norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
data_samples: List[TextRecogDataSample] = None
|
||||
) -> torch.Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): A Tensor of shape :math:`(N, H/16, W/4, 256)`.
|
||||
data_samples (list[TextRecogDataSample]): Batch of
|
||||
TextRecogDataSample. Defaults to None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A Tensor of shape :math:`(N, 1, W/4, 192)`.
|
||||
|
|
|
@ -9,8 +9,9 @@ from .nrtr import NRTR
|
|||
from .robust_scanner import RobustScanner
|
||||
from .sar import SARNet
|
||||
from .satrn import SATRN
|
||||
from .svtr import SVTR
|
||||
|
||||
__all__ = [
|
||||
'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR',
|
||||
'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'ASTER'
|
||||
'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmocr.registry import MODELS
|
||||
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SVTR(EncoderDecoderRecognizer):
|
||||
"""A PyTorch implementation of : `SVTR: Scene Text Recognition with a
|
||||
Single Visual Model <https://arxiv.org/abs/2205.00159>`_"""
|
|
@ -12,6 +12,7 @@ Import:
|
|||
- configs/textrecog/crnn/metafile.yml
|
||||
- configs/textrecog/master/metafile.yml
|
||||
- configs/textrecog/nrtr/metafile.yml
|
||||
- configs/textrecog/svtr/metafile.yml
|
||||
- configs/textrecog/robust_scanner/metafile.yml
|
||||
- configs/textrecog/sar/metafile.yml
|
||||
- configs/textrecog/satrn/metafile.yml
|
||||
|
|
|
@ -46,7 +46,7 @@ class TestSVTRDecoder(TestCase):
|
|||
in_channels=192, dictionary=dict_cfg, module_loss=loss_cfg)
|
||||
|
||||
def test_forward_train(self):
|
||||
feat = torch.randn(1, 192, 1, 25)
|
||||
out_enc = torch.randn(1, 192, 1, 25)
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
max_seq_len = 25
|
||||
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||
|
@ -67,11 +67,12 @@ class TestSVTRDecoder(TestCase):
|
|||
max_seq_len=max_seq_len,
|
||||
)
|
||||
data_samples = decoder.module_loss.get_targets(self.data_info)
|
||||
output = decoder.forward_train(feat=feat, data_samples=data_samples)
|
||||
output = decoder.forward_train(
|
||||
out_enc=out_enc, data_samples=data_samples)
|
||||
self.assertTupleEqual(tuple(output.shape), (1, max_seq_len, 39))
|
||||
|
||||
def test_forward_test(self):
|
||||
feat = torch.randn(1, 192, 1, 25)
|
||||
out_enc = torch.randn(1, 192, 1, 25)
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||
create_dummy_dict_file(dict_file)
|
||||
|
@ -90,5 +91,6 @@ class TestSVTRDecoder(TestCase):
|
|||
dictionary=dict_cfg,
|
||||
module_loss=loss_cfg,
|
||||
max_seq_len=25)
|
||||
output = decoder.forward_test(feat=feat, data_samples=self.data_info)
|
||||
output = decoder.forward_test(
|
||||
out_enc=out_enc, data_samples=self.data_info)
|
||||
self.assertTupleEqual(tuple(output.shape), (1, 25, 39))
|
||||
|
|
Loading…
Reference in New Issue