[Refactor] Unify model registries into MODELS

pull/1178/head
liukuikun 2022-05-12 03:01:34 +00:00 committed by gaotongxiao
parent 3f24e34a5d
commit 23458f8a47
99 changed files with 240 additions and 382 deletions
mmocr
core/deployment
models
common
backbones
kie
utils

View File

@ -3,7 +3,6 @@ from argparse import ArgumentParser
from mmocr.apis import init_detector
from mmocr.apis.inference import text_model_inference
from mmocr.models import build_detector # NOQA
from mmocr.registry import DATASETS # NOQA

View File

@ -5,7 +5,6 @@ import cv2
import torch
from mmocr.apis import init_detector, model_inference
from mmocr.models import build_detector # noqa: F401
from mmocr.registry import DATASETS # noqa: F401

View File

@ -11,7 +11,7 @@ from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmocr.models import build_detector
from mmocr.registry import MODELS
from mmocr.utils import is_2dlist
from .utils import disable_text_recog_aug_test
@ -40,7 +40,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
if config.model.get('pretrained'):
config.model.pretrained = None
config.model.train_cfg = None
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
model = MODELS.build(config.model, test_cfg=config.get('test_cfg'))
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if 'CLASSES' in checkpoint.get('meta', {}):

View File

@ -5,7 +5,6 @@ from typing import Any, Iterable
import numpy as np
import torch
from mmdet.models.builder import DETECTORS
from mmocr.models.textdet.detectors.single_stage_text_detector import \
SingleStageTextDetector
@ -13,6 +12,7 @@ from mmocr.models.textdet.detectors.text_detector_mixin import \
TextDetectorMixin
from mmocr.models.textrecog.recognizer.encode_decode_recognizer import \
EncodeDecodeRecognizer
from mmocr.registry import MODELS
def inference_with_session(sess, io_binding, input_name, output_names,
@ -34,7 +34,7 @@ def inference_with_session(sess, io_binding, input_name, output_names,
return pred
@DETECTORS.register_module()
@MODELS.register_module()
class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
"""The class for evaluating onnx file of detection."""
@ -110,7 +110,7 @@ class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
return boundaries
@DETECTORS.register_module()
@MODELS.register_module()
class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
"""The class for evaluating onnx file of recognition."""
@ -201,7 +201,7 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
return results
@DETECTORS.register_module()
@MODELS.register_module()
class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
"""The class for evaluating TensorRT file of detection."""
@ -257,7 +257,7 @@ class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
return boundaries
@DETECTORS.register_module()
@MODELS.register_module()
class TensorRTRecognizer(EncodeDecodeRecognizer):
"""The class for evaluating TensorRT file of recognition."""

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmocr.models.builder import build_convertor
from mmocr.registry import TRANSFORMS
from mmocr.registry import MODELS, TRANSFORMS
@TRANSFORMS.register_module()
@ -18,7 +17,7 @@ class NerTransform:
"""
def __init__(self, label_convertor, max_len):
self.label_convertor = build_convertor(label_convertor)
self.label_convertor = MODELS.build(label_convertor)
self.max_len = max_len
def __call__(self, results):

View File

@ -4,8 +4,7 @@ import numpy as np
from mmdet.core import BitmapMasks
import mmocr.utils.check_argument as check_argument
from mmocr.models.builder import build_convertor
from mmocr.registry import TRANSFORMS
from mmocr.registry import MODELS, TRANSFORMS
@TRANSFORMS.register_module()
@ -41,7 +40,7 @@ class OCRSegTargets:
self.attn_shrink_ratio = attn_shrink_ratio
self.seg_shrink_ratio = seg_shrink_ratio
self.label_convertor = build_convertor(label_convertor)
self.label_convertor = MODELS.build(label_convertor)
self.box_type = box_type
self.pad_val = pad_val

View File

@ -1,19 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import common, kie, textdet, textrecog
from .builder import (BACKBONES, CONVERTORS, DECODERS, DETECTORS, ENCODERS,
HEADS, LOSSES, NECKS, PREPROCESSOR, build_backbone,
build_convertor, build_decoder, build_detector,
build_encoder, build_loss, build_preprocessor)
from .common import * # NOQA
from .kie import * # NOQA
from .ner import * # NOQA
from .textdet import * # NOQA
from .textrecog import * # NOQA
__all__ = [
'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone',
'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS',
'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder',
'build_preprocessor'
]
__all__ += common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__
__all__ = common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__

View File

@ -1,115 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch.nn as nn
from mmcv.cnn import ACTIVATION_LAYERS as MMCV_ACTIVATION_LAYERS
from mmcv.cnn import UPSAMPLE_LAYERS as MMCV_UPSAMPLE_LAYERS
from mmcv.utils import Registry, build_from_cfg
from mmocr.registry import MODELS
# Legacy per-category registry names. Each one is bound to the very same
# ``MODELS`` object, so registering or building through any of these names
# goes through that single registry.
CONVERTORS = MODELS
ENCODERS = MODELS
DECODERS = MODELS
PREPROCESSOR = MODELS
POSTPROCESSOR = MODELS
UPSAMPLE_LAYERS = MODELS
BACKBONES = MODELS
LOSSES = MODELS
DETECTORS = MODELS
ROI_EXTRACTORS = MODELS
HEADS = MODELS
NECKS = MODELS
FUSERS = MODELS
RECOGNIZERS = MODELS
ACTIVATION_LAYERS = MODELS
def build_recognizer(cfg, train_cfg=None, test_cfg=None):
    """Build recognizer.

    Emits a warning that points callers at
    ``mmocr.registry.MODELS.build()`` and then builds through the
    ``RECOGNIZERS`` registry.

    Args:
        cfg: Recognizer config handed to ``RECOGNIZERS.build``.
        train_cfg: Optional training config, forwarded via ``default_args``.
        test_cfg: Optional testing config, forwarded via ``default_args``.

    Returns:
        The object produced by ``RECOGNIZERS.build``.
    """
    warnings.warn('``build_recognizer`` would be deprecated soon, please use '
                  '``mmocr.registry.MODELS.build()`` ')
    # Construction must go through the registry's ``build`` method; the
    # registry object is not meant to be called directly.
    return RECOGNIZERS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_convertor(cfg):
    """Build label convertor for scene text recognizer.

    Warns that this helper is on its way out, then delegates to
    ``CONVERTORS.build``.

    Args:
        cfg: Convertor config, passed unchanged to ``CONVERTORS.build``.

    Returns:
        Whatever ``CONVERTORS.build(cfg)`` returns.
    """
    notice = ('``build_convertor`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    convertor = CONVERTORS.build(cfg)
    return convertor
def build_encoder(cfg):
    """Build encoder for scene text recognizer.

    Warns that this helper is on its way out, then delegates to
    ``ENCODERS.build``.

    Args:
        cfg: Encoder config, passed unchanged to ``ENCODERS.build``.

    Returns:
        Whatever ``ENCODERS.build(cfg)`` returns.
    """
    notice = ('``build_encoder`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    encoder = ENCODERS.build(cfg)
    return encoder
def build_decoder(cfg):
    """Build decoder for scene text recognizer.

    Warns that this helper is on its way out, then delegates to
    ``DECODERS.build``.

    Args:
        cfg: Decoder config, passed unchanged to ``DECODERS.build``.

    Returns:
        Whatever ``DECODERS.build(cfg)`` returns.
    """
    notice = ('``build_decoder`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    decoder = DECODERS.build(cfg)
    return decoder
def build_preprocessor(cfg):
    """Build preprocessor for scene text recognizer.

    Emits a warning that points callers at
    ``mmocr.registry.MODELS.build()`` and then builds through the
    ``PREPROCESSOR`` registry.

    Args:
        cfg: Preprocessor config handed to ``PREPROCESSOR.build``.

    Returns:
        The object produced by ``PREPROCESSOR.build``.
    """
    warnings.warn(
        '``build_preprocessor`` would be deprecated soon, please use '
        '``mmocr.registry.MODELS.build()`` ')
    # Construction must go through the registry's ``build`` method; the
    # registry object is not meant to be called directly.
    return PREPROCESSOR.build(cfg)
def build_postprocessor(cfg):
    """Build postprocessor for scene text detector.

    Warns that this helper is on its way out, then delegates to
    ``POSTPROCESSOR.build``.

    Args:
        cfg: Postprocessor config, passed unchanged to
            ``POSTPROCESSOR.build``.

    Returns:
        Whatever ``POSTPROCESSOR.build(cfg)`` returns.
    """
    notice = ('``build_postprocessor`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    postprocessor = POSTPROCESSOR.build(cfg)
    return postprocessor
def build_roi_extractor(cfg):
    """Build roi extractor.

    Warns that this helper is on its way out, then delegates to
    ``ROI_EXTRACTORS.build``.

    Args:
        cfg: RoI extractor config, passed unchanged to
            ``ROI_EXTRACTORS.build``.

    Returns:
        Whatever ``ROI_EXTRACTORS.build(cfg)`` returns.
    """
    notice = ('``build_roi_extractor`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    roi_extractor = ROI_EXTRACTORS.build(cfg)
    return roi_extractor
def build_loss(cfg):
    """Build loss.

    Warns that this helper is on its way out, then delegates to
    ``LOSSES.build``.

    Args:
        cfg: Loss config, passed unchanged to ``LOSSES.build``.

    Returns:
        Whatever ``LOSSES.build(cfg)`` returns.
    """
    notice = ('``build_loss`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    loss = LOSSES.build(cfg)
    return loss
def build_backbone(cfg):
    """Build backbone.

    Warns that this helper is on its way out, then delegates to
    ``BACKBONES.build``.

    Args:
        cfg: Backbone config, passed unchanged to ``BACKBONES.build``.

    Returns:
        Whatever ``BACKBONES.build(cfg)`` returns.
    """
    notice = ('``build_backbone`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    backbone = BACKBONES.build(cfg)
    return backbone
def build_head(cfg):
    """Build head.

    Warns that this helper is on its way out, then delegates to
    ``HEADS.build``.

    Args:
        cfg: Head config, passed unchanged to ``HEADS.build``.

    Returns:
        Whatever ``HEADS.build(cfg)`` returns.
    """
    notice = ('``build_head`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    head = HEADS.build(cfg)
    return head
def build_neck(cfg):
    """Build neck.

    Warns that this helper is on its way out, then delegates to
    ``NECKS.build``.

    Args:
        cfg: Neck config, passed unchanged to ``NECKS.build``.

    Returns:
        Whatever ``NECKS.build(cfg)`` returns.
    """
    notice = ('``build_neck`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    neck = NECKS.build(cfg)
    return neck
def build_fuser(cfg):
    """Build fuser.

    Warns that this helper is on its way out, then delegates to
    ``FUSERS.build``.

    Args:
        cfg: Fuser config, passed unchanged to ``FUSERS.build``.

    Returns:
        Whatever ``FUSERS.build(cfg)`` returns.
    """
    notice = ('``build_fuser`` would be deprecated soon, please use '
              '``mmocr.registry.MODELS.build()`` ')
    warnings.warn(notice)
    fuser = FUSERS.build(cfg)
    return fuser
UPSAMPLE_LAYERS = Registry('upsample layer', parent=MMCV_UPSAMPLE_LAYERS)
ACTIVATION_LAYERS = Registry('activation layer', parent=MMCV_ACTIVATION_LAYERS)
def build_upsample_layer(cfg, *args, **kwargs):
@ -160,21 +56,4 @@ def build_activation_layer(cfg):
Returns:
nn.Module: Created activation layer.
"""
warnings.warn(
'``build_activation_layer`` would be deprecated soon, please use '
'``mmocr.registry.MODELS.build()`` ')
return ACTIVATION_LAYERS.build(cfg)
def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector.

    Warns when ``train_cfg`` or ``test_cfg`` is supplied as an outer
    argument, checks that neither is given both here and inside ``cfg``,
    and then builds through the ``DETECTORS`` registry.

    Args:
        cfg: Detector config; must support ``.get`` and is passed to
            ``DETECTORS.build``.
        train_cfg: Optional outer training config, forwarded via
            ``default_args``.
        test_cfg: Optional outer testing config, forwarded via
            ``default_args``.

    Returns:
        The object produced by ``DETECTORS.build``.

    Raises:
        AssertionError: If ``train_cfg`` (or ``test_cfg``) is set both as an
            argument and inside ``cfg``.
    """
    outer_cfg_given = not (train_cfg is None and test_cfg is None)
    if outer_cfg_given:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)

    # Each setting may come from only one place: the model config or the
    # outer argument, never both.
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '

    extra_args = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return DETECTORS.build(cfg, default_args=extra_args)
return build_from_cfg(cfg, ACTIVATION_LAYERS)

View File

@ -6,8 +6,9 @@ from mmcv.cnn import ConvModule, build_norm_layer
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmocr.models.builder import (BACKBONES, UPSAMPLE_LAYERS,
build_activation_layer, build_upsample_layer)
from mmocr.models.builder import (UPSAMPLE_LAYERS, build_activation_layer,
build_upsample_layer)
from mmocr.registry import MODELS
class UpConvBlock(nn.Module):
@ -317,7 +318,7 @@ class InterpConv(nn.Module):
return out
@BACKBONES.register_module()
@MODELS.register_module()
class UNet(BaseModule):
"""UNet backbone.
U-Net: Convolutional Networks for Biomedical Image Segmentation.

View File

@ -4,11 +4,10 @@ import warnings
from mmdet.models.detectors import \
SingleStageDetector as MMDET_SingleStageDetector
from mmocr.models.builder import (DETECTORS, build_backbone, build_head,
build_neck)
from mmocr.registry import MODELS
@DETECTORS.register_module()
@MODELS.register_module()
class SingleStageDetector(MMDET_SingleStageDetector):
"""Base class for single-stage detectors.
@ -29,11 +28,11 @@ class SingleStageDetector(MMDET_SingleStageDetector):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = build_neck(neck)
self.neck = MODELS.build(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.bbox_head = MODELS.build(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg

View File

@ -2,10 +2,10 @@
import torch
import torch.nn as nn
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class DiceLoss(nn.Module):
def __init__(self, eps=1e-6):

View File

@ -7,12 +7,12 @@ from torch import nn
from torch.nn import functional as F
from mmocr.core import imshow_edge, imshow_node
from mmocr.models.builder import DETECTORS, build_roi_extractor
from mmocr.models.common.detectors import SingleStageDetector
from mmocr.registry import MODELS
from mmocr.utils import list_from_file
@DETECTORS.register_module()
@MODELS.register_module()
class SDMGR(SingleStageDetector):
"""The implementation of the paper: Spatial Dual-Modality Graph Reasoning
for Key Information Extraction. https://arxiv.org/abs/2103.14470.
@ -42,7 +42,7 @@ class SDMGR(SingleStageDetector):
backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
self.visual_modality = visual_modality
if visual_modality:
self.extractor = build_roi_extractor({
self.extractor = MODELS.build({
**extractor, 'out_channels':
self.backbone.base_channels
})

View File

@ -4,10 +4,10 @@ from mmcv.runner import BaseModule
from torch import nn
from torch.nn import functional as F
from mmocr.models.builder import HEADS, build_loss
from mmocr.registry import MODELS
@HEADS.register_module()
@MODELS.register_module()
class SDMGRHead(BaseModule):
def __init__(self,
@ -45,7 +45,7 @@ class SDMGRHead(BaseModule):
[GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
self.node_cls = nn.Linear(node_embed, num_classes)
self.edge_cls = nn.Linear(edge_embed, 2)
self.loss = build_loss(loss)
self.loss = MODELS.build(loss)
def forward(self, relations, texts, x=None):
node_nums, char_nums = [], []

View File

@ -3,10 +3,10 @@ import torch
from mmdet.models.losses import accuracy
from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class SDMGRLoss(nn.Module):
"""The implementation of the loss of key information extraction proposed in
the paper: Spatial Dual-Modality Graph Reasoning for Key Information

View File

@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import (DETECTORS, build_convertor, build_decoder,
build_encoder, build_loss)
from mmocr.models.textrecog.recognizer.base import BaseRecognizer
from mmocr.registry import MODELS
@DETECTORS.register_module()
@MODELS.register_module()
class NerClassifier(BaseRecognizer):
"""Base class for NER classifier."""
@ -17,15 +16,15 @@ class NerClassifier(BaseRecognizer):
test_cfg=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.label_convertor = build_convertor(label_convertor)
self.label_convertor = MODELS.build(label_convertor)
self.encoder = build_encoder(encoder)
self.encoder = MODELS.build(encoder)
decoder.update(num_labels=self.label_convertor.num_labels)
self.decoder = build_decoder(decoder)
self.decoder = MODELS.build(decoder)
loss.update(num_labels=self.label_convertor.num_labels)
self.loss = build_loss(loss)
self.loss = MODELS.build(loss)
def extract_feat(self, imgs):
"""Extract features from images."""

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmocr.models.builder import CONVERTORS
from mmocr.registry import MODELS
from mmocr.utils import list_from_file
@CONVERTORS.register_module()
@MODELS.register_module()
class NerConvertor:
"""Convert between text, index and tensor for NER pipeline.

View File

@ -4,10 +4,10 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule
from mmocr.models.builder import DECODERS
from mmocr.registry import MODELS
@DECODERS.register_module()
@MODELS.register_module()
class FCDecoder(BaseModule):
"""FC Decoder class for Ner.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule
from mmocr.models.builder import ENCODERS
from mmocr.models.ner.utils.bert import BertModel
from mmocr.registry import MODELS
@ENCODERS.register_module()
@MODELS.register_module()
class BertEncoder(BaseModule):
"""Bert encoder
Args:

View File

@ -2,10 +2,10 @@
from torch import nn
from torch.nn import CrossEntropyLoss
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class MaskedCrossEntropyLoss(nn.Module):
"""The implementation of masked cross entropy loss.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.models.common.losses.focal_loss import FocalLoss
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class MaskedFocalLoss(nn.Module):
"""The implementation of masked focal loss.

View File

@ -5,11 +5,11 @@ import torch
import torch.nn as nn
from mmcv.runner import BaseModule, Sequential
from mmocr.models.builder import HEADS
from mmocr.registry import MODELS
from .head_mixin import HeadMixin
@HEADS.register_module()
@MODELS.register_module()
class DBHead(HeadMixin, BaseModule):
"""The class for DBNet head.

View File

@ -7,13 +7,13 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule
from mmocr.models.builder import HEADS, build_loss
from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
from mmocr.registry import MODELS
from mmocr.utils import check_argument
from .head_mixin import HeadMixin
@HEADS.register_module()
@MODELS.register_module()
class DRRGHead(HeadMixin, BaseModule):
"""The class for DRRG head: `Deep Relational Reasoning Graph Network for
Arbitrary Shape Text Detection <https://arxiv.org/abs/2003.07493>`_.
@ -118,7 +118,7 @@ class DRRGHead(HeadMixin, BaseModule):
self.center_region_thr = center_region_thr
self.center_region_area_thr = center_region_area_thr
self.local_graph_thr = local_graph_thr
self.loss_module = build_loss(loss)
self.loss_module = MODELS.build(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg

View File

@ -5,12 +5,12 @@ import torch.nn as nn
from mmcv.runner import BaseModule
from mmdet.core import multi_apply
from mmocr.models.builder import HEADS
from mmocr.registry import MODELS
from ..postprocess.utils import poly_nms
from .head_mixin import HeadMixin
@HEADS.register_module()
@MODELS.register_module()
class FCEHead(HeadMixin, BaseModule):
"""The class for implementing FCENet head.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmocr.models.builder import HEADS, build_loss, build_postprocessor
from mmocr.registry import MODELS
from mmocr.utils import check_argument
@HEADS.register_module()
@MODELS.register_module()
class HeadMixin:
"""Base head class for text detection, including loss calculation and
postprocess.
@ -19,8 +19,8 @@ class HeadMixin:
assert isinstance(loss, dict)
assert isinstance(postprocessor, dict)
self.loss_module = build_loss(loss)
self.postprocessor = build_postprocessor(postprocessor)
self.loss_module = MODELS.build(loss)
self.postprocessor = MODELS.build(postprocessor)
def resize_boundary(self, boundaries, scale_factor):
"""Rescale boundaries via scale_factor.

View File

@ -6,12 +6,12 @@ import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from mmocr.models.builder import HEADS
from mmocr.registry import MODELS
from mmocr.utils import check_argument
from .head_mixin import HeadMixin
@HEADS.register_module()
@MODELS.register_module()
class PANHead(HeadMixin, BaseModule):
"""The class for PANet head.

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import HEADS
from mmocr.registry import MODELS
from . import PANHead
@HEADS.register_module()
@MODELS.register_module()
class PSEHead(PANHead):
"""The class for PSENet head.

View File

@ -4,11 +4,11 @@ import warnings
import torch.nn as nn
from mmcv.runner import BaseModule
from mmocr.models.builder import HEADS
from mmocr.registry import MODELS
from .head_mixin import HeadMixin
@HEADS.register_module()
@MODELS.register_module()
class TextSnakeHead(HeadMixin, BaseModule):
"""The class for TextSnake head: TextSnake: A Flexible Representation for
Detecting Text of Arbitrary Shapes.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS
from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module()
@MODELS.register_module()
class DBNet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing DBNet text detector: Real-time Scene Text
Detection with Differentiable Binarization.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS
from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module()
@MODELS.register_module()
class DRRG(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing DRRG text detector. Deep Relational Reasoning
Graph Network for Arbitrary Shape Text Detection.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS
from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module()
@MODELS.register_module()
class FCENet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing FCENet text detector
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text

View File

@ -2,11 +2,11 @@
from mmdet.models.detectors import MaskRCNN
from mmocr.core import seg2boundary
from mmocr.models.builder import DETECTORS
from mmocr.registry import MODELS
from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module()
@MODELS.register_module()
class OCRMaskRCNN(TextDetectorMixin, MaskRCNN):
"""Mask RCNN tailored for OCR."""

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS
from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module()
@MODELS.register_module()
class PANet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing PANet text detector:

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS
from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module()
@MODELS.register_module()
class PSENet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing PSENet text detector: Shape Robust Text
Detection with Progressive Scale Expansion Network.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmocr.models.builder import DETECTORS
from mmocr.models.common.detectors import SingleStageDetector
from mmocr.registry import MODELS
@DETECTORS.register_module()
@MODELS.register_module()
class SingleStageTextDetector(SingleStageDetector):
"""The class for implementing single stage text detector."""

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS
from mmocr.registry import MODELS
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
@DETECTORS.register_module()
@MODELS.register_module()
class TextSnake(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing TextSnake text detector: TextSnake: A
Flexible Representation for Detecting Text of Arbitrary Shapes.

View File

@ -3,11 +3,11 @@ import torch
import torch.nn.functional as F
from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.models.common.losses.dice_loss import DiceLoss
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class DBLoss(nn.Module):
"""The class for implementing DBNet loss.

View File

@ -4,11 +4,11 @@ import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
from mmocr.utils import check_argument
@LOSSES.register_module()
@MODELS.register_module()
class DRRGLoss(nn.Module):
"""The class for implementing DRRG loss. This is partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license.

View File

@ -5,10 +5,10 @@ import torch.nn.functional as F
from mmdet.core import multi_apply
from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class FCELoss(nn.Module):
"""The class for implementing FCENet loss.

View File

@ -8,11 +8,11 @@ import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
from mmocr.utils import check_argument
@LOSSES.register_module()
@MODELS.register_module()
class PANLoss(nn.Module):
"""The class for implementing PANet loss. This was partially adapted from
https://github.com/WenmuZhou/PAN.pytorch.

View File

@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core import BitmapMasks
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
from mmocr.utils import check_argument
from . import PANLoss
@LOSSES.register_module()
@MODELS.register_module()
class PSELoss(PANLoss):
r"""The class for implementing PSENet loss. This is partially adapted from
https://github.com/whai362/PSENet.

View File

@ -4,11 +4,11 @@ import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
from mmocr.utils import check_argument
@LOSSES.register_module()
@MODELS.register_module()
class TextSnakeLoss(nn.Module):
"""The class for implementing TextSnake loss. This is partially adapted
from https://github.com/princewang1994/TextSnake.pytorch.

View File

@ -3,7 +3,7 @@ import torch.nn.functional as F
from mmcv.runner import BaseModule, ModuleList
from torch import nn
from mmocr.models.builder import NECKS
from mmocr.registry import MODELS
class FPEM(BaseModule):
@ -72,7 +72,7 @@ class SeparableConv2d(BaseModule):
return x
@NECKS.register_module()
@MODELS.register_module()
class FPEM_FFM(BaseModule):
"""This code is from https://github.com/WenmuZhou/PAN.pytorch.

View File

@ -5,10 +5,10 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList, Sequential, auto_fp16
from mmocr.models.builder import NECKS
from mmocr.registry import MODELS
@NECKS.register_module()
@MODELS.register_module()
class FPNC(BaseModule):
"""FPN-like fusion module in Real-time Scene Text Detection with
Differentiable Binarization.

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from mmcv.runner import BaseModule
from torch import nn
from mmocr.models.builder import NECKS
from mmocr.registry import MODELS
class UpBlock(BaseModule):
@ -30,7 +30,7 @@ class UpBlock(BaseModule):
return x
@NECKS.register_module()
@MODELS.register_module()
class FPN_UNet(BaseModule):
"""The class for implementing DRRG and TextSnake U-Net-like FPN.

View File

@ -4,10 +4,10 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList, auto_fp16
from mmocr.models.builder import NECKS
from mmocr.registry import MODELS
@NECKS.register_module()
@MODELS.register_module()
class FPNF(BaseModule):
"""FPN-like fusion module in Shape Robust Text Detection with Progressive
Scale Expansion Network.

View File

@ -3,12 +3,12 @@ import cv2
import numpy as np
from mmocr.core import points2boundary
from mmocr.models.builder import POSTPROCESSOR
from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor
from .utils import box_score_fast, unclip
@POSTPROCESSOR.register_module()
@MODELS.register_module()
class DBPostprocessor(BasePostprocessor):
"""Decoding predictions of DbNet to instances. This is partially adapted
from https://github.com/MhLiao/DB.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import POSTPROCESSOR
from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor
from .utils import (clusters2labels, comps2boundaries, connected_components,
graph_propagation, remove_single)
@POSTPROCESSOR.register_module()
@MODELS.register_module()
class DRRGPostprocessor(BasePostprocessor):
"""Merge text components and construct boundaries of text instances.

View File

@ -2,12 +2,12 @@
import cv2
import numpy as np
from mmocr.models.builder import POSTPROCESSOR
from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor
from .utils import fill_hole, fourier2poly, poly_nms
@POSTPROCESSOR.register_module()
@MODELS.register_module()
class FCEPostprocessor(BasePostprocessor):
"""Decoding predictions of FCENet to instances.

View File

@ -5,11 +5,11 @@ import torch
from mmcv.ops import pixel_group
from mmocr.core import points2boundary
from mmocr.models.builder import POSTPROCESSOR
from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor
@POSTPROCESSOR.register_module()
@MODELS.register_module()
class PANPostprocessor(BasePostprocessor):
"""Convert scores to quadrangles via post processing in PANet. This is
partially adapted from https://github.com/WenmuZhou/PAN.pytorch.

View File

@ -6,11 +6,11 @@ import torch
from mmcv.ops import contour_expand
from mmocr.core import points2boundary
from mmocr.models.builder import POSTPROCESSOR
from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor
@POSTPROCESSOR.register_module()
@MODELS.register_module()
class PSEPostprocessor(BasePostprocessor):
"""Decoding predictions of PSENet to instances. This is partially adapted
from https://github.com/whai362/PSENet.

View File

@ -5,12 +5,12 @@ import numpy as np
import torch
from skimage.morphology import skeletonize
from mmocr.models.builder import POSTPROCESSOR
from mmocr.registry import MODELS
from .base_postprocessor import BasePostprocessor
from .utils import centralize, fill_hole, merge_disks
@POSTPROCESSOR.register_module()
@MODELS.register_module()
class TextSnakePostprocessor(BasePostprocessor):
"""Decoding predictions of TextSnake to instances. This was partially
adapted from https://github.com/princewang1994/TextSnake.pytorch.

View File

@ -2,10 +2,10 @@
import torch.nn as nn
from mmcv.runner import BaseModule
from mmocr.models.builder import BACKBONES
from mmocr.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class NRTRModalityTransform(BaseModule):
def __init__(self,

View File

@ -3,11 +3,11 @@ from mmcv.cnn import ConvModule, build_plugin_layer
from mmcv.runner import BaseModule, Sequential
import mmocr.utils as utils
from mmocr.models.builder import BACKBONES
from mmocr.models.textrecog.layers import BasicBlock
from mmocr.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class ResNet(BaseModule):
"""
Args:

View File

@ -3,11 +3,11 @@ import torch.nn as nn
from mmcv.runner import BaseModule, Sequential
import mmocr.utils as utils
from mmocr.models.builder import BACKBONES
from mmocr.models.textrecog.layers import BasicBlock
from mmocr.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class ResNet31OCR(BaseModule):
"""Implement ResNet backbone for text recognition, modified from
`ResNet <https://arxiv.org/pdf/1512.03385.pdf>`_

View File

@ -3,11 +3,11 @@ import torch.nn as nn
from mmcv.runner import BaseModule, Sequential
import mmocr.utils as utils
from mmocr.models.builder import BACKBONES
from mmocr.models.textrecog.layers import BasicBlock
from mmocr.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class ResNetABI(BaseModule):
"""Implement ResNet backbone for text recognition, modified from `ResNet.

View File

@ -3,10 +3,10 @@ import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmocr.models.builder import BACKBONES
from mmocr.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class ShallowCNN(BaseModule):
"""Implement Shallow CNN block for SATRN.

View File

@ -2,10 +2,10 @@
import torch.nn as nn
from mmcv.runner import BaseModule, Sequential
from mmocr.models.builder import BACKBONES
from mmocr.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class VeryDeepVgg(BaseModule):
"""Implement VGG-VeryDeep backbone for text recognition, modified from
`VGG-VeryDeep <https://arxiv.org/pdf/1409.1556.pdf>`_

View File

@ -2,11 +2,11 @@
import torch
import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from mmocr.registry import MODELS
from .attn import AttnConvertor
@CONVERTORS.register_module()
@MODELS.register_module()
class ABIConvertor(AttnConvertor):
"""Convert between text, index and tensor for encoder-decoder based
pipeline. Modified from AttnConvertor to get closer to ABINet's original

View File

@ -2,11 +2,11 @@
import torch
import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from mmocr.registry import MODELS
from .base import BaseConvertor
@CONVERTORS.register_module()
@MODELS.register_module()
class AttnConvertor(BaseConvertor):
"""Convert between text, index and tensor for encoder-decoder based
pipeline.

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import CONVERTORS
from mmocr.registry import MODELS
from mmocr.utils import list_from_file
@CONVERTORS.register_module()
@MODELS.register_module()
class BaseConvertor:
"""Convert between text, index and tensor for text recognition pipeline.

View File

@ -5,11 +5,11 @@ import torch
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from mmocr.registry import MODELS
from .base import BaseConvertor
@CONVERTORS.register_module()
@MODELS.register_module()
class CTCConvertor(BaseConvertor):
"""Convert between text, index and tensor for CTC loss-based pipeline.

View File

@ -4,11 +4,11 @@ import numpy as np
import torch
import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from mmocr.registry import MODELS
from .base import BaseConvertor
@CONVERTORS.register_module()
@MODELS.register_module()
class SegConvertor(BaseConvertor):
"""Convert between text, index and tensor for segmentation based pipeline.

View File

@ -6,12 +6,12 @@ import torch.nn as nn
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import ModuleList
from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@DECODERS.register_module()
@MODELS.register_module()
class ABILanguageDecoder(BaseDecoder):
r"""Transformer-based language model responsible for spell correction.
Implementation of language model of \

View File

@ -3,12 +3,12 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@DECODERS.register_module()
@MODELS.register_module()
class ABIVisionDecoder(BaseDecoder):
"""Converts visual features into text characters.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule
from mmocr.models.builder import DECODERS
from mmocr.registry import MODELS
@DECODERS.register_module()
@MODELS.register_module()
class BaseDecoder(BaseModule):
"""Base decoder class for text recognition."""

View File

@ -2,12 +2,12 @@
import torch.nn as nn
from mmcv.runner import Sequential
from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import BidirectionalLSTM
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@DECODERS.register_module()
@MODELS.register_module()
class CRNNDecoder(BaseDecoder):
"""Decoder for CRNN.

View File

@ -8,8 +8,8 @@ import torch.nn.functional as F
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import ModuleList
from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@ -30,7 +30,7 @@ class Embeddings(nn.Module):
return self.lut(x) * math.sqrt(self.d_model)
@DECODERS.register_module()
@MODELS.register_module()
class MasterDecoder(BaseDecoder):
"""Decoder module in `MASTER <https://arxiv.org/abs/1910.02562>`_.

View File

@ -6,12 +6,12 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList
from mmocr.models.builder import DECODERS
from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@DECODERS.register_module()
@MODELS.register_module()
class NRTRDecoder(BaseDecoder):
"""Transformer Decoder block with self attention mechanism.

View File

@ -4,13 +4,13 @@ import math
import torch
import torch.nn as nn
from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
PositionAwareLayer)
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@DECODERS.register_module()
@MODELS.register_module()
class PositionAttentionDecoder(BaseDecoder):
"""Position attention decoder for RobustScanner.

View File

@ -3,12 +3,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmocr.models.builder import DECODERS, build_decoder
from mmocr.models.textrecog.layers import RobustScannerFusionLayer
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@DECODERS.register_module()
@MODELS.register_module()
class RobustScannerDecoder(BaseDecoder):
"""Decoder for RobustScanner.
@ -72,7 +72,7 @@ class RobustScannerDecoder(BaseDecoder):
hybrid_decoder.update(encode_value=self.encode_value)
hybrid_decoder.update(return_feature=True)
self.hybrid_decoder = build_decoder(hybrid_decoder)
self.hybrid_decoder = MODELS.build(hybrid_decoder)
# init position decoder
position_decoder.update(num_classes=self.num_classes)
@ -83,7 +83,7 @@ class RobustScannerDecoder(BaseDecoder):
position_decoder.update(encode_value=self.encode_value)
position_decoder.update(return_feature=True)
self.position_decoder = build_decoder(position_decoder)
self.position_decoder = MODELS.build(position_decoder)
self.fusion_module = RobustScannerFusionLayer(
self.dim_model if encode_value else dim_input)

View File

@ -6,11 +6,11 @@ import torch.nn as nn
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.models.builder import DECODERS
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@DECODERS.register_module()
@MODELS.register_module()
class ParallelSARDecoder(BaseDecoder):
"""Implementation of Parallel Decoder module in `SAR.
@ -255,7 +255,7 @@ class ParallelSARDecoder(BaseDecoder):
return outputs
@DECODERS.register_module()
@MODELS.register_module()
class SequentialSARDecoder(BaseDecoder):
"""Implementation of Sequential Decoder module in `SAR.

View File

@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.models.builder import DECODERS
from mmocr.registry import MODELS
from . import ParallelSARDecoder
@ -31,7 +31,7 @@ class DecodeNode:
return accu_score
@DECODERS.register_module()
@MODELS.register_module()
class ParallelSARDecoderWithBS(ParallelSARDecoder):
"""Parallel Decoder module with beam-search in SAR.

View File

@ -5,12 +5,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import DotProductAttentionLayer
from mmocr.registry import MODELS
from .base_decoder import BaseDecoder
@DECODERS.register_module()
@MODELS.register_module()
class SequenceAttentionDecoder(BaseDecoder):
"""Sequence attention decoder for RobustScanner.

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import ENCODERS, build_decoder, build_encoder
from mmocr.registry import MODELS
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
@MODELS.register_module()
class ABIVisionModel(BaseEncoder):
"""A wrapper of visual feature encoder and language token decoder that
converts visual features into text tokens.
@ -23,8 +23,8 @@ class ABIVisionModel(BaseEncoder):
init_cfg=dict(type='Xavier', layer='Conv2d'),
**kwargs):
super().__init__(init_cfg=init_cfg)
self.encoder = build_encoder(encoder)
self.decoder = build_decoder(decoder)
self.encoder = MODELS.build(encoder)
self.decoder = MODELS.build(decoder)
def forward(self, feat, img_metas=None):
"""

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule
from mmocr.models.builder import ENCODERS
from mmocr.registry import MODELS
@ENCODERS.register_module()
@MODELS.register_module()
class BaseEncoder(BaseModule):
"""Base Encoder class for text recognition."""

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmocr.models.builder import ENCODERS
from mmocr.registry import MODELS
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
@MODELS.register_module()
class ChannelReductionEncoder(BaseEncoder):
"""Change the channel number with a one by one convolutional layer.

View File

@ -4,12 +4,12 @@ import math
import torch.nn as nn
from mmcv.runner import ModuleList
from mmocr.models.builder import ENCODERS
from mmocr.models.common import TFEncoderLayer
from mmocr.registry import MODELS
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
@MODELS.register_module()
class NRTREncoder(BaseEncoder):
"""Transformer Encoder block with self attention mechanism.

View File

@ -6,11 +6,11 @@ import torch.nn as nn
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.models.builder import ENCODERS
from mmocr.registry import MODELS
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
@MODELS.register_module()
class SAREncoder(BaseEncoder):
"""Implementation of encoder module in `SAR.

View File

@ -4,13 +4,13 @@ import math
import torch.nn as nn
from mmcv.runner import ModuleList
from mmocr.models.builder import ENCODERS
from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding,
SatrnEncoderLayer)
from mmocr.registry import MODELS
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
@MODELS.register_module()
class SatrnEncoder(BaseEncoder):
"""Implement encoder for SATRN, see `SATRN.

View File

@ -4,11 +4,11 @@ import copy
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import BaseModule, ModuleList
from mmocr.models.builder import ENCODERS
from mmocr.models.common.modules import PositionalEncoding
from mmocr.registry import MODELS
@ENCODERS.register_module()
@MODELS.register_module()
class TransformerEncoder(BaseModule):
"""Implement transformer encoder for text recognition, modified from
`<https://github.com/FangShancheng/ABINet>`.

View File

@ -3,10 +3,10 @@ import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from mmocr.models.builder import FUSERS
from mmocr.registry import MODELS
@FUSERS.register_module()
@MODELS.register_module()
class ABIFuser(BaseModule):
"""Mix and align visual feature and linguistic feature. Implementation of
language model of `ABINet <https://arxiv.org/abs/2103.06495>`_.

View File

@ -4,10 +4,10 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn
from mmocr.models.builder import HEADS
from mmocr.registry import MODELS
@HEADS.register_module()
@MODELS.register_module()
class SegHead(BaseModule):
"""Head for segmentation based text recognition.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class CELoss(nn.Module):
"""Implementation of loss module for encoder-decoder based text recognition
method with CrossEntropy loss.
@ -63,7 +63,7 @@ class CELoss(nn.Module):
return losses
@LOSSES.register_module()
@MODELS.register_module()
class SARLoss(CELoss):
"""Implementation of loss module in `SAR.
@ -95,7 +95,7 @@ class SARLoss(CELoss):
return outputs, targets
@LOSSES.register_module()
@MODELS.register_module()
class TFLoss(CELoss):
"""Implementation of loss module for transformer.

View File

@ -4,10 +4,10 @@ import math
import torch
import torch.nn as nn
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class CTCLoss(nn.Module):
"""Implementation of loss module for CTC-loss based text recognition.

View File

@ -3,10 +3,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class ABILoss(nn.Module):
"""Implementation of ABINet multiloss that allows mixing different types of
losses with weights.

View File

@ -3,10 +3,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmocr.models.builder import LOSSES
from mmocr.registry import MODELS
@LOSSES.register_module()
@MODELS.register_module()
class SegLoss(nn.Module):
"""Implementation of loss module for segmentation based text recognition
method.

View File

@ -4,10 +4,10 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList
from mmocr.models.builder import NECKS
from mmocr.registry import MODELS
@NECKS.register_module()
@MODELS.register_module()
class FPNOCR(BaseModule):
"""FPN-like Network for segmentation based text recognition.

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule
from mmocr.models.builder import PREPROCESSOR
from mmocr.registry import MODELS
@PREPROCESSOR.register_module()
@MODELS.register_module()
class BasePreprocessor(BaseModule):
"""Base Preprocessor class for text recognition."""

View File

@ -17,11 +17,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmocr.models.builder import PREPROCESSOR
from mmocr.registry import MODELS
from .base_preprocessor import BasePreprocessor
@PREPROCESSOR.register_module()
@MODELS.register_module()
class TPSPreprocessor(BasePreprocessor):
"""Rectification Network of RARE, namely TPS based STN in
https://arxiv.org/pdf/1603.03915.pdf.

View File

@ -3,13 +3,11 @@ import warnings
import torch
from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor,
build_decoder, build_encoder, build_fuser,
build_loss, build_preprocessor)
from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module()
@MODELS.register_module()
class ABINet(EncodeDecodeRecognizer):
"""Implementation of `Read Like Humans: Autonomous, Bidirectional and
Iterative LanguageModeling for Scene Text Recognition.
@ -36,21 +34,21 @@ class ABINet(EncodeDecodeRecognizer):
# Label convertor (str2tensor, tensor2str)
assert label_convertor is not None
label_convertor.update(max_seq_len=max_seq_len)
self.label_convertor = build_convertor(label_convertor)
self.label_convertor = MODELS.build(label_convertor)
# Preprocessor module, e.g., TPS
self.preprocessor = None
if preprocessor is not None:
self.preprocessor = build_preprocessor(preprocessor)
self.preprocessor = MODELS.build(preprocessor)
# Backbone
assert backbone is not None
self.backbone = build_backbone(backbone)
self.backbone = MODELS.build(backbone)
# Encoder module
self.encoder = None
if encoder is not None:
self.encoder = build_encoder(encoder)
self.encoder = MODELS.build(encoder)
# Decoder module
self.decoder = None
@ -59,11 +57,11 @@ class ABINet(EncodeDecodeRecognizer):
decoder.update(start_idx=self.label_convertor.start_idx)
decoder.update(padding_idx=self.label_convertor.padding_idx)
decoder.update(max_seq_len=max_seq_len)
self.decoder = build_decoder(decoder)
self.decoder = MODELS.build(decoder)
# Loss
assert loss is not None
self.loss = build_loss(loss)
self.loss = MODELS.build(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
@ -78,7 +76,7 @@ class ABINet(EncodeDecodeRecognizer):
self.fuser = None
if fuser is not None:
self.fuser = build_fuser(fuser)
self.fuser = MODELS.build(fuser)
def forward_train(self, img, img_metas):
"""

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS
from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module()
@MODELS.register_module()
class CRNNNet(EncodeDecodeRecognizer):
"""CTC-loss based recognizer."""

View File

@ -3,13 +3,11 @@ import warnings
import torch
from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor,
build_decoder, build_encoder, build_loss,
build_preprocessor)
from mmocr.registry import MODELS
from .base import BaseRecognizer
@RECOGNIZERS.register_module()
@MODELS.register_module()
class EncodeDecodeRecognizer(BaseRecognizer):
"""Base class for encode-decode recognizer."""
@ -31,21 +29,21 @@ class EncodeDecodeRecognizer(BaseRecognizer):
# Label convertor (str2tensor, tensor2str)
assert label_convertor is not None
label_convertor.update(max_seq_len=max_seq_len)
self.label_convertor = build_convertor(label_convertor)
self.label_convertor = MODELS.build(label_convertor)
# Preprocessor module, e.g., TPS
self.preprocessor = None
if preprocessor is not None:
self.preprocessor = build_preprocessor(preprocessor)
self.preprocessor = MODELS.build(preprocessor)
# Backbone
assert backbone is not None
self.backbone = build_backbone(backbone)
self.backbone = MODELS.build(backbone)
# Encoder module
self.encoder = None
if encoder is not None:
self.encoder = build_encoder(encoder)
self.encoder = MODELS.build(encoder)
# Decoder module
assert decoder is not None
@ -53,12 +51,12 @@ class EncodeDecodeRecognizer(BaseRecognizer):
decoder.update(start_idx=self.label_convertor.start_idx)
decoder.update(padding_idx=self.label_convertor.padding_idx)
decoder.update(max_seq_len=max_seq_len)
self.decoder = build_decoder(decoder)
self.decoder = MODELS.build(decoder)
# Loss
assert loss is not None
loss.update(ignore_index=self.label_convertor.padding_idx)
self.loss = build_loss(loss)
self.loss = MODELS.build(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import DETECTORS
from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@DETECTORS.register_module()
@MODELS.register_module()
class MASTER(EncodeDecodeRecognizer):
"""Implementation of `MASTER <https://arxiv.org/abs/1910.02562>`_"""

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS
from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module()
@MODELS.register_module()
class NRTR(EncodeDecodeRecognizer):
"""Implementation of `NRTR <https://arxiv.org/pdf/1806.00926.pdf>`_"""

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS
from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module()
@MODELS.register_module()
class RobustScanner(EncodeDecodeRecognizer):
"""Implementation of `RobustScanner.

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS
from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module()
@MODELS.register_module()
class SARNet(EncodeDecodeRecognizer):
"""Implementation of `SAR <https://arxiv.org/abs/1811.00751>`_"""

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import RECOGNIZERS
from mmocr.registry import MODELS
from .encode_decode_recognizer import EncodeDecodeRecognizer
@RECOGNIZERS.register_module()
@MODELS.register_module()
class SATRN(EncodeDecodeRecognizer):
"""Implementation of `SATRN <https://arxiv.org/abs/1910.04396>`_"""

View File

@ -1,13 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor,
build_head, build_loss, build_neck,
build_preprocessor)
from mmocr.registry import MODELS
from .base import BaseRecognizer
@RECOGNIZERS.register_module()
@MODELS.register_module()
class SegRecognizer(BaseRecognizer):
"""Base class for segmentation based recognizer."""
@ -26,29 +24,29 @@ class SegRecognizer(BaseRecognizer):
# Label_convertor
assert label_convertor is not None
self.label_convertor = build_convertor(label_convertor)
self.label_convertor = MODELS.build(label_convertor)
# Preprocessor module, e.g., TPS
self.preprocessor = None
if preprocessor is not None:
self.preprocessor = build_preprocessor(preprocessor)
self.preprocessor = MODELS.build(preprocessor)
# Backbone
assert backbone is not None
self.backbone = build_backbone(backbone)
self.backbone = MODELS.build(backbone)
# Neck
assert neck is not None
self.neck = build_neck(neck)
self.neck = MODELS.build(neck)
# Head
assert head is not None
head.update(num_classes=self.label_convertor.num_classes())
self.head = build_head(head)
self.head = MODELS.build(head)
# Loss
assert loss is not None
self.loss = build_loss(loss)
self.loss = MODELS.build(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg

View File

@ -24,9 +24,9 @@ from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.kie_dataset import KIEDataset
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.models import build_detector
from mmocr.models.textdet.detectors import TextDetectorMixin
from mmocr.models.textrecog.recognizer import BaseRecognizer
from mmocr.registry import MODELS
from mmocr.utils import is_type_list
from mmocr.utils.box_util import stitch_boxes_into_lines
from mmocr.utils.fileio import list_from_file
@ -427,7 +427,7 @@ class MMOCR:
'kie/' + kie_models[self.kie]['ckpt']
kie_cfg = Config.fromfile(kie_config)
self.kie_model = build_detector(
self.kie_model = MODELS.build(
kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
self.kie_model = revert_sync_batchnorm(self.kie_model)
self.kie_model.cfg = kie_cfg