Mirror of https://github.com/open-mmlab/mmocr.git
Commit 0b5d2df310 — "[Refactor] BaseDecoder" (parent commit: 6cd38a038f).
|
@ -1,30 +1,98 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmocr.core.data_structures import TextRecogDataSample
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
class BaseDecoder(BaseModule):
    """Base decoder for text recognition, builds the loss and postprocessor.

    Args:
        loss (dict, optional): Config to build loss. Defaults to None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.loss = None
        self.postprocessor = None
        # Configs, when given, must be dicts so the registry can build them.
        # Asserts (rather than raises) are kept: the unit tests expect
        # AssertionError on invalid config types.
        if loss is not None:
            assert isinstance(loss, dict)
            self.loss = MODELS.build(loss)

        if postprocessor is not None:
            assert isinstance(postprocessor, dict)
            self.postprocessor = MODELS.build(postprocessor)

    def forward_train(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        datasamples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for training.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            datasamples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Raises:
            NotImplementedError: Subclasses must implement this method.
        """
        raise NotImplementedError

    def forward_test(
        self,
        feat: Optional[torch.Tensor] = None,
        out_enc: Optional[torch.Tensor] = None,
        datasamples: Optional[Sequence[TextRecogDataSample]] = None
    ) -> torch.Tensor:
        """Forward for testing.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            datasamples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.

        Raises:
            NotImplementedError: Subclasses must implement this method.
        """
        raise NotImplementedError

    def forward(self,
                feat: Optional[torch.Tensor] = None,
                out_enc: Optional[torch.Tensor] = None,
                datasamples: Optional[Sequence[TextRecogDataSample]] = None,
                train_mode: bool = True) -> torch.Tensor:
        """Dispatch to the train or test forward path.

        Args:
            feat (torch.Tensor, optional): The feature map from backbone of
                shape :math:`(N, E, H, W)`. Defaults to None.
            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
            datasamples (Sequence[TextRecogDataSample]): Batch of
                TextRecogDataSample, containing gt_text information. Defaults
                to None.
            train_mode (bool): Train or test. Defaults to True.

        Returns:
            torch.Tensor: Decoder output.
        """
        # Record the mode so subclass implementations can branch on it.
        self.train_mode = train_mode
        if train_mode:
            return self.forward_train(feat, out_enc, datasamples)

        return self.forward_test(feat, out_enc, datasamples)
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase, mock
|
||||
|
||||
from mmocr.models.textrecog.decoders import BaseDecoder
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
class Tmp:
    """Placeholder component registered only so tests can exercise
    ``MODELS.build`` with a real config."""
|
||||
|
||||
|
||||
class TestBaseDecoder(TestCase):
    """Unit tests for :class:`BaseDecoder`."""

    def test_init(self):
        cfg = dict(type='Tmp')
        # Non-dict configs must be rejected with AssertionError.
        with self.assertRaises(AssertionError):
            BaseDecoder([], cfg)
        with self.assertRaises(AssertionError):
            BaseDecoder(cfg, [])
        # With no configs, both components stay unset.
        decoder = BaseDecoder()
        self.assertIsNone(decoder.loss)
        self.assertIsNone(decoder.postprocessor)

        # Valid dict configs are built through the registry.
        decoder = BaseDecoder(cfg, cfg)
        self.assertIsInstance(decoder.loss, Tmp)
        self.assertIsInstance(decoder.postprocessor, Tmp)

    def test_forward_train(self):
        decoder = BaseDecoder()
        with self.assertRaises(NotImplementedError):
            decoder.forward_train(None, None, None)

    def test_forward_test(self):
        decoder = BaseDecoder()
        with self.assertRaises(NotImplementedError):
            decoder.forward_test(None, None, None)

    @mock.patch(f'{__name__}.BaseDecoder.forward_test')
    @mock.patch(f'{__name__}.BaseDecoder.forward_train')
    def test_forward(self, mock_forward_train, mock_forward_test):
        # The train path reports True and the test path False, so the
        # return value reveals which branch ``forward`` dispatched to.
        mock_forward_train.side_effect = (
            lambda feat, out_enc, datasamples: True)
        mock_forward_test.side_effect = (
            lambda feat, out_enc, datasamples: False)
        cfg = dict(type='Tmp')
        decoder = BaseDecoder(cfg, cfg)

        self.assertTrue(decoder(None, None, None, True))
        self.assertFalse(decoder(None, None, None, False))
|
Loading…
Reference in New Issue