mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] Rename base_xxx.py as base.py (#1322)
parent
b81d58e70c
commit
1b5764b155
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_textdet_head import BaseTextDetHead
|
||||
from .base import BaseTextDetHead
|
||||
from .db_head import DBHead
|
||||
from .drrg_head import DRRGHead
|
||||
from .fce_head import FCEHead
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn as nn
|
|||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from mmocr.utils import check_argument
|
||||
from .base_textdet_head import BaseTextDetHead
|
||||
from .base import BaseTextDetHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_postprocessor import BaseTextDetPostProcessor
|
||||
from .base import BaseTextDetPostProcessor
|
||||
from .db_postprocessor import DBPostprocessor
|
||||
from .drrg_postprocessor import DRRGPostprocessor
|
||||
from .fce_postprocessor import FCEPostprocessor
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch import Tensor
|
|||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from mmocr.utils import offset_polygon
|
||||
from .base_postprocessor import BaseTextDetPostProcessor
|
||||
from .base import BaseTextDetPostProcessor
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -11,7 +11,7 @@ from numpy import ndarray
|
|||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from .base_postprocessor import BaseTextDetPostProcessor
|
||||
from .base import BaseTextDetPostProcessor
|
||||
|
||||
|
||||
class Node:
|
||||
|
|
|
@ -10,7 +10,7 @@ from numpy.fft import ifft
|
|||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from mmocr.utils import fill_hole
|
||||
from .base_postprocessor import BaseTextDetPostProcessor
|
||||
from .base import BaseTextDetPostProcessor
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmengine.data import InstanceData
|
|||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from .base_postprocessor import BaseTextDetPostProcessor
|
||||
from .base import BaseTextDetPostProcessor
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -12,7 +12,7 @@ from skimage.morphology import skeletonize
|
|||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from mmocr.utils import fill_hole
|
||||
from .base_postprocessor import BaseTextDetPostProcessor
|
||||
from .base import BaseTextDetPostProcessor
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from .abi_fuser import ABIFuser
|
||||
from .abi_language_decoder import ABILanguageDecoder
|
||||
from .abi_vision_decoder import ABIVisionDecoder
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
from .crnn_decoder import CRNNDecoder
|
||||
from .master_decoder import MasterDecoder
|
||||
from .nrtr_decoder import NRTRDecoder
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
from mmocr.models.common.dictionary import Dictionary
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -11,7 +11,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||
from mmocr.models.common.modules import PositionalEncoding
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||
from mmocr.models.common.modules import PositionalEncoding
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||
from mmocr.models.textrecog.layers import BidirectionalLSTM
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -12,7 +12,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||
from mmocr.models.common.modules import PositionalEncoding
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
def clones(module: nn.Module, N: int) -> nn.ModuleList:
|
||||
|
|
|
@ -10,7 +10,7 @@ from mmocr.models.common import PositionalEncoding, TFDecoderLayer
|
|||
from mmocr.models.common.dictionary import Dictionary
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -10,7 +10,7 @@ from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
|
|||
PositionAwareLayer)
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
from mmocr.models.common.dictionary import Dictionary
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch.nn.functional as F
|
|||
from mmocr.models.common.dictionary import Dictionary
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||
from mmocr.models.textrecog.layers import DotProductAttentionLayer
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_decoder import BaseDecoder
|
||||
from .base import BaseDecoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .abi_encoder import ABIEncoder
|
||||
from .base_encoder import BaseEncoder
|
||||
from .base import BaseEncoder
|
||||
from .channel_reduction_encoder import ChannelReductionEncoder
|
||||
from .nrtr_encoder import NRTREncoder
|
||||
from .sar_encoder import SAREncoder
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_encoder import BaseEncoder
|
||||
from .base import BaseEncoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmengine.model import ModuleList
|
|||
from mmocr.models.common import TFEncoderLayer
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_encoder import BaseEncoder
|
||||
from .base import BaseEncoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn.functional as F
|
|||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_encoder import BaseEncoder
|
||||
from .base import BaseEncoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -11,7 +11,7 @@ from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding,
|
|||
SATRNEncoderLayer)
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_encoder import BaseEncoder
|
||||
from .base import BaseEncoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .abi_module_loss import ABIModuleLoss
|
||||
from .base_recog_module_loss import BaseRecogModuleLoss
|
||||
from .base import BaseTextRecogModuleLoss
|
||||
from .ce_module_loss import CEModuleLoss
|
||||
from .ctc_module_loss import CTCModuleLoss
|
||||
|
||||
__all__ = [
|
||||
'BaseRecogModuleLoss', 'CEModuleLoss', 'CTCModuleLoss', 'ABIModuleLoss'
|
||||
'BaseTextRecogModuleLoss', 'CEModuleLoss', 'CTCModuleLoss', 'ABIModuleLoss'
|
||||
]
|
||||
|
|
|
@ -6,12 +6,12 @@ import torch
|
|||
from mmocr.models.common.dictionary import Dictionary
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_recog_module_loss import BaseRecogModuleLoss
|
||||
from .base import BaseTextRecogModuleLoss
|
||||
from .ce_module_loss import CEModuleLoss
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ABIModuleLoss(BaseRecogModuleLoss):
|
||||
class ABIModuleLoss(BaseTextRecogModuleLoss):
|
||||
"""Implementation of ABINet multiloss that allows mixing different types of
|
||||
losses with weights.
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from mmocr.registry import TASK_UTILS
|
|||
from mmocr.structures import TextRecogDataSample
|
||||
|
||||
|
||||
class BaseRecogModuleLoss(nn.Module):
|
||||
class BaseTextRecogModuleLoss(nn.Module):
|
||||
"""Base recognition loss.
|
||||
|
||||
Args:
|
|
@ -8,11 +8,11 @@ import torch.nn as nn
|
|||
from mmocr.models.common.dictionary import Dictionary
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_recog_module_loss import BaseRecogModuleLoss
|
||||
from .base import BaseTextRecogModuleLoss
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CEModuleLoss(BaseRecogModuleLoss):
|
||||
class CEModuleLoss(BaseTextRecogModuleLoss):
|
||||
"""Implementation of loss module for encoder-decoder based text recognition
|
||||
method with CrossEntropy loss.
|
||||
|
||||
|
|
|
@ -8,11 +8,11 @@ import torch.nn as nn
|
|||
from mmocr.models.common.dictionary import Dictionary
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_recog_module_loss import BaseRecogModuleLoss
|
||||
from .base import BaseTextRecogModuleLoss
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CTCModuleLoss(BaseRecogModuleLoss):
|
||||
class CTCModuleLoss(BaseTextRecogModuleLoss):
|
||||
"""Implementation of loss module for CTC-loss based text recognition.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .attn_postprocessor import AttentionPostprocessor
|
||||
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
|
||||
from .base import BaseTextRecogPostprocessor
|
||||
from .ctc_postprocessor import CTCPostProcessor
|
||||
|
||||
__all__ = [
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
|
||||
from .base import BaseTextRecogPostprocessor
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
|
||||
from .base import BaseTextRecogPostprocessor
|
||||
|
||||
|
||||
# TODO support beam search
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .abinet import ABINet
|
||||
from .base_recognizer import BaseRecognizer
|
||||
from .base import BaseRecognizer
|
||||
from .crnn import CRNN
|
||||
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
|
||||
from .master import MASTER
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmocr.registry import MODELS
|
|||
from mmocr.utils.typing import (ConfigType, InitConfigType, OptConfigType,
|
||||
OptRecSampleList, RecForwardResults,
|
||||
RecSampleList)
|
||||
from .base_recognizer import BaseRecognizer
|
||||
from .base import BaseRecognizer
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch
|
|||
from mmengine.data import LabelData
|
||||
|
||||
from mmocr.models.common.dictionary import Dictionary
|
||||
from mmocr.models.textrecog.module_losses import BaseRecogModuleLoss
|
||||
from mmocr.models.textrecog.module_losses import BaseTextRecogModuleLoss
|
||||
from mmocr.structures import TextRecogDataSample
|
||||
from mmocr.testing import create_dummy_dict_file
|
||||
|
||||
|
@ -34,27 +34,30 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=True)
|
||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg)
|
||||
base_recog_loss = BaseTextRecogModuleLoss(dict_cfg)
|
||||
self.assertIsInstance(base_recog_loss.dictionary, Dictionary)
|
||||
# test case mode
|
||||
with self.assertRaises(AssertionError):
|
||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg, letter_case='no')
|
||||
base_recog_loss = BaseTextRecogModuleLoss(
|
||||
dict_cfg, letter_case='no')
|
||||
# test invalid pad_with
|
||||
with self.assertRaises(AssertionError):
|
||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg, pad_with='test')
|
||||
base_recog_loss = BaseTextRecogModuleLoss(
|
||||
dict_cfg, pad_with='test')
|
||||
# test invalid combination of dictionary and pad_with
|
||||
dict_cfg = dict(type='Dictionary', dict_file=dict_file, with_end=False)
|
||||
for pad_with in ['end', 'padding']:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, f'pad_with="{pad_with}", but'
|
||||
f' dictionary.{pad_with}_idx is None'):
|
||||
base_recog_loss = BaseRecogModuleLoss(
|
||||
base_recog_loss = BaseTextRecogModuleLoss(
|
||||
dict_cfg, pad_with=pad_with)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'pad_with="auto", but'
|
||||
' dictionary.end_idx and dictionary.padding_idx are both'
|
||||
' None'):
|
||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg, pad_with='auto')
|
||||
base_recog_loss = BaseTextRecogModuleLoss(
|
||||
dict_cfg, pad_with='auto')
|
||||
|
||||
# test dictionary is invalid type
|
||||
dict_cfg = ['tmp']
|
||||
|
@ -62,7 +65,7 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||
TypeError, ('The type of dictionary should be `Dictionary`'
|
||||
' or dict, '
|
||||
f'but got {type(dict_cfg)}')):
|
||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg)
|
||||
base_recog_loss = BaseTextRecogModuleLoss(dict_cfg)
|
||||
|
||||
tmp_dir.cleanup()
|
||||
|
||||
|
@ -81,7 +84,7 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=True)
|
||||
base_recog_loss = BaseRecogModuleLoss(dictionary, max_seq_len=10)
|
||||
base_recog_loss = BaseTextRecogModuleLoss(dictionary, max_seq_len=10)
|
||||
target_data_samples = base_recog_loss.get_targets([data_sample])
|
||||
assert self._equal(target_data_samples[0].gt_text.indexes,
|
||||
torch.LongTensor([0, 1, 2, 3]))
|
||||
|
@ -104,7 +107,7 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=True)
|
||||
base_recog_loss = BaseRecogModuleLoss(dictionary, max_seq_len=3)
|
||||
base_recog_loss = BaseTextRecogModuleLoss(dictionary, max_seq_len=3)
|
||||
data_sample.gt_text.item = '0123'
|
||||
target_data_samples = base_recog_loss.get_targets([data_sample])
|
||||
assert self._equal(target_data_samples[0].gt_text.indexes,
|
||||
|
@ -122,7 +125,7 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=True)
|
||||
base_recog_loss = BaseRecogModuleLoss(
|
||||
base_recog_loss = BaseTextRecogModuleLoss(
|
||||
dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none')
|
||||
data_sample.gt_text.item = '0123'
|
||||
target_data_samples = base_recog_loss.get_targets([data_sample])
|
||||
|
|
Loading…
Reference in New Issue