[Fix] Base decoder forgot to pass dictionary

This commit is contained in:
liukuikun 2022-05-25 02:46:34 +00:00 committed by gaotongxiao
parent e2577741dd
commit 05e31e09bc
2 changed files with 93 additions and 10 deletions

View File

@ -5,7 +5,8 @@ import torch
from mmcv.runner import BaseModule
from mmocr.core.data_structures import TextRecogDataSample
from mmocr.registry import MODELS
from mmocr.models.textrecog.dictionary import Dictionary
from mmocr.registry import MODELS, TASK_UTILS
@MODELS.register_module()
@ -13,27 +14,45 @@ class BaseDecoder(BaseModule):
"""Base decoder for text recognition, build the loss and postprocessor.
Args:
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
the instance of `Dictionary`.
loss (dict, optional): Config to build loss. Defaults to None.
postprocessor (dict, optional): Config to build postprocessor.
Defaults to None.
max_seq_len (int): Maximum sequence length. The
sequence is usually generated from decoder. Defaults to 40.
init_cfg (dict or list[dict], optional): Initialization configs.
Defaults to None.
"""
def __init__(self,
             dictionary: Union[Dict, Dictionary],
             loss: Optional[Dict] = None,
             postprocessor: Optional[Dict] = None,
             max_seq_len: int = 40,
             init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
    """Set up the dictionary and the optional loss and postprocessor.

    Args:
        dictionary (dict or :obj:`Dictionary`): Config that is built
            through ``TASK_UTILS``, or an existing ``Dictionary``.
        loss (dict, optional): Config built through ``MODELS``.
            Defaults to None.
        postprocessor (dict, optional): Config built through ``MODELS``.
            Defaults to None.
        max_seq_len (int): Forwarded to the loss and postprocessor
            configs. Defaults to 40.
        init_cfg (dict or list[dict], optional): Passed to the parent
            initializer. Defaults to None.

    Raises:
        TypeError: If ``dictionary`` is neither a dict nor a
            ``Dictionary``.
        AssertionError: If ``loss`` or ``postprocessor`` is given but is
            not a dict.
    """
    super().__init__(init_cfg=init_cfg)
    if isinstance(dictionary, dict):
        self.dictionary = TASK_UTILS.build(dictionary)
    elif isinstance(dictionary, Dictionary):
        self.dictionary = dictionary
    else:
        raise TypeError(
            'The type of dictionary should be `Dictionary` or dict, '
            f'but got {type(dictionary)}')
    self.loss = None
    self.postprocessor = None
    if loss is not None:
        # Kept as an assert: callers (and the unit tests) rely on
        # AssertionError for a non-dict config.
        assert isinstance(loss, dict)
        # Work on a shallow copy so the caller's config is not mutated;
        # the same config object may be reused for several components.
        loss_cfg = loss.copy()
        loss_cfg.update(dictionary=dictionary)
        loss_cfg.update(max_seq_len=max_seq_len)
        self.loss = MODELS.build(loss_cfg)
    if postprocessor is not None:
        assert isinstance(postprocessor, dict)
        # Same reasoning as for the loss: never write into the caller's
        # dict.
        postprocessor_cfg = postprocessor.copy()
        postprocessor_cfg.update(dictionary=dictionary)
        postprocessor_cfg.update(max_seq_len=max_seq_len)
        self.postprocessor = MODELS.build(postprocessor_cfg)
def forward_train(

View File

@ -1,40 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase, mock
from mmocr.models.textrecog.decoders import BaseDecoder
from mmocr.models.textrecog.dictionary.dictionary import Dictionary
from mmocr.registry import MODELS
@MODELS.register_module()
class Tmp:
    """Minimal stand-in module registered in ``MODELS`` for the tests.

    It accepts the ``max_seq_len`` and ``dictionary`` keyword arguments and
    ignores them.
    """

    def __init__(self, max_seq_len, dictionary) -> None:
        """Accept both arguments and deliberately store nothing."""
class TestBaseDecoder(TestCase):
def _create_dummy_dict_file(
        self, dict_file,
        chars=list('0123456789abcdefghijklmnopqrstuvwxyz')):  # NOQA
    """Write ``chars`` to ``dict_file``, one entry per line.

    Args:
        dict_file (str): Path of the file to create or overwrite.
        chars (list[str]): Entries to write. Defaults to the digits
            followed by the lowercase ASCII letters.
    """
    with open(dict_file, 'w') as handle:
        # One newline-terminated entry per character, written in order.
        handle.writelines(entry + '\n' for entry in chars)
def test_init(self):
    """Check argument validation and component building of BaseDecoder."""
    cfg = dict(type='Tmp')
    # The context manager removes the directory even if an assertion
    # below fails, which a trailing ``cleanup()`` call would not.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        # test dictionary cfg
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=False,
            with_padding=True,
            with_unknown=True)
        # A non-dict loss config must be rejected.
        with self.assertRaises(AssertionError):
            BaseDecoder(dict_cfg, [], cfg)
        # A non-dict postprocessor config must be rejected.
        with self.assertRaises(AssertionError):
            BaseDecoder(dict_cfg, cfg, [])
        # A dictionary that is neither a dict nor a Dictionary must be
        # rejected.
        with self.assertRaises(TypeError):
            BaseDecoder([], cfg, cfg)
        decoder = BaseDecoder(dictionary=dict_cfg)
        self.assertIsNone(decoder.loss)
        self.assertIsNone(decoder.postprocessor)
        self.assertIsInstance(decoder.dictionary, Dictionary)
        decoder = BaseDecoder(dict_cfg, cfg, cfg)
        self.assertIsInstance(decoder.loss, Tmp)
        self.assertIsInstance(decoder.postprocessor, Tmp)
def test_forward_train(self):
    """The base class must leave ``forward_train`` unimplemented."""
    # The context manager removes the directory even if the assertion
    # below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        # test dictionary cfg
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=False,
            with_padding=True,
            with_unknown=True)
        # ``dictionary`` is a required argument, so the decoder is always
        # built with it.
        decoder = BaseDecoder(dictionary=dict_cfg)
        with self.assertRaises(NotImplementedError):
            decoder.forward_train(None, None, None)
def test_forward_test(self):
    """The base class must leave ``forward_test`` unimplemented."""
    # The context manager removes the directory even if the assertion
    # below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        self._create_dummy_dict_file(dict_file)
        dict_cfg = dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=False,
            with_padding=True,
            with_unknown=True)
        # ``dictionary`` is a required argument, so the decoder is always
        # built with it.
        decoder = BaseDecoder(dictionary=dict_cfg)
        with self.assertRaises(NotImplementedError):
            decoder.forward_test(None, None, None)
@mock.patch(f'{__name__}.BaseDecoder.forward_test')
@mock.patch(f'{__name__}.BaseDecoder.forward_train')
@ -46,10 +98,22 @@ class TestBaseDecoder(TestCase):
def mock_func_test(feat, out_enc, datasamples):
return False
tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
self._create_dummy_dict_file(dict_file)
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=False,
with_padding=True,
with_unknown=True)
mock_forward_train.side_effect = mock_func_train
mock_forward_test.side_effect = mock_func_test
cfg = dict(type='Tmp')
decoder = BaseDecoder(cfg, cfg)
decoder = BaseDecoder(dict_cfg, cfg, cfg)
self.assertTrue(decoder(None, None, None, True))
self.assertFalse(decoder(None, None, None, False))
tmp_dir.cleanup()