From 23458f8a47415d7f5a0beaad341af6de3c4455e2 Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Thu, 12 May 2022 03:01:34 +0000 Subject: [PATCH] [Refactor] union to MODELS --- demo/ner_demo.py | 1 - demo/webcam_demo.py | 1 - mmocr/apis/inference.py | 4 +- mmocr/core/deployment/deploy_utils.py | 10 +- mmocr/datasets/pipelines/ner_transforms.py | 5 +- mmocr/datasets/pipelines/ocr_seg_targets.py | 5 +- mmocr/models/__init__.py | 12 +- mmocr/models/builder.py | 133 +----------------- mmocr/models/common/backbones/unet.py | 7 +- mmocr/models/common/detectors/single_stage.py | 11 +- mmocr/models/common/losses/dice_loss.py | 4 +- mmocr/models/kie/extractors/sdmgr.py | 6 +- mmocr/models/kie/heads/sdmgr_head.py | 6 +- mmocr/models/kie/losses/sdmgr_loss.py | 4 +- .../models/ner/classifiers/ner_classifier.py | 13 +- mmocr/models/ner/convertors/ner_convertor.py | 4 +- mmocr/models/ner/decoders/fc_decoder.py | 4 +- mmocr/models/ner/encoders/bert_encoder.py | 4 +- .../ner/losses/masked_cross_entropy_loss.py | 4 +- mmocr/models/ner/losses/masked_focal_loss.py | 4 +- mmocr/models/textdet/dense_heads/db_head.py | 4 +- mmocr/models/textdet/dense_heads/drrg_head.py | 6 +- mmocr/models/textdet/dense_heads/fce_head.py | 4 +- .../models/textdet/dense_heads/head_mixin.py | 8 +- mmocr/models/textdet/dense_heads/pan_head.py | 4 +- mmocr/models/textdet/dense_heads/pse_head.py | 4 +- .../textdet/dense_heads/textsnake_head.py | 4 +- mmocr/models/textdet/detectors/dbnet.py | 4 +- mmocr/models/textdet/detectors/drrg.py | 4 +- mmocr/models/textdet/detectors/fcenet.py | 4 +- .../models/textdet/detectors/ocr_mask_rcnn.py | 4 +- mmocr/models/textdet/detectors/panet.py | 4 +- mmocr/models/textdet/detectors/psenet.py | 4 +- .../detectors/single_stage_text_detector.py | 4 +- mmocr/models/textdet/detectors/textsnake.py | 4 +- mmocr/models/textdet/losses/db_loss.py | 4 +- mmocr/models/textdet/losses/drrg_loss.py | 4 +- mmocr/models/textdet/losses/fce_loss.py | 4 +- mmocr/models/textdet/losses/pan_loss.py | 4 +- mmocr/models/textdet/losses/pse_loss.py | 4 +- mmocr/models/textdet/losses/textsnake_loss.py | 4 +- mmocr/models/textdet/necks/fpem_ffm.py | 4 +- mmocr/models/textdet/necks/fpn_cat.py | 4 +- mmocr/models/textdet/necks/fpn_unet.py | 4 +- mmocr/models/textdet/necks/fpnf.py | 4 +- .../textdet/postprocess/db_postprocessor.py | 4 +- .../textdet/postprocess/drrg_postprocessor.py | 4 +- .../textdet/postprocess/fce_postprocessor.py | 4 +- .../textdet/postprocess/pan_postprocessor.py | 4 +- .../textdet/postprocess/pse_postprocessor.py | 4 +- .../postprocess/textsnake_postprocessor.py | 4 +- .../backbones/nrtr_modality_transformer.py | 4 +- mmocr/models/textrecog/backbones/resnet.py | 4 +- .../textrecog/backbones/resnet31_ocr.py | 4 +- .../models/textrecog/backbones/resnet_abi.py | 4 +- .../models/textrecog/backbones/shallow_cnn.py | 4 +- .../textrecog/backbones/very_deep_vgg.py | 4 +- mmocr/models/textrecog/convertors/abi.py | 4 +- mmocr/models/textrecog/convertors/attn.py | 4 +- mmocr/models/textrecog/convertors/base.py | 4 +- mmocr/models/textrecog/convertors/ctc.py | 4 +- mmocr/models/textrecog/convertors/seg.py | 4 +- .../decoders/abinet_language_decoder.py | 4 +- .../decoders/abinet_vision_decoder.py | 4 +- .../models/textrecog/decoders/base_decoder.py | 4 +- .../models/textrecog/decoders/crnn_decoder.py | 4 +- .../textrecog/decoders/master_decoder.py | 4 +- .../models/textrecog/decoders/nrtr_decoder.py | 4 +- .../decoders/position_attention_decoder.py | 4 +- .../decoders/robust_scanner_decoder.py | 8 +- .../models/textrecog/decoders/sar_decoder.py | 6 +- .../textrecog/decoders/sar_decoder_with_bs.py | 4 +- .../decoders/sequence_attention_decoder.py | 4 +- .../textrecog/encoders/abinet_vision_model.py | 8 +- .../models/textrecog/encoders/base_encoder.py | 4 +- .../encoders/channel_reduction_encoder.py | 4 +- .../models/textrecog/encoders/nrtr_encoder.py | 4 +- .../models/textrecog/encoders/sar_encoder.py | 4 +- .../textrecog/encoders/satrn_encoder.py | 4 +- .../models/textrecog/encoders/transformer.py | 4 +- mmocr/models/textrecog/fusers/abi_fuser.py | 4 +- mmocr/models/textrecog/heads/seg_head.py | 4 +- mmocr/models/textrecog/losses/ce_loss.py | 8 +- mmocr/models/textrecog/losses/ctc_loss.py | 4 +- mmocr/models/textrecog/losses/mix_loss.py | 4 +- mmocr/models/textrecog/losses/seg_loss.py | 4 +- mmocr/models/textrecog/necks/fpn_ocr.py | 4 +- .../preprocessor/base_preprocessor.py | 4 +- .../preprocessor/tps_preprocessor.py | 4 +- mmocr/models/textrecog/recognizer/abinet.py | 20 ++- mmocr/models/textrecog/recognizer/crnn.py | 4 +- .../recognizer/encode_decode_recognizer.py | 18 ++- mmocr/models/textrecog/recognizer/master.py | 4 +- mmocr/models/textrecog/recognizer/nrtr.py | 4 +- .../textrecog/recognizer/robust_scanner.py | 4 +- mmocr/models/textrecog/recognizer/sar.py | 4 +- mmocr/models/textrecog/recognizer/satrn.py | 4 +- .../textrecog/recognizer/seg_recognizer.py | 18 ++- mmocr/utils/ocr.py | 4 +- 99 files changed, 240 insertions(+), 382 deletions(-) diff --git a/demo/ner_demo.py b/demo/ner_demo.py index f61ee390..003f2e89 100755 --- a/demo/ner_demo.py +++ b/demo/ner_demo.py @@ -3,7 +3,6 @@ from argparse import ArgumentParser from mmocr.apis import init_detector from mmocr.apis.inference import text_model_inference -from mmocr.models import build_detector # NOQA from mmocr.registry import DATASETS # NOQA diff --git a/demo/webcam_demo.py b/demo/webcam_demo.py index 9d2d6859..87db08e6 100644 --- a/demo/webcam_demo.py +++ b/demo/webcam_demo.py @@ -5,7 +5,6 @@ import cv2 import torch from mmocr.apis import init_detector, model_inference -from mmocr.models import build_detector # noqa: F401 from mmocr.registry import DATASETS # noqa: F401 diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py index 1a8d5eec..a912ce31 100644 --- a/mmocr/apis/inference.py +++ b/mmocr/apis/inference.py @@ -11,7 +11,7 @@ from mmdet.core import get_classes from mmdet.datasets import replace_ImageToTensor from mmdet.datasets.pipelines import Compose -from mmocr.models import build_detector +from mmocr.registry import MODELS from mmocr.utils import is_2dlist from .utils import disable_text_recog_aug_test @@ -40,7 +40,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None): if config.model.get('pretrained'): config.model.pretrained = None config.model.train_cfg = None - model = build_detector(config.model, test_cfg=config.get('test_cfg')) + model = MODELS.build(config.model, test_cfg=config.get('test_cfg')) if checkpoint is not None: checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') if 'CLASSES' in checkpoint.get('meta', {}): diff --git a/mmocr/core/deployment/deploy_utils.py b/mmocr/core/deployment/deploy_utils.py index 9f5b31bb..7500a2c9 100644 --- a/mmocr/core/deployment/deploy_utils.py +++ b/mmocr/core/deployment/deploy_utils.py @@ -5,7 +5,6 @@ from typing import Any, Iterable import numpy as np import torch -from mmdet.models.builder import DETECTORS from mmocr.models.textdet.detectors.single_stage_text_detector import \ SingleStageTextDetector @@ -13,6 +12,7 @@ from mmocr.models.textdet.detectors.text_detector_mixin import \ TextDetectorMixin from mmocr.models.textrecog.recognizer.encode_decode_recognizer import \ EncodeDecodeRecognizer +from mmocr.registry import MODELS def inference_with_session(sess, io_binding, input_name, output_names, @@ -34,7 +34,7 @@ def inference_with_session(sess, io_binding, input_name, output_names, return pred -@DETECTORS.register_module() +@MODELS.register_module() class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector): """The class for evaluating onnx file of detection.""" @@ -110,7 +110,7 @@ class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector): return boundaries -@DETECTORS.register_module() +@MODELS.register_module() class ONNXRuntimeRecognizer(EncodeDecodeRecognizer): """The class for evaluating onnx file of recognition.""" @@ -201,7 +201,7 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer): return results -@DETECTORS.register_module() +@MODELS.register_module() class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector): """The class for evaluating TensorRT file of detection.""" @@ -257,7 +257,7 @@ class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector): return boundaries -@DETECTORS.register_module() +@MODELS.register_module() class TensorRTRecognizer(EncodeDecodeRecognizer): """The class for evaluating TensorRT file of recognition.""" diff --git a/mmocr/datasets/pipelines/ner_transforms.py b/mmocr/datasets/pipelines/ner_transforms.py index d230fe48..6be0c828 100644 --- a/mmocr/datasets/pipelines/ner_transforms.py +++ b/mmocr/datasets/pipelines/ner_transforms.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmocr.models.builder import build_convertor -from mmocr.registry import TRANSFORMS +from mmocr.registry import MODELS, TRANSFORMS @TRANSFORMS.register_module() @@ -18,7 +17,7 @@ class NerTransform: """ def __init__(self, label_convertor, max_len): - self.label_convertor = build_convertor(label_convertor) + self.label_convertor = MODELS.build(label_convertor) self.max_len = max_len def __call__(self, results): diff --git a/mmocr/datasets/pipelines/ocr_seg_targets.py b/mmocr/datasets/pipelines/ocr_seg_targets.py index 9a4258fa..04cbc3ea 100644 --- a/mmocr/datasets/pipelines/ocr_seg_targets.py +++ b/mmocr/datasets/pipelines/ocr_seg_targets.py @@ -4,8 +4,7 @@ import numpy as np from mmdet.core import BitmapMasks import mmocr.utils.check_argument as check_argument -from mmocr.models.builder import build_convertor -from mmocr.registry import TRANSFORMS +from mmocr.registry import MODELS, TRANSFORMS @TRANSFORMS.register_module() @@ -41,7 +40,7 @@ class OCRSegTargets: self.attn_shrink_ratio = attn_shrink_ratio self.seg_shrink_ratio = seg_shrink_ratio - self.label_convertor = build_convertor(label_convertor) + self.label_convertor = MODELS.build(label_convertor) self.box_type = box_type self.pad_val = pad_val diff --git a/mmocr/models/__init__.py b/mmocr/models/__init__.py index e0c7bb89..8e0fcc60 100644 --- a/mmocr/models/__init__.py +++ b/mmocr/models/__init__.py @@ -1,19 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from . import common, kie, textdet, textrecog -from .builder import (BACKBONES, CONVERTORS, DECODERS, DETECTORS, ENCODERS, - HEADS, LOSSES, NECKS, PREPROCESSOR, build_backbone, - build_convertor, build_decoder, build_detector, - build_encoder, build_loss, build_preprocessor) from .common import * # NOQA from .kie import * # NOQA from .ner import * # NOQA from .textdet import * # NOQA from .textrecog import * # NOQA -__all__ = [ - 'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone', - 'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS', - 'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder', - 'build_preprocessor' -] -__all__ += common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__ +__all__ = common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__ diff --git a/mmocr/models/builder.py b/mmocr/models/builder.py index 30d2d657..5fc8296a 100644 --- a/mmocr/models/builder.py +++ b/mmocr/models/builder.py @@ -1,115 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -import warnings - import torch.nn as nn +from mmcv.cnn import ACTIVATION_LAYERS as MMCV_ACTIVATION_LAYERS +from mmcv.cnn import UPSAMPLE_LAYERS as MMCV_UPSAMPLE_LAYERS +from mmcv.utils import Registry, build_from_cfg -from mmocr.registry import MODELS - -CONVERTORS = MODELS -ENCODERS = MODELS -DECODERS = MODELS -PREPROCESSOR = MODELS -POSTPROCESSOR = MODELS - -UPSAMPLE_LAYERS = MODELS -BACKBONES = MODELS -LOSSES = MODELS -DETECTORS = MODELS -ROI_EXTRACTORS = MODELS -HEADS = MODELS -NECKS = MODELS -FUSERS = MODELS -RECOGNIZERS = MODELS - -ACTIVATION_LAYERS = MODELS - - -def build_recognizer(cfg, train_cfg=None, test_cfg=None): - """Build recognizer.""" - warnings.warn('``build_recognizer`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return RECOGNIZERS( - cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) - - -def build_convertor(cfg): - """Build label convertor for scene text recognizer.""" - warnings.warn('``build_convertor`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return CONVERTORS.build(cfg) - - -def build_encoder(cfg): - """Build encoder for scene text recognizer.""" - warnings.warn('``build_encoder`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return ENCODERS.build(cfg) - - -def build_decoder(cfg): - """Build decoder for scene text recognizer.""" - warnings.warn('``build_decoder`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return DECODERS.build(cfg) - - -def build_preprocessor(cfg): - """Build preprocessor for scene text recognizer.""" - warnings.warn( - '``build_preprocessor`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return PREPROCESSOR(cfg) - - -def build_postprocessor(cfg): - """Build postprocessor for scene text detector.""" - warnings.warn( - '``build_postprocessor`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return POSTPROCESSOR.build(cfg) - - -def build_roi_extractor(cfg): - """Build roi extractor.""" - warnings.warn( - '``build_roi_extractor`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return ROI_EXTRACTORS.build(cfg) - - -def build_loss(cfg): - """Build loss.""" - warnings.warn('``build_loss`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return LOSSES.build(cfg) - - -def build_backbone(cfg): - """Build backbone.""" - warnings.warn('``build_backbone`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return BACKBONES.build(cfg) - - -def build_head(cfg): - """Build head.""" - warnings.warn('``build_head`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return HEADS.build(cfg) - - -def build_neck(cfg): - """Build neck.""" - warnings.warn('``build_neck`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return NECKS.build(cfg) - - -def build_fuser(cfg): - """Build fuser.""" - warnings.warn('``build_fuser`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return FUSERS.build(cfg) +UPSAMPLE_LAYERS = Registry('upsample layer', parent=MMCV_UPSAMPLE_LAYERS) +ACTIVATION_LAYERS = Registry('activation layer', parent=MMCV_ACTIVATION_LAYERS) def build_upsample_layer(cfg, *args, **kwargs): @@ -160,21 +56,4 @@ def build_activation_layer(cfg): Returns: nn.Module: Created activation layer. """ - warnings.warn( - '``build_activation_layer`` would be deprecated soon, please use ' - '``mmocr.registry.MODELS.build()`` ') - return ACTIVATION_LAYERS.build(cfg) - - -def build_detector(cfg, train_cfg=None, test_cfg=None): - """Build detector.""" - if train_cfg is not None or test_cfg is not None: - warnings.warn( - 'train_cfg and test_cfg is deprecated, ' - 'please specify them in model', UserWarning) - assert cfg.get('train_cfg') is None or train_cfg is None, \ - 'train_cfg specified in both outer field and model field ' - assert cfg.get('test_cfg') is None or test_cfg is None, \ - 'test_cfg specified in both outer field and model field ' - return DETECTORS.build( - cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) + return build_from_cfg(cfg, ACTIVATION_LAYERS) diff --git a/mmocr/models/common/backbones/unet.py b/mmocr/models/common/backbones/unet.py index a69e9f72..5624c68e 100644 --- a/mmocr/models/common/backbones/unet.py +++ b/mmocr/models/common/backbones/unet.py @@ -6,8 +6,9 @@ from mmcv.cnn import ConvModule, build_norm_layer from mmcv.runner import BaseModule from mmcv.utils.parrots_wrapper import _BatchNorm -from mmocr.models.builder import (BACKBONES, UPSAMPLE_LAYERS, - build_activation_layer, build_upsample_layer) +from mmocr.models.builder import (UPSAMPLE_LAYERS, build_activation_layer, + build_upsample_layer) +from mmocr.registry import MODELS class UpConvBlock(nn.Module): @@ -317,7 +318,7 @@ class InterpConv(nn.Module): return out -@BACKBONES.register_module() +@MODELS.register_module() class UNet(BaseModule): """UNet backbone. U-Net: Convolutional Networks for Biomedical Image Segmentation. diff --git a/mmocr/models/common/detectors/single_stage.py b/mmocr/models/common/detectors/single_stage.py index d3a8aebb..b5336523 100644 --- a/mmocr/models/common/detectors/single_stage.py +++ b/mmocr/models/common/detectors/single_stage.py @@ -4,11 +4,10 @@ import warnings from mmdet.models.detectors import \ SingleStageDetector as MMDET_SingleStageDetector -from mmocr.models.builder import (DETECTORS, build_backbone, build_head, - build_neck) +from mmocr.registry import MODELS -@DETECTORS.register_module() +@MODELS.register_module() class SingleStageDetector(MMDET_SingleStageDetector): """Base class for single-stage detectors. @@ -29,11 +28,11 @@ class SingleStageDetector(MMDET_SingleStageDetector): warnings.warn('DeprecationWarning: pretrained is deprecated, ' 'please use "init_cfg" instead') backbone.pretrained = pretrained - self.backbone = build_backbone(backbone) + self.backbone = MODELS.build(backbone) if neck is not None: - self.neck = build_neck(neck) + self.neck = MODELS.build(neck) bbox_head.update(train_cfg=train_cfg) bbox_head.update(test_cfg=test_cfg) - self.bbox_head = build_head(bbox_head) + self.bbox_head = MODELS.build(bbox_head) self.train_cfg = train_cfg self.test_cfg = test_cfg diff --git a/mmocr/models/common/losses/dice_loss.py b/mmocr/models/common/losses/dice_loss.py index 0777200b..6a3973d9 100644 --- a/mmocr/models/common/losses/dice_loss.py +++ b/mmocr/models/common/losses/dice_loss.py @@ -2,10 +2,10 @@ import torch import torch.nn as nn -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class DiceLoss(nn.Module): def __init__(self, eps=1e-6): diff --git a/mmocr/models/kie/extractors/sdmgr.py b/mmocr/models/kie/extractors/sdmgr.py index 9fa08ccc..726a5f61 100644 --- a/mmocr/models/kie/extractors/sdmgr.py +++ b/mmocr/models/kie/extractors/sdmgr.py @@ -7,12 +7,12 @@ from torch import nn from torch.nn import functional as F from mmocr.core import imshow_edge, imshow_node -from mmocr.models.builder import DETECTORS, build_roi_extractor from mmocr.models.common.detectors import SingleStageDetector +from mmocr.registry import MODELS from mmocr.utils import list_from_file -@DETECTORS.register_module() +@MODELS.register_module() class SDMGR(SingleStageDetector): """The implementation of the paper: Spatial Dual-Modality Graph Reasoning for Key Information Extraction. https://arxiv.org/abs/2103.14470. @@ -42,7 +42,7 @@ class SDMGR(SingleStageDetector): backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg) self.visual_modality = visual_modality if visual_modality: - self.extractor = build_roi_extractor({ + self.extractor = MODELS.build({ **extractor, 'out_channels': self.backbone.base_channels }) diff --git a/mmocr/models/kie/heads/sdmgr_head.py b/mmocr/models/kie/heads/sdmgr_head.py index 9898518d..cd69b857 100644 --- a/mmocr/models/kie/heads/sdmgr_head.py +++ b/mmocr/models/kie/heads/sdmgr_head.py @@ -4,10 +4,10 @@ from mmcv.runner import BaseModule from torch import nn from torch.nn import functional as F -from mmocr.models.builder import HEADS, build_loss +from mmocr.registry import MODELS -@HEADS.register_module() +@MODELS.register_module() class SDMGRHead(BaseModule): def __init__(self, @@ -45,7 +45,7 @@ class SDMGRHead(BaseModule): [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)]) self.node_cls = nn.Linear(node_embed, num_classes) self.edge_cls = nn.Linear(edge_embed, 2) - self.loss = build_loss(loss) + self.loss = MODELS.build(loss) def forward(self, relations, texts, x=None): node_nums, char_nums = [], [] diff --git a/mmocr/models/kie/losses/sdmgr_loss.py b/mmocr/models/kie/losses/sdmgr_loss.py index dba2d12d..bf80702c 100644 --- a/mmocr/models/kie/losses/sdmgr_loss.py +++ b/mmocr/models/kie/losses/sdmgr_loss.py @@ -3,10 +3,10 @@ import torch from mmdet.models.losses import accuracy from torch import nn -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class SDMGRLoss(nn.Module): """The implementation the loss of key information extraction proposed in the paper: Spatial Dual-Modality Graph Reasoning for Key Information diff --git a/mmocr/models/ner/classifiers/ner_classifier.py b/mmocr/models/ner/classifiers/ner_classifier.py index 7fefef60..1dc8eccb 100644 --- a/mmocr/models/ner/classifiers/ner_classifier.py +++ b/mmocr/models/ner/classifiers/ner_classifier.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import (DETECTORS, build_convertor, build_decoder, - build_encoder, build_loss) from mmocr.models.textrecog.recognizer.base import BaseRecognizer +from mmocr.registry import MODELS -@DETECTORS.register_module() +@MODELS.register_module() class NerClassifier(BaseRecognizer): """Base class for NER classifier.""" @@ -17,15 +16,15 @@ class NerClassifier(BaseRecognizer): test_cfg=None, init_cfg=None): super().__init__(init_cfg=init_cfg) - self.label_convertor = build_convertor(label_convertor) + self.label_convertor = MODELS.build(label_convertor) - self.encoder = build_encoder(encoder) + self.encoder = MODELS.build(encoder) decoder.update(num_labels=self.label_convertor.num_labels) - self.decoder = build_decoder(decoder) + self.decoder = MODELS.build(decoder) loss.update(num_labels=self.label_convertor.num_labels) - self.loss = build_loss(loss) + self.loss = MODELS.build(loss) def extract_feat(self, imgs): """Extract features from images.""" diff --git a/mmocr/models/ner/convertors/ner_convertor.py b/mmocr/models/ner/convertors/ner_convertor.py index ca7288bc..26f006c6 100644 --- a/mmocr/models/ner/convertors/ner_convertor.py +++ b/mmocr/models/ner/convertors/ner_convertor.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np -from mmocr.models.builder import CONVERTORS +from mmocr.registry import MODELS from mmocr.utils import list_from_file -@CONVERTORS.register_module() +@MODELS.register_module() class NerConvertor: """Convert between text, index and tensor for NER pipeline. diff --git a/mmocr/models/ner/decoders/fc_decoder.py b/mmocr/models/ner/decoders/fc_decoder.py index b88302f1..1eb7018e 100644 --- a/mmocr/models/ner/decoders/fc_decoder.py +++ b/mmocr/models/ner/decoders/fc_decoder.py @@ -4,10 +4,10 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.runner import BaseModule -from mmocr.models.builder import DECODERS +from mmocr.registry import MODELS -@DECODERS.register_module() +@MODELS.register_module() class FCDecoder(BaseModule): """FC Decoder class for Ner. diff --git a/mmocr/models/ner/encoders/bert_encoder.py b/mmocr/models/ner/encoders/bert_encoder.py index 24c60aae..d464c19f 100644 --- a/mmocr/models/ner/encoders/bert_encoder.py +++ b/mmocr/models/ner/encoders/bert_encoder.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmcv.runner import BaseModule -from mmocr.models.builder import ENCODERS from mmocr.models.ner.utils.bert import BertModel +from mmocr.registry import MODELS -@ENCODERS.register_module() +@MODELS.register_module() class BertEncoder(BaseModule): """Bert encoder Args: diff --git a/mmocr/models/ner/losses/masked_cross_entropy_loss.py b/mmocr/models/ner/losses/masked_cross_entropy_loss.py index 034fb295..e98d90cc 100644 --- a/mmocr/models/ner/losses/masked_cross_entropy_loss.py +++ b/mmocr/models/ner/losses/masked_cross_entropy_loss.py @@ -2,10 +2,10 @@ from torch import nn from torch.nn import CrossEntropyLoss -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class MaskedCrossEntropyLoss(nn.Module): """The implementation of masked cross entropy loss. diff --git a/mmocr/models/ner/losses/masked_focal_loss.py b/mmocr/models/ner/losses/masked_focal_loss.py index 065dc781..11850d1f 100644 --- a/mmocr/models/ner/losses/masked_focal_loss.py +++ b/mmocr/models/ner/losses/masked_focal_loss.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from torch import nn -from mmocr.models.builder import LOSSES from mmocr.models.common.losses.focal_loss import FocalLoss +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class MaskedFocalLoss(nn.Module): """The implementation of masked focal loss. diff --git a/mmocr/models/textdet/dense_heads/db_head.py b/mmocr/models/textdet/dense_heads/db_head.py index b843c29f..2aa2990c 100644 --- a/mmocr/models/textdet/dense_heads/db_head.py +++ b/mmocr/models/textdet/dense_heads/db_head.py @@ -5,11 +5,11 @@ import torch import torch.nn as nn from mmcv.runner import BaseModule, Sequential -from mmocr.models.builder import HEADS +from mmocr.registry import MODELS from .head_mixin import HeadMixin -@HEADS.register_module() +@MODELS.register_module() class DBHead(HeadMixin, BaseModule): """The class for DBNet head. diff --git a/mmocr/models/textdet/dense_heads/drrg_head.py b/mmocr/models/textdet/dense_heads/drrg_head.py index e3135ee0..ecbf74e2 100644 --- a/mmocr/models/textdet/dense_heads/drrg_head.py +++ b/mmocr/models/textdet/dense_heads/drrg_head.py @@ -7,13 +7,13 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.runner import BaseModule -from mmocr.models.builder import HEADS, build_loss from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs +from mmocr.registry import MODELS from mmocr.utils import check_argument from .head_mixin import HeadMixin -@HEADS.register_module() +@MODELS.register_module() class DRRGHead(HeadMixin, BaseModule): """The class for DRRG head: `Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection `_. @@ -118,7 +118,7 @@ class DRRGHead(HeadMixin, BaseModule): self.center_region_thr = center_region_thr self.center_region_area_thr = center_region_area_thr self.local_graph_thr = local_graph_thr - self.loss_module = build_loss(loss) + self.loss_module = MODELS.build(loss) self.train_cfg = train_cfg self.test_cfg = test_cfg diff --git a/mmocr/models/textdet/dense_heads/fce_head.py b/mmocr/models/textdet/dense_heads/fce_head.py index 07855578..33f9a1c0 100644 --- a/mmocr/models/textdet/dense_heads/fce_head.py +++ b/mmocr/models/textdet/dense_heads/fce_head.py @@ -5,12 +5,12 @@ import torch.nn as nn from mmcv.runner import BaseModule from mmdet.core import multi_apply -from mmocr.models.builder import HEADS +from mmocr.registry import MODELS from ..postprocess.utils import poly_nms from .head_mixin import HeadMixin -@HEADS.register_module() +@MODELS.register_module() class FCEHead(HeadMixin, BaseModule): """The class for implementing FCENet head. diff --git a/mmocr/models/textdet/dense_heads/head_mixin.py b/mmocr/models/textdet/dense_heads/head_mixin.py index c232e3be..e35a2e74 100644 --- a/mmocr/models/textdet/dense_heads/head_mixin.py +++ b/mmocr/models/textdet/dense_heads/head_mixin.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np -from mmocr.models.builder import HEADS, build_loss, build_postprocessor +from mmocr.registry import MODELS from mmocr.utils import check_argument -@HEADS.register_module() +@MODELS.register_module() class HeadMixin: """Base head class for text detection, including loss calcalation and postprocess. @@ -19,8 +19,8 @@ class HeadMixin: assert isinstance(loss, dict) assert isinstance(postprocessor, dict) - self.loss_module = build_loss(loss) - self.postprocessor = build_postprocessor(postprocessor) + self.loss_module = MODELS.build(loss) + self.postprocessor = MODELS.build(postprocessor) def resize_boundary(self, boundaries, scale_factor): """Rescale boundaries via scale_factor. diff --git a/mmocr/models/textdet/dense_heads/pan_head.py b/mmocr/models/textdet/dense_heads/pan_head.py index cd696aa3..9f75a0ec 100644 --- a/mmocr/models/textdet/dense_heads/pan_head.py +++ b/mmocr/models/textdet/dense_heads/pan_head.py @@ -6,12 +6,12 @@ import torch import torch.nn as nn from mmcv.runner import BaseModule -from mmocr.models.builder import HEADS +from mmocr.registry import MODELS from mmocr.utils import check_argument from .head_mixin import HeadMixin -@HEADS.register_module() +@MODELS.register_module() class PANHead(HeadMixin, BaseModule): """The class for PANet head. diff --git a/mmocr/models/textdet/dense_heads/pse_head.py b/mmocr/models/textdet/dense_heads/pse_head.py index 4952e0a1..63738c8d 100644 --- a/mmocr/models/textdet/dense_heads/pse_head.py +++ b/mmocr/models/textdet/dense_heads/pse_head.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import HEADS +from mmocr.registry import MODELS from . import PANHead -@HEADS.register_module() +@MODELS.register_module() class PSEHead(PANHead): """The class for PSENet head. diff --git a/mmocr/models/textdet/dense_heads/textsnake_head.py b/mmocr/models/textdet/dense_heads/textsnake_head.py index 777bd703..eb6f047c 100644 --- a/mmocr/models/textdet/dense_heads/textsnake_head.py +++ b/mmocr/models/textdet/dense_heads/textsnake_head.py @@ -4,11 +4,11 @@ import warnings import torch.nn as nn from mmcv.runner import BaseModule -from mmocr.models.builder import HEADS +from mmocr.registry import MODELS from .head_mixin import HeadMixin -@HEADS.register_module() +@MODELS.register_module() class TextSnakeHead(HeadMixin, BaseModule): """The class for TextSnake head: TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes. diff --git a/mmocr/models/textdet/detectors/dbnet.py b/mmocr/models/textdet/detectors/dbnet.py index 643e3213..83ba9e05 100644 --- a/mmocr/models/textdet/detectors/dbnet.py +++ b/mmocr/models/textdet/detectors/dbnet.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import DETECTORS +from mmocr.registry import MODELS from .single_stage_text_detector import SingleStageTextDetector from .text_detector_mixin import TextDetectorMixin -@DETECTORS.register_module() +@MODELS.register_module() class DBNet(TextDetectorMixin, SingleStageTextDetector): """The class for implementing DBNet text detector: Real-time Scene Text Detection with Differentiable Binarization. diff --git a/mmocr/models/textdet/detectors/drrg.py b/mmocr/models/textdet/detectors/drrg.py index a5bbc2b8..5e592aef 100644 --- a/mmocr/models/textdet/detectors/drrg.py +++ b/mmocr/models/textdet/detectors/drrg.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import DETECTORS +from mmocr.registry import MODELS from .single_stage_text_detector import SingleStageTextDetector from .text_detector_mixin import TextDetectorMixin -@DETECTORS.register_module() +@MODELS.register_module() class DRRG(TextDetectorMixin, SingleStageTextDetector): """The class for implementing DRRG text detector. Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection. diff --git a/mmocr/models/textdet/detectors/fcenet.py b/mmocr/models/textdet/detectors/fcenet.py index da9bcb7c..f08355b6 100644 --- a/mmocr/models/textdet/detectors/fcenet.py +++ b/mmocr/models/textdet/detectors/fcenet.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import DETECTORS +from mmocr.registry import MODELS from .single_stage_text_detector import SingleStageTextDetector from .text_detector_mixin import TextDetectorMixin -@DETECTORS.register_module() +@MODELS.register_module() class FCENet(TextDetectorMixin, SingleStageTextDetector): """The class for implementing FCENet text detector FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text diff --git a/mmocr/models/textdet/detectors/ocr_mask_rcnn.py b/mmocr/models/textdet/detectors/ocr_mask_rcnn.py index 3cfbff57..444699eb 100644 --- a/mmocr/models/textdet/detectors/ocr_mask_rcnn.py +++ b/mmocr/models/textdet/detectors/ocr_mask_rcnn.py @@ -2,11 +2,11 @@ from mmdet.models.detectors import MaskRCNN from mmocr.core import seg2boundary -from mmocr.models.builder import DETECTORS +from mmocr.registry import MODELS from .text_detector_mixin import TextDetectorMixin -@DETECTORS.register_module() +@MODELS.register_module() class OCRMaskRCNN(TextDetectorMixin, MaskRCNN): """Mask RCNN tailored for OCR.""" diff --git a/mmocr/models/textdet/detectors/panet.py b/mmocr/models/textdet/detectors/panet.py index 1c952513..3f6cba39 100644 --- a/mmocr/models/textdet/detectors/panet.py +++ b/mmocr/models/textdet/detectors/panet.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import DETECTORS +from mmocr.registry import MODELS from .single_stage_text_detector import SingleStageTextDetector from .text_detector_mixin import TextDetectorMixin -@DETECTORS.register_module() +@MODELS.register_module() class PANet(TextDetectorMixin, SingleStageTextDetector): """The class for implementing PANet text detector: diff --git a/mmocr/models/textdet/detectors/psenet.py b/mmocr/models/textdet/detectors/psenet.py index 58dabccb..8ab38073 100644 --- a/mmocr/models/textdet/detectors/psenet.py +++ b/mmocr/models/textdet/detectors/psenet.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import DETECTORS +from mmocr.registry import MODELS from .single_stage_text_detector import SingleStageTextDetector from .text_detector_mixin import TextDetectorMixin -@DETECTORS.register_module() +@MODELS.register_module() class PSENet(TextDetectorMixin, SingleStageTextDetector): """The class for implementing PSENet text detector: Shape Robust Text Detection with Progressive Scale Expansion Network. diff --git a/mmocr/models/textdet/detectors/single_stage_text_detector.py b/mmocr/models/textdet/detectors/single_stage_text_detector.py index d6d27ba2..0b0e1a33 100644 --- a/mmocr/models/textdet/detectors/single_stage_text_detector.py +++ b/mmocr/models/textdet/detectors/single_stage_text_detector.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmocr.models.builder import DETECTORS from mmocr.models.common.detectors import SingleStageDetector +from mmocr.registry import MODELS -@DETECTORS.register_module() +@MODELS.register_module() class SingleStageTextDetector(SingleStageDetector): """The class for implementing single stage text detector.""" diff --git a/mmocr/models/textdet/detectors/textsnake.py b/mmocr/models/textdet/detectors/textsnake.py index 1b9bc3e2..69be2139 100644 --- a/mmocr/models/textdet/detectors/textsnake.py +++ b/mmocr/models/textdet/detectors/textsnake.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import DETECTORS +from mmocr.registry import MODELS from .single_stage_text_detector import SingleStageTextDetector from .text_detector_mixin import TextDetectorMixin -@DETECTORS.register_module() +@MODELS.register_module() class TextSnake(TextDetectorMixin, SingleStageTextDetector): """The class for implementing TextSnake text detector: TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes. diff --git a/mmocr/models/textdet/losses/db_loss.py b/mmocr/models/textdet/losses/db_loss.py index 20ca2259..cc497b03 100644 --- a/mmocr/models/textdet/losses/db_loss.py +++ b/mmocr/models/textdet/losses/db_loss.py @@ -3,11 +3,11 @@ import torch import torch.nn.functional as F from torch import nn -from mmocr.models.builder import LOSSES from mmocr.models.common.losses.dice_loss import DiceLoss +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class DBLoss(nn.Module): """The class for implementing DBNet loss. diff --git a/mmocr/models/textdet/losses/drrg_loss.py b/mmocr/models/textdet/losses/drrg_loss.py index a59868d9..f9388e6c 100644 --- a/mmocr/models/textdet/losses/drrg_loss.py +++ b/mmocr/models/textdet/losses/drrg_loss.py @@ -4,11 +4,11 @@ import torch.nn.functional as F from mmdet.core import BitmapMasks from torch import nn -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS from mmocr.utils import check_argument -@LOSSES.register_module() +@MODELS.register_module() class DRRGLoss(nn.Module): """The class for implementing DRRG loss. This is partially adapted from https://github.com/GXYM/DRRG licensed under the MIT license. diff --git a/mmocr/models/textdet/losses/fce_loss.py b/mmocr/models/textdet/losses/fce_loss.py index e956f10e..5272dfdc 100644 --- a/mmocr/models/textdet/losses/fce_loss.py +++ b/mmocr/models/textdet/losses/fce_loss.py @@ -5,10 +5,10 @@ import torch.nn.functional as F from mmdet.core import multi_apply from torch import nn -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class FCELoss(nn.Module): """The class for implementing FCENet loss. diff --git a/mmocr/models/textdet/losses/pan_loss.py b/mmocr/models/textdet/losses/pan_loss.py index 04f691eb..1776cc33 100644 --- a/mmocr/models/textdet/losses/pan_loss.py +++ b/mmocr/models/textdet/losses/pan_loss.py @@ -8,11 +8,11 @@ import torch.nn.functional as F from mmdet.core import BitmapMasks from torch import nn -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS from mmocr.utils import check_argument -@LOSSES.register_module() +@MODELS.register_module() class PANLoss(nn.Module): """The class for implementing PANet loss. This was partially adapted from https://github.com/WenmuZhou/PAN.pytorch. diff --git a/mmocr/models/textdet/losses/pse_loss.py b/mmocr/models/textdet/losses/pse_loss.py index 8ab1c0e1..eebaf48a 100644 --- a/mmocr/models/textdet/losses/pse_loss.py +++ b/mmocr/models/textdet/losses/pse_loss.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmdet.core import BitmapMasks -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS from mmocr.utils import check_argument from . import PANLoss -@LOSSES.register_module() +@MODELS.register_module() class PSELoss(PANLoss): r"""The class for implementing PSENet loss. This is partially adapted from https://github.com/whai362/PSENet. diff --git a/mmocr/models/textdet/losses/textsnake_loss.py b/mmocr/models/textdet/losses/textsnake_loss.py index d36abb56..43b67b6f 100644 --- a/mmocr/models/textdet/losses/textsnake_loss.py +++ b/mmocr/models/textdet/losses/textsnake_loss.py @@ -4,11 +4,11 @@ import torch.nn.functional as F from mmdet.core import BitmapMasks from torch import nn -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS from mmocr.utils import check_argument -@LOSSES.register_module() +@MODELS.register_module() class TextSnakeLoss(nn.Module): """The class for implementing TextSnake loss. This is partially adapted from https://github.com/princewang1994/TextSnake.pytorch. diff --git a/mmocr/models/textdet/necks/fpem_ffm.py b/mmocr/models/textdet/necks/fpem_ffm.py index c98b43f1..b588dbed 100644 --- a/mmocr/models/textdet/necks/fpem_ffm.py +++ b/mmocr/models/textdet/necks/fpem_ffm.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from mmcv.runner import BaseModule, ModuleList from torch import nn -from mmocr.models.builder import NECKS +from mmocr.registry import MODELS class FPEM(BaseModule): @@ -72,7 +72,7 @@ class SeparableConv2d(BaseModule): return x -@NECKS.register_module() +@MODELS.register_module() class FPEM_FFM(BaseModule): """This code is from https://github.com/WenmuZhou/PAN.pytorch. diff --git a/mmocr/models/textdet/necks/fpn_cat.py b/mmocr/models/textdet/necks/fpn_cat.py index fa85b538..c7479247 100644 --- a/mmocr/models/textdet/necks/fpn_cat.py +++ b/mmocr/models/textdet/necks/fpn_cat.py @@ -5,10 +5,10 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmcv.runner import BaseModule, ModuleList, Sequential, auto_fp16 -from mmocr.models.builder import NECKS +from mmocr.registry import MODELS -@NECKS.register_module() +@MODELS.register_module() class FPNC(BaseModule): """FPN-like fusion module in Real-time Scene Text Detection with Differentiable Binarization. diff --git a/mmocr/models/textdet/necks/fpn_unet.py b/mmocr/models/textdet/necks/fpn_unet.py index c5c48604..609cace3 100644 --- a/mmocr/models/textdet/necks/fpn_unet.py +++ b/mmocr/models/textdet/necks/fpn_unet.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from mmcv.runner import BaseModule from torch import nn -from mmocr.models.builder import NECKS +from mmocr.registry import MODELS class UpBlock(BaseModule): @@ -30,7 +30,7 @@ class UpBlock(BaseModule): return x -@NECKS.register_module() +@MODELS.register_module() class FPN_UNet(BaseModule): """The class for implementing DRRG and TextSnake U-Net-like FPN. diff --git a/mmocr/models/textdet/necks/fpnf.py b/mmocr/models/textdet/necks/fpnf.py index 6138d5bf..dd623559 100644 --- a/mmocr/models/textdet/necks/fpnf.py +++ b/mmocr/models/textdet/necks/fpnf.py @@ -4,10 +4,10 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmcv.runner import BaseModule, ModuleList, auto_fp16 -from mmocr.models.builder import NECKS +from mmocr.registry import MODELS -@NECKS.register_module() +@MODELS.register_module() class FPNF(BaseModule): """FPN-like fusion module in Shape Robust Text Detection with Progressive Scale Expansion Network. diff --git a/mmocr/models/textdet/postprocess/db_postprocessor.py b/mmocr/models/textdet/postprocess/db_postprocessor.py index f185b633..0d760411 100644 --- a/mmocr/models/textdet/postprocess/db_postprocessor.py +++ b/mmocr/models/textdet/postprocess/db_postprocessor.py @@ -3,12 +3,12 @@ import cv2 import numpy as np from mmocr.core import points2boundary -from mmocr.models.builder import POSTPROCESSOR +from mmocr.registry import MODELS from .base_postprocessor import BasePostprocessor from .utils import box_score_fast, unclip -@POSTPROCESSOR.register_module() +@MODELS.register_module() class DBPostprocessor(BasePostprocessor): """Decoding predictions of DbNet to instances. This is partially adapted from https://github.com/MhLiao/DB. diff --git a/mmocr/models/textdet/postprocess/drrg_postprocessor.py b/mmocr/models/textdet/postprocess/drrg_postprocessor.py index ebfb17b9..b583e71f 100644 --- a/mmocr/models/textdet/postprocess/drrg_postprocessor.py +++ b/mmocr/models/textdet/postprocess/drrg_postprocessor.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import POSTPROCESSOR +from mmocr.registry import MODELS from .base_postprocessor import BasePostprocessor from .utils import (clusters2labels, comps2boundaries, connected_components, graph_propagation, remove_single) -@POSTPROCESSOR.register_module() +@MODELS.register_module() class DRRGPostprocessor(BasePostprocessor): """Merge text components and construct boundaries of text instances. diff --git a/mmocr/models/textdet/postprocess/fce_postprocessor.py b/mmocr/models/textdet/postprocess/fce_postprocessor.py index 226e3bd7..c9f075aa 100644 --- a/mmocr/models/textdet/postprocess/fce_postprocessor.py +++ b/mmocr/models/textdet/postprocess/fce_postprocessor.py @@ -2,12 +2,12 @@ import cv2 import numpy as np -from mmocr.models.builder import POSTPROCESSOR +from mmocr.registry import MODELS from .base_postprocessor import BasePostprocessor from .utils import fill_hole, fourier2poly, poly_nms -@POSTPROCESSOR.register_module() +@MODELS.register_module() class FCEPostprocessor(BasePostprocessor): """Decoding predictions of FCENet to instances. diff --git a/mmocr/models/textdet/postprocess/pan_postprocessor.py b/mmocr/models/textdet/postprocess/pan_postprocessor.py index 11271418..f1d5e01e 100644 --- a/mmocr/models/textdet/postprocess/pan_postprocessor.py +++ b/mmocr/models/textdet/postprocess/pan_postprocessor.py @@ -5,11 +5,11 @@ import torch from mmcv.ops import pixel_group from mmocr.core import points2boundary -from mmocr.models.builder import POSTPROCESSOR +from mmocr.registry import MODELS from .base_postprocessor import BasePostprocessor -@POSTPROCESSOR.register_module() +@MODELS.register_module() class PANPostprocessor(BasePostprocessor): """Convert scores to quadrangles via post processing in PANet. This is partially adapted from https://github.com/WenmuZhou/PAN.pytorch. diff --git a/mmocr/models/textdet/postprocess/pse_postprocessor.py b/mmocr/models/textdet/postprocess/pse_postprocessor.py index 4cf53661..3c47e9bd 100644 --- a/mmocr/models/textdet/postprocess/pse_postprocessor.py +++ b/mmocr/models/textdet/postprocess/pse_postprocessor.py @@ -6,11 +6,11 @@ import torch from mmcv.ops import contour_expand from mmocr.core import points2boundary -from mmocr.models.builder import POSTPROCESSOR +from mmocr.registry import MODELS from .base_postprocessor import BasePostprocessor -@POSTPROCESSOR.register_module() +@MODELS.register_module() class PSEPostprocessor(BasePostprocessor): """Decoding predictions of PSENet to instances. This is partially adapted from https://github.com/whai362/PSENet. diff --git a/mmocr/models/textdet/postprocess/textsnake_postprocessor.py b/mmocr/models/textdet/postprocess/textsnake_postprocessor.py index 3e37154c..ba875fc2 100644 --- a/mmocr/models/textdet/postprocess/textsnake_postprocessor.py +++ b/mmocr/models/textdet/postprocess/textsnake_postprocessor.py @@ -5,12 +5,12 @@ import numpy as np import torch from skimage.morphology import skeletonize -from mmocr.models.builder import POSTPROCESSOR +from mmocr.registry import MODELS from .base_postprocessor import BasePostprocessor from .utils import centralize, fill_hole, merge_disks -@POSTPROCESSOR.register_module() +@MODELS.register_module() class TextSnakePostprocessor(BasePostprocessor): """Decoding predictions of TextSnake to instances. This was partially adapted from https://github.com/princewang1994/TextSnake.pytorch. diff --git a/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py b/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py index a514ffdf..5ef99132 100644 --- a/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py +++ b/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py @@ -2,10 +2,10 @@ import torch.nn as nn from mmcv.runner import BaseModule -from mmocr.models.builder import BACKBONES +from mmocr.registry import MODELS -@BACKBONES.register_module() +@MODELS.register_module() class NRTRModalityTransform(BaseModule): def __init__(self, diff --git a/mmocr/models/textrecog/backbones/resnet.py b/mmocr/models/textrecog/backbones/resnet.py index 170e8b7d..445be51f 100644 --- a/mmocr/models/textrecog/backbones/resnet.py +++ b/mmocr/models/textrecog/backbones/resnet.py @@ -3,11 +3,11 @@ from mmcv.cnn import ConvModule, build_plugin_layer from mmcv.runner import BaseModule, Sequential import mmocr.utils as utils -from mmocr.models.builder import BACKBONES from mmocr.models.textrecog.layers import BasicBlock +from mmocr.registry import MODELS -@BACKBONES.register_module() +@MODELS.register_module() class ResNet(BaseModule): """ Args: diff --git a/mmocr/models/textrecog/backbones/resnet31_ocr.py b/mmocr/models/textrecog/backbones/resnet31_ocr.py index bf83546f..3ef7d08d 100644 --- a/mmocr/models/textrecog/backbones/resnet31_ocr.py +++ b/mmocr/models/textrecog/backbones/resnet31_ocr.py @@ -3,11 +3,11 @@ import torch.nn as nn from mmcv.runner import BaseModule, Sequential import mmocr.utils as utils -from mmocr.models.builder import BACKBONES from mmocr.models.textrecog.layers import BasicBlock +from mmocr.registry import MODELS -@BACKBONES.register_module() +@MODELS.register_module() class ResNet31OCR(BaseModule): """Implement ResNet backbone for text recognition, modified from `ResNet `_ diff --git a/mmocr/models/textrecog/backbones/resnet_abi.py b/mmocr/models/textrecog/backbones/resnet_abi.py index 026a786f..08a6788b 100644 --- a/mmocr/models/textrecog/backbones/resnet_abi.py +++ b/mmocr/models/textrecog/backbones/resnet_abi.py @@ -3,11 +3,11 @@ import torch.nn as nn from mmcv.runner import BaseModule, Sequential import mmocr.utils as utils -from mmocr.models.builder import BACKBONES from mmocr.models.textrecog.layers import BasicBlock +from mmocr.registry import MODELS -@BACKBONES.register_module() +@MODELS.register_module() class ResNetABI(BaseModule): """Implement ResNet backbone for text recognition, modified from `ResNet. diff --git a/mmocr/models/textrecog/backbones/shallow_cnn.py b/mmocr/models/textrecog/backbones/shallow_cnn.py index f2cd89a6..c32e3615 100644 --- a/mmocr/models/textrecog/backbones/shallow_cnn.py +++ b/mmocr/models/textrecog/backbones/shallow_cnn.py @@ -3,10 +3,10 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmcv.runner import BaseModule -from mmocr.models.builder import BACKBONES +from mmocr.registry import MODELS -@BACKBONES.register_module() +@MODELS.register_module() class ShallowCNN(BaseModule): """Implement Shallow CNN block for SATRN. diff --git a/mmocr/models/textrecog/backbones/very_deep_vgg.py b/mmocr/models/textrecog/backbones/very_deep_vgg.py index 2831f2b3..a75610f4 100644 --- a/mmocr/models/textrecog/backbones/very_deep_vgg.py +++ b/mmocr/models/textrecog/backbones/very_deep_vgg.py @@ -2,10 +2,10 @@ import torch.nn as nn from mmcv.runner import BaseModule, Sequential -from mmocr.models.builder import BACKBONES +from mmocr.registry import MODELS -@BACKBONES.register_module() +@MODELS.register_module() class VeryDeepVgg(BaseModule): """Implement VGG-VeryDeep backbone for text recognition, modified from `VGG-VeryDeep `_ diff --git a/mmocr/models/textrecog/convertors/abi.py b/mmocr/models/textrecog/convertors/abi.py index e9243992..591c9c92 100644 --- a/mmocr/models/textrecog/convertors/abi.py +++ b/mmocr/models/textrecog/convertors/abi.py @@ -2,11 +2,11 @@ import torch import mmocr.utils as utils -from mmocr.models.builder import CONVERTORS +from mmocr.registry import MODELS from .attn import AttnConvertor -@CONVERTORS.register_module() +@MODELS.register_module() class ABIConvertor(AttnConvertor): """Convert between text, index and tensor for encoder-decoder based pipeline. Modified from AttnConvertor to get closer to ABINet's original diff --git a/mmocr/models/textrecog/convertors/attn.py b/mmocr/models/textrecog/convertors/attn.py index e17411c4..a545d651 100644 --- a/mmocr/models/textrecog/convertors/attn.py +++ b/mmocr/models/textrecog/convertors/attn.py @@ -2,11 +2,11 @@ import torch import mmocr.utils as utils -from mmocr.models.builder import CONVERTORS +from mmocr.registry import MODELS from .base import BaseConvertor -@CONVERTORS.register_module() +@MODELS.register_module() class AttnConvertor(BaseConvertor): """Convert between text, index and tensor for encoder-decoder based pipeline. diff --git a/mmocr/models/textrecog/convertors/base.py b/mmocr/models/textrecog/convertors/base.py index 83b1ab76..71d65eab 100644 --- a/mmocr/models/textrecog/convertors/base.py +++ b/mmocr/models/textrecog/convertors/base.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import CONVERTORS +from mmocr.registry import MODELS from mmocr.utils import list_from_file -@CONVERTORS.register_module() +@MODELS.register_module() class BaseConvertor: """Convert between text, index and tensor for text recognize pipeline. diff --git a/mmocr/models/textrecog/convertors/ctc.py b/mmocr/models/textrecog/convertors/ctc.py index ec4d037d..84576489 100644 --- a/mmocr/models/textrecog/convertors/ctc.py +++ b/mmocr/models/textrecog/convertors/ctc.py @@ -5,11 +5,11 @@ import torch import torch.nn.functional as F import mmocr.utils as utils -from mmocr.models.builder import CONVERTORS +from mmocr.registry import MODELS from .base import BaseConvertor -@CONVERTORS.register_module() +@MODELS.register_module() class CTCConvertor(BaseConvertor): """Convert between text, index and tensor for CTC loss-based pipeline. diff --git a/mmocr/models/textrecog/convertors/seg.py b/mmocr/models/textrecog/convertors/seg.py index 5bc115d1..c86c3e48 100644 --- a/mmocr/models/textrecog/convertors/seg.py +++ b/mmocr/models/textrecog/convertors/seg.py @@ -4,11 +4,11 @@ import numpy as np import torch import mmocr.utils as utils -from mmocr.models.builder import CONVERTORS +from mmocr.registry import MODELS from .base import BaseConvertor -@CONVERTORS.register_module() +@MODELS.register_module() class SegConvertor(BaseConvertor): """Convert between text, index and tensor for segmentation based pipeline. diff --git a/mmocr/models/textrecog/decoders/abinet_language_decoder.py b/mmocr/models/textrecog/decoders/abinet_language_decoder.py index 4c4ce96e..7ec9564b 100644 --- a/mmocr/models/textrecog/decoders/abinet_language_decoder.py +++ b/mmocr/models/textrecog/decoders/abinet_language_decoder.py @@ -6,12 +6,12 @@ import torch.nn as nn from mmcv.cnn.bricks.transformer import BaseTransformerLayer from mmcv.runner import ModuleList -from mmocr.models.builder import DECODERS from mmocr.models.common.modules import PositionalEncoding +from mmocr.registry import MODELS from .base_decoder import BaseDecoder -@DECODERS.register_module() +@MODELS.register_module() class ABILanguageDecoder(BaseDecoder): r"""Transformer-based language model responsible for spell correction. Implementation of language model of \ diff --git a/mmocr/models/textrecog/decoders/abinet_vision_decoder.py b/mmocr/models/textrecog/decoders/abinet_vision_decoder.py index 7c565bd9..ca8638bd 100644 --- a/mmocr/models/textrecog/decoders/abinet_vision_decoder.py +++ b/mmocr/models/textrecog/decoders/abinet_vision_decoder.py @@ -3,12 +3,12 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule -from mmocr.models.builder import DECODERS from mmocr.models.common.modules import PositionalEncoding +from mmocr.registry import MODELS from .base_decoder import BaseDecoder -@DECODERS.register_module() +@MODELS.register_module() class ABIVisionDecoder(BaseDecoder): """Converts visual features into text characters. diff --git a/mmocr/models/textrecog/decoders/base_decoder.py b/mmocr/models/textrecog/decoders/base_decoder.py index 09e2db88..ba640ddd 100644 --- a/mmocr/models/textrecog/decoders/base_decoder.py +++ b/mmocr/models/textrecog/decoders/base_decoder.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmcv.runner import BaseModule -from mmocr.models.builder import DECODERS +from mmocr.registry import MODELS -@DECODERS.register_module() +@MODELS.register_module() class BaseDecoder(BaseModule): """Base decoder class for text recognition.""" diff --git a/mmocr/models/textrecog/decoders/crnn_decoder.py b/mmocr/models/textrecog/decoders/crnn_decoder.py index 9f40f4e2..82322133 100644 --- a/mmocr/models/textrecog/decoders/crnn_decoder.py +++ b/mmocr/models/textrecog/decoders/crnn_decoder.py @@ -2,12 +2,12 @@ import torch.nn as nn from mmcv.runner import Sequential -from mmocr.models.builder import DECODERS from mmocr.models.textrecog.layers import BidirectionalLSTM +from mmocr.registry import MODELS from .base_decoder import BaseDecoder -@DECODERS.register_module() +@MODELS.register_module() class CRNNDecoder(BaseDecoder): """Decoder for CRNN. diff --git a/mmocr/models/textrecog/decoders/master_decoder.py b/mmocr/models/textrecog/decoders/master_decoder.py index 297d52f6..dc2e1051 100644 --- a/mmocr/models/textrecog/decoders/master_decoder.py +++ b/mmocr/models/textrecog/decoders/master_decoder.py @@ -8,8 +8,8 @@ import torch.nn.functional as F from mmcv.cnn.bricks.transformer import BaseTransformerLayer from mmcv.runner import ModuleList -from mmocr.models.builder import DECODERS from mmocr.models.common.modules import PositionalEncoding +from mmocr.registry import MODELS from .base_decoder import BaseDecoder @@ -30,7 +30,7 @@ class Embeddings(nn.Module): return self.lut(x) * math.sqrt(self.d_model) -@DECODERS.register_module() +@MODELS.register_module() class MasterDecoder(BaseDecoder): """Decoder module in `MASTER `_. diff --git a/mmocr/models/textrecog/decoders/nrtr_decoder.py b/mmocr/models/textrecog/decoders/nrtr_decoder.py index c21c0248..9d19844d 100644 --- a/mmocr/models/textrecog/decoders/nrtr_decoder.py +++ b/mmocr/models/textrecog/decoders/nrtr_decoder.py @@ -6,12 +6,12 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.runner import ModuleList -from mmocr.models.builder import DECODERS from mmocr.models.common import PositionalEncoding, TFDecoderLayer +from mmocr.registry import MODELS from .base_decoder import BaseDecoder -@DECODERS.register_module() +@MODELS.register_module() class NRTRDecoder(BaseDecoder): """Transformer Decoder block with self attention mechanism. diff --git a/mmocr/models/textrecog/decoders/position_attention_decoder.py b/mmocr/models/textrecog/decoders/position_attention_decoder.py index 37ab7389..80ff6c55 100644 --- a/mmocr/models/textrecog/decoders/position_attention_decoder.py +++ b/mmocr/models/textrecog/decoders/position_attention_decoder.py @@ -4,13 +4,13 @@ import math import torch import torch.nn as nn -from mmocr.models.builder import DECODERS from mmocr.models.textrecog.layers import (DotProductAttentionLayer, PositionAwareLayer) +from mmocr.registry import MODELS from .base_decoder import BaseDecoder -@DECODERS.register_module() +@MODELS.register_module() class PositionAttentionDecoder(BaseDecoder): """Position attention decoder for RobustScanner. diff --git a/mmocr/models/textrecog/decoders/robust_scanner_decoder.py b/mmocr/models/textrecog/decoders/robust_scanner_decoder.py index 0e2bbd47..451d01d7 100644 --- a/mmocr/models/textrecog/decoders/robust_scanner_decoder.py +++ b/mmocr/models/textrecog/decoders/robust_scanner_decoder.py @@ -3,12 +3,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmocr.models.builder import DECODERS, build_decoder from mmocr.models.textrecog.layers import RobustScannerFusionLayer +from mmocr.registry import MODELS from .base_decoder import BaseDecoder -@DECODERS.register_module() +@MODELS.register_module() class RobustScannerDecoder(BaseDecoder): """Decoder for RobustScanner. @@ -72,7 +72,7 @@ class RobustScannerDecoder(BaseDecoder): hybrid_decoder.update(encode_value=self.encode_value) hybrid_decoder.update(return_feature=True) - self.hybrid_decoder = build_decoder(hybrid_decoder) + self.hybrid_decoder = MODELS.build(hybrid_decoder) # init position decoder position_decoder.update(num_classes=self.num_classes) @@ -83,7 +83,7 @@ class RobustScannerDecoder(BaseDecoder): position_decoder.update(encode_value=self.encode_value) position_decoder.update(return_feature=True) - self.position_decoder = build_decoder(position_decoder) + self.position_decoder = MODELS.build(position_decoder) self.fusion_module = RobustScannerFusionLayer( self.dim_model if encode_value else dim_input) diff --git a/mmocr/models/textrecog/decoders/sar_decoder.py b/mmocr/models/textrecog/decoders/sar_decoder.py index ee79e8c0..597b2caf 100755 --- a/mmocr/models/textrecog/decoders/sar_decoder.py +++ b/mmocr/models/textrecog/decoders/sar_decoder.py @@ -6,11 +6,11 @@ import torch.nn as nn import torch.nn.functional as F import mmocr.utils as utils -from mmocr.models.builder import DECODERS +from mmocr.registry import MODELS from .base_decoder import BaseDecoder -@DECODERS.register_module() +@MODELS.register_module() class ParallelSARDecoder(BaseDecoder): """Implementation Parallel Decoder module in `SAR. @@ -255,7 +255,7 @@ class ParallelSARDecoder(BaseDecoder): return outputs -@DECODERS.register_module() +@MODELS.register_module() class SequentialSARDecoder(BaseDecoder): """Implementation Sequential Decoder module in `SAR. diff --git a/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py b/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py index d00e385d..495b72fb 100755 --- a/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py +++ b/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F import mmocr.utils as utils -from mmocr.models.builder import DECODERS +from mmocr.registry import MODELS from . import ParallelSARDecoder @@ -31,7 +31,7 @@ class DecodeNode: return accu_score -@DECODERS.register_module() +@MODELS.register_module() class ParallelSARDecoderWithBS(ParallelSARDecoder): """Parallel Decoder module with beam-search in SAR. diff --git a/mmocr/models/textrecog/decoders/sequence_attention_decoder.py b/mmocr/models/textrecog/decoders/sequence_attention_decoder.py index a6a10f72..9075a653 100644 --- a/mmocr/models/textrecog/decoders/sequence_attention_decoder.py +++ b/mmocr/models/textrecog/decoders/sequence_attention_decoder.py @@ -5,12 +5,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmocr.models.builder import DECODERS from mmocr.models.textrecog.layers import DotProductAttentionLayer +from mmocr.registry import MODELS from .base_decoder import BaseDecoder -@DECODERS.register_module() +@MODELS.register_module() class SequenceAttentionDecoder(BaseDecoder): """Sequence attention decoder for RobustScanner. diff --git a/mmocr/models/textrecog/encoders/abinet_vision_model.py b/mmocr/models/textrecog/encoders/abinet_vision_model.py index 5c19c8ef..188063d0 100644 --- a/mmocr/models/textrecog/encoders/abinet_vision_model.py +++ b/mmocr/models/textrecog/encoders/abinet_vision_model.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import ENCODERS, build_decoder, build_encoder +from mmocr.registry import MODELS from .base_encoder import BaseEncoder -@ENCODERS.register_module() +@MODELS.register_module() class ABIVisionModel(BaseEncoder): """A wrapper of visual feature encoder and language token decoder that converts visual features into text tokens. @@ -23,8 +23,8 @@ class ABIVisionModel(BaseEncoder): init_cfg=dict(type='Xavier', layer='Conv2d'), **kwargs): super().__init__(init_cfg=init_cfg) - self.encoder = build_encoder(encoder) - self.decoder = build_decoder(decoder) + self.encoder = MODELS.build(encoder) + self.decoder = MODELS.build(decoder) def forward(self, feat, img_metas=None): """ diff --git a/mmocr/models/textrecog/encoders/base_encoder.py b/mmocr/models/textrecog/encoders/base_encoder.py index 726c78a8..3078fd5f 100644 --- a/mmocr/models/textrecog/encoders/base_encoder.py +++ b/mmocr/models/textrecog/encoders/base_encoder.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmcv.runner import BaseModule -from mmocr.models.builder import ENCODERS +from mmocr.registry import MODELS -@ENCODERS.register_module() +@MODELS.register_module() class BaseEncoder(BaseModule): """Base Encoder class for text recognition.""" diff --git a/mmocr/models/textrecog/encoders/channel_reduction_encoder.py b/mmocr/models/textrecog/encoders/channel_reduction_encoder.py index 0e957f85..790c57f2 100644 --- a/mmocr/models/textrecog/encoders/channel_reduction_encoder.py +++ b/mmocr/models/textrecog/encoders/channel_reduction_encoder.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch.nn as nn -from mmocr.models.builder import ENCODERS +from mmocr.registry import MODELS from .base_encoder import BaseEncoder -@ENCODERS.register_module() +@MODELS.register_module() class ChannelReductionEncoder(BaseEncoder): """Change the channel number with a one by one convoluational layer. diff --git a/mmocr/models/textrecog/encoders/nrtr_encoder.py b/mmocr/models/textrecog/encoders/nrtr_encoder.py index 72b229f0..81214a2c 100644 --- a/mmocr/models/textrecog/encoders/nrtr_encoder.py +++ b/mmocr/models/textrecog/encoders/nrtr_encoder.py @@ -4,12 +4,12 @@ import math import torch.nn as nn from mmcv.runner import ModuleList -from mmocr.models.builder import ENCODERS from mmocr.models.common import TFEncoderLayer +from mmocr.registry import MODELS from .base_encoder import BaseEncoder -@ENCODERS.register_module() +@MODELS.register_module() class NRTREncoder(BaseEncoder): """Transformer Encoder block with self attention mechanism. diff --git a/mmocr/models/textrecog/encoders/sar_encoder.py b/mmocr/models/textrecog/encoders/sar_encoder.py index d2f0a8e1..1e48c8ad 100644 --- a/mmocr/models/textrecog/encoders/sar_encoder.py +++ b/mmocr/models/textrecog/encoders/sar_encoder.py @@ -6,11 +6,11 @@ import torch.nn as nn import torch.nn.functional as F import mmocr.utils as utils -from mmocr.models.builder import ENCODERS +from mmocr.registry import MODELS from .base_encoder import BaseEncoder -@ENCODERS.register_module() +@MODELS.register_module() class SAREncoder(BaseEncoder): """Implementation of encoder module in `SAR. diff --git a/mmocr/models/textrecog/encoders/satrn_encoder.py b/mmocr/models/textrecog/encoders/satrn_encoder.py index 00af0826..59944659 100644 --- a/mmocr/models/textrecog/encoders/satrn_encoder.py +++ b/mmocr/models/textrecog/encoders/satrn_encoder.py @@ -4,13 +4,13 @@ import math import torch.nn as nn from mmcv.runner import ModuleList -from mmocr.models.builder import ENCODERS from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding, SatrnEncoderLayer) +from mmocr.registry import MODELS from .base_encoder import BaseEncoder -@ENCODERS.register_module() +@MODELS.register_module() class SatrnEncoder(BaseEncoder): """Implement encoder for SATRN, see `SATRN. diff --git a/mmocr/models/textrecog/encoders/transformer.py b/mmocr/models/textrecog/encoders/transformer.py index 887b4ef8..08b29771 100644 --- a/mmocr/models/textrecog/encoders/transformer.py +++ b/mmocr/models/textrecog/encoders/transformer.py @@ -4,11 +4,11 @@ import copy from mmcv.cnn.bricks.transformer import BaseTransformerLayer from mmcv.runner import BaseModule, ModuleList -from mmocr.models.builder import ENCODERS from mmocr.models.common.modules import PositionalEncoding +from mmocr.registry import MODELS -@ENCODERS.register_module() +@MODELS.register_module() class TransformerEncoder(BaseModule): """Implement transformer encoder for text recognition, modified from ``. diff --git a/mmocr/models/textrecog/fusers/abi_fuser.py b/mmocr/models/textrecog/fusers/abi_fuser.py index 310cf6f0..8277a9aa 100644 --- a/mmocr/models/textrecog/fusers/abi_fuser.py +++ b/mmocr/models/textrecog/fusers/abi_fuser.py @@ -3,10 +3,10 @@ import torch import torch.nn as nn from mmcv.runner import BaseModule -from mmocr.models.builder import FUSERS +from mmocr.registry import MODELS -@FUSERS.register_module() +@MODELS.register_module() class ABIFuser(BaseModule): """Mix and align visual feature and linguistic feature Implementation of language model of `ABINet `_. diff --git a/mmocr/models/textrecog/heads/seg_head.py b/mmocr/models/textrecog/heads/seg_head.py index e8686db8..3ee4790f 100644 --- a/mmocr/models/textrecog/heads/seg_head.py +++ b/mmocr/models/textrecog/heads/seg_head.py @@ -4,10 +4,10 @@ from mmcv.cnn import ConvModule from mmcv.runner import BaseModule from torch import nn -from mmocr.models.builder import HEADS +from mmocr.registry import MODELS -@HEADS.register_module() +@MODELS.register_module() class SegHead(BaseModule): """Head for segmentation based text recognition. diff --git a/mmocr/models/textrecog/losses/ce_loss.py b/mmocr/models/textrecog/losses/ce_loss.py index 38883bbd..58b1c87a 100644 --- a/mmocr/models/textrecog/losses/ce_loss.py +++ b/mmocr/models/textrecog/losses/ce_loss.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch.nn as nn -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class CELoss(nn.Module): """Implementation of loss module for encoder-decoder based text recognition method with CrossEntropy loss. @@ -63,7 +63,7 @@ class CELoss(nn.Module): return losses -@LOSSES.register_module() +@MODELS.register_module() class SARLoss(CELoss): """Implementation of loss module in `SAR. @@ -95,7 +95,7 @@ class SARLoss(CELoss): return outputs, targets -@LOSSES.register_module() +@MODELS.register_module() class TFLoss(CELoss): """Implementation of loss module for transformer. diff --git a/mmocr/models/textrecog/losses/ctc_loss.py b/mmocr/models/textrecog/losses/ctc_loss.py index 24c6390b..dbf8a4ab 100644 --- a/mmocr/models/textrecog/losses/ctc_loss.py +++ b/mmocr/models/textrecog/losses/ctc_loss.py @@ -4,10 +4,10 @@ import math import torch import torch.nn as nn -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class CTCLoss(nn.Module): """Implementation of loss module for CTC-loss based text recognition. diff --git a/mmocr/models/textrecog/losses/mix_loss.py b/mmocr/models/textrecog/losses/mix_loss.py index e7f05f45..64de3788 100644 --- a/mmocr/models/textrecog/losses/mix_loss.py +++ b/mmocr/models/textrecog/losses/mix_loss.py @@ -3,10 +3,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class ABILoss(nn.Module): """Implementation of ABINet multiloss that allows mixing different types of losses with weights. diff --git a/mmocr/models/textrecog/losses/seg_loss.py b/mmocr/models/textrecog/losses/seg_loss.py index 5adc2873..a79725d4 100644 --- a/mmocr/models/textrecog/losses/seg_loss.py +++ b/mmocr/models/textrecog/losses/seg_loss.py @@ -3,10 +3,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmocr.models.builder import LOSSES +from mmocr.registry import MODELS -@LOSSES.register_module() +@MODELS.register_module() class SegLoss(nn.Module): """Implementation of loss module for segmentation based text recognition method. diff --git a/mmocr/models/textrecog/necks/fpn_ocr.py b/mmocr/models/textrecog/necks/fpn_ocr.py index e1a6aae1..e6874e46 100644 --- a/mmocr/models/textrecog/necks/fpn_ocr.py +++ b/mmocr/models/textrecog/necks/fpn_ocr.py @@ -4,10 +4,10 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmcv.runner import BaseModule, ModuleList -from mmocr.models.builder import NECKS +from mmocr.registry import MODELS -@NECKS.register_module() +@MODELS.register_module() class FPNOCR(BaseModule): """FPN-like Network for segmentation based text recognition. diff --git a/mmocr/models/textrecog/preprocessor/base_preprocessor.py b/mmocr/models/textrecog/preprocessor/base_preprocessor.py index ddd4a8f7..bf6a6520 100644 --- a/mmocr/models/textrecog/preprocessor/base_preprocessor.py +++ b/mmocr/models/textrecog/preprocessor/base_preprocessor.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmcv.runner import BaseModule -from mmocr.models.builder import PREPROCESSOR +from mmocr.registry import MODELS -@PREPROCESSOR.register_module() +@MODELS.register_module() class BasePreprocessor(BaseModule): """Base Preprocessor class for text recognition.""" diff --git a/mmocr/models/textrecog/preprocessor/tps_preprocessor.py b/mmocr/models/textrecog/preprocessor/tps_preprocessor.py index 44c332db..e34c28cc 100644 --- a/mmocr/models/textrecog/preprocessor/tps_preprocessor.py +++ b/mmocr/models/textrecog/preprocessor/tps_preprocessor.py @@ -17,11 +17,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmocr.models.builder import PREPROCESSOR +from mmocr.registry import MODELS from .base_preprocessor import BasePreprocessor -@PREPROCESSOR.register_module() +@MODELS.register_module() class TPSPreprocessor(BasePreprocessor): """Rectification Network of RARE, namely TPS based STN in https://arxiv.org/pdf/1603.03915.pdf. diff --git a/mmocr/models/textrecog/recognizer/abinet.py b/mmocr/models/textrecog/recognizer/abinet.py index 43cd9d8c..40084f7f 100644 --- a/mmocr/models/textrecog/recognizer/abinet.py +++ b/mmocr/models/textrecog/recognizer/abinet.py @@ -3,13 +3,11 @@ import warnings import torch -from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor, - build_decoder, build_encoder, build_fuser, - build_loss, build_preprocessor) +from mmocr.registry import MODELS from .encode_decode_recognizer import EncodeDecodeRecognizer -@RECOGNIZERS.register_module() +@MODELS.register_module() class ABINet(EncodeDecodeRecognizer): """Implementation of `Read Like Humans: Autonomous, Bidirectional and Iterative LanguageModeling for Scene Text Recognition. @@ -36,21 +34,21 @@ class ABINet(EncodeDecodeRecognizer): # Label convertor (str2tensor, tensor2str) assert label_convertor is not None label_convertor.update(max_seq_len=max_seq_len) - self.label_convertor = build_convertor(label_convertor) + self.label_convertor = MODELS.build(label_convertor) # Preprocessor module, e.g., TPS self.preprocessor = None if preprocessor is not None: - self.preprocessor = build_preprocessor(preprocessor) + self.preprocessor = MODELS.build(preprocessor) # Backbone assert backbone is not None - self.backbone = build_backbone(backbone) + self.backbone = MODELS.build(backbone) # Encoder module self.encoder = None if encoder is not None: - self.encoder = build_encoder(encoder) + self.encoder = MODELS.build(encoder) # Decoder module self.decoder = None @@ -59,11 +57,11 @@ class ABINet(EncodeDecodeRecognizer): decoder.update(start_idx=self.label_convertor.start_idx) decoder.update(padding_idx=self.label_convertor.padding_idx) decoder.update(max_seq_len=max_seq_len) - self.decoder = build_decoder(decoder) + self.decoder = MODELS.build(decoder) # Loss assert loss is not None - self.loss = build_loss(loss) + self.loss = MODELS.build(loss) self.train_cfg = train_cfg self.test_cfg = test_cfg @@ -78,7 +76,7 @@ class ABINet(EncodeDecodeRecognizer): self.fuser = None if fuser is not None: - self.fuser = build_fuser(fuser) + self.fuser = MODELS.build(fuser) def forward_train(self, img, img_metas): """ diff --git a/mmocr/models/textrecog/recognizer/crnn.py b/mmocr/models/textrecog/recognizer/crnn.py index d4ab90b9..f27375f2 100644 --- a/mmocr/models/textrecog/recognizer/crnn.py +++ b/mmocr/models/textrecog/recognizer/crnn.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import RECOGNIZERS +from mmocr.registry import MODELS from .encode_decode_recognizer import EncodeDecodeRecognizer -@RECOGNIZERS.register_module() +@MODELS.register_module() class CRNNNet(EncodeDecodeRecognizer): """CTC-loss based recognizer.""" diff --git a/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py index f219a857..52614444 100644 --- a/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py +++ b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py @@ -3,13 +3,11 @@ import warnings import torch -from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor, - build_decoder, build_encoder, build_loss, - build_preprocessor) +from mmocr.registry import MODELS from .base import BaseRecognizer -@RECOGNIZERS.register_module() +@MODELS.register_module() class EncodeDecodeRecognizer(BaseRecognizer): """Base class for encode-decode recognizer.""" @@ -31,21 +29,21 @@ class EncodeDecodeRecognizer(BaseRecognizer): # Label convertor (str2tensor, tensor2str) assert label_convertor is not None label_convertor.update(max_seq_len=max_seq_len) - self.label_convertor = build_convertor(label_convertor) + self.label_convertor = MODELS.build(label_convertor) # Preprocessor module, e.g., TPS self.preprocessor = None if preprocessor is not None: - self.preprocessor = build_preprocessor(preprocessor) + self.preprocessor = MODELS.build(preprocessor) # Backbone assert backbone is not None - self.backbone = build_backbone(backbone) + self.backbone = MODELS.build(backbone) # Encoder module self.encoder = None if encoder is not None: - self.encoder = build_encoder(encoder) + self.encoder = MODELS.build(encoder) # Decoder module assert decoder is not None @@ -53,12 +51,12 @@ class EncodeDecodeRecognizer(BaseRecognizer): decoder.update(start_idx=self.label_convertor.start_idx) decoder.update(padding_idx=self.label_convertor.padding_idx) decoder.update(max_seq_len=max_seq_len) - self.decoder = build_decoder(decoder) + self.decoder = MODELS.build(decoder) # Loss assert loss is not None loss.update(ignore_index=self.label_convertor.padding_idx) - self.loss = build_loss(loss) + self.loss = MODELS.build(loss) self.train_cfg = train_cfg self.test_cfg = test_cfg diff --git a/mmocr/models/textrecog/recognizer/master.py b/mmocr/models/textrecog/recognizer/master.py index ff616e7a..22049b3c 100644 --- a/mmocr/models/textrecog/recognizer/master.py +++ b/mmocr/models/textrecog/recognizer/master.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import DETECTORS +from mmocr.registry import MODELS from .encode_decode_recognizer import EncodeDecodeRecognizer -@DETECTORS.register_module() +@MODELS.register_module() class MASTER(EncodeDecodeRecognizer): """Implementation of `MASTER `_""" diff --git a/mmocr/models/textrecog/recognizer/nrtr.py b/mmocr/models/textrecog/recognizer/nrtr.py index 36096bed..de499b05 100644 --- a/mmocr/models/textrecog/recognizer/nrtr.py +++ b/mmocr/models/textrecog/recognizer/nrtr.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import RECOGNIZERS +from mmocr.registry import MODELS from .encode_decode_recognizer import EncodeDecodeRecognizer -@RECOGNIZERS.register_module() +@MODELS.register_module() class NRTR(EncodeDecodeRecognizer): """Implementation of `NRTR `_""" diff --git a/mmocr/models/textrecog/recognizer/robust_scanner.py b/mmocr/models/textrecog/recognizer/robust_scanner.py index 666be91e..0b2a404e 100644 --- a/mmocr/models/textrecog/recognizer/robust_scanner.py +++ b/mmocr/models/textrecog/recognizer/robust_scanner.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import RECOGNIZERS +from mmocr.registry import MODELS from .encode_decode_recognizer import EncodeDecodeRecognizer -@RECOGNIZERS.register_module() +@MODELS.register_module() class RobustScanner(EncodeDecodeRecognizer): """Implementation of `RobustScanner. diff --git a/mmocr/models/textrecog/recognizer/sar.py b/mmocr/models/textrecog/recognizer/sar.py index 3f84cd00..867f2ac7 100644 --- a/mmocr/models/textrecog/recognizer/sar.py +++ b/mmocr/models/textrecog/recognizer/sar.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import RECOGNIZERS +from mmocr.registry import MODELS from .encode_decode_recognizer import EncodeDecodeRecognizer -@RECOGNIZERS.register_module() +@MODELS.register_module() class SARNet(EncodeDecodeRecognizer): """Implementation of `SAR `_""" diff --git a/mmocr/models/textrecog/recognizer/satrn.py b/mmocr/models/textrecog/recognizer/satrn.py index c2d3121b..eb3aa4c5 100644 --- a/mmocr/models/textrecog/recognizer/satrn.py +++ b/mmocr/models/textrecog/recognizer/satrn.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmocr.models.builder import RECOGNIZERS +from mmocr.registry import MODELS from .encode_decode_recognizer import EncodeDecodeRecognizer -@RECOGNIZERS.register_module() +@MODELS.register_module() class SATRN(EncodeDecodeRecognizer): """Implementation of `SATRN `_""" diff --git a/mmocr/models/textrecog/recognizer/seg_recognizer.py b/mmocr/models/textrecog/recognizer/seg_recognizer.py index 1746dbf9..8b7bd63f 100644 --- a/mmocr/models/textrecog/recognizer/seg_recognizer.py +++ b/mmocr/models/textrecog/recognizer/seg_recognizer.py @@ -1,13 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings -from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor, - build_head, build_loss, build_neck, - build_preprocessor) +from mmocr.registry import MODELS from .base import BaseRecognizer -@RECOGNIZERS.register_module() +@MODELS.register_module() class SegRecognizer(BaseRecognizer): """Base class for segmentation based recognizer.""" @@ -26,29 +24,29 @@ class SegRecognizer(BaseRecognizer): # Label_convertor assert label_convertor is not None - self.label_convertor = build_convertor(label_convertor) + self.label_convertor = MODELS.build(label_convertor) # Preprocessor module, e.g., TPS self.preprocessor = None if preprocessor is not None: - self.preprocessor = build_preprocessor(preprocessor) + self.preprocessor = MODELS.build(preprocessor) # Backbone assert backbone is not None - self.backbone = build_backbone(backbone) + self.backbone = MODELS.build(backbone) # Neck assert neck is not None - self.neck = build_neck(neck) + self.neck = MODELS.build(neck) # Head assert head is not None head.update(num_classes=self.label_convertor.num_classes()) - self.head = build_head(head) + self.head = MODELS.build(head) # Loss assert loss is not None - self.loss = build_loss(loss) + self.loss = MODELS.build(loss) self.train_cfg = train_cfg self.test_cfg = test_cfg diff --git a/mmocr/utils/ocr.py b/mmocr/utils/ocr.py index 18bc454c..8f16b88b 100755 --- a/mmocr/utils/ocr.py +++ b/mmocr/utils/ocr.py @@ -24,9 +24,9 @@ from mmocr.apis.inference import model_inference from mmocr.core.visualize import det_recog_show_result from mmocr.datasets.kie_dataset import KIEDataset from mmocr.datasets.pipelines.crop import crop_img -from mmocr.models import build_detector from mmocr.models.textdet.detectors import TextDetectorMixin from mmocr.models.textrecog.recognizer import BaseRecognizer +from mmocr.registry import MODELS from mmocr.utils import is_type_list from mmocr.utils.box_util import stitch_boxes_into_lines from mmocr.utils.fileio import list_from_file @@ -427,7 +427,7 @@ class MMOCR: 'kie/' + kie_models[self.kie]['ckpt'] kie_cfg = Config.fromfile(kie_config) - self.kie_model = build_detector( + self.kie_model = MODELS.build( kie_cfg.model, test_cfg=kie_cfg.get('test_cfg')) self.kie_model = revert_sync_batchnorm(self.kie_model) self.kie_model.cfg = kie_cfg