mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Refactor] Rename base_xxx.py as base.py (#1322)
This commit is contained in:
parent
b81d58e70c
commit
1b5764b155
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .base_textdet_head import BaseTextDetHead
|
from .base import BaseTextDetHead
|
||||||
from .db_head import DBHead
|
from .db_head import DBHead
|
||||||
from .drrg_head import DRRGHead
|
from .drrg_head import DRRGHead
|
||||||
from .fce_head import FCEHead
|
from .fce_head import FCEHead
|
||||||
|
@ -7,7 +7,7 @@ import torch.nn as nn
|
|||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextDetDataSample
|
from mmocr.structures import TextDetDataSample
|
||||||
from mmocr.utils import check_argument
|
from mmocr.utils import check_argument
|
||||||
from .base_textdet_head import BaseTextDetHead
|
from .base import BaseTextDetHead
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .base_postprocessor import BaseTextDetPostProcessor
|
from .base import BaseTextDetPostProcessor
|
||||||
from .db_postprocessor import DBPostprocessor
|
from .db_postprocessor import DBPostprocessor
|
||||||
from .drrg_postprocessor import DRRGPostprocessor
|
from .drrg_postprocessor import DRRGPostprocessor
|
||||||
from .fce_postprocessor import FCEPostprocessor
|
from .fce_postprocessor import FCEPostprocessor
|
||||||
|
@ -11,7 +11,7 @@ from torch import Tensor
|
|||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextDetDataSample
|
from mmocr.structures import TextDetDataSample
|
||||||
from mmocr.utils import offset_polygon
|
from mmocr.utils import offset_polygon
|
||||||
from .base_postprocessor import BaseTextDetPostProcessor
|
from .base import BaseTextDetPostProcessor
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -11,7 +11,7 @@ from numpy import ndarray
|
|||||||
|
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextDetDataSample
|
from mmocr.structures import TextDetDataSample
|
||||||
from .base_postprocessor import BaseTextDetPostProcessor
|
from .base import BaseTextDetPostProcessor
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
class Node:
|
||||||
|
@ -10,7 +10,7 @@ from numpy.fft import ifft
|
|||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextDetDataSample
|
from mmocr.structures import TextDetDataSample
|
||||||
from mmocr.utils import fill_hole
|
from mmocr.utils import fill_hole
|
||||||
from .base_postprocessor import BaseTextDetPostProcessor
|
from .base import BaseTextDetPostProcessor
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -9,7 +9,7 @@ from mmengine.data import InstanceData
|
|||||||
|
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextDetDataSample
|
from mmocr.structures import TextDetDataSample
|
||||||
from .base_postprocessor import BaseTextDetPostProcessor
|
from .base import BaseTextDetPostProcessor
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -12,7 +12,7 @@ from skimage.morphology import skeletonize
|
|||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextDetDataSample
|
from mmocr.structures import TextDetDataSample
|
||||||
from mmocr.utils import fill_hole
|
from mmocr.utils import fill_hole
|
||||||
from .base_postprocessor import BaseTextDetPostProcessor
|
from .base import BaseTextDetPostProcessor
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from .abi_fuser import ABIFuser
|
from .abi_fuser import ABIFuser
|
||||||
from .abi_language_decoder import ABILanguageDecoder
|
from .abi_language_decoder import ABILanguageDecoder
|
||||||
from .abi_vision_decoder import ABIVisionDecoder
|
from .abi_vision_decoder import ABIVisionDecoder
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
from .crnn_decoder import CRNNDecoder
|
from .crnn_decoder import CRNNDecoder
|
||||||
from .master_decoder import MasterDecoder
|
from .master_decoder import MasterDecoder
|
||||||
from .nrtr_decoder import NRTRDecoder
|
from .nrtr_decoder import NRTRDecoder
|
||||||
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||||||
from mmocr.models.common.dictionary import Dictionary
|
from mmocr.models.common.dictionary import Dictionary
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -11,7 +11,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||||||
from mmocr.models.common.modules import PositionalEncoding
|
from mmocr.models.common.modules import PositionalEncoding
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||||||
from mmocr.models.common.modules import PositionalEncoding
|
from mmocr.models.common.modules import PositionalEncoding
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||||||
from mmocr.models.textrecog.layers import BidirectionalLSTM
|
from mmocr.models.textrecog.layers import BidirectionalLSTM
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -12,7 +12,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||||||
from mmocr.models.common.modules import PositionalEncoding
|
from mmocr.models.common.modules import PositionalEncoding
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
def clones(module: nn.Module, N: int) -> nn.ModuleList:
|
def clones(module: nn.Module, N: int) -> nn.ModuleList:
|
||||||
|
@ -10,7 +10,7 @@ from mmocr.models.common import PositionalEncoding, TFDecoderLayer
|
|||||||
from mmocr.models.common.dictionary import Dictionary
|
from mmocr.models.common.dictionary import Dictionary
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -10,7 +10,7 @@ from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
|
|||||||
PositionAwareLayer)
|
PositionAwareLayer)
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||||||
from mmocr.models.common.dictionary import Dictionary
|
from mmocr.models.common.dictionary import Dictionary
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -9,7 +9,7 @@ import torch.nn.functional as F
|
|||||||
from mmocr.models.common.dictionary import Dictionary
|
from mmocr.models.common.dictionary import Dictionary
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -9,7 +9,7 @@ from mmocr.models.common.dictionary import Dictionary
|
|||||||
from mmocr.models.textrecog.layers import DotProductAttentionLayer
|
from mmocr.models.textrecog.layers import DotProductAttentionLayer
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_decoder import BaseDecoder
|
from .base import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .abi_encoder import ABIEncoder
|
from .abi_encoder import ABIEncoder
|
||||||
from .base_encoder import BaseEncoder
|
from .base import BaseEncoder
|
||||||
from .channel_reduction_encoder import ChannelReductionEncoder
|
from .channel_reduction_encoder import ChannelReductionEncoder
|
||||||
from .nrtr_encoder import NRTREncoder
|
from .nrtr_encoder import NRTREncoder
|
||||||
from .sar_encoder import SAREncoder
|
from .sar_encoder import SAREncoder
|
||||||
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_encoder import BaseEncoder
|
from .base import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -9,7 +9,7 @@ from mmengine.model import ModuleList
|
|||||||
from mmocr.models.common import TFEncoderLayer
|
from mmocr.models.common import TFEncoderLayer
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_encoder import BaseEncoder
|
from .base import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -8,7 +8,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_encoder import BaseEncoder
|
from .base import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -11,7 +11,7 @@ from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding,
|
|||||||
SATRNEncoderLayer)
|
SATRNEncoderLayer)
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_encoder import BaseEncoder
|
from .base import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .abi_module_loss import ABIModuleLoss
|
from .abi_module_loss import ABIModuleLoss
|
||||||
from .base_recog_module_loss import BaseRecogModuleLoss
|
from .base import BaseTextRecogModuleLoss
|
||||||
from .ce_module_loss import CEModuleLoss
|
from .ce_module_loss import CEModuleLoss
|
||||||
from .ctc_module_loss import CTCModuleLoss
|
from .ctc_module_loss import CTCModuleLoss
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseRecogModuleLoss', 'CEModuleLoss', 'CTCModuleLoss', 'ABIModuleLoss'
|
'BaseTextRecogModuleLoss', 'CEModuleLoss', 'CTCModuleLoss', 'ABIModuleLoss'
|
||||||
]
|
]
|
||||||
|
@ -6,12 +6,12 @@ import torch
|
|||||||
from mmocr.models.common.dictionary import Dictionary
|
from mmocr.models.common.dictionary import Dictionary
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_recog_module_loss import BaseRecogModuleLoss
|
from .base import BaseTextRecogModuleLoss
|
||||||
from .ce_module_loss import CEModuleLoss
|
from .ce_module_loss import CEModuleLoss
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class ABIModuleLoss(BaseRecogModuleLoss):
|
class ABIModuleLoss(BaseTextRecogModuleLoss):
|
||||||
"""Implementation of ABINet multiloss that allows mixing different types of
|
"""Implementation of ABINet multiloss that allows mixing different types of
|
||||||
losses with weights.
|
losses with weights.
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from mmocr.registry import TASK_UTILS
|
|||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
|
|
||||||
|
|
||||||
class BaseRecogModuleLoss(nn.Module):
|
class BaseTextRecogModuleLoss(nn.Module):
|
||||||
"""Base recognition loss.
|
"""Base recognition loss.
|
||||||
|
|
||||||
Args:
|
Args:
|
@ -8,11 +8,11 @@ import torch.nn as nn
|
|||||||
from mmocr.models.common.dictionary import Dictionary
|
from mmocr.models.common.dictionary import Dictionary
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_recog_module_loss import BaseRecogModuleLoss
|
from .base import BaseTextRecogModuleLoss
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class CEModuleLoss(BaseRecogModuleLoss):
|
class CEModuleLoss(BaseTextRecogModuleLoss):
|
||||||
"""Implementation of loss module for encoder-decoder based text recognition
|
"""Implementation of loss module for encoder-decoder based text recognition
|
||||||
method with CrossEntropy loss.
|
method with CrossEntropy loss.
|
||||||
|
|
||||||
|
@ -8,11 +8,11 @@ import torch.nn as nn
|
|||||||
from mmocr.models.common.dictionary import Dictionary
|
from mmocr.models.common.dictionary import Dictionary
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_recog_module_loss import BaseRecogModuleLoss
|
from .base import BaseTextRecogModuleLoss
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class CTCModuleLoss(BaseRecogModuleLoss):
|
class CTCModuleLoss(BaseTextRecogModuleLoss):
|
||||||
"""Implementation of loss module for CTC-loss based text recognition.
|
"""Implementation of loss module for CTC-loss based text recognition.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .attn_postprocessor import AttentionPostprocessor
|
from .attn_postprocessor import AttentionPostprocessor
|
||||||
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
|
from .base import BaseTextRecogPostprocessor
|
||||||
from .ctc_postprocessor import CTCPostProcessor
|
from .ctc_postprocessor import CTCPostProcessor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -5,7 +5,7 @@ import torch
|
|||||||
|
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
|
from .base import BaseTextRecogPostprocessor
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -6,7 +6,7 @@ import torch
|
|||||||
|
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from .base_textrecog_postprocessor import BaseTextRecogPostprocessor
|
from .base import BaseTextRecogPostprocessor
|
||||||
|
|
||||||
|
|
||||||
# TODO support beam search
|
# TODO support beam search
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .abinet import ABINet
|
from .abinet import ABINet
|
||||||
from .base_recognizer import BaseRecognizer
|
from .base import BaseRecognizer
|
||||||
from .crnn import CRNN
|
from .crnn import CRNN
|
||||||
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
|
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
|
||||||
from .master import MASTER
|
from .master import MASTER
|
||||||
|
@ -8,7 +8,7 @@ from mmocr.registry import MODELS
|
|||||||
from mmocr.utils.typing import (ConfigType, InitConfigType, OptConfigType,
|
from mmocr.utils.typing import (ConfigType, InitConfigType, OptConfigType,
|
||||||
OptRecSampleList, RecForwardResults,
|
OptRecSampleList, RecForwardResults,
|
||||||
RecSampleList)
|
RecSampleList)
|
||||||
from .base_recognizer import BaseRecognizer
|
from .base import BaseRecognizer
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
|
@ -8,7 +8,7 @@ import torch
|
|||||||
from mmengine.data import LabelData
|
from mmengine.data import LabelData
|
||||||
|
|
||||||
from mmocr.models.common.dictionary import Dictionary
|
from mmocr.models.common.dictionary import Dictionary
|
||||||
from mmocr.models.textrecog.module_losses import BaseRecogModuleLoss
|
from mmocr.models.textrecog.module_losses import BaseTextRecogModuleLoss
|
||||||
from mmocr.structures import TextRecogDataSample
|
from mmocr.structures import TextRecogDataSample
|
||||||
from mmocr.testing import create_dummy_dict_file
|
from mmocr.testing import create_dummy_dict_file
|
||||||
|
|
||||||
@ -34,27 +34,30 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||||||
same_start_end=False,
|
same_start_end=False,
|
||||||
with_padding=True,
|
with_padding=True,
|
||||||
with_unknown=True)
|
with_unknown=True)
|
||||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg)
|
base_recog_loss = BaseTextRecogModuleLoss(dict_cfg)
|
||||||
self.assertIsInstance(base_recog_loss.dictionary, Dictionary)
|
self.assertIsInstance(base_recog_loss.dictionary, Dictionary)
|
||||||
# test case mode
|
# test case mode
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg, letter_case='no')
|
base_recog_loss = BaseTextRecogModuleLoss(
|
||||||
|
dict_cfg, letter_case='no')
|
||||||
# test invalid pad_with
|
# test invalid pad_with
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg, pad_with='test')
|
base_recog_loss = BaseTextRecogModuleLoss(
|
||||||
|
dict_cfg, pad_with='test')
|
||||||
# test invalid combination of dictionary and pad_with
|
# test invalid combination of dictionary and pad_with
|
||||||
dict_cfg = dict(type='Dictionary', dict_file=dict_file, with_end=False)
|
dict_cfg = dict(type='Dictionary', dict_file=dict_file, with_end=False)
|
||||||
for pad_with in ['end', 'padding']:
|
for pad_with in ['end', 'padding']:
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, f'pad_with="{pad_with}", but'
|
ValueError, f'pad_with="{pad_with}", but'
|
||||||
f' dictionary.{pad_with}_idx is None'):
|
f' dictionary.{pad_with}_idx is None'):
|
||||||
base_recog_loss = BaseRecogModuleLoss(
|
base_recog_loss = BaseTextRecogModuleLoss(
|
||||||
dict_cfg, pad_with=pad_with)
|
dict_cfg, pad_with=pad_with)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, 'pad_with="auto", but'
|
ValueError, 'pad_with="auto", but'
|
||||||
' dictionary.end_idx and dictionary.padding_idx are both'
|
' dictionary.end_idx and dictionary.padding_idx are both'
|
||||||
' None'):
|
' None'):
|
||||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg, pad_with='auto')
|
base_recog_loss = BaseTextRecogModuleLoss(
|
||||||
|
dict_cfg, pad_with='auto')
|
||||||
|
|
||||||
# test dictionary is invalid type
|
# test dictionary is invalid type
|
||||||
dict_cfg = ['tmp']
|
dict_cfg = ['tmp']
|
||||||
@ -62,7 +65,7 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||||||
TypeError, ('The type of dictionary should be `Dictionary`'
|
TypeError, ('The type of dictionary should be `Dictionary`'
|
||||||
' or dict, '
|
' or dict, '
|
||||||
f'but got {type(dict_cfg)}')):
|
f'but got {type(dict_cfg)}')):
|
||||||
base_recog_loss = BaseRecogModuleLoss(dict_cfg)
|
base_recog_loss = BaseTextRecogModuleLoss(dict_cfg)
|
||||||
|
|
||||||
tmp_dir.cleanup()
|
tmp_dir.cleanup()
|
||||||
|
|
||||||
@ -81,7 +84,7 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||||||
same_start_end=False,
|
same_start_end=False,
|
||||||
with_padding=True,
|
with_padding=True,
|
||||||
with_unknown=True)
|
with_unknown=True)
|
||||||
base_recog_loss = BaseRecogModuleLoss(dictionary, max_seq_len=10)
|
base_recog_loss = BaseTextRecogModuleLoss(dictionary, max_seq_len=10)
|
||||||
target_data_samples = base_recog_loss.get_targets([data_sample])
|
target_data_samples = base_recog_loss.get_targets([data_sample])
|
||||||
assert self._equal(target_data_samples[0].gt_text.indexes,
|
assert self._equal(target_data_samples[0].gt_text.indexes,
|
||||||
torch.LongTensor([0, 1, 2, 3]))
|
torch.LongTensor([0, 1, 2, 3]))
|
||||||
@ -104,7 +107,7 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||||||
same_start_end=False,
|
same_start_end=False,
|
||||||
with_padding=True,
|
with_padding=True,
|
||||||
with_unknown=True)
|
with_unknown=True)
|
||||||
base_recog_loss = BaseRecogModuleLoss(dictionary, max_seq_len=3)
|
base_recog_loss = BaseTextRecogModuleLoss(dictionary, max_seq_len=3)
|
||||||
data_sample.gt_text.item = '0123'
|
data_sample.gt_text.item = '0123'
|
||||||
target_data_samples = base_recog_loss.get_targets([data_sample])
|
target_data_samples = base_recog_loss.get_targets([data_sample])
|
||||||
assert self._equal(target_data_samples[0].gt_text.indexes,
|
assert self._equal(target_data_samples[0].gt_text.indexes,
|
||||||
@ -122,7 +125,7 @@ class TestBaseRecogModuleLoss(TestCase):
|
|||||||
same_start_end=False,
|
same_start_end=False,
|
||||||
with_padding=True,
|
with_padding=True,
|
||||||
with_unknown=True)
|
with_unknown=True)
|
||||||
base_recog_loss = BaseRecogModuleLoss(
|
base_recog_loss = BaseTextRecogModuleLoss(
|
||||||
dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none')
|
dict_cfg, max_seq_len=10, letter_case='lower', pad_with='none')
|
||||||
data_sample.gt_text.item = '0123'
|
data_sample.gt_text.item = '0123'
|
||||||
target_data_samples = base_recog_loss.get_targets([data_sample])
|
target_data_samples = base_recog_loss.get_targets([data_sample])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user