[Refactor] Rename base_xxx.py as base.py (#1322)

pull/1317/head^2
Tong Gao 2022-08-25 11:20:42 +08:00 committed by GitHub
parent b81d58e70c
commit 1b5764b155
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 51 additions and 48 deletions

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_textdet_head import BaseTextDetHead
from .base import BaseTextDetHead
from .db_head import DBHead
from .drrg_head import DRRGHead
from .fce_head import FCEHead

View File

@ -7,7 +7,7 @@ import torch.nn as nn
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import check_argument
from .base_textdet_head import BaseTextDetHead
from .base import BaseTextDetHead
@MODELS.register_module()

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_postprocessor import BaseTextDetPostProcessor
from .base import BaseTextDetPostProcessor
from .db_postprocessor import DBPostprocessor
from .drrg_postprocessor import DRRGPostprocessor
from .fce_postprocessor import FCEPostprocessor

View File

@ -11,7 +11,7 @@ from torch import Tensor
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import offset_polygon
from .base_postprocessor import BaseTextDetPostProcessor
from .base import BaseTextDetPostProcessor
@MODELS.register_module()

View File

@ -11,7 +11,7 @@ from numpy import ndarray
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from .base_postprocessor import BaseTextDetPostProcessor
from .base import BaseTextDetPostProcessor
class Node:

View File

@ -10,7 +10,7 @@ from numpy.fft import ifft
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import fill_hole
from .base_postprocessor import BaseTextDetPostProcessor
from .base import BaseTextDetPostProcessor
@MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmengine.data import InstanceData
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from .base_postprocessor import BaseTextDetPostProcessor
from .base import BaseTextDetPostProcessor
@MODELS.register_module()

View File

@ -12,7 +12,7 @@ from skimage.morphology import skeletonize
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import fill_hole
from .base_postprocessor import BaseTextDetPostProcessor
from .base import BaseTextDetPostProcessor
@MODELS.register_module()

View File

@ -2,7 +2,7 @@
from .abi_fuser import ABIFuser
from .abi_language_decoder import ABILanguageDecoder
from .abi_vision_decoder import ABIVisionDecoder
from .base_decoder import BaseDecoder
from .base import BaseDecoder
from .crnn_decoder import CRNNDecoder
from .master_decoder import MasterDecoder
from .nrtr_decoder import NRTRDecoder

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
@MODELS.register_module()

View File

@ -11,7 +11,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
@MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
@MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.layers import BidirectionalLSTM
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
@MODELS.register_module()

View File

@ -12,7 +12,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
def clones(module: nn.Module, N: int) -> nn.ModuleList:

View File

@ -10,7 +10,7 @@ from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
@MODELS.register_module()

View File

@ -10,7 +10,7 @@ from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
PositionAwareLayer)
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
@MODELS.register_module()

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
@MODELS.register_module()

View File

@ -9,7 +9,7 @@ import torch.nn.functional as F
from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
@MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.layers import DotProductAttentionLayer
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_decoder import BaseDecoder
from .base import BaseDecoder
@MODELS.register_module()

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .abi_encoder import ABIEncoder
from .base_encoder import BaseEncoder
from .base import BaseEncoder
from .channel_reduction_encoder import ChannelReductionEncoder
from .nrtr_encoder import NRTREncoder
from .sar_encoder import SAREncoder

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_encoder import BaseEncoder
from .base import BaseEncoder
@MODELS.register_module()

View File

@ -9,7 +9,7 @@ from mmengine.model import ModuleList
from mmocr.models.common import TFEncoderLayer
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_encoder import BaseEncoder
from .base import BaseEncoder
@MODELS.register_module()

View File

@ -8,7 +8,7 @@ import torch.nn.functional as F
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_encoder import BaseEncoder
from .base import BaseEncoder
@MODELS.register_module()

View File

@ -11,7 +11,7 @@ from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding,
SATRNEncoderLayer)
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_encoder import BaseEncoder
from .base import BaseEncoder
@MODELS.register_module()

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .abi_module_loss import ABIModuleLoss
from .base_recog_module_loss import BaseRecogModuleLoss
from .base import BaseTextRecogModuleLoss
from .ce_module_loss import CEModuleLoss
from .ctc_module_loss import CTCModuleLoss
__all__ = [
'BaseRecogModuleLoss', 'CEModuleLoss', 'CTCModuleLoss', 'ABIModuleLoss'
'BaseTextRecogModuleLoss', 'CEModuleLoss', 'CTCModuleLoss', 'ABIModuleLoss'
]

View File

@ -6,12 +6,12 @@ import torch
from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_recog_module_loss import BaseRecogModuleLoss
from .base import BaseTextRecogModuleLoss
from .ce_module_loss import CEModuleLoss
@MODELS.register_module()
class ABIModuleLoss(BaseRecogModuleLoss):
class ABIModuleLoss(BaseTextRecogModuleLoss):
"""Implementation of ABINet multiloss that allows mixing different types of
losses with weights.

View File

@ -10,7 +10,7 @@ from mmocr.registry import TASK_UTILS
from mmocr.structures import TextRecogDataSample
class BaseRecogModuleLoss(nn.Module):
class BaseTextRecogModuleLoss(nn.Module):
"""Base recognition loss.
Args:

View File

@ -8,11 +8,11 @@ import torch.nn as nn
from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_recog_module_loss import BaseRecogModuleLoss
from .base import BaseTextRecogModuleLoss
@MODELS.register_module()
class CEModuleLoss(BaseRecogModuleLoss):
class CEModuleLoss(BaseTextRecogModuleLoss):
"""Implementation of loss module for encoder-decoder based text recognition
method with CrossEntropy loss.

View File

@ -8,11 +8,11 @@ import torch.nn as nn
from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_recog_module_loss import BaseRecogModuleLoss
from .base import BaseTextRecogModuleLoss
@MODELS.register_module()
class CTCModuleLoss(BaseRecogModuleLoss):
class CTCModuleLoss(BaseTextRecogModuleLoss):
"""Implementation of loss module for CTC-loss based text recognition.
Args:

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .attn_postprocessor import AttentionPostprocessor
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
from .base import BaseTextRecogPostprocessor
from .ctc_postprocessor import CTCPostProcessor
__all__ = [

View File

@ -5,7 +5,7 @@ import torch
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
from .base import BaseTextRecogPostprocessor
@MODELS.register_module()

View File

@ -6,7 +6,7 @@ import torch
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
from .base import BaseTextRecogPostprocessor
# TODO support beam search

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .abinet import ABINet
from .base_recognizer import BaseRecognizer
from .base import BaseRecognizer
from .crnn import CRNN
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
from .master import MASTER

View File

@ -8,7 +8,7 @@ from mmocr.registry import MODELS
from mmocr.utils.typing import (ConfigType, InitConfigType, OptConfigType,
OptRecSampleList, RecForwardResults,
RecSampleList)
from .base_recognizer import BaseRecognizer
from .base import BaseRecognizer
@MODELS.register_module()

View File

@ -8,7 +8,7 @@ import torch
from mmengine.data import LabelData
from mmocr.models.common.dictionary import Dictionary
from mmocr.models.textrecog.module_losses import BaseRecogModuleLoss
from mmocr.models.textrecog.module_losses import BaseTextRecogModuleLoss
from mmocr.structures import TextRecogDataSample
from mmocr.testing import create_dummy_dict_file
@ -34,27 +34,30 @@ class TestBaseRecogModuleLoss(TestCase):
same_start_end=False,
with_padding=True,
with_unknown=True)
base_recog_loss = BaseRecogModuleLoss(dict_cfg)
base_recog_loss = BaseTextRecogModuleLoss(dict_cfg)
self.assertIsInstance(base_recog_loss.dictionary, Dictionary)
# test case mode
with self.assertRaises(AssertionError):
base_recog_loss = BaseRecogModuleLoss(dict_cfg, letter_case='no')
base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, letter_case='no')
# test invalid pad_with
with self.assertRaises(AssertionError):
base_recog_loss = BaseRecogModuleLoss(dict_cfg, pad_with='test')
base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, pad_with='test')
# test invalid combination of dictionary and pad_with
dict_cfg = dict(type='Dictionary', dict_file=dict_file, with_end=False)
for pad_with in ['end', 'padding']:
with self.assertRaisesRegex(
ValueError, f'pad_with="{pad_with}", but'
f' dictionary.{pad_with}_idx is None'):
base_recog_loss = BaseRecogModuleLoss(
base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, pad_with=pad_with)
with self.assertRaisesRegex(
ValueError, 'pad_with="auto", but'
' dictionary.end_idx and dictionary.padding_idx are both'
' None'):
base_recog_loss = BaseRecogModuleLoss(dict_cfg, pad_with='auto')
base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, pad_with='auto')
# test dictionary is invalid type
dict_cfg = ['tmp']
@ -62,7 +65,7 @@ class TestBaseRecogModuleLoss(TestCase):
TypeError, ('The type of dictionary should be `Dictionary`'
' or dict, '
f'but got {type(dict_cfg)}')):
base_recog_loss = BaseRecogModuleLoss(dict_cfg)
base_recog_loss = BaseTextRecogModuleLoss(dict_cfg)
tmp_dir.cleanup()
@ -81,7 +84,7 @@ class TestBaseRecogModuleLoss(TestCase):
same_start_end=False,
with_padding=True,
with_unknown=True)
base_recog_loss = BaseRecogModuleLoss(dictionary, max_seq_len=10)
base_recog_loss = BaseTextRecogModuleLoss(dictionary, max_seq_len=10)
target_data_samples = base_recog_loss.get_targets([data_sample])
assert self._equal(target_data_samples[0].gt_text.indexes,
torch.LongTensor([0, 1, 2, 3]))
@ -104,7 +107,7 @@ class TestBaseRecogModuleLoss(TestCase):
same_start_end=False,
with_padding=True,
with_unknown=True)
base_recog_loss = BaseRecogModuleLoss(dictionary, max_seq_len=3)
base_recog_loss = BaseTextRecogModuleLoss(dictionary, max_seq_len=3)
data_sample.gt_text.item = '0123'
target_data_samples = base_recog_loss.get_targets([data_sample])
assert self._equal(target_data_samples[0].gt_text.indexes,
@ -122,7 +125,7 @@ class TestBaseRecogModuleLoss(TestCase):
same_start_end=False,
with_padding=True,
with_unknown=True)
base_recog_loss = BaseRecogModuleLoss(
base_recog_loss = BaseTextRecogModuleLoss(
dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none')
data_sample.gt_text.item = '0123'
target_data_samples = base_recog_loss.get_targets([data_sample])