mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix] base decoder forget passing dictionary
This commit is contained in:
parent
e2577741dd
commit
05e31e09bc
@ -5,7 +5,8 @@ import torch
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmocr.core.data_structures import TextRecogDataSample
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.models.textrecog.dictionary import Dictionary
|
||||
from mmocr.registry import MODELS, TASK_UTILS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
@ -13,27 +14,45 @@ class BaseDecoder(BaseModule):
|
||||
"""Base decoder for text recognition, build the loss and postprocessor.
|
||||
|
||||
Args:
|
||||
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
|
||||
the instance of `Dictionary`.
|
||||
loss (dict, optional): Config to build loss. Defaults to None.
|
||||
postprocessor (dict, optional): Config to build postprocessor.
|
||||
Defaults to None.
|
||||
max_seq_len (int): Maximum sequence length. The
|
||||
sequence is usually generated from decoder. Defaults to 40.
|
||||
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dictionary: Union[Dict, Dictionary],
|
||||
loss: Optional[Dict] = None,
|
||||
postprocessor: Optional[Dict] = None,
|
||||
max_seq_len: int = 40,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
if isinstance(dictionary, dict):
|
||||
self.dictionary = TASK_UTILS.build(dictionary)
|
||||
elif isinstance(dictionary, Dictionary):
|
||||
self.dictionary = dictionary
|
||||
else:
|
||||
raise TypeError(
|
||||
'The type of dictionary should be `Dictionary` or dict, '
|
||||
f'but got {type(dictionary)}')
|
||||
self.loss = None
|
||||
self.postprocessor = None
|
||||
|
||||
if loss is not None:
|
||||
assert isinstance(loss, dict)
|
||||
loss.update(dictionary=dictionary)
|
||||
loss.update(max_seq_len=max_seq_len)
|
||||
self.loss = MODELS.build(loss)
|
||||
|
||||
if postprocessor is not None:
|
||||
assert isinstance(postprocessor, dict)
|
||||
postprocessor.update(dictionary=dictionary)
|
||||
postprocessor.update(max_seq_len=max_seq_len)
|
||||
self.postprocessor = MODELS.build(postprocessor)
|
||||
|
||||
def forward_train(
|
||||
|
@ -1,40 +1,92 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase, mock
|
||||
|
||||
from mmocr.models.textrecog.decoders import BaseDecoder
|
||||
from mmocr.models.textrecog.dictionary.dictionary import Dictionary
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class Tmp:
|
||||
pass
|
||||
|
||||
def __init__(self, max_seq_len, dictionary) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class TestBaseDecoder(TestCase):
|
||||
|
||||
def _create_dummy_dict_file(
|
||||
self, dict_file,
|
||||
chars=list('0123456789abcdefghijklmnopqrstuvwxyz')): # NOQA
|
||||
with open(dict_file, 'w') as f:
|
||||
for char in chars:
|
||||
f.write(char + '\n')
|
||||
|
||||
def test_init(self):
|
||||
cfg = dict(type='Tmp')
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||
self._create_dummy_dict_file(dict_file)
|
||||
# test diction cfg
|
||||
dict_cfg = dict(
|
||||
type='Dictionary',
|
||||
dict_file=dict_file,
|
||||
with_start=True,
|
||||
with_end=True,
|
||||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=True)
|
||||
with self.assertRaises(AssertionError):
|
||||
BaseDecoder([], cfg)
|
||||
BaseDecoder(dict_cfg, [], cfg)
|
||||
with self.assertRaises(AssertionError):
|
||||
BaseDecoder(cfg, [])
|
||||
decoder = BaseDecoder()
|
||||
BaseDecoder(dict_cfg, cfg, [])
|
||||
with self.assertRaises(TypeError):
|
||||
BaseDecoder([], cfg, cfg)
|
||||
decoder = BaseDecoder(dictionary=dict_cfg)
|
||||
self.assertIsNone(decoder.loss)
|
||||
self.assertIsNone(decoder.postprocessor)
|
||||
|
||||
decoder = BaseDecoder(cfg, cfg)
|
||||
self.assertIsInstance(decoder.dictionary, Dictionary)
|
||||
decoder = BaseDecoder(dict_cfg, cfg, cfg)
|
||||
self.assertIsInstance(decoder.loss, Tmp)
|
||||
self.assertIsInstance(decoder.postprocessor, Tmp)
|
||||
tmp_dir.cleanup()
|
||||
|
||||
def test_forward_train(self):
|
||||
decoder = BaseDecoder()
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||
self._create_dummy_dict_file(dict_file)
|
||||
# test diction cfg
|
||||
dict_cfg = dict(
|
||||
type='Dictionary',
|
||||
dict_file=dict_file,
|
||||
with_start=True,
|
||||
with_end=True,
|
||||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=True)
|
||||
decoder = BaseDecoder(dictionary=dict_cfg)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
decoder.forward_train(None, None, None)
|
||||
tmp_dir.cleanup()
|
||||
|
||||
def test_forward_test(self):
|
||||
decoder = BaseDecoder()
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||
self._create_dummy_dict_file(dict_file)
|
||||
dict_cfg = dict(
|
||||
type='Dictionary',
|
||||
dict_file=dict_file,
|
||||
with_start=True,
|
||||
with_end=True,
|
||||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=True)
|
||||
decoder = BaseDecoder(dictionary=dict_cfg)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
decoder.forward_test(None, None, None)
|
||||
tmp_dir.cleanup()
|
||||
|
||||
@mock.patch(f'{__name__}.BaseDecoder.forward_test')
|
||||
@mock.patch(f'{__name__}.BaseDecoder.forward_train')
|
||||
@ -46,10 +98,22 @@ class TestBaseDecoder(TestCase):
|
||||
def mock_func_test(feat, out_enc, datasamples):
|
||||
return False
|
||||
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||
self._create_dummy_dict_file(dict_file)
|
||||
dict_cfg = dict(
|
||||
type='Dictionary',
|
||||
dict_file=dict_file,
|
||||
with_start=True,
|
||||
with_end=True,
|
||||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=True)
|
||||
mock_forward_train.side_effect = mock_func_train
|
||||
mock_forward_test.side_effect = mock_func_test
|
||||
cfg = dict(type='Tmp')
|
||||
decoder = BaseDecoder(cfg, cfg)
|
||||
decoder = BaseDecoder(dict_cfg, cfg, cfg)
|
||||
|
||||
self.assertTrue(decoder(None, None, None, True))
|
||||
self.assertFalse(decoder(None, None, None, False))
|
||||
tmp_dir.cleanup()
|
||||
|
Loading…
x
Reference in New Issue
Block a user