[Refactor] Rename base_xxx.py as base.py (#1322)

This commit is contained in:
Tong Gao 2022-08-25 11:20:42 +08:00 committed by GitHub
parent b81d58e70c
commit 1b5764b155
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 51 additions and 48 deletions

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .base_textdet_head import BaseTextDetHead from .base import BaseTextDetHead
from .db_head import DBHead from .db_head import DBHead
from .drrg_head import DRRGHead from .drrg_head import DRRGHead
from .fce_head import FCEHead from .fce_head import FCEHead

View File

@ -7,7 +7,7 @@ import torch.nn as nn
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample from mmocr.structures import TextDetDataSample
from mmocr.utils import check_argument from mmocr.utils import check_argument
from .base_textdet_head import BaseTextDetHead from .base import BaseTextDetHead
@MODELS.register_module() @MODELS.register_module()

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .base_postprocessor import BaseTextDetPostProcessor from .base import BaseTextDetPostProcessor
from .db_postprocessor import DBPostprocessor from .db_postprocessor import DBPostprocessor
from .drrg_postprocessor import DRRGPostprocessor from .drrg_postprocessor import DRRGPostprocessor
from .fce_postprocessor import FCEPostprocessor from .fce_postprocessor import FCEPostprocessor

View File

@ -11,7 +11,7 @@ from torch import Tensor
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample from mmocr.structures import TextDetDataSample
from mmocr.utils import offset_polygon from mmocr.utils import offset_polygon
from .base_postprocessor import BaseTextDetPostProcessor from .base import BaseTextDetPostProcessor
@MODELS.register_module() @MODELS.register_module()

View File

@ -11,7 +11,7 @@ from numpy import ndarray
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample from mmocr.structures import TextDetDataSample
from .base_postprocessor import BaseTextDetPostProcessor from .base import BaseTextDetPostProcessor
class Node: class Node:

View File

@ -10,7 +10,7 @@ from numpy.fft import ifft
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample from mmocr.structures import TextDetDataSample
from mmocr.utils import fill_hole from mmocr.utils import fill_hole
from .base_postprocessor import BaseTextDetPostProcessor from .base import BaseTextDetPostProcessor
@MODELS.register_module() @MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmengine.data import InstanceData
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample from mmocr.structures import TextDetDataSample
from .base_postprocessor import BaseTextDetPostProcessor from .base import BaseTextDetPostProcessor
@MODELS.register_module() @MODELS.register_module()

View File

@ -12,7 +12,7 @@ from skimage.morphology import skeletonize
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample from mmocr.structures import TextDetDataSample
from mmocr.utils import fill_hole from mmocr.utils import fill_hole
from .base_postprocessor import BaseTextDetPostProcessor from .base import BaseTextDetPostProcessor
@MODELS.register_module() @MODELS.register_module()

View File

@ -2,7 +2,7 @@
from .abi_fuser import ABIFuser from .abi_fuser import ABIFuser
from .abi_language_decoder import ABILanguageDecoder from .abi_language_decoder import ABILanguageDecoder
from .abi_vision_decoder import ABIVisionDecoder from .abi_vision_decoder import ABIVisionDecoder
from .base_decoder import BaseDecoder from .base import BaseDecoder
from .crnn_decoder import CRNNDecoder from .crnn_decoder import CRNNDecoder
from .master_decoder import MasterDecoder from .master_decoder import MasterDecoder
from .nrtr_decoder import NRTRDecoder from .nrtr_decoder import NRTRDecoder

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from mmocr.models.common.dictionary import Dictionary from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -11,7 +11,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.common.modules import PositionalEncoding from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.common.modules import PositionalEncoding from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.layers import BidirectionalLSTM from mmocr.models.textrecog.layers import BidirectionalLSTM
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -12,7 +12,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.common.modules import PositionalEncoding from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
def clones(module: nn.Module, N: int) -> nn.ModuleList: def clones(module: nn.Module, N: int) -> nn.ModuleList:

View File

@ -10,7 +10,7 @@ from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from mmocr.models.common.dictionary import Dictionary from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -10,7 +10,7 @@ from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
PositionAwareLayer) PositionAwareLayer)
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from mmocr.models.common.dictionary import Dictionary from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -9,7 +9,7 @@ import torch.nn.functional as F
from mmocr.models.common.dictionary import Dictionary from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.layers import DotProductAttentionLayer from mmocr.models.textrecog.layers import DotProductAttentionLayer
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder from .base import BaseDecoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .abi_encoder import ABIEncoder from .abi_encoder import ABIEncoder
from .base_encoder import BaseEncoder from .base import BaseEncoder
from .channel_reduction_encoder import ChannelReductionEncoder from .channel_reduction_encoder import ChannelReductionEncoder
from .nrtr_encoder import NRTREncoder from .nrtr_encoder import NRTREncoder
from .sar_encoder import SAREncoder from .sar_encoder import SAREncoder

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_encoder import BaseEncoder from .base import BaseEncoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmengine.model import ModuleList
from mmocr.models.common import TFEncoderLayer from mmocr.models.common import TFEncoderLayer
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_encoder import BaseEncoder from .base import BaseEncoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -8,7 +8,7 @@ import torch.nn.functional as F
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_encoder import BaseEncoder from .base import BaseEncoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -11,7 +11,7 @@ from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding,
SATRNEncoderLayer) SATRNEncoderLayer)
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_encoder import BaseEncoder from .base import BaseEncoder
@MODELS.register_module() @MODELS.register_module()

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .abi_module_loss import ABIModuleLoss from .abi_module_loss import ABIModuleLoss
from .base_recog_module_loss import BaseRecogModuleLoss from .base import BaseTextRecogModuleLoss
from .ce_module_loss import CEModuleLoss from .ce_module_loss import CEModuleLoss
from .ctc_module_loss import CTCModuleLoss from .ctc_module_loss import CTCModuleLoss
__all__ = [ __all__ = [
'BaseRecogModuleLoss', 'CEModuleLoss', 'CTCModuleLoss', 'ABIModuleLoss' 'BaseTextRecogModuleLoss', 'CEModuleLoss', 'CTCModuleLoss', 'ABIModuleLoss'
] ]

View File

@ -6,12 +6,12 @@ import torch
from mmocr.models.common.dictionary import Dictionary from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_recog_module_loss import BaseRecogModuleLoss from .base import BaseTextRecogModuleLoss
from .ce_module_loss import CEModuleLoss from .ce_module_loss import CEModuleLoss
@MODELS.register_module() @MODELS.register_module()
class ABIModuleLoss(BaseRecogModuleLoss): class ABIModuleLoss(BaseTextRecogModuleLoss):
"""Implementation of ABINet multiloss that allows mixing different types of """Implementation of ABINet multiloss that allows mixing different types of
losses with weights. losses with weights.

View File

@ -10,7 +10,7 @@ from mmocr.registry import TASK_UTILS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
class BaseRecogModuleLoss(nn.Module): class BaseTextRecogModuleLoss(nn.Module):
"""Base recognition loss. """Base recognition loss.
Args: Args:

View File

@ -8,11 +8,11 @@ import torch.nn as nn
from mmocr.models.common.dictionary import Dictionary from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_recog_module_loss import BaseRecogModuleLoss from .base import BaseTextRecogModuleLoss
@MODELS.register_module() @MODELS.register_module()
class CEModuleLoss(BaseRecogModuleLoss): class CEModuleLoss(BaseTextRecogModuleLoss):
"""Implementation of loss module for encoder-decoder based text recognition """Implementation of loss module for encoder-decoder based text recognition
method with CrossEntropy loss. method with CrossEntropy loss.

View File

@ -8,11 +8,11 @@ import torch.nn as nn
from mmocr.models.common.dictionary import Dictionary from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_recog_module_loss import BaseRecogModuleLoss from .base import BaseTextRecogModuleLoss
@MODELS.register_module() @MODELS.register_module()
class CTCModuleLoss(BaseRecogModuleLoss): class CTCModuleLoss(BaseTextRecogModuleLoss):
"""Implementation of loss module for CTC-loss based text recognition. """Implementation of loss module for CTC-loss based text recognition.
Args: Args:

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .attn_postprocessor import AttentionPostprocessor from .attn_postprocessor import AttentionPostprocessor
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor from .base import BaseTextRecogPostprocessor
from .ctc_postprocessor import CTCPostProcessor from .ctc_postprocessor import CTCPostProcessor
__all__ = [ __all__ = [

View File

@ -5,7 +5,7 @@ import torch
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor from .base import BaseTextRecogPostprocessor
@MODELS.register_module() @MODELS.register_module()

View File

@ -6,7 +6,7 @@ import torch
from mmocr.registry import MODELS from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor from .base import BaseTextRecogPostprocessor
# TODO support beam search # TODO support beam search

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .abinet import ABINet from .abinet import ABINet
from .base_recognizer import BaseRecognizer from .base import BaseRecognizer
from .crnn import CRNN from .crnn import CRNN
from .encoder_decoder_recognizer import EncoderDecoderRecognizer from .encoder_decoder_recognizer import EncoderDecoderRecognizer
from .master import MASTER from .master import MASTER

View File

@ -8,7 +8,7 @@ from mmocr.registry import MODELS
from mmocr.utils.typing import (ConfigType, InitConfigType, OptConfigType, from mmocr.utils.typing import (ConfigType, InitConfigType, OptConfigType,
OptRecSampleList, RecForwardResults, OptRecSampleList, RecForwardResults,
RecSampleList) RecSampleList)
from .base_recognizer import BaseRecognizer from .base import BaseRecognizer
@MODELS.register_module() @MODELS.register_module()

View File

@ -8,7 +8,7 @@ import torch
from mmengine.data import LabelData from mmengine.data import LabelData
from mmocr.models.common.dictionary import Dictionary from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.module_losses import BaseRecogModuleLoss from mmocr.models.textrecog.module_losses import BaseTextRecogModuleLoss
from mmocr.structures import TextRecogDataSample from mmocr.structures import TextRecogDataSample
from mmocr.testing import create_dummy_dict_file from mmocr.testing import create_dummy_dict_file
@ -34,27 +34,30 @@ class TestBaseRecogModuleLoss(TestCase):
same_start_end=False, same_start_end=False,
with_padding=True, with_padding=True,
with_unknown=True) with_unknown=True)
base_recog_loss = BaseRecogModuleLoss(dict_cfg) base_recog_loss = BaseTextRecogModuleLoss(dict_cfg)
self.assertIsInstance(base_recog_loss.dictionary, Dictionary) self.assertIsInstance(base_recog_loss.dictionary, Dictionary)
# test case mode # test case mode
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
base_recog_loss = BaseRecogModuleLoss(dict_cfg, letter_case='no') base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, letter_case='no')
# test invalid pad_with # test invalid pad_with
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
base_recog_loss = BaseRecogModuleLoss(dict_cfg, pad_with='test') base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, pad_with='test')
# test invalid combination of dictionary and pad_with # test invalid combination of dictionary and pad_with
dict_cfg = dict(type='Dictionary', dict_file=dict_file, with_end=False) dict_cfg = dict(type='Dictionary', dict_file=dict_file, with_end=False)
for pad_with in ['end', 'padding']: for pad_with in ['end', 'padding']:
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, f'pad_with="{pad_with}", but' ValueError, f'pad_with="{pad_with}", but'
f' dictionary.{pad_with}_idx is None'): f' dictionary.{pad_with}_idx is None'):
base_recog_loss = BaseRecogModuleLoss( base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, pad_with=pad_with) dict_cfg, pad_with=pad_with)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, 'pad_with="auto", but' ValueError, 'pad_with="auto", but'
' dictionary.end_idx and dictionary.padding_idx are both' ' dictionary.end_idx and dictionary.padding_idx are both'
' None'): ' None'):
base_recog_loss = BaseRecogModuleLoss(dict_cfg, pad_with='auto') base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, pad_with='auto')
# test dictionary is invalid type # test dictionary is invalid type
dict_cfg = ['tmp'] dict_cfg = ['tmp']
@ -62,7 +65,7 @@ class TestBaseRecogModuleLoss(TestCase):
TypeError, ('The type of dictionary should be `Dictionary`' TypeError, ('The type of dictionary should be `Dictionary`'
' or dict, ' ' or dict, '
f'but got {type(dict_cfg)}')): f'but got {type(dict_cfg)}')):
base_recog_loss = BaseRecogModuleLoss(dict_cfg) base_recog_loss = BaseTextRecogModuleLoss(dict_cfg)
tmp_dir.cleanup() tmp_dir.cleanup()
@ -81,7 +84,7 @@ class TestBaseRecogModuleLoss(TestCase):
same_start_end=False, same_start_end=False,
with_padding=True, with_padding=True,
with_unknown=True) with_unknown=True)
base_recog_loss = BaseRecogModuleLoss(dictionary, max_seq_len=10) base_recog_loss = BaseTextRecogModuleLoss(dictionary, max_seq_len=10)
target_data_samples = base_recog_loss.get_targets([data_sample]) target_data_samples = base_recog_loss.get_targets([data_sample])
assert self._equal(target_data_samples[0].gt_text.indexes, assert self._equal(target_data_samples[0].gt_text.indexes,
torch.LongTensor([0, 1, 2, 3])) torch.LongTensor([0, 1, 2, 3]))
@ -104,7 +107,7 @@ class TestBaseRecogModuleLoss(TestCase):
same_start_end=False, same_start_end=False,
with_padding=True, with_padding=True,
with_unknown=True) with_unknown=True)
base_recog_loss = BaseRecogModuleLoss(dictionary, max_seq_len=3) base_recog_loss = BaseTextRecogModuleLoss(dictionary, max_seq_len=3)
data_sample.gt_text.item = '0123' data_sample.gt_text.item = '0123'
target_data_samples = base_recog_loss.get_targets([data_sample]) target_data_samples = base_recog_loss.get_targets([data_sample])
assert self._equal(target_data_samples[0].gt_text.indexes, assert self._equal(target_data_samples[0].gt_text.indexes,
@ -122,7 +125,7 @@ class TestBaseRecogModuleLoss(TestCase):
same_start_end=False, same_start_end=False,
with_padding=True, with_padding=True,
with_unknown=True) with_unknown=True)
base_recog_loss = BaseRecogModuleLoss( base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none') dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none')
data_sample.gt_text.item = '0123' data_sample.gt_text.item = '0123'
target_data_samples = base_recog_loss.get_targets([data_sample]) target_data_samples = base_recog_loss.get_targets([data_sample])