mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix] base decoder forget passing dictionary
This commit is contained in:
parent
e2577741dd
commit
05e31e09bc
@ -5,7 +5,8 @@ import torch
|
|||||||
from mmcv.runner import BaseModule
|
from mmcv.runner import BaseModule
|
||||||
|
|
||||||
from mmocr.core.data_structures import TextRecogDataSample
|
from mmocr.core.data_structures import TextRecogDataSample
|
||||||
from mmocr.registry import MODELS
|
from mmocr.models.textrecog.dictionary import Dictionary
|
||||||
|
from mmocr.registry import MODELS, TASK_UTILS
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
@ -13,27 +14,45 @@ class BaseDecoder(BaseModule):
|
|||||||
"""Base decoder for text recognition, build the loss and postprocessor.
|
"""Base decoder for text recognition, build the loss and postprocessor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
|
||||||
|
the instance of `Dictionary`.
|
||||||
loss (dict, optional): Config to build loss. Defaults to None.
|
loss (dict, optional): Config to build loss. Defaults to None.
|
||||||
postprocessor (dict, optional): Config to build postprocessor.
|
postprocessor (dict, optional): Config to build postprocessor.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
max_seq_len (int): Maximum sequence length. The
|
||||||
|
sequence is usually generated from decoder. Defaults to 40.
|
||||||
init_cfg (dict or list[dict], optional): Initialization configs.
|
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
dictionary: Union[Dict, Dictionary],
|
||||||
loss: Optional[Dict] = None,
|
loss: Optional[Dict] = None,
|
||||||
postprocessor: Optional[Dict] = None,
|
postprocessor: Optional[Dict] = None,
|
||||||
|
max_seq_len: int = 40,
|
||||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||||
super().__init__(init_cfg=init_cfg)
|
super().__init__(init_cfg=init_cfg)
|
||||||
|
if isinstance(dictionary, dict):
|
||||||
|
self.dictionary = TASK_UTILS.build(dictionary)
|
||||||
|
elif isinstance(dictionary, Dictionary):
|
||||||
|
self.dictionary = dictionary
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'The type of dictionary should be `Dictionary` or dict, '
|
||||||
|
f'but got {type(dictionary)}')
|
||||||
self.loss = None
|
self.loss = None
|
||||||
self.postprocessor = None
|
self.postprocessor = None
|
||||||
|
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
assert isinstance(loss, dict)
|
assert isinstance(loss, dict)
|
||||||
|
loss.update(dictionary=dictionary)
|
||||||
|
loss.update(max_seq_len=max_seq_len)
|
||||||
self.loss = MODELS.build(loss)
|
self.loss = MODELS.build(loss)
|
||||||
|
|
||||||
if postprocessor is not None:
|
if postprocessor is not None:
|
||||||
assert isinstance(postprocessor, dict)
|
assert isinstance(postprocessor, dict)
|
||||||
|
postprocessor.update(dictionary=dictionary)
|
||||||
|
postprocessor.update(max_seq_len=max_seq_len)
|
||||||
self.postprocessor = MODELS.build(postprocessor)
|
self.postprocessor = MODELS.build(postprocessor)
|
||||||
|
|
||||||
def forward_train(
|
def forward_train(
|
||||||
|
@ -1,40 +1,92 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os.path as osp
|
||||||
|
import tempfile
|
||||||
from unittest import TestCase, mock
|
from unittest import TestCase, mock
|
||||||
|
|
||||||
from mmocr.models.textrecog.decoders import BaseDecoder
|
from mmocr.models.textrecog.decoders import BaseDecoder
|
||||||
|
from mmocr.models.textrecog.dictionary.dictionary import Dictionary
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class Tmp:
|
class Tmp:
|
||||||
|
|
||||||
|
def __init__(self, max_seq_len, dictionary) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestBaseDecoder(TestCase):
|
class TestBaseDecoder(TestCase):
|
||||||
|
|
||||||
|
def _create_dummy_dict_file(
|
||||||
|
self, dict_file,
|
||||||
|
chars=list('0123456789abcdefghijklmnopqrstuvwxyz')): # NOQA
|
||||||
|
with open(dict_file, 'w') as f:
|
||||||
|
for char in chars:
|
||||||
|
f.write(char + '\n')
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
cfg = dict(type='Tmp')
|
cfg = dict(type='Tmp')
|
||||||
|
tmp_dir = tempfile.TemporaryDirectory()
|
||||||
|
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||||
|
self._create_dummy_dict_file(dict_file)
|
||||||
|
# test diction cfg
|
||||||
|
dict_cfg = dict(
|
||||||
|
type='Dictionary',
|
||||||
|
dict_file=dict_file,
|
||||||
|
with_start=True,
|
||||||
|
with_end=True,
|
||||||
|
same_start_end=False,
|
||||||
|
with_padding=True,
|
||||||
|
with_unknown=True)
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
BaseDecoder([], cfg)
|
BaseDecoder(dict_cfg, [], cfg)
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
BaseDecoder(cfg, [])
|
BaseDecoder(dict_cfg, cfg, [])
|
||||||
decoder = BaseDecoder()
|
with self.assertRaises(TypeError):
|
||||||
|
BaseDecoder([], cfg, cfg)
|
||||||
|
decoder = BaseDecoder(dictionary=dict_cfg)
|
||||||
self.assertIsNone(decoder.loss)
|
self.assertIsNone(decoder.loss)
|
||||||
self.assertIsNone(decoder.postprocessor)
|
self.assertIsNone(decoder.postprocessor)
|
||||||
|
self.assertIsInstance(decoder.dictionary, Dictionary)
|
||||||
decoder = BaseDecoder(cfg, cfg)
|
decoder = BaseDecoder(dict_cfg, cfg, cfg)
|
||||||
self.assertIsInstance(decoder.loss, Tmp)
|
self.assertIsInstance(decoder.loss, Tmp)
|
||||||
self.assertIsInstance(decoder.postprocessor, Tmp)
|
self.assertIsInstance(decoder.postprocessor, Tmp)
|
||||||
|
tmp_dir.cleanup()
|
||||||
|
|
||||||
def test_forward_train(self):
|
def test_forward_train(self):
|
||||||
decoder = BaseDecoder()
|
tmp_dir = tempfile.TemporaryDirectory()
|
||||||
|
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||||
|
self._create_dummy_dict_file(dict_file)
|
||||||
|
# test diction cfg
|
||||||
|
dict_cfg = dict(
|
||||||
|
type='Dictionary',
|
||||||
|
dict_file=dict_file,
|
||||||
|
with_start=True,
|
||||||
|
with_end=True,
|
||||||
|
same_start_end=False,
|
||||||
|
with_padding=True,
|
||||||
|
with_unknown=True)
|
||||||
|
decoder = BaseDecoder(dictionary=dict_cfg)
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
decoder.forward_train(None, None, None)
|
decoder.forward_train(None, None, None)
|
||||||
|
tmp_dir.cleanup()
|
||||||
|
|
||||||
def test_forward_test(self):
|
def test_forward_test(self):
|
||||||
decoder = BaseDecoder()
|
tmp_dir = tempfile.TemporaryDirectory()
|
||||||
|
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||||
|
self._create_dummy_dict_file(dict_file)
|
||||||
|
dict_cfg = dict(
|
||||||
|
type='Dictionary',
|
||||||
|
dict_file=dict_file,
|
||||||
|
with_start=True,
|
||||||
|
with_end=True,
|
||||||
|
same_start_end=False,
|
||||||
|
with_padding=True,
|
||||||
|
with_unknown=True)
|
||||||
|
decoder = BaseDecoder(dictionary=dict_cfg)
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
decoder.forward_test(None, None, None)
|
decoder.forward_test(None, None, None)
|
||||||
|
tmp_dir.cleanup()
|
||||||
|
|
||||||
@mock.patch(f'{__name__}.BaseDecoder.forward_test')
|
@mock.patch(f'{__name__}.BaseDecoder.forward_test')
|
||||||
@mock.patch(f'{__name__}.BaseDecoder.forward_train')
|
@mock.patch(f'{__name__}.BaseDecoder.forward_train')
|
||||||
@ -46,10 +98,22 @@ class TestBaseDecoder(TestCase):
|
|||||||
def mock_func_test(feat, out_enc, datasamples):
|
def mock_func_test(feat, out_enc, datasamples):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
tmp_dir = tempfile.TemporaryDirectory()
|
||||||
|
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
|
||||||
|
self._create_dummy_dict_file(dict_file)
|
||||||
|
dict_cfg = dict(
|
||||||
|
type='Dictionary',
|
||||||
|
dict_file=dict_file,
|
||||||
|
with_start=True,
|
||||||
|
with_end=True,
|
||||||
|
same_start_end=False,
|
||||||
|
with_padding=True,
|
||||||
|
with_unknown=True)
|
||||||
mock_forward_train.side_effect = mock_func_train
|
mock_forward_train.side_effect = mock_func_train
|
||||||
mock_forward_test.side_effect = mock_func_test
|
mock_forward_test.side_effect = mock_func_test
|
||||||
cfg = dict(type='Tmp')
|
cfg = dict(type='Tmp')
|
||||||
decoder = BaseDecoder(cfg, cfg)
|
decoder = BaseDecoder(dict_cfg, cfg, cfg)
|
||||||
|
|
||||||
self.assertTrue(decoder(None, None, None, True))
|
self.assertTrue(decoder(None, None, None, True))
|
||||||
self.assertFalse(decoder(None, None, None, False))
|
self.assertFalse(decoder(None, None, None, False))
|
||||||
|
tmp_dir.cleanup()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user