[Refactor] union to MODELS

pull/1178/head
liukuikun 2022-05-12 03:01:34 +00:00 committed by gaotongxiao
parent 3f24e34a5d
commit 23458f8a47
99 changed files with 240 additions and 382 deletions

View File

@ -3,7 +3,6 @@ from argparse import ArgumentParser
from mmocr.apis import init_detector from mmocr.apis import init_detector
from mmocr.apis.inference import text_model_inference from mmocr.apis.inference import text_model_inference
from mmocr.models import build_detector # NOQA
from mmocr.registry import DATASETS # NOQA from mmocr.registry import DATASETS # NOQA

View File

@ -5,7 +5,6 @@ import cv2
import torch import torch
from mmocr.apis import init_detector, model_inference from mmocr.apis import init_detector, model_inference
from mmocr.models import build_detector # noqa: F401
from mmocr.registry import DATASETS # noqa: F401 from mmocr.registry import DATASETS # noqa: F401

View File

@ -11,7 +11,7 @@ from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose from mmdet.datasets.pipelines import Compose
from mmocr.models import build_detector from mmocr.registry import MODELS
from mmocr.utils import is_2dlist from mmocr.utils import is_2dlist
from .utils import disable_text_recog_aug_test from .utils import disable_text_recog_aug_test
@ -40,7 +40,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
if config.model.get('pretrained'): if config.model.get('pretrained'):
config.model.pretrained = None config.model.pretrained = None
config.model.train_cfg = None config.model.train_cfg = None
model = build_detector(config.model, test_cfg=config.get('test_cfg')) model = MODELS.build(config.model, test_cfg=config.get('test_cfg'))
if checkpoint is not None: if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'CLASSES' in checkpoint.get('meta', {}): if 'CLASSES' in checkpoint.get('meta', {}):

View File

@ -5,7 +5,6 @@ from typing import Any, Iterable
import numpy as np import numpy as np
import torch import torch
from mmdet.models.builder import DETECTORS
from mmocr.models.textdet.detectors.single_stage_text_detector import \ from mmocr.models.textdet.detectors.single_stage_text_detector import \
SingleStageTextDetector SingleStageTextDetector
@ -13,6 +12,7 @@ from mmocr.models.textdet.detectors.text_detector_mixin import \
TextDetectorMixin TextDetectorMixin
from mmocr.models.textrecog.recognizer.encode_decode_recognizer import \ from mmocr.models.textrecog.recognizer.encode_decode_recognizer import \
EncodeDecodeRecognizer EncodeDecodeRecognizer
from mmocr.registry import MODELS
def inference_with_session(sess, io_binding, input_name, output_names, def inference_with_session(sess, io_binding, input_name, output_names,
@ -34,7 +34,7 @@ def inference_with_session(sess, io_binding, input_name, output_names,
return pred return pred
@DETECTORS.register_module() @MODELS.register_module()
class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector): class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
"""The class for evaluating onnx file of detection.""" """The class for evaluating onnx file of detection."""
@ -110,7 +110,7 @@ class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
return boundaries return boundaries
@DETECTORS.register_module() @MODELS.register_module()
class ONNXRuntimeRecognizer(EncodeDecodeRecognizer): class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
"""The class for evaluating onnx file of recognition.""" """The class for evaluating onnx file of recognition."""
@ -201,7 +201,7 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
return results return results
@DETECTORS.register_module() @MODELS.register_module()
class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector): class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
"""The class for evaluating TensorRT file of detection.""" """The class for evaluating TensorRT file of detection."""
@ -257,7 +257,7 @@ class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
return boundaries return boundaries
@DETECTORS.register_module() @MODELS.register_module()
class TensorRTRecognizer(EncodeDecodeRecognizer): class TensorRTRecognizer(EncodeDecodeRecognizer):
"""The class for evaluating TensorRT file of recognition.""" """The class for evaluating TensorRT file of recognition."""

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmocr.models.builder import build_convertor from mmocr.registry import MODELS, TRANSFORMS
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
@ -18,7 +17,7 @@ class NerTransform:
""" """
def __init__(self, label_convertor, max_len): def __init__(self, label_convertor, max_len):
self.label_convertor = build_convertor(label_convertor) self.label_convertor = MODELS.build(label_convertor)
self.max_len = max_len self.max_len = max_len
def __call__(self, results): def __call__(self, results):

View File

@ -4,8 +4,7 @@ import numpy as np
from mmdet.core import BitmapMasks from mmdet.core import BitmapMasks
import mmocr.utils.check_argument as check_argument import mmocr.utils.check_argument as check_argument
from mmocr.models.builder import build_convertor from mmocr.registry import MODELS, TRANSFORMS
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
@ -41,7 +40,7 @@ class OCRSegTargets:
self.attn_shrink_ratio = attn_shrink_ratio self.attn_shrink_ratio = attn_shrink_ratio
self.seg_shrink_ratio = seg_shrink_ratio self.seg_shrink_ratio = seg_shrink_ratio
self.label_convertor = build_convertor(label_convertor) self.label_convertor = MODELS.build(label_convertor)
self.box_type = box_type self.box_type = box_type
self.pad_val = pad_val self.pad_val = pad_val

View File

@ -1,19 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from . import common, kie, textdet, textrecog from . import common, kie, textdet, textrecog
from .builder import (BACKBONES, CONVERTORS, DECODERS, DETECTORS, ENCODERS,
HEADS, LOSSES, NECKS, PREPROCESSOR, build_backbone,
build_convertor, build_decoder, build_detector,
build_encoder, build_loss, build_preprocessor)
from .common import * # NOQA from .common import * # NOQA
from .kie import * # NOQA from .kie import * # NOQA
from .ner import * # NOQA from .ner import * # NOQA
from .textdet import * # NOQA from .textdet import * # NOQA
from .textrecog import * # NOQA from .textrecog import * # NOQA
__all__ = [ __all__ = common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__
'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone',
'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS',
'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder',
'build_preprocessor'
]
__all__ += common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__

View File

@ -1,115 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ACTIVATION_LAYERS as MMCV_ACTIVATION_LAYERS
from mmcv.cnn import UPSAMPLE_LAYERS as MMCV_UPSAMPLE_LAYERS
from mmcv.utils import Registry, build_from_cfg
from mmocr.registry import MODELS UPSAMPLE_LAYERS = Registry('upsample layer', parent=MMCV_UPSAMPLE_LAYERS)
ACTIVATION_LAYERS = Registry('activation layer', parent=MMCV_ACTIVATION_LAYERS)
CONVERTORS = MODELS
ENCODERS = MODELS
DECODERS = MODELS
PREPROCESSOR = MODELS
POSTPROCESSOR = MODELS
UPSAMPLE_LAYERS = MODELS
BACKBONES = MODELS
LOSSES = MODELS
DETECTORS = MODELS
ROI_EXTRACTORS = MODELS
HEADS = MODELS
NECKS = MODELS
FUSERS = MODELS
RECOGNIZERS = MODELS
ACTIVATION_LAYERS = MODELS
def build_recognizer(cfg, train_cfg=None, test_cfg=None):
"""Build recognizer."""
warnings.warn('``build_recognizer`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return RECOGNIZERS(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_convertor(cfg):
"""Build label convertor for scene text recognizer."""
warnings.warn('``build_convertor`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return CONVERTORS.build(cfg)
def build_encoder(cfg):
"""Build encoder for scene text recognizer."""
warnings.warn('``build_encoder`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return ENCODERS.build(cfg)
def build_decoder(cfg):
"""Build decoder for scene text recognizer."""
warnings.warn('``build_decoder`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return DECODERS.build(cfg)
def build_preprocessor(cfg):
"""Build preprocessor for scene text recognizer."""
warnings.warn(
'``build_preprocessor`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return PREPROCESSOR(cfg)
def build_postprocessor(cfg):
"""Build postprocessor for scene text detector."""
warnings.warn(
'``build_postprocessor`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return POSTPROCESSOR.build(cfg)
def build_roi_extractor(cfg):
"""Build roi extractor."""
warnings.warn(
'``build_roi_extractor`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return ROI_EXTRACTORS.build(cfg)
def build_loss(cfg):
"""Build loss."""
warnings.warn('``build_loss`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return LOSSES.build(cfg)
def build_backbone(cfg):
"""Build backbone."""
warnings.warn('``build_backbone`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return BACKBONES.build(cfg)
def build_head(cfg):
"""Build head."""
warnings.warn('``build_head`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return HEADS.build(cfg)
def build_neck(cfg):
"""Build neck."""
warnings.warn('``build_neck`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return NECKS.build(cfg)
def build_fuser(cfg):
"""Build fuser."""
warnings.warn('``build_fuser`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return FUSERS.build(cfg)
def build_upsample_layer(cfg, *args, **kwargs): def build_upsample_layer(cfg, *args, **kwargs):
@ -160,21 +56,4 @@ def build_activation_layer(cfg):
Returns: Returns:
nn.Module: Created activation layer. nn.Module: Created activation layer.
""" """
warnings.warn( return build_from_cfg(cfg, ACTIVATION_LAYERS)
'``build_activation_layer`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return ACTIVATION_LAYERS.build(cfg)
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
if train_cfg is not None or test_cfg is not None:
warnings.warn(
'train_cfg and test_cfg is deprecated, '
'please specify them in model', UserWarning)
assert cfg.get('train_cfg') is None or train_cfg is None, \
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))

View File

@ -6,8 +6,9 @@ from mmcv.cnn import ConvModule, build_norm_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmocr.models.builder import (BACKBONES, UPSAMPLE_LAYERS, from mmocr.models.builder import (UPSAMPLE_LAYERS, build_activation_layer,
build_activation_layer, build_upsample_layer) build_upsample_layer)
from mmocr.registry import MODELS
class UpConvBlock(nn.Module): class UpConvBlock(nn.Module):
@ -317,7 +318,7 @@ class InterpConv(nn.Module):
return out return out
@BACKBONES.register_module() @MODELS.register_module()
class UNet(BaseModule): class UNet(BaseModule):
"""UNet backbone. """UNet backbone.
U-Net: Convolutional Networks for Biomedical Image Segmentation. U-Net: Convolutional Networks for Biomedical Image Segmentation.

View File

@ -4,11 +4,10 @@ import warnings
from mmdet.models.detectors import \ from mmdet.models.detectors import \
SingleStageDetector as MMDET_SingleStageDetector SingleStageDetector as MMDET_SingleStageDetector
from mmocr.models.builder import (DETECTORS, build_backbone, build_head, from mmocr.registry import MODELS
build_neck)
@DETECTORS.register_module() @MODELS.register_module()
class SingleStageDetector(MMDET_SingleStageDetector): class SingleStageDetector(MMDET_SingleStageDetector):
"""Base class for single-stage detectors. """Base class for single-stage detectors.
@ -29,11 +28,11 @@ class SingleStageDetector(MMDET_SingleStageDetector):
warnings.warn('DeprecationWarning: pretrained is deprecated, ' warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead') 'please use "init_cfg" instead')
backbone.pretrained = pretrained backbone.pretrained = pretrained
self.backbone = build_backbone(backbone) self.backbone = MODELS.build(backbone)
if neck is not None: if neck is not None:
self.neck = build_neck(neck) self.neck = MODELS.build(neck)
bbox_head.update(train_cfg=train_cfg) bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg) bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head) self.bbox_head = MODELS.build(bbox_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg

View File

@ -2,10 +2,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class DiceLoss(nn.Module): class DiceLoss(nn.Module):
def __init__(self, eps=1e-6): def __init__(self, eps=1e-6):

View File

@ -7,12 +7,12 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from mmocr.core import imshow_edge, imshow_node from mmocr.core import imshow_edge, imshow_node
from mmocr.models.builder import DETECTORS, build_roi_extractor
from mmocr.models.common.detectors import SingleStageDetector from mmocr.models.common.detectors import SingleStageDetector
from mmocr.registry import MODELS
from mmocr.utils import list_from_file from mmocr.utils import list_from_file
@DETECTORS.register_module() @MODELS.register_module()
class SDMGR(SingleStageDetector): class SDMGR(SingleStageDetector):
"""The implementation of the paper: Spatial Dual-Modality Graph Reasoning """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
for Key Information Extraction. https://arxiv.org/abs/2103.14470. for Key Information Extraction. https://arxiv.org/abs/2103.14470.
@ -42,7 +42,7 @@ class SDMGR(SingleStageDetector):
backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg) backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
self.visual_modality = visual_modality self.visual_modality = visual_modality
if visual_modality: if visual_modality:
self.extractor = build_roi_extractor({ self.extractor = MODELS.build({
**extractor, 'out_channels': **extractor, 'out_channels':
self.backbone.base_channels self.backbone.base_channels
}) })

View File

@ -4,10 +4,10 @@ from mmcv.runner import BaseModule
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from mmocr.models.builder import HEADS, build_loss from mmocr.registry import MODELS
@HEADS.register_module() @MODELS.register_module()
class SDMGRHead(BaseModule): class SDMGRHead(BaseModule):
def __init__(self, def __init__(self,
@ -45,7 +45,7 @@ class SDMGRHead(BaseModule):
[GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)]) [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
self.node_cls = nn.Linear(node_embed, num_classes) self.node_cls = nn.Linear(node_embed, num_classes)
self.edge_cls = nn.Linear(edge_embed, 2) self.edge_cls = nn.Linear(edge_embed, 2)
self.loss = build_loss(loss) self.loss = MODELS.build(loss)
def forward(self, relations, texts, x=None): def forward(self, relations, texts, x=None):
node_nums, char_nums = [], [] node_nums, char_nums = [], []

View File

@ -3,10 +3,10 @@ import torch
from mmdet.models.losses import accuracy from mmdet.models.losses import accuracy
from torch import nn from torch import nn
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class SDMGRLoss(nn.Module): class SDMGRLoss(nn.Module):
"""The implementation the loss of key information extraction proposed in """The implementation the loss of key information extraction proposed in
the paper: Spatial Dual-Modality Graph Reasoning for Key Information the paper: Spatial Dual-Modality Graph Reasoning for Key Information

View File

@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import (DETECTORS, build_convertor, build_decoder,
build_encoder, build_loss)
from mmocr.models.textrecog.recognizer.base import BaseRecognizer from mmocr.models.textrecog.recognizer.base import BaseRecognizer
from mmocr.registry import MODELS
@DETECTORS.register_module() @MODELS.register_module()
class NerClassifier(BaseRecognizer): class NerClassifier(BaseRecognizer):
"""Base class for NER classifier.""" """Base class for NER classifier."""
@ -17,15 +16,15 @@ class NerClassifier(BaseRecognizer):
test_cfg=None, test_cfg=None,
init_cfg=None): init_cfg=None):
super().__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.label_convertor = build_convertor(label_convertor) self.label_convertor = MODELS.build(label_convertor)
self.encoder = build_encoder(encoder) self.encoder = MODELS.build(encoder)
decoder.update(num_labels=self.label_convertor.num_labels) decoder.update(num_labels=self.label_convertor.num_labels)
self.decoder = build_decoder(decoder) self.decoder = MODELS.build(decoder)
loss.update(num_labels=self.label_convertor.num_labels) loss.update(num_labels=self.label_convertor.num_labels)
self.loss = build_loss(loss) self.loss = MODELS.build(loss)
def extract_feat(self, imgs): def extract_feat(self, imgs):
"""Extract features from images.""" """Extract features from images."""

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np import numpy as np
from mmocr.models.builder import CONVERTORS from mmocr.registry import MODELS
from mmocr.utils import list_from_file from mmocr.utils import list_from_file
@CONVERTORS.register_module() @MODELS.register_module()
class NerConvertor: class NerConvertor:
"""Convert between text, index and tensor for NER pipeline. """Convert between text, index and tensor for NER pipeline.

View File

@ -4,10 +4,10 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import DECODERS from mmocr.registry import MODELS
@DECODERS.register_module() @MODELS.register_module()
class FCDecoder(BaseModule): class FCDecoder(BaseModule):
"""FC Decoder class for Ner. """FC Decoder class for Ner.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import ENCODERS
from mmocr.models.ner.utils.bert import BertModel from mmocr.models.ner.utils.bert import BertModel
from mmocr.registry import MODELS
@ENCODERS.register_module() @MODELS.register_module()
class BertEncoder(BaseModule): class BertEncoder(BaseModule):
"""Bert encoder """Bert encoder
Args: Args:

View File

@ -2,10 +2,10 @@
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class MaskedCrossEntropyLoss(nn.Module): class MaskedCrossEntropyLoss(nn.Module):
"""The implementation of masked cross entropy loss. """The implementation of masked cross entropy loss.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from torch import nn from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.models.common.losses.focal_loss import FocalLoss from mmocr.models.common.losses.focal_loss import FocalLoss
from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class MaskedFocalLoss(nn.Module): class MaskedFocalLoss(nn.Module):
"""The implementation of masked focal loss. """The implementation of masked focal loss.

View File

@ -5,11 +5,11 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.runner import BaseModule, Sequential from mmcv.runner import BaseModule, Sequential
from mmocr.models.builder import HEADS from mmocr.registry import MODELS
from .head_mixin import HeadMixin from .head_mixin import HeadMixin
@HEADS.register_module() @MODELS.register_module()
class DBHead(HeadMixin, BaseModule): class DBHead(HeadMixin, BaseModule):
"""The class for DBNet head. """The class for DBNet head.

View File

@ -7,13 +7,13 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import HEADS, build_loss
from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
from mmocr.registry import MODELS
from mmocr.utils import check_argument from mmocr.utils import check_argument
from .head_mixin import HeadMixin from .head_mixin import HeadMixin
@HEADS.register_module() @MODELS.register_module()
class DRRGHead(HeadMixin, BaseModule): class DRRGHead(HeadMixin, BaseModule):
"""The class for DRRG head: `Deep Relational Reasoning Graph Network for """The class for DRRG head: `Deep Relational Reasoning Graph Network for
Arbitrary Shape Text Detection <https://arxiv.org/abs/2003.07493>`_. Arbitrary Shape Text Detection <https://arxiv.org/abs/2003.07493>`_.
@ -118,7 +118,7 @@ class DRRGHead(HeadMixin, BaseModule):
self.center_region_thr = center_region_thr self.center_region_thr = center_region_thr
self.center_region_area_thr = center_region_area_thr self.center_region_area_thr = center_region_area_thr
self.local_graph_thr = local_graph_thr self.local_graph_thr = local_graph_thr
self.loss_module = build_loss(loss) self.loss_module = MODELS.build(loss)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg

View File

@ -5,12 +5,12 @@ import torch.nn as nn
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmdet.core import multi_apply from mmdet.core import multi_apply
from mmocr.models.builder import HEADS from mmocr.registry import MODELS
from ..postprocess.utils import poly_nms from ..postprocess.utils import poly_nms
from .head_mixin import HeadMixin from .head_mixin import HeadMixin
@HEADS.register_module() @MODELS.register_module()
class FCEHead(HeadMixin, BaseModule): class FCEHead(HeadMixin, BaseModule):
"""The class for implementing FCENet head. """The class for implementing FCENet head.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np import numpy as np
from mmocr.models.builder import HEADS, build_loss, build_postprocessor from mmocr.registry import MODELS
from mmocr.utils import check_argument from mmocr.utils import check_argument
@HEADS.register_module() @MODELS.register_module()
class HeadMixin: class HeadMixin:
"""Base head class for text detection, including loss calcalation and """Base head class for text detection, including loss calcalation and
postprocess. postprocess.
@ -19,8 +19,8 @@ class HeadMixin:
assert isinstance(loss, dict) assert isinstance(loss, dict)
assert isinstance(postprocessor, dict) assert isinstance(postprocessor, dict)
self.loss_module = build_loss(loss) self.loss_module = MODELS.build(loss)
self.postprocessor = build_postprocessor(postprocessor) self.postprocessor = MODELS.build(postprocessor)
def resize_boundary(self, boundaries, scale_factor): def resize_boundary(self, boundaries, scale_factor):
"""Rescale boundaries via scale_factor. """Rescale boundaries via scale_factor.

View File

@ -6,12 +6,12 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import HEADS from mmocr.registry import MODELS
from mmocr.utils import check_argument from mmocr.utils import check_argument
from .head_mixin import HeadMixin from .head_mixin import HeadMixin
@HEADS.register_module() @MODELS.register_module()
class PANHead(HeadMixin, BaseModule): class PANHead(HeadMixin, BaseModule):
"""The class for PANet head. """The class for PANet head.

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import HEADS from mmocr.registry import MODELS
from . import PANHead from . import PANHead
@HEADS.register_module() @MODELS.register_module()
class PSEHead(PANHead): class PSEHead(PANHead):
"""The class for PSENet head. """The class for PSENet head.

View File

@ -4,11 +4,11 @@ import warnings
import torch.nn as nn import torch.nn as nn
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import HEADS from mmocr.registry import MODELS
from .head_mixin import HeadMixin from .head_mixin import HeadMixin
@HEADS.register_module() @MODELS.register_module()
class TextSnakeHead(HeadMixin, BaseModule): class TextSnakeHead(HeadMixin, BaseModule):
"""The class for TextSnake head: TextSnake: A Flexible Representation for """The class for TextSnake head: TextSnake: A Flexible Representation for
Detecting Text of Arbitrary Shapes. Detecting Text of Arbitrary Shapes.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module() @MODELS.register_module()
class DBNet(TextDetectorMixin, SingleStageTextDetector): class DBNet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing DBNet text detector: Real-time Scene Text """The class for implementing DBNet text detector: Real-time Scene Text
Detection with Differentiable Binarization. Detection with Differentiable Binarization.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module() @MODELS.register_module()
class DRRG(TextDetectorMixin, SingleStageTextDetector): class DRRG(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing DRRG text detector. Deep Relational Reasoning """The class for implementing DRRG text detector. Deep Relational Reasoning
Graph Network for Arbitrary Shape Text Detection. Graph Network for Arbitrary Shape Text Detection.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module() @MODELS.register_module()
class FCENet(TextDetectorMixin, SingleStageTextDetector): class FCENet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing FCENet text detector """The class for implementing FCENet text detector
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text

View File

@ -2,11 +2,11 @@
from mmdet.models.detectors import MaskRCNN from mmdet.models.detectors import MaskRCNN
from mmocr.core import seg2boundary from mmocr.core import seg2boundary
from mmocr.models.builder import DETECTORS from mmocr.registry import MODELS
from .text_detector_mixin import TextDetectorMixin from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module() @MODELS.register_module()
class OCRMaskRCNN(TextDetectorMixin, MaskRCNN): class OCRMaskRCNN(TextDetectorMixin, MaskRCNN):
"""Mask RCNN tailored for OCR.""" """Mask RCNN tailored for OCR."""

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module() @MODELS.register_module()
class PANet(TextDetectorMixin, SingleStageTextDetector): class PANet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing PANet text detector: """The class for implementing PANet text detector:

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module() @MODELS.register_module()
class PSENet(TextDetectorMixin, SingleStageTextDetector): class PSENet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing PSENet text detector: Shape Robust Text """The class for implementing PSENet text detector: Shape Robust Text
Detection with Progressive Scale Expansion Network. Detection with Progressive Scale Expansion Network.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmocr.models.builder import DETECTORS
from mmocr.models.common.detectors import SingleStageDetector from mmocr.models.common.detectors import SingleStageDetector
from mmocr.registry import MODELS
@DETECTORS.register_module() @MODELS.register_module()
class SingleStageTextDetector(SingleStageDetector): class SingleStageTextDetector(SingleStageDetector):
"""The class for implementing single stage text detector.""" """The class for implementing single stage text detector."""

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module() @MODELS.register_module()
class TextSnake(TextDetectorMixin, SingleStageTextDetector): class TextSnake(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing TextSnake text detector: TextSnake: A """The class for implementing TextSnake text detector: TextSnake: A
Flexible Representation for Detecting Text of Arbitrary Shapes. Flexible Representation for Detecting Text of Arbitrary Shapes.

View File

@ -3,11 +3,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.models.common.losses.dice_loss import DiceLoss from mmocr.models.common.losses.dice_loss import DiceLoss
from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class DBLoss(nn.Module): class DBLoss(nn.Module):
"""The class for implementing DBNet loss. """The class for implementing DBNet loss.

View File

@ -4,11 +4,11 @@ import torch.nn.functional as F
from mmdet.core import BitmapMasks from mmdet.core import BitmapMasks
from torch import nn from torch import nn
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
from mmocr.utils import check_argument from mmocr.utils import check_argument
@LOSSES.register_module() @MODELS.register_module()
class DRRGLoss(nn.Module): class DRRGLoss(nn.Module):
"""The class for implementing DRRG loss. This is partially adapted from """The class for implementing DRRG loss. This is partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license. https://github.com/GXYM/DRRG licensed under the MIT license.

View File

@ -5,10 +5,10 @@ import torch.nn.functional as F
from mmdet.core import multi_apply from mmdet.core import multi_apply
from torch import nn from torch import nn
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class FCELoss(nn.Module): class FCELoss(nn.Module):
"""The class for implementing FCENet loss. """The class for implementing FCENet loss.

View File

@ -8,11 +8,11 @@ import torch.nn.functional as F
from mmdet.core import BitmapMasks from mmdet.core import BitmapMasks
from torch import nn from torch import nn
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
from mmocr.utils import check_argument from mmocr.utils import check_argument
@LOSSES.register_module() @MODELS.register_module()
class PANLoss(nn.Module): class PANLoss(nn.Module):
"""The class for implementing PANet loss. This was partially adapted from """The class for implementing PANet loss. This was partially adapted from
https://github.com/WenmuZhou/PAN.pytorch. https://github.com/WenmuZhou/PAN.pytorch.

View File

@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core import BitmapMasks from mmdet.core import BitmapMasks
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
from mmocr.utils import check_argument from mmocr.utils import check_argument
from . import PANLoss from . import PANLoss
@LOSSES.register_module() @MODELS.register_module()
class PSELoss(PANLoss): class PSELoss(PANLoss):
r"""The class for implementing PSENet loss. This is partially adapted from r"""The class for implementing PSENet loss. This is partially adapted from
https://github.com/whai362/PSENet. https://github.com/whai362/PSENet.

View File

@ -4,11 +4,11 @@ import torch.nn.functional as F
from mmdet.core import BitmapMasks from mmdet.core import BitmapMasks
from torch import nn from torch import nn
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
from mmocr.utils import check_argument from mmocr.utils import check_argument
@LOSSES.register_module() @MODELS.register_module()
class TextSnakeLoss(nn.Module): class TextSnakeLoss(nn.Module):
"""The class for implementing TextSnake loss. This is partially adapted """The class for implementing TextSnake loss. This is partially adapted
from https://github.com/princewang1994/TextSnake.pytorch. from https://github.com/princewang1994/TextSnake.pytorch.

View File

@ -3,7 +3,7 @@ import torch.nn.functional as F
from mmcv.runner import BaseModule, ModuleList from mmcv.runner import BaseModule, ModuleList
from torch import nn from torch import nn
from mmocr.models.builder import NECKS from mmocr.registry import MODELS
class FPEM(BaseModule): class FPEM(BaseModule):
@ -72,7 +72,7 @@ class SeparableConv2d(BaseModule):
return x return x
@NECKS.register_module() @MODELS.register_module()
class FPEM_FFM(BaseModule): class FPEM_FFM(BaseModule):
"""This code is from https://github.com/WenmuZhou/PAN.pytorch. """This code is from https://github.com/WenmuZhou/PAN.pytorch.

View File

@ -5,10 +5,10 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList, Sequential, auto_fp16 from mmcv.runner import BaseModule, ModuleList, Sequential, auto_fp16
from mmocr.models.builder import NECKS from mmocr.registry import MODELS
@NECKS.register_module() @MODELS.register_module()
class FPNC(BaseModule): class FPNC(BaseModule):
"""FPN-like fusion module in Real-time Scene Text Detection with """FPN-like fusion module in Real-time Scene Text Detection with
Differentiable Binarization. Differentiable Binarization.

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn from torch import nn
from mmocr.models.builder import NECKS from mmocr.registry import MODELS
class UpBlock(BaseModule): class UpBlock(BaseModule):
@ -30,7 +30,7 @@ class UpBlock(BaseModule):
return x return x
@NECKS.register_module() @MODELS.register_module()
class FPN_UNet(BaseModule): class FPN_UNet(BaseModule):
"""The class for implementing DRRG and TextSnake U-Net-like FPN. """The class for implementing DRRG and TextSnake U-Net-like FPN.

View File

@ -4,10 +4,10 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList, auto_fp16 from mmcv.runner import BaseModule, ModuleList, auto_fp16
from mmocr.models.builder import NECKS from mmocr.registry import MODELS
@NECKS.register_module() @MODELS.register_module()
class FPNF(BaseModule): class FPNF(BaseModule):
"""FPN-like fusion module in Shape Robust Text Detection with Progressive """FPN-like fusion module in Shape Robust Text Detection with Progressive
Scale Expansion Network. Scale Expansion Network.

View File

@ -3,12 +3,12 @@ import cv2
import numpy as np import numpy as np
from mmocr.core import points2boundary from mmocr.core import points2boundary
from mmocr.models.builder import POSTPROCESSOR from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor from .base_postprocessor import BasePostprocessor
from .utils import box_score_fast, unclip from .utils import box_score_fast, unclip
@POSTPROCESSOR.register_module() @MODELS.register_module()
class DBPostprocessor(BasePostprocessor): class DBPostprocessor(BasePostprocessor):
"""Decoding predictions of DbNet to instances. This is partially adapted """Decoding predictions of DbNet to instances. This is partially adapted
from https://github.com/MhLiao/DB. from https://github.com/MhLiao/DB.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import POSTPROCESSOR from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor from .base_postprocessor import BasePostprocessor
from .utils import (clusters2labels, comps2boundaries, connected_components, from .utils import (clusters2labels, comps2boundaries, connected_components,
graph_propagation, remove_single) graph_propagation, remove_single)
@POSTPROCESSOR.register_module() @MODELS.register_module()
class DRRGPostprocessor(BasePostprocessor): class DRRGPostprocessor(BasePostprocessor):
"""Merge text components and construct boundaries of text instances. """Merge text components and construct boundaries of text instances.

View File

@ -2,12 +2,12 @@
import cv2 import cv2
import numpy as np import numpy as np
from mmocr.models.builder import POSTPROCESSOR from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor from .base_postprocessor import BasePostprocessor
from .utils import fill_hole, fourier2poly, poly_nms from .utils import fill_hole, fourier2poly, poly_nms
@POSTPROCESSOR.register_module() @MODELS.register_module()
class FCEPostprocessor(BasePostprocessor): class FCEPostprocessor(BasePostprocessor):
"""Decoding predictions of FCENet to instances. """Decoding predictions of FCENet to instances.

View File

@ -5,11 +5,11 @@ import torch
from mmcv.ops import pixel_group from mmcv.ops import pixel_group
from mmocr.core import points2boundary from mmocr.core import points2boundary
from mmocr.models.builder import POSTPROCESSOR from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor from .base_postprocessor import BasePostprocessor
@POSTPROCESSOR.register_module() @MODELS.register_module()
class PANPostprocessor(BasePostprocessor): class PANPostprocessor(BasePostprocessor):
"""Convert scores to quadrangles via post processing in PANet. This is """Convert scores to quadrangles via post processing in PANet. This is
partially adapted from https://github.com/WenmuZhou/PAN.pytorch. partially adapted from https://github.com/WenmuZhou/PAN.pytorch.

View File

@ -6,11 +6,11 @@ import torch
from mmcv.ops import contour_expand from mmcv.ops import contour_expand
from mmocr.core import points2boundary from mmocr.core import points2boundary
from mmocr.models.builder import POSTPROCESSOR from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor from .base_postprocessor import BasePostprocessor
@POSTPROCESSOR.register_module() @MODELS.register_module()
class PSEPostprocessor(BasePostprocessor): class PSEPostprocessor(BasePostprocessor):
"""Decoding predictions of PSENet to instances. This is partially adapted """Decoding predictions of PSENet to instances. This is partially adapted
from https://github.com/whai362/PSENet. from https://github.com/whai362/PSENet.

View File

@ -5,12 +5,12 @@ import numpy as np
import torch import torch
from skimage.morphology import skeletonize from skimage.morphology import skeletonize
from mmocr.models.builder import POSTPROCESSOR from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor from .base_postprocessor import BasePostprocessor
from .utils import centralize, fill_hole, merge_disks from .utils import centralize, fill_hole, merge_disks
@POSTPROCESSOR.register_module() @MODELS.register_module()
class TextSnakePostprocessor(BasePostprocessor): class TextSnakePostprocessor(BasePostprocessor):
"""Decoding predictions of TextSnake to instances. This was partially """Decoding predictions of TextSnake to instances. This was partially
adapted from https://github.com/princewang1994/TextSnake.pytorch. adapted from https://github.com/princewang1994/TextSnake.pytorch.

View File

@ -2,10 +2,10 @@
import torch.nn as nn import torch.nn as nn
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import BACKBONES from mmocr.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class NRTRModalityTransform(BaseModule): class NRTRModalityTransform(BaseModule):
def __init__(self, def __init__(self,

View File

@ -3,11 +3,11 @@ from mmcv.cnn import ConvModule, build_plugin_layer
from mmcv.runner import BaseModule, Sequential from mmcv.runner import BaseModule, Sequential
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import BACKBONES
from mmocr.models.textrecog.layers import BasicBlock from mmocr.models.textrecog.layers import BasicBlock
from mmocr.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class ResNet(BaseModule): class ResNet(BaseModule):
""" """
Args: Args:

View File

@ -3,11 +3,11 @@ import torch.nn as nn
from mmcv.runner import BaseModule, Sequential from mmcv.runner import BaseModule, Sequential
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import BACKBONES
from mmocr.models.textrecog.layers import BasicBlock from mmocr.models.textrecog.layers import BasicBlock
from mmocr.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class ResNet31OCR(BaseModule): class ResNet31OCR(BaseModule):
"""Implement ResNet backbone for text recognition, modified from """Implement ResNet backbone for text recognition, modified from
`ResNet <https://arxiv.org/pdf/1512.03385.pdf>`_ `ResNet <https://arxiv.org/pdf/1512.03385.pdf>`_

View File

@ -3,11 +3,11 @@ import torch.nn as nn
from mmcv.runner import BaseModule, Sequential from mmcv.runner import BaseModule, Sequential
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import BACKBONES
from mmocr.models.textrecog.layers import BasicBlock from mmocr.models.textrecog.layers import BasicBlock
from mmocr.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class ResNetABI(BaseModule): class ResNetABI(BaseModule):
"""Implement ResNet backbone for text recognition, modified from `ResNet. """Implement ResNet backbone for text recognition, modified from `ResNet.

View File

@ -3,10 +3,10 @@ import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import BACKBONES from mmocr.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class ShallowCNN(BaseModule): class ShallowCNN(BaseModule):
"""Implement Shallow CNN block for SATRN. """Implement Shallow CNN block for SATRN.

View File

@ -2,10 +2,10 @@
import torch.nn as nn import torch.nn as nn
from mmcv.runner import BaseModule, Sequential from mmcv.runner import BaseModule, Sequential
from mmocr.models.builder import BACKBONES from mmocr.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class VeryDeepVgg(BaseModule): class VeryDeepVgg(BaseModule):
"""Implement VGG-VeryDeep backbone for text recognition, modified from """Implement VGG-VeryDeep backbone for text recognition, modified from
`VGG-VeryDeep <https://arxiv.org/pdf/1409.1556.pdf>`_ `VGG-VeryDeep <https://arxiv.org/pdf/1409.1556.pdf>`_

View File

@ -2,11 +2,11 @@
import torch import torch
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS from mmocr.registry import MODELS
from .attn import AttnConvertor from .attn import AttnConvertor
@CONVERTORS.register_module() @MODELS.register_module()
class ABIConvertor(AttnConvertor): class ABIConvertor(AttnConvertor):
"""Convert between text, index and tensor for encoder-decoder based """Convert between text, index and tensor for encoder-decoder based
pipeline. Modified from AttnConvertor to get closer to ABINet's original pipeline. Modified from AttnConvertor to get closer to ABINet's original

View File

@ -2,11 +2,11 @@
import torch import torch
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS from mmocr.registry import MODELS
from .base import BaseConvertor from .base import BaseConvertor
@CONVERTORS.register_module() @MODELS.register_module()
class AttnConvertor(BaseConvertor): class AttnConvertor(BaseConvertor):
"""Convert between text, index and tensor for encoder-decoder based """Convert between text, index and tensor for encoder-decoder based
pipeline. pipeline.

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import CONVERTORS from mmocr.registry import MODELS
from mmocr.utils import list_from_file from mmocr.utils import list_from_file
@CONVERTORS.register_module() @MODELS.register_module()
class BaseConvertor: class BaseConvertor:
"""Convert between text, index and tensor for text recognize pipeline. """Convert between text, index and tensor for text recognize pipeline.

View File

@ -5,11 +5,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS from mmocr.registry import MODELS
from .base import BaseConvertor from .base import BaseConvertor
@CONVERTORS.register_module() @MODELS.register_module()
class CTCConvertor(BaseConvertor): class CTCConvertor(BaseConvertor):
"""Convert between text, index and tensor for CTC loss-based pipeline. """Convert between text, index and tensor for CTC loss-based pipeline.

View File

@ -4,11 +4,11 @@ import numpy as np
import torch import torch
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS from mmocr.registry import MODELS
from .base import BaseConvertor from .base import BaseConvertor
@CONVERTORS.register_module() @MODELS.register_module()
class SegConvertor(BaseConvertor): class SegConvertor(BaseConvertor):
"""Convert between text, index and tensor for segmentation based pipeline. """Convert between text, index and tensor for segmentation based pipeline.

View File

@ -6,12 +6,12 @@ import torch.nn as nn
from mmcv.cnn.bricks.transformer import BaseTransformerLayer from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import ModuleList from mmcv.runner import ModuleList
from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder from .base_decoder import BaseDecoder
@DECODERS.register_module() @MODELS.register_module()
class ABILanguageDecoder(BaseDecoder): class ABILanguageDecoder(BaseDecoder):
r"""Transformer-based language model responsible for spell correction. r"""Transformer-based language model responsible for spell correction.
Implementation of language model of \ Implementation of language model of \

View File

@ -3,12 +3,12 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder from .base_decoder import BaseDecoder
@DECODERS.register_module() @MODELS.register_module()
class ABIVisionDecoder(BaseDecoder): class ABIVisionDecoder(BaseDecoder):
"""Converts visual features into text characters. """Converts visual features into text characters.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import DECODERS from mmocr.registry import MODELS
@DECODERS.register_module() @MODELS.register_module()
class BaseDecoder(BaseModule): class BaseDecoder(BaseModule):
"""Base decoder class for text recognition.""" """Base decoder class for text recognition."""

View File

@ -2,12 +2,12 @@
import torch.nn as nn import torch.nn as nn
from mmcv.runner import Sequential from mmcv.runner import Sequential
from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import BidirectionalLSTM from mmocr.models.textrecog.layers import BidirectionalLSTM
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder from .base_decoder import BaseDecoder
@DECODERS.register_module() @MODELS.register_module()
class CRNNDecoder(BaseDecoder): class CRNNDecoder(BaseDecoder):
"""Decoder for CRNN. """Decoder for CRNN.

View File

@ -8,8 +8,8 @@ import torch.nn.functional as F
from mmcv.cnn.bricks.transformer import BaseTransformerLayer from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import ModuleList from mmcv.runner import ModuleList
from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder from .base_decoder import BaseDecoder
@ -30,7 +30,7 @@ class Embeddings(nn.Module):
return self.lut(x) * math.sqrt(self.d_model) return self.lut(x) * math.sqrt(self.d_model)
@DECODERS.register_module() @MODELS.register_module()
class MasterDecoder(BaseDecoder): class MasterDecoder(BaseDecoder):
"""Decoder module in `MASTER <https://arxiv.org/abs/1910.02562>`_. """Decoder module in `MASTER <https://arxiv.org/abs/1910.02562>`_.

View File

@ -6,12 +6,12 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.runner import ModuleList from mmcv.runner import ModuleList
from mmocr.models.builder import DECODERS
from mmocr.models.common import PositionalEncoding, TFDecoderLayer from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder from .base_decoder import BaseDecoder
@DECODERS.register_module() @MODELS.register_module()
class NRTRDecoder(BaseDecoder): class NRTRDecoder(BaseDecoder):
"""Transformer Decoder block with self attention mechanism. """Transformer Decoder block with self attention mechanism.

View File

@ -4,13 +4,13 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import (DotProductAttentionLayer, from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
PositionAwareLayer) PositionAwareLayer)
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder from .base_decoder import BaseDecoder
@DECODERS.register_module() @MODELS.register_module()
class PositionAttentionDecoder(BaseDecoder): class PositionAttentionDecoder(BaseDecoder):
"""Position attention decoder for RobustScanner. """Position attention decoder for RobustScanner.

View File

@ -3,12 +3,12 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmocr.models.builder import DECODERS, build_decoder
from mmocr.models.textrecog.layers import RobustScannerFusionLayer from mmocr.models.textrecog.layers import RobustScannerFusionLayer
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder from .base_decoder import BaseDecoder
@DECODERS.register_module() @MODELS.register_module()
class RobustScannerDecoder(BaseDecoder): class RobustScannerDecoder(BaseDecoder):
"""Decoder for RobustScanner. """Decoder for RobustScanner.
@ -72,7 +72,7 @@ class RobustScannerDecoder(BaseDecoder):
hybrid_decoder.update(encode_value=self.encode_value) hybrid_decoder.update(encode_value=self.encode_value)
hybrid_decoder.update(return_feature=True) hybrid_decoder.update(return_feature=True)
self.hybrid_decoder = build_decoder(hybrid_decoder) self.hybrid_decoder = MODELS.build(hybrid_decoder)
# init position decoder # init position decoder
position_decoder.update(num_classes=self.num_classes) position_decoder.update(num_classes=self.num_classes)
@ -83,7 +83,7 @@ class RobustScannerDecoder(BaseDecoder):
position_decoder.update(encode_value=self.encode_value) position_decoder.update(encode_value=self.encode_value)
position_decoder.update(return_feature=True) position_decoder.update(return_feature=True)
self.position_decoder = build_decoder(position_decoder) self.position_decoder = MODELS.build(position_decoder)
self.fusion_module = RobustScannerFusionLayer( self.fusion_module = RobustScannerFusionLayer(
self.dim_model if encode_value else dim_input) self.dim_model if encode_value else dim_input)

View File

@ -6,11 +6,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import DECODERS from mmocr.registry import MODELS
from .base_decoder import BaseDecoder from .base_decoder import BaseDecoder
@DECODERS.register_module() @MODELS.register_module()
class ParallelSARDecoder(BaseDecoder): class ParallelSARDecoder(BaseDecoder):
"""Implementation Parallel Decoder module in `SAR. """Implementation Parallel Decoder module in `SAR.
@ -255,7 +255,7 @@ class ParallelSARDecoder(BaseDecoder):
return outputs return outputs
@DECODERS.register_module() @MODELS.register_module()
class SequentialSARDecoder(BaseDecoder): class SequentialSARDecoder(BaseDecoder):
"""Implementation Sequential Decoder module in `SAR. """Implementation Sequential Decoder module in `SAR.

View File

@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import DECODERS from mmocr.registry import MODELS
from . import ParallelSARDecoder from . import ParallelSARDecoder
@ -31,7 +31,7 @@ class DecodeNode:
return accu_score return accu_score
@DECODERS.register_module() @MODELS.register_module()
class ParallelSARDecoderWithBS(ParallelSARDecoder): class ParallelSARDecoderWithBS(ParallelSARDecoder):
"""Parallel Decoder module with beam-search in SAR. """Parallel Decoder module with beam-search in SAR.

View File

@ -5,12 +5,12 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import DotProductAttentionLayer from mmocr.models.textrecog.layers import DotProductAttentionLayer
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder from .base_decoder import BaseDecoder
@DECODERS.register_module() @MODELS.register_module()
class SequenceAttentionDecoder(BaseDecoder): class SequenceAttentionDecoder(BaseDecoder):
"""Sequence attention decoder for RobustScanner. """Sequence attention decoder for RobustScanner.

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import ENCODERS, build_decoder, build_encoder from mmocr.registry import MODELS
from .base_encoder import BaseEncoder from .base_encoder import BaseEncoder
@ENCODERS.register_module() @MODELS.register_module()
class ABIVisionModel(BaseEncoder): class ABIVisionModel(BaseEncoder):
"""A wrapper of visual feature encoder and language token decoder that """A wrapper of visual feature encoder and language token decoder that
converts visual features into text tokens. converts visual features into text tokens.
@ -23,8 +23,8 @@ class ABIVisionModel(BaseEncoder):
init_cfg=dict(type='Xavier', layer='Conv2d'), init_cfg=dict(type='Xavier', layer='Conv2d'),
**kwargs): **kwargs):
super().__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.encoder = build_encoder(encoder) self.encoder = MODELS.build(encoder)
self.decoder = build_decoder(decoder) self.decoder = MODELS.build(decoder)
def forward(self, feat, img_metas=None): def forward(self, feat, img_metas=None):
""" """

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import ENCODERS from mmocr.registry import MODELS
@ENCODERS.register_module() @MODELS.register_module()
class BaseEncoder(BaseModule): class BaseEncoder(BaseModule):
"""Base Encoder class for text recognition.""" """Base Encoder class for text recognition."""

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn import torch.nn as nn
from mmocr.models.builder import ENCODERS from mmocr.registry import MODELS
from .base_encoder import BaseEncoder from .base_encoder import BaseEncoder
@ENCODERS.register_module() @MODELS.register_module()
class ChannelReductionEncoder(BaseEncoder): class ChannelReductionEncoder(BaseEncoder):
"""Change the channel number with a one by one convoluational layer. """Change the channel number with a one by one convoluational layer.

View File

@ -4,12 +4,12 @@ import math
import torch.nn as nn import torch.nn as nn
from mmcv.runner import ModuleList from mmcv.runner import ModuleList
from mmocr.models.builder import ENCODERS
from mmocr.models.common import TFEncoderLayer from mmocr.models.common import TFEncoderLayer
from mmocr.registry import MODELS
from .base_encoder import BaseEncoder from .base_encoder import BaseEncoder
@ENCODERS.register_module() @MODELS.register_module()
class NRTREncoder(BaseEncoder): class NRTREncoder(BaseEncoder):
"""Transformer Encoder block with self attention mechanism. """Transformer Encoder block with self attention mechanism.

View File

@ -6,11 +6,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import mmocr.utils as utils import mmocr.utils as utils
from mmocr.models.builder import ENCODERS from mmocr.registry import MODELS
from .base_encoder import BaseEncoder from .base_encoder import BaseEncoder
@ENCODERS.register_module() @MODELS.register_module()
class SAREncoder(BaseEncoder): class SAREncoder(BaseEncoder):
"""Implementation of encoder module in `SAR. """Implementation of encoder module in `SAR.

View File

@ -4,13 +4,13 @@ import math
import torch.nn as nn import torch.nn as nn
from mmcv.runner import ModuleList from mmcv.runner import ModuleList
from mmocr.models.builder import ENCODERS
from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding, from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding,
SatrnEncoderLayer) SatrnEncoderLayer)
from mmocr.registry import MODELS
from .base_encoder import BaseEncoder from .base_encoder import BaseEncoder
@ENCODERS.register_module() @MODELS.register_module()
class SatrnEncoder(BaseEncoder): class SatrnEncoder(BaseEncoder):
"""Implement encoder for SATRN, see `SATRN. """Implement encoder for SATRN, see `SATRN.

View File

@ -4,11 +4,11 @@ import copy
from mmcv.cnn.bricks.transformer import BaseTransformerLayer from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import BaseModule, ModuleList from mmcv.runner import BaseModule, ModuleList
from mmocr.models.builder import ENCODERS
from mmocr.models.common.modules import PositionalEncoding from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
@ENCODERS.register_module() @MODELS.register_module()
class TransformerEncoder(BaseModule): class TransformerEncoder(BaseModule):
"""Implement transformer encoder for text recognition, modified from """Implement transformer encoder for text recognition, modified from
`<https://github.com/FangShancheng/ABINet>`. `<https://github.com/FangShancheng/ABINet>`.

View File

@ -3,10 +3,10 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import FUSERS from mmocr.registry import MODELS
@FUSERS.register_module() @MODELS.register_module()
class ABIFuser(BaseModule): class ABIFuser(BaseModule):
"""Mix and align visual feature and linguistic feature Implementation of """Mix and align visual feature and linguistic feature Implementation of
language model of `ABINet <https://arxiv.org/abs/1910.04396>`_. language model of `ABINet <https://arxiv.org/abs/1910.04396>`_.

View File

@ -4,10 +4,10 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn from torch import nn
from mmocr.models.builder import HEADS from mmocr.registry import MODELS
@HEADS.register_module() @MODELS.register_module()
class SegHead(BaseModule): class SegHead(BaseModule):
"""Head for segmentation based text recognition. """Head for segmentation based text recognition.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn import torch.nn as nn
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class CELoss(nn.Module): class CELoss(nn.Module):
"""Implementation of loss module for encoder-decoder based text recognition """Implementation of loss module for encoder-decoder based text recognition
method with CrossEntropy loss. method with CrossEntropy loss.
@ -63,7 +63,7 @@ class CELoss(nn.Module):
return losses return losses
@LOSSES.register_module() @MODELS.register_module()
class SARLoss(CELoss): class SARLoss(CELoss):
"""Implementation of loss module in `SAR. """Implementation of loss module in `SAR.
@ -95,7 +95,7 @@ class SARLoss(CELoss):
return outputs, targets return outputs, targets
@LOSSES.register_module() @MODELS.register_module()
class TFLoss(CELoss): class TFLoss(CELoss):
"""Implementation of loss module for transformer. """Implementation of loss module for transformer.

View File

@ -4,10 +4,10 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class CTCLoss(nn.Module): class CTCLoss(nn.Module):
"""Implementation of loss module for CTC-loss based text recognition. """Implementation of loss module for CTC-loss based text recognition.

View File

@ -3,10 +3,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class ABILoss(nn.Module): class ABILoss(nn.Module):
"""Implementation of ABINet multiloss that allows mixing different types of """Implementation of ABINet multiloss that allows mixing different types of
losses with weights. losses with weights.

View File

@ -3,10 +3,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmocr.models.builder import LOSSES from mmocr.registry import MODELS
@LOSSES.register_module() @MODELS.register_module()
class SegLoss(nn.Module): class SegLoss(nn.Module):
"""Implementation of loss module for segmentation based text recognition """Implementation of loss module for segmentation based text recognition
method. method.

View File

@ -4,10 +4,10 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList from mmcv.runner import BaseModule, ModuleList
from mmocr.models.builder import NECKS from mmocr.registry import MODELS
@NECKS.register_module() @MODELS.register_module()
class FPNOCR(BaseModule): class FPNOCR(BaseModule):
"""FPN-like Network for segmentation based text recognition. """FPN-like Network for segmentation based text recognition.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmocr.models.builder import PREPROCESSOR from mmocr.registry import MODELS
@PREPROCESSOR.register_module() @MODELS.register_module()
class BasePreprocessor(BaseModule): class BasePreprocessor(BaseModule):
"""Base Preprocessor class for text recognition.""" """Base Preprocessor class for text recognition."""

View File

@ -17,11 +17,11 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmocr.models.builder import PREPROCESSOR from mmocr.registry import MODELS
from .base_preprocessor import BasePreprocessor from .base_preprocessor import BasePreprocessor
@PREPROCESSOR.register_module() @MODELS.register_module()
class TPSPreprocessor(BasePreprocessor): class TPSPreprocessor(BasePreprocessor):
"""Rectification Network of RARE, namely TPS based STN in """Rectification Network of RARE, namely TPS based STN in
https://arxiv.org/pdf/1603.03915.pdf. https://arxiv.org/pdf/1603.03915.pdf.

View File

@ -3,13 +3,11 @@ import warnings
import torch import torch
from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor, from mmocr.registry import MODELS
build_decoder, build_encoder, build_fuser,
build_loss, build_preprocessor)
from .encode_decode_recognizer import EncodeDecodeRecognizer from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module() @MODELS.register_module()
class ABINet(EncodeDecodeRecognizer): class ABINet(EncodeDecodeRecognizer):
"""Implementation of `Read Like Humans: Autonomous, Bidirectional and """Implementation of `Read Like Humans: Autonomous, Bidirectional and
Iterative LanguageModeling for Scene Text Recognition. Iterative LanguageModeling for Scene Text Recognition.
@ -36,21 +34,21 @@ class ABINet(EncodeDecodeRecognizer):
# Label convertor (str2tensor, tensor2str) # Label convertor (str2tensor, tensor2str)
assert label_convertor is not None assert label_convertor is not None
label_convertor.update(max_seq_len=max_seq_len) label_convertor.update(max_seq_len=max_seq_len)
self.label_convertor = build_convertor(label_convertor) self.label_convertor = MODELS.build(label_convertor)
# Preprocessor module, e.g., TPS # Preprocessor module, e.g., TPS
self.preprocessor = None self.preprocessor = None
if preprocessor is not None: if preprocessor is not None:
self.preprocessor = build_preprocessor(preprocessor) self.preprocessor = MODELS.build(preprocessor)
# Backbone # Backbone
assert backbone is not None assert backbone is not None
self.backbone = build_backbone(backbone) self.backbone = MODELS.build(backbone)
# Encoder module # Encoder module
self.encoder = None self.encoder = None
if encoder is not None: if encoder is not None:
self.encoder = build_encoder(encoder) self.encoder = MODELS.build(encoder)
# Decoder module # Decoder module
self.decoder = None self.decoder = None
@ -59,11 +57,11 @@ class ABINet(EncodeDecodeRecognizer):
decoder.update(start_idx=self.label_convertor.start_idx) decoder.update(start_idx=self.label_convertor.start_idx)
decoder.update(padding_idx=self.label_convertor.padding_idx) decoder.update(padding_idx=self.label_convertor.padding_idx)
decoder.update(max_seq_len=max_seq_len) decoder.update(max_seq_len=max_seq_len)
self.decoder = build_decoder(decoder) self.decoder = MODELS.build(decoder)
# Loss # Loss
assert loss is not None assert loss is not None
self.loss = build_loss(loss) self.loss = MODELS.build(loss)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
@ -78,7 +76,7 @@ class ABINet(EncodeDecodeRecognizer):
self.fuser = None self.fuser = None
if fuser is not None: if fuser is not None:
self.fuser = build_fuser(fuser) self.fuser = MODELS.build(fuser)
def forward_train(self, img, img_metas): def forward_train(self, img, img_metas):
""" """

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module() @MODELS.register_module()
class CRNNNet(EncodeDecodeRecognizer): class CRNNNet(EncodeDecodeRecognizer):
"""CTC-loss based recognizer.""" """CTC-loss based recognizer."""

View File

@ -3,13 +3,11 @@ import warnings
import torch import torch
from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor, from mmocr.registry import MODELS
build_decoder, build_encoder, build_loss,
build_preprocessor)
from .base import BaseRecognizer from .base import BaseRecognizer
@RECOGNIZERS.register_module() @MODELS.register_module()
class EncodeDecodeRecognizer(BaseRecognizer): class EncodeDecodeRecognizer(BaseRecognizer):
"""Base class for encode-decode recognizer.""" """Base class for encode-decode recognizer."""
@ -31,21 +29,21 @@ class EncodeDecodeRecognizer(BaseRecognizer):
# Label convertor (str2tensor, tensor2str) # Label convertor (str2tensor, tensor2str)
assert label_convertor is not None assert label_convertor is not None
label_convertor.update(max_seq_len=max_seq_len) label_convertor.update(max_seq_len=max_seq_len)
self.label_convertor = build_convertor(label_convertor) self.label_convertor = MODELS.build(label_convertor)
# Preprocessor module, e.g., TPS # Preprocessor module, e.g., TPS
self.preprocessor = None self.preprocessor = None
if preprocessor is not None: if preprocessor is not None:
self.preprocessor = build_preprocessor(preprocessor) self.preprocessor = MODELS.build(preprocessor)
# Backbone # Backbone
assert backbone is not None assert backbone is not None
self.backbone = build_backbone(backbone) self.backbone = MODELS.build(backbone)
# Encoder module # Encoder module
self.encoder = None self.encoder = None
if encoder is not None: if encoder is not None:
self.encoder = build_encoder(encoder) self.encoder = MODELS.build(encoder)
# Decoder module # Decoder module
assert decoder is not None assert decoder is not None
@ -53,12 +51,12 @@ class EncodeDecodeRecognizer(BaseRecognizer):
decoder.update(start_idx=self.label_convertor.start_idx) decoder.update(start_idx=self.label_convertor.start_idx)
decoder.update(padding_idx=self.label_convertor.padding_idx) decoder.update(padding_idx=self.label_convertor.padding_idx)
decoder.update(max_seq_len=max_seq_len) decoder.update(max_seq_len=max_seq_len)
self.decoder = build_decoder(decoder) self.decoder = MODELS.build(decoder)
# Loss # Loss
assert loss is not None assert loss is not None
loss.update(ignore_index=self.label_convertor.padding_idx) loss.update(ignore_index=self.label_convertor.padding_idx)
self.loss = build_loss(loss) self.loss = MODELS.build(loss)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer from .encode_decode_recognizer import EncodeDecodeRecognizer
@DETECTORS.register_module() @MODELS.register_module()
class MASTER(EncodeDecodeRecognizer): class MASTER(EncodeDecodeRecognizer):
"""Implementation of `MASTER <https://arxiv.org/abs/1910.02562>`_""" """Implementation of `MASTER <https://arxiv.org/abs/1910.02562>`_"""

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module() @MODELS.register_module()
class NRTR(EncodeDecodeRecognizer): class NRTR(EncodeDecodeRecognizer):
"""Implementation of `NRTR <https://arxiv.org/pdf/1806.00926.pdf>`_""" """Implementation of `NRTR <https://arxiv.org/pdf/1806.00926.pdf>`_"""

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module() @MODELS.register_module()
class RobustScanner(EncodeDecodeRecognizer): class RobustScanner(EncodeDecodeRecognizer):
"""Implementation of `RobustScanner. """Implementation of `RobustScanner.

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module() @MODELS.register_module()
class SARNet(EncodeDecodeRecognizer): class SARNet(EncodeDecodeRecognizer):
"""Implementation of `SAR <https://arxiv.org/abs/1811.00751>`_""" """Implementation of `SAR <https://arxiv.org/abs/1811.00751>`_"""

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module() @MODELS.register_module()
class SATRN(EncodeDecodeRecognizer): class SATRN(EncodeDecodeRecognizer):
"""Implementation of `SATRN <https://arxiv.org/abs/1910.04396>`_""" """Implementation of `SATRN <https://arxiv.org/abs/1910.04396>`_"""

View File

@ -1,13 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor, from mmocr.registry import MODELS
build_head, build_loss, build_neck,
build_preprocessor)
from .base import BaseRecognizer from .base import BaseRecognizer
@RECOGNIZERS.register_module() @MODELS.register_module()
class SegRecognizer(BaseRecognizer): class SegRecognizer(BaseRecognizer):
"""Base class for segmentation based recognizer.""" """Base class for segmentation based recognizer."""
@ -26,29 +24,29 @@ class SegRecognizer(BaseRecognizer):
# Label_convertor # Label_convertor
assert label_convertor is not None assert label_convertor is not None
self.label_convertor = build_convertor(label_convertor) self.label_convertor = MODELS.build(label_convertor)
# Preprocessor module, e.g., TPS # Preprocessor module, e.g., TPS
self.preprocessor = None self.preprocessor = None
if preprocessor is not None: if preprocessor is not None:
self.preprocessor = build_preprocessor(preprocessor) self.preprocessor = MODELS.build(preprocessor)
# Backbone # Backbone
assert backbone is not None assert backbone is not None
self.backbone = build_backbone(backbone) self.backbone = MODELS.build(backbone)
# Neck # Neck
assert neck is not None assert neck is not None
self.neck = build_neck(neck) self.neck = MODELS.build(neck)
# Head # Head
assert head is not None assert head is not None
head.update(num_classes=self.label_convertor.num_classes()) head.update(num_classes=self.label_convertor.num_classes())
self.head = build_head(head) self.head = MODELS.build(head)
# Loss # Loss
assert loss is not None assert loss is not None
self.loss = build_loss(loss) self.loss = MODELS.build(loss)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg

View File

@ -24,9 +24,9 @@ from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.kie_dataset import KIEDataset from mmocr.datasets.kie_dataset import KIEDataset
from mmocr.datasets.pipelines.crop import crop_img from mmocr.datasets.pipelines.crop import crop_img
from mmocr.models import build_detector
from mmocr.models.textdet.detectors import TextDetectorMixin from mmocr.models.textdet.detectors import TextDetectorMixin
from mmocr.models.textrecog.recognizer import BaseRecognizer from mmocr.models.textrecog.recognizer import BaseRecognizer
from mmocr.registry import MODELS
from mmocr.utils import is_type_list from mmocr.utils import is_type_list
from mmocr.utils.box_util import stitch_boxes_into_lines from mmocr.utils.box_util import stitch_boxes_into_lines
from mmocr.utils.fileio import list_from_file from mmocr.utils.fileio import list_from_file
@ -427,7 +427,7 @@ class MMOCR:
'kie/' + kie_models[self.kie]['ckpt'] 'kie/' + kie_models[self.kie]['ckpt']
kie_cfg = Config.fromfile(kie_config) kie_cfg = Config.fromfile(kie_config)
self.kie_model = build_detector( self.kie_model = MODELS.build(
kie_cfg.model, test_cfg=kie_cfg.get('test_cfg')) kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
self.kie_model = revert_sync_batchnorm(self.kie_model) self.kie_model = revert_sync_batchnorm(self.kie_model)
self.kie_model.cfg = kie_cfg self.kie_model.cfg = kie_cfg