[Fix] base decoder forget passing dictionary

This commit is contained in:
liukuikun 2022-05-25 02:46:34 +00:00 committed by gaotongxiao
parent e2577741dd
commit 05e31e09bc
2 changed files with 93 additions and 10 deletions

View File

@ -5,7 +5,8 @@ import torch
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.core.data_structures import TextRecogDataSample from mmocr.core.data_structures import TextRecogDataSample
from mmocr.registry import MODELS from mmocr.models.textrecog.dictionary import Dictionary
from mmocr.registry import MODELS, TASK_UTILS
@MODELS.register_module() @MODELS.register_module()
@ -13,27 +14,45 @@ class BaseDecoder(BaseModule):
"""Base decoder for text recognition, build the loss and postprocessor. """Base decoder for text recognition, build the loss and postprocessor.
Args: Args:
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
the instance of `Dictionary`.
loss (dict, optional): Config to build loss. Defaults to None. loss (dict, optional): Config to build loss. Defaults to None.
postprocessor (dict, optional): Config to build postprocessor. postprocessor (dict, optional): Config to build postprocessor.
Defaults to None. Defaults to None.
max_seq_len (int): Maximum sequence length. The
sequence is usually generated from decoder. Defaults to 40.
init_cfg (dict or list[dict], optional): Initialization configs. init_cfg (dict or list[dict], optional): Initialization configs.
Defaults to None. Defaults to None.
""" """
def __init__(self, def __init__(self,
dictionary: Union[Dict, Dictionary],
loss: Optional[Dict] = None, loss: Optional[Dict] = None,
postprocessor: Optional[Dict] = None, postprocessor: Optional[Dict] = None,
max_seq_len: int = 40,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None: init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
if isinstance(dictionary, dict):
self.dictionary = TASK_UTILS.build(dictionary)
elif isinstance(dictionary, Dictionary):
self.dictionary = dictionary
else:
raise TypeError(
'The type of dictionary should be `Dictionary` or dict, '
f'but got {type(dictionary)}')
self.loss = None self.loss = None
self.postprocessor = None self.postprocessor = None
if loss is not None: if loss is not None:
assert isinstance(loss, dict) assert isinstance(loss, dict)
loss.update(dictionary=dictionary)
loss.update(max_seq_len=max_seq_len)
self.loss = MODELS.build(loss) self.loss = MODELS.build(loss)
if postprocessor is not None: if postprocessor is not None:
assert isinstance(postprocessor, dict) assert isinstance(postprocessor, dict)
postprocessor.update(dictionary=dictionary)
postprocessor.update(max_seq_len=max_seq_len)
self.postprocessor = MODELS.build(postprocessor) self.postprocessor = MODELS.build(postprocessor)
def forward_train( def forward_train(

View File

@ -1,40 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase, mock from unittest import TestCase, mock
from mmocr.models.textrecog.decoders import BaseDecoder from mmocr.models.textrecog.decoders import BaseDecoder
from mmocr.models.textrecog.dictionary.dictionary import Dictionary
from mmocr.registry import MODELS from mmocr.registry import MODELS
@MODELS.register_module() @MODELS.register_module()
class Tmp: class Tmp:
def __init__(self, max_seq_len, dictionary) -> None:
pass pass
class TestBaseDecoder(TestCase): class TestBaseDecoder(TestCase):
def _create_dummy_dict_file(
self, dict_file,
chars=list('0123456789abcdefghijklmnopqrstuvwxyz')): # NOQA
with open(dict_file, 'w') as f:
for char in chars:
f.write(char + '\n')
def test_init(self): def test_init(self):
cfg = dict(type='Tmp') cfg = dict(type='Tmp')
tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
self._create_dummy_dict_file(dict_file)
# test diction cfg
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=False,
with_padding=True,
with_unknown=True)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
BaseDecoder([], cfg) BaseDecoder(dict_cfg, [], cfg)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
BaseDecoder(cfg, []) BaseDecoder(dict_cfg, cfg, [])
decoder = BaseDecoder() with self.assertRaises(TypeError):
BaseDecoder([], cfg, cfg)
decoder = BaseDecoder(dictionary=dict_cfg)
self.assertIsNone(decoder.loss) self.assertIsNone(decoder.loss)
self.assertIsNone(decoder.postprocessor) self.assertIsNone(decoder.postprocessor)
self.assertIsInstance(decoder.dictionary, Dictionary)
decoder = BaseDecoder(cfg, cfg) decoder = BaseDecoder(dict_cfg, cfg, cfg)
self.assertIsInstance(decoder.loss, Tmp) self.assertIsInstance(decoder.loss, Tmp)
self.assertIsInstance(decoder.postprocessor, Tmp) self.assertIsInstance(decoder.postprocessor, Tmp)
tmp_dir.cleanup()
def test_forward_train(self): def test_forward_train(self):
decoder = BaseDecoder() tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
self._create_dummy_dict_file(dict_file)
# test diction cfg
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=False,
with_padding=True,
with_unknown=True)
decoder = BaseDecoder(dictionary=dict_cfg)
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
decoder.forward_train(None, None, None) decoder.forward_train(None, None, None)
tmp_dir.cleanup()
def test_forward_test(self): def test_forward_test(self):
decoder = BaseDecoder() tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
self._create_dummy_dict_file(dict_file)
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=False,
with_padding=True,
with_unknown=True)
decoder = BaseDecoder(dictionary=dict_cfg)
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
decoder.forward_test(None, None, None) decoder.forward_test(None, None, None)
tmp_dir.cleanup()
@mock.patch(f'{__name__}.BaseDecoder.forward_test') @mock.patch(f'{__name__}.BaseDecoder.forward_test')
@mock.patch(f'{__name__}.BaseDecoder.forward_train') @mock.patch(f'{__name__}.BaseDecoder.forward_train')
@ -46,10 +98,22 @@ class TestBaseDecoder(TestCase):
def mock_func_test(feat, out_enc, datasamples): def mock_func_test(feat, out_enc, datasamples):
return False return False
tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
self._create_dummy_dict_file(dict_file)
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=False,
with_padding=True,
with_unknown=True)
mock_forward_train.side_effect = mock_func_train mock_forward_train.side_effect = mock_func_train
mock_forward_test.side_effect = mock_func_test mock_forward_test.side_effect = mock_func_test
cfg = dict(type='Tmp') cfg = dict(type='Tmp')
decoder = BaseDecoder(cfg, cfg) decoder = BaseDecoder(dict_cfg, cfg, cfg)
self.assertTrue(decoder(None, None, None, True)) self.assertTrue(decoder(None, None, None, True))
self.assertFalse(decoder(None, None, None, False)) self.assertFalse(decoder(None, None, None, False))
tmp_dir.cleanup()