diff --git a/README.md b/README.md
index 77f155f5..e466530b 100644
--- a/README.md
+++ b/README.md
@@ -147,6 +147,7 @@ Supported algorithms:
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
- [x] [SAR](configs/textrecog/sar/README.md) (AAAI'2019)
- [x] [SATRN](configs/textrecog/satrn/README.md) (CVPR'2020 Workshop on Text and Documents in the Deep Learning Era)
+- [x] [SVTR](configs/textrecog/svtr/README.md) (IJCAI'2022)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index f19319a1..e681e8be 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -147,6 +147,7 @@ pip3 install -e .
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
- [x] [SAR](configs/textrecog/sar/README.md) (AAAI'2019)
- [x] [SATRN](configs/textrecog/satrn/README.md) (CVPR'2020 Workshop on Text and Documents in the Deep Learning Era)
+- [x] [SVTR](configs/textrecog/svtr/README.md) (IJCAI'2022)
diff --git a/configs/textrecog/svtr/README.md b/configs/textrecog/svtr/README.md
new file mode 100644
index 00000000..80c8317f
--- /dev/null
+++ b/configs/textrecog/svtr/README.md
@@ -0,0 +1,67 @@
+# SVTR
+
+> [SVTR: Scene Text Recognition with a Single Visual Model](https://arxiv.org/abs/2205.00159)
+
+
+
+## Abstract
+
+Dominant scene text recognition models commonly contain two building blocks, a visual model for feature extraction and a sequence model for text transcription. This hybrid architecture, although accurate, is complex and less efficient. In this study, we propose a Single Visual model for Scene Text recognition within the patch-wise image tokenization framework, which dispenses with the sequential modeling entirely. The method, termed SVTR, firstly decomposes an image text into small patches named character components. Afterward, hierarchical stages are recurrently carried out by component-level mixing, merging and/or combining. Global and local mixing blocks are devised to perceive the inter-character and intra-character patterns, leading to a multi-grained character component perception. Thus, characters are recognized by a simple linear prediction. Experimental results on both English and Chinese scene text recognition tasks demonstrate the effectiveness of SVTR. SVTR-L (Large) achieves highly competitive accuracy in English and outperforms existing methods by a large margin in Chinese, while running faster. In addition, SVTR-T (Tiny) is an effective and much smaller model, which shows appealing speed at inference.
+
+
+

+
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :-------: | :----------: | :--------: | :----: |
+| SynthText | 7266686 | 1 | synth |
+| Syn90k | 8919273 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | type |
+| :-----: | :----------: | :-------: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+| IC15 | 2077 | irregular |
+| SVTP | 645 | irregular |
+| CT80 | 288 | irregular |
+
+## Results and Models
+
+| Methods | | Regular Text | | | | Irregular Text | | download |
+| :-----------------------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :----: | :------------------------------------------------------------------------------: |
+| | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
+| [SVTR-tiny](/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
+| [SVTR-small](/configs/textrecog/svtr/svtr-small_20e_st_mj.py) | 0.8553 | 0.9026 | 0.9448 | | 0.7496 | 0.8496 | 0.8854 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/20230105_184454.log) |
+| [SVTR-base](/configs/textrecog/svtr/svtr-base_20e_st_mj.py) | 0.8570 | 0.9181 | 0.9438 | | 0.7448 | 0.8388 | 0.9028 | [model](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/20221227_175415.log) |
+| [SVTR-large](/configs/textrecog/svtr/svtr-large_20e_st_mj.py) | - | - | - | | - | - | - | [model](<>) \| [log](<>) |
+
+```{note}
+The implementation and configuration follow the original code and paper, but there is still a gap between the reproduced results and the official ones. We appreciate any suggestions to improve its performance.
+```
+
+## Citation
+
+```bibtex
+@inproceedings{ijcai2022p124,
+ title = {SVTR: Scene Text Recognition with a Single Visual Model},
+ author = {Du, Yongkun and Chen, Zhineng and Jia, Caiyan and Yin, Xiaoting and Zheng, Tianlun and Li, Chenxia and Du, Yuning and Jiang, Yu-Gang},
+ booktitle = {Proceedings of the Thirty-First International Joint Conference on
+ Artificial Intelligence, {IJCAI-22}},
+ publisher = {International Joint Conferences on Artificial Intelligence Organization},
+ editor = {Lud De Raedt},
+ pages = {884--890},
+ year = {2022},
+ month = {7},
+ note = {Main Track},
+ doi = {10.24963/ijcai.2022/124},
+ url = {https://doi.org/10.24963/ijcai.2022/124},
+}
+
+```
diff --git a/configs/textrecog/svtr/_base_svtr-tiny.py b/configs/textrecog/svtr/_base_svtr-tiny.py
new file mode 100644
index 00000000..dcfd7867
--- /dev/null
+++ b/configs/textrecog/svtr/_base_svtr-tiny.py
@@ -0,0 +1,38 @@
+dictionary = dict(
+ type='Dictionary',
+ dict_file='{{ fileDirname }}/../../../dicts/lower_english_digits.txt',
+ with_padding=True,
+ with_unknown=True,
+)
+
+model = dict(
+ type='SVTR',
+ preprocessor=dict(
+ type='STN',
+ in_channels=3,
+ resized_image_size=(32, 64),
+ output_image_size=(32, 100),
+ num_control_points=20,
+ margins=[0.05, 0.05]),
+ encoder=dict(
+ type='SVTREncoder',
+ img_size=[32, 100],
+ in_channels=3,
+ out_channels=192,
+ embed_dims=[64, 128, 256],
+ depth=[3, 6, 3],
+ num_heads=[2, 4, 8],
+ mixer_types=['Local'] * 6 + ['Global'] * 6,
+ window_size=[[7, 11], [7, 11], [7, 11]],
+ merging_types='Conv',
+ prenorm=False,
+ max_seq_len=25),
+ decoder=dict(
+ type='SVTRDecoder',
+ in_channels=192,
+ module_loss=dict(
+ type='CTCModuleLoss', letter_case='lower', zero_infinity=True),
+ postprocessor=dict(type='CTCPostProcessor'),
+ dictionary=dictionary),
+ data_preprocessor=dict(
+ type='TextRecogDataPreprocessor', mean=[127.5], std=[127.5]))
diff --git a/configs/textrecog/svtr/metafile.yml b/configs/textrecog/svtr/metafile.yml
new file mode 100644
index 00000000..ff690156
--- /dev/null
+++ b/configs/textrecog/svtr/metafile.yml
@@ -0,0 +1,89 @@
+Collections:
+- Name: SVTR
+ Metadata:
+ Training Data: OCRDataset
+ Training Techniques:
+ - AdamW
+ Training Resources: 4x Tesla A100
+ Epochs: 20
+ Batch Size: 2048
+ Architecture:
+ - STN
+ - SVTREncoder
+ - SVTRDecoder
+ Paper:
+ URL: https://arxiv.org/pdf/2205.00159.pdf
+ Title: 'SVTR: Scene Text Recognition with a Single Visual Model'
+ README: configs/textrecog/svtr/README.md
+
+Models:
+ - Name: svtr-small_20e_st_mj
+ Alias: svtr-small
+ In Collection: SVTR
+ Config: configs/textrecog/svtr/svtr-small_20e_st_mj.py
+ Metadata:
+ Training Data:
+ - SynthText
+ - Syn90k
+ Results:
+ - Task: Text Recognition
+ Dataset: IIIT5K
+ Metrics:
+ word_acc: 0.8553
+ - Task: Text Recognition
+ Dataset: SVT
+ Metrics:
+ word_acc: 0.9026
+ - Task: Text Recognition
+ Dataset: ICDAR2013
+ Metrics:
+ word_acc: 0.9448
+ - Task: Text Recognition
+ Dataset: ICDAR2015
+ Metrics:
+ word_acc: 0.7496
+ - Task: Text Recognition
+ Dataset: SVTP
+ Metrics:
+ word_acc: 0.8496
+ - Task: Text Recognition
+ Dataset: CT80
+ Metrics:
+ word_acc: 0.8854
+ Weights: https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-small_20e_st_mj/svtr-small_20e_st_mj-35d800d6.pth
+
+ - Name: svtr-base_20e_st_mj
+ Alias: svtr-base
+ Batch Size: 1024
+ In Collection: SVTR
+ Config: configs/textrecog/svtr/svtr-base_20e_st_mj.py
+ Metadata:
+ Training Data:
+ - SynthText
+ - Syn90k
+ Results:
+ - Task: Text Recognition
+ Dataset: IIIT5K
+ Metrics:
+ word_acc: 0.8570
+ - Task: Text Recognition
+ Dataset: SVT
+ Metrics:
+ word_acc: 0.9181
+ - Task: Text Recognition
+ Dataset: ICDAR2013
+ Metrics:
+ word_acc: 0.9438
+ - Task: Text Recognition
+ Dataset: ICDAR2015
+ Metrics:
+ word_acc: 0.7448
+ - Task: Text Recognition
+ Dataset: SVTP
+ Metrics:
+ word_acc: 0.8388
+ - Task: Text Recognition
+ Dataset: CT80
+ Metrics:
+ word_acc: 0.9028
+ Weights: https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth
diff --git a/configs/textrecog/svtr/svtr-base_20e_st_mj.py b/configs/textrecog/svtr/svtr-base_20e_st_mj.py
new file mode 100644
index 00000000..7dda8501
--- /dev/null
+++ b/configs/textrecog/svtr/svtr-base_20e_st_mj.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'svtr-tiny_20e_st_mj.py',
+]
+
+model = dict(
+ preprocessor=dict(output_image_size=(48, 160), ),
+ encoder=dict(
+ img_size=[48, 160],
+ max_seq_len=40,
+ out_channels=256,
+ embed_dims=[128, 256, 384],
+ depth=[3, 6, 9],
+ num_heads=[4, 8, 12],
+ mixer_types=['Local'] * 8 + ['Global'] * 10),
+ decoder=dict(in_channels=256))
+
+train_dataloader = dict(batch_size=256, )
diff --git a/configs/textrecog/svtr/svtr-large_20e_st_mj.py b/configs/textrecog/svtr/svtr-large_20e_st_mj.py
new file mode 100644
index 00000000..1082d761
--- /dev/null
+++ b/configs/textrecog/svtr/svtr-large_20e_st_mj.py
@@ -0,0 +1,19 @@
+_base_ = [
+ 'svtr-tiny_20e_st_mj.py',
+]
+
+model = dict(
+ preprocessor=dict(output_image_size=(48, 160), ),
+ encoder=dict(
+ img_size=[48, 160],
+ max_seq_len=40,
+ out_channels=384,
+ embed_dims=[192, 256, 512],
+ depth=[3, 9, 9],
+ num_heads=[6, 8, 16],
+ mixer_types=['Local'] * 10 + ['Global'] * 11),
+ decoder=dict(in_channels=384))
+
+train_dataloader = dict(batch_size=128, )
+
+optim_wrapper = dict(optimizer=dict(lr=2.5 / (10**4)))
diff --git a/configs/textrecog/svtr/svtr-small_20e_st_mj.py b/configs/textrecog/svtr/svtr-small_20e_st_mj.py
new file mode 100644
index 00000000..bd73e46f
--- /dev/null
+++ b/configs/textrecog/svtr/svtr-small_20e_st_mj.py
@@ -0,0 +1,10 @@
+_base_ = [
+ 'svtr-tiny_20e_st_mj.py',
+]
+
+model = dict(
+ encoder=dict(
+ embed_dims=[96, 192, 256],
+ depth=[3, 6, 6],
+ num_heads=[3, 6, 8],
+ mixer_types=['Local'] * 8 + ['Global'] * 7))
diff --git a/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py b/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py
new file mode 100644
index 00000000..6f7dcdda
--- /dev/null
+++ b/configs/textrecog/svtr/svtr-tiny_20e_st_mj.py
@@ -0,0 +1,162 @@
+_base_ = [
+ '_base_svtr-tiny.py',
+ '../_base_/default_runtime.py',
+ '../_base_/datasets/mjsynth.py',
+ '../_base_/datasets/synthtext.py',
+ '../_base_/datasets/cute80.py',
+ '../_base_/datasets/iiit5k.py',
+ '../_base_/datasets/svt.py',
+ '../_base_/datasets/svtp.py',
+ '../_base_/datasets/icdar2013.py',
+ '../_base_/datasets/icdar2015.py',
+ '../_base_/schedules/schedule_adam_base.py',
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=20, val_interval=1)
+
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW',
+ lr=5 / (10**4) * 2048 / 2048,
+ betas=(0.9, 0.99),
+ eps=8e-8,
+ weight_decay=0.05))
+
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=0.5,
+ end_factor=1.,
+ end=2,
+ verbose=False,
+ convert_to_iter_based=True),
+ dict(
+ type='CosineAnnealingLR',
+ T_max=19,
+ begin=2,
+ end=20,
+ verbose=False,
+ convert_to_iter_based=True),
+]
+
+file_client_args = dict(backend='disk')
+
+train_pipeline = [
+ dict(
+ type='LoadImageFromFile',
+ file_client_args=file_client_args,
+ ignore_empty=True,
+ min_size=5),
+ dict(type='LoadOCRAnnotations', with_text=True),
+ dict(
+ type='RandomApply',
+ prob=0.4,
+ transforms=[
+ dict(type='TextRecogGeneralAug', ),
+ ],
+ ),
+ dict(
+ type='RandomApply',
+ prob=0.4,
+ transforms=[
+ dict(type='CropHeight', ),
+ ],
+ ),
+ dict(
+ type='ConditionApply',
+ condition='min(results["img_shape"])>10',
+ true_transforms=dict(
+ type='RandomApply',
+ prob=0.4,
+ transforms=[
+ dict(
+ type='TorchVisionWrapper',
+ op='GaussianBlur',
+ kernel_size=5,
+ sigma=1,
+ ),
+ ],
+ )),
+ dict(
+ type='RandomApply',
+ prob=0.4,
+ transforms=[
+ dict(
+ type='TorchVisionWrapper',
+ op='ColorJitter',
+ brightness=0.5,
+ saturation=0.5,
+ contrast=0.5,
+ hue=0.1),
+ ]),
+ dict(
+ type='RandomApply',
+ prob=0.4,
+ transforms=[
+ dict(type='ImageContentJitter', ),
+ ],
+ ),
+ dict(
+ type='RandomApply',
+ prob=0.4,
+ transforms=[
+ dict(
+ type='ImgAugWrapper',
+ args=[dict(cls='AdditiveGaussianNoise', scale=0.1**0.5)]),
+ ],
+ ),
+ dict(
+ type='RandomApply',
+ prob=0.4,
+ transforms=[
+ dict(type='ReversePixels', ),
+ ],
+ ),
+ dict(type='Resize', scale=(256, 64)),
+ dict(
+ type='PackTextRecogInputs',
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile', file_client_args=file_client_args),
+ dict(type='Resize', scale=(256, 64)),
+ dict(type='LoadOCRAnnotations', with_text=True),
+ dict(
+ type='PackTextRecogInputs',
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
+]
+
+# dataset settings
+train_list = [_base_.mjsynth_textrecog_test, _base_.synthtext_textrecog_train]
+test_list = [
+ _base_.cute80_textrecog_test, _base_.iiit5k_textrecog_test,
+ _base_.svt_textrecog_test, _base_.svtp_textrecog_test,
+ _base_.icdar2013_textrecog_test, _base_.icdar2015_textrecog_test
+]
+
+val_evaluator = dict(
+ dataset_prefixes=['CUTE80', 'IIIT5K', 'SVT', 'SVTP', 'IC13', 'IC15'])
+test_evaluator = val_evaluator
+
+train_dataloader = dict(
+ batch_size=512,
+ num_workers=24,
+ persistent_workers=True,
+ pin_memory=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(
+ type='ConcatDataset', datasets=train_list, pipeline=train_pipeline))
+
+val_dataloader = dict(
+ batch_size=128,
+ num_workers=8,
+ persistent_workers=True,
+ pin_memory=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type='ConcatDataset', datasets=test_list, pipeline=test_pipeline))
+
+test_dataloader = val_dataloader
diff --git a/mmocr/models/textrecog/decoders/svtr_decoder.py b/mmocr/models/textrecog/decoders/svtr_decoder.py
index a4df37a7..122a51dc 100644
--- a/mmocr/models/textrecog/decoders/svtr_decoder.py
+++ b/mmocr/models/textrecog/decoders/svtr_decoder.py
@@ -56,9 +56,9 @@ class SVTRDecoder(BaseDecoder):
"""Forward for training.
Args:
- feat (torch.Tensor, optional): The feature map from backbone of
- shape :math:`(N, E, H, W)`. Defaults to None.
- out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
+ feat (torch.Tensor, optional): The feature map. Defaults to None.
+ out_enc (torch.Tensor, optional): Encoder output from encoder of
+ shape :math:`(N, 1, H, W)`. Defaults to None.
data_samples (Sequence[TextRecogDataSample]): Batch of
TextRecogDataSample, containing gt_text information. Defaults
to None.
@@ -67,8 +67,8 @@ class SVTRDecoder(BaseDecoder):
Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
:math:`C` is ``num_classes``.
"""
- assert feat.size(2) == 1, 'feature height must be 1'
- x = feat.squeeze(2)
+ assert out_enc.size(2) == 1, 'feature height must be 1'
+ x = out_enc.squeeze(2)
x = x.permute(0, 2, 1)
predicts = self.decoder(x)
return predicts
@@ -82,9 +82,9 @@ class SVTRDecoder(BaseDecoder):
"""Forward for testing.
Args:
- feat (torch.Tensor, optional): The feature map from backbone of
- shape :math:`(N, E, H, W)`. Defaults to None.
- out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
+ feat (torch.Tensor, optional): The feature map. Defaults to None.
+ out_enc (torch.Tensor, optional): Encoder output from encoder of
+ shape :math:`(N, 1, H, W)`. Defaults to None.
data_samples (Sequence[TextRecogDataSample]): Batch of
TextRecogDataSample, containing gt_text information. Defaults
to None.
diff --git a/mmocr/models/textrecog/encoders/svtr_encoder.py b/mmocr/models/textrecog/encoders/svtr_encoder.py
index f97550c1..aa27f422 100644
--- a/mmocr/models/textrecog/encoders/svtr_encoder.py
+++ b/mmocr/models/textrecog/encoders/svtr_encoder.py
@@ -11,6 +11,7 @@ from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_init
from mmocr.registry import MODELS
+from mmocr.structures import TextRecogDataSample
class OverlapPatchEmbed(BaseModule):
@@ -612,11 +613,16 @@ class SVTREncoder(BaseModule):
x = self.layer_norm(x)
return x
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self,
+ x: torch.Tensor,
+ data_samples: List[TextRecogDataSample] = None
+ ) -> torch.Tensor:
"""Forward function.
Args:
x (torch.Tensor): A Tensor of shape :math:`(N, H/16, W/4, 256)`.
+ data_samples (list[TextRecogDataSample]): Batch of
+ TextRecogDataSample. Defaults to None.
Returns:
torch.Tensor: A Tensor of shape :math:`(N, 1, W/4, 192)`.
diff --git a/mmocr/models/textrecog/recognizers/__init__.py b/mmocr/models/textrecog/recognizers/__init__.py
index dc0ee711..a2f81941 100644
--- a/mmocr/models/textrecog/recognizers/__init__.py
+++ b/mmocr/models/textrecog/recognizers/__init__.py
@@ -9,8 +9,9 @@ from .nrtr import NRTR
from .robust_scanner import RobustScanner
from .sar import SARNet
from .satrn import SATRN
+from .svtr import SVTR
__all__ = [
'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR',
- 'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'ASTER'
+ 'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER'
]
diff --git a/mmocr/models/textrecog/recognizers/svtr.py b/mmocr/models/textrecog/recognizers/svtr.py
new file mode 100644
index 00000000..6fc42b85
--- /dev/null
+++ b/mmocr/models/textrecog/recognizers/svtr.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmocr.registry import MODELS
+from .encoder_decoder_recognizer import EncoderDecoderRecognizer
+
+
+@MODELS.register_module()
+class SVTR(EncoderDecoderRecognizer):
+ """A PyTorch implementation of : `SVTR: Scene Text Recognition with a
+ Single Visual Model `_"""
diff --git a/model-index.yml b/model-index.yml
index efa682fd..563372c2 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -12,6 +12,7 @@ Import:
- configs/textrecog/crnn/metafile.yml
- configs/textrecog/master/metafile.yml
- configs/textrecog/nrtr/metafile.yml
+ - configs/textrecog/svtr/metafile.yml
- configs/textrecog/robust_scanner/metafile.yml
- configs/textrecog/sar/metafile.yml
- configs/textrecog/satrn/metafile.yml
diff --git a/tests/test_models/test_textrecog/test_decoders/test_svtr_decoder.py b/tests/test_models/test_textrecog/test_decoders/test_svtr_decoder.py
index cb475cf0..63396511 100644
--- a/tests/test_models/test_textrecog/test_decoders/test_svtr_decoder.py
+++ b/tests/test_models/test_textrecog/test_decoders/test_svtr_decoder.py
@@ -46,7 +46,7 @@ class TestSVTRDecoder(TestCase):
in_channels=192, dictionary=dict_cfg, module_loss=loss_cfg)
def test_forward_train(self):
- feat = torch.randn(1, 192, 1, 25)
+ out_enc = torch.randn(1, 192, 1, 25)
tmp_dir = tempfile.TemporaryDirectory()
max_seq_len = 25
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
@@ -67,11 +67,12 @@ class TestSVTRDecoder(TestCase):
max_seq_len=max_seq_len,
)
data_samples = decoder.module_loss.get_targets(self.data_info)
- output = decoder.forward_train(feat=feat, data_samples=data_samples)
+ output = decoder.forward_train(
+ out_enc=out_enc, data_samples=data_samples)
self.assertTupleEqual(tuple(output.shape), (1, max_seq_len, 39))
def test_forward_test(self):
- feat = torch.randn(1, 192, 1, 25)
+ out_enc = torch.randn(1, 192, 1, 25)
tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
create_dummy_dict_file(dict_file)
@@ -90,5 +91,6 @@ class TestSVTRDecoder(TestCase):
dictionary=dict_cfg,
module_loss=loss_cfg,
max_seq_len=25)
- output = decoder.forward_test(feat=feat, data_samples=self.data_info)
+ output = decoder.forward_test(
+ out_enc=out_enc, data_samples=self.data_info)
self.assertTupleEqual(tuple(output.shape), (1, 25, 39))