[Refactor] BaseDecoder

pull/1178/head
liukuikun 2022-05-19 09:41:10 +00:00 committed by gaotongxiao
parent 6cd38a038f
commit 0b5d2df310
2 changed files with 134 additions and 11 deletions


@@ -1,30 +1,98 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Sequence, Union
+
+import torch
 from mmcv.runner import BaseModule

+from mmocr.core.data_structures import TextRecogDataSample
 from mmocr.registry import MODELS


 @MODELS.register_module()
 class BaseDecoder(BaseModule):
-    """Base decoder class for text recognition."""
+    """Base decoder for text recognition, which builds the loss and
+    postprocessor.
+
+    Args:
+        loss (dict, optional): Config to build loss. Defaults to None.
+        postprocessor (dict, optional): Config to build postprocessor.
+            Defaults to None.
+        init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to None.
+    """

-    def __init__(self, init_cfg=None, **kwargs):
+    def __init__(self,
+                 loss: Optional[Dict] = None,
+                 postprocessor: Optional[Dict] = None,
+                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
         super().__init__(init_cfg=init_cfg)
+        self.loss = None
+        self.postprocessor = None
+        if loss is not None:
+            assert isinstance(loss, dict)
+            self.loss = MODELS.build(loss)
+        if postprocessor is not None:
+            assert isinstance(postprocessor, dict)
+            self.postprocessor = MODELS.build(postprocessor)

-    def forward_train(self, feat, out_enc, targets_dict, img_metas):
+    def forward_train(
+        self,
+        feat: Optional[torch.Tensor] = None,
+        out_enc: Optional[torch.Tensor] = None,
+        datasamples: Optional[Sequence[TextRecogDataSample]] = None
+    ) -> torch.Tensor:
+        """Forward for training.
+
+        Args:
+            feat (torch.Tensor, optional): The feature map from backbone of
+                shape :math:`(N, E, H, W)`. Defaults to None.
+            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
+            datasamples (Sequence[TextRecogDataSample]): Batch of
+                TextRecogDataSample, containing gt_text information. Defaults
+                to None.
+        """
         raise NotImplementedError

-    def forward_test(self, feat, out_enc, img_metas):
+    def forward_test(
+        self,
+        feat: Optional[torch.Tensor] = None,
+        out_enc: Optional[torch.Tensor] = None,
+        datasamples: Optional[Sequence[TextRecogDataSample]] = None
+    ) -> torch.Tensor:
+        """Forward for testing.
+
+        Args:
+            feat (torch.Tensor, optional): The feature map from backbone of
+                shape :math:`(N, E, H, W)`. Defaults to None.
+            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
+            datasamples (Sequence[TextRecogDataSample]): Batch of
+                TextRecogDataSample, containing gt_text information. Defaults
+                to None.
+        """
         raise NotImplementedError

     def forward(self,
-                feat,
-                out_enc,
-                targets_dict=None,
-                img_metas=None,
-                train_mode=True):
+                feat: Optional[torch.Tensor] = None,
+                out_enc: Optional[torch.Tensor] = None,
+                datasamples: Optional[Sequence[TextRecogDataSample]] = None,
+                train_mode: bool = True) -> torch.Tensor:
+        """
+        Args:
+            feat (torch.Tensor, optional): The feature map from backbone of
+                shape :math:`(N, E, H, W)`. Defaults to None.
+            out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
+            datasamples (Sequence[TextRecogDataSample]): Batch of
+                TextRecogDataSample, containing gt_text information. Defaults
+                to None.
+            train_mode (bool): Train or test. Defaults to True.
+
+        Returns:
+            torch.Tensor: Decoder output.
+        """
         self.train_mode = train_mode

         if train_mode:
-            return self.forward_train(feat, out_enc, targets_dict, img_metas)
+            return self.forward_train(feat, out_enc, datasamples)

-        return self.forward_test(feat, out_enc, img_metas)
+        return self.forward_test(feat, out_enc, datasamples)
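
Below is a minimal usage sketch, not part of this commit, of how a concrete decoder might subclass the refactored BaseDecoder and be built through the MODELS registry. MyDecoder, MyLoss and MyPostprocessor are hypothetical names used only for illustration, and the forward methods simply echo out_enc.

# Hypothetical usage sketch -- not part of this commit.
from typing import Optional, Sequence

import torch

from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.textrecog.decoders import BaseDecoder
from mmocr.registry import MODELS


@MODELS.register_module()
class MyLoss:  # stand-in for a real recognition loss module
    pass


@MODELS.register_module()
class MyPostprocessor:  # stand-in for a real postprocessor
    pass


@MODELS.register_module()
class MyDecoder(BaseDecoder):
    """Toy decoder that just passes the encoder output through."""

    def forward_train(self,
                      feat: Optional[torch.Tensor] = None,
                      out_enc: Optional[torch.Tensor] = None,
                      datasamples: Optional[
                          Sequence[TextRecogDataSample]] = None
                      ) -> torch.Tensor:
        return out_enc

    def forward_test(self,
                     feat: Optional[torch.Tensor] = None,
                     out_enc: Optional[torch.Tensor] = None,
                     datasamples: Optional[
                         Sequence[TextRecogDataSample]] = None
                     ) -> torch.Tensor:
        return out_enc


# loss and postprocessor are plain config dicts; BaseDecoder.__init__ builds
# them via MODELS.build, so they must name registered types.
decoder = MODELS.build(
    dict(
        type='MyDecoder',
        loss=dict(type='MyLoss'),
        postprocessor=dict(type='MyPostprocessor')))
out = decoder(out_enc=torch.rand(1, 10, 37), train_mode=False)  # forward_test

This is the same pattern the new unit test below exercises with its Tmp placeholder module.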


@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase, mock
+
+from mmocr.models.textrecog.decoders import BaseDecoder
+from mmocr.registry import MODELS
+
+
+@MODELS.register_module()
+class Tmp:
+    pass
+
+
+class TestBaseDecoder(TestCase):
+
+    def test_init(self):
+        cfg = dict(type='Tmp')
+        with self.assertRaises(AssertionError):
+            BaseDecoder([], cfg)
+        with self.assertRaises(AssertionError):
+            BaseDecoder(cfg, [])
+        decoder = BaseDecoder()
+        self.assertIsNone(decoder.loss)
+        self.assertIsNone(decoder.postprocessor)
+        decoder = BaseDecoder(cfg, cfg)
+        self.assertIsInstance(decoder.loss, Tmp)
+        self.assertIsInstance(decoder.postprocessor, Tmp)
+
+    def test_forward_train(self):
+        decoder = BaseDecoder()
+        with self.assertRaises(NotImplementedError):
+            decoder.forward_train(None, None, None)
+
+    def test_forward_test(self):
+        decoder = BaseDecoder()
+        with self.assertRaises(NotImplementedError):
+            decoder.forward_test(None, None, None)
+
+    @mock.patch(f'{__name__}.BaseDecoder.forward_test')
+    @mock.patch(f'{__name__}.BaseDecoder.forward_train')
+    def test_forward(self, mock_forward_train, mock_forward_test):
+
+        def mock_func_train(feat, out_enc, datasamples):
+            return True
+
+        def mock_func_test(feat, out_enc, datasamples):
+            return False
+
+        mock_forward_train.side_effect = mock_func_train
+        mock_forward_test.side_effect = mock_func_test
+        cfg = dict(type='Tmp')
+        decoder = BaseDecoder(cfg, cfg)
+        self.assertTrue(decoder(None, None, None, True))
+        self.assertFalse(decoder(None, None, None, False))
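
A note on the mocking in test_forward: mock.patch decorators are applied bottom-up, so the decorator closest to the function supplies the first mock argument, and patching f'{__name__}.BaseDecoder...' works because BaseDecoder is imported into the test module itself. Here is a standalone sketch of that ordering, with hypothetical Target and TestOrdering names:

# Standalone illustration (not part of this commit) of mock.patch ordering.
from unittest import TestCase, mock


class Target:

    def a(self):
        return 'real a'

    def b(self):
        return 'real b'


class TestOrdering(TestCase):

    @mock.patch(f'{__name__}.Target.b')  # outer decorator -> second argument
    @mock.patch(f'{__name__}.Target.a')  # inner decorator -> first argument
    def test_order(self, mock_a, mock_b):
        mock_a.return_value = 'mocked a'
        mock_b.return_value = 'mocked b'
        target = Target()
        self.assertEqual(target.a(), 'mocked a')
        self.assertEqual(target.b(), 'mocked b')

In the commit's test, the bottom decorator patches forward_train, so mock_forward_train receives that mock and decoder(None, None, None, True) returns True.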