mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] union to MODELS
parent 3f24e34a5d
commit 23458f8a47
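The diff below applies one mechanical change across the code base: the per-component registry aliases (DETECTORS, HEADS, LOSSES, CONVERTORS, ...) and the build_*() helpers from mmocr.models.builder are replaced by the single MODELS registry in mmocr.registry. A minimal sketch of the pattern, assuming nothing beyond what the hunks show (the MyHead class and its config are illustrative only, not part of the commit):

    from mmocr.registry import MODELS

    # before: @HEADS.register_module() plus build_head(cfg) from mmocr.models.builder
    # after: one registry handles both registration and construction
    @MODELS.register_module()
    class MyHead:

        def __init__(self, in_channels=256):
            self.in_channels = in_channels

    head = MODELS.build(dict(type='MyHead', in_channels=128))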
@@ -3,7 +3,6 @@ from argparse import ArgumentParser
 from mmocr.apis import init_detector
 from mmocr.apis.inference import text_model_inference
-from mmocr.models import build_detector # NOQA
 from mmocr.registry import DATASETS # NOQA
@@ -5,7 +5,6 @@ import cv2
 import torch

 from mmocr.apis import init_detector, model_inference
-from mmocr.models import build_detector # noqa: F401
 from mmocr.registry import DATASETS # noqa: F401
@@ -11,7 +11,7 @@ from mmdet.core import get_classes
 from mmdet.datasets import replace_ImageToTensor
 from mmdet.datasets.pipelines import Compose

-from mmocr.models import build_detector
+from mmocr.registry import MODELS
 from mmocr.utils import is_2dlist
 from .utils import disable_text_recog_aug_test
@@ -40,7 +40,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
     if config.model.get('pretrained'):
         config.model.pretrained = None
     config.model.train_cfg = None
-    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
+    model = MODELS.build(config.model, test_cfg=config.get('test_cfg'))
     if checkpoint is not None:
         checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
         if 'CLASSES' in checkpoint.get('meta', {}):
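The public entry points in mmocr.apis keep their signatures; only the internal construction path changes to MODELS. A usage sketch with placeholder config and checkpoint paths:

    from mmocr.apis import init_detector, model_inference

    # the paths are placeholders; any MMOCR 0.x detection config/checkpoint pair works
    model = init_detector('configs/textdet/dbnet/dbnet_r18.py',
                          checkpoint='dbnet_r18.pth', device='cpu')
    result = model_inference(model, 'demo/demo_text_det.jpg')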
@@ -5,7 +5,6 @@ from typing import Any, Iterable
 import numpy as np
 import torch
-from mmdet.models.builder import DETECTORS

 from mmocr.models.textdet.detectors.single_stage_text_detector import \
     SingleStageTextDetector
@@ -13,6 +12,7 @@ from mmocr.models.textdet.detectors.text_detector_mixin import \
     TextDetectorMixin
 from mmocr.models.textrecog.recognizer.encode_decode_recognizer import \
     EncodeDecodeRecognizer
+from mmocr.registry import MODELS


 def inference_with_session(sess, io_binding, input_name, output_names,
@@ -34,7 +34,7 @@ def inference_with_session(sess, io_binding, input_name, output_names,
     return pred


-@DETECTORS.register_module()
+@MODELS.register_module()
 class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
     """The class for evaluating onnx file of detection."""

@@ -110,7 +110,7 @@ class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
         return boundaries


-@DETECTORS.register_module()
+@MODELS.register_module()
 class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
     """The class for evaluating onnx file of recognition."""

@@ -201,7 +201,7 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
         return results


-@DETECTORS.register_module()
+@MODELS.register_module()
 class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
     """The class for evaluating TensorRT file of detection."""

@@ -257,7 +257,7 @@ class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
         return boundaries


-@DETECTORS.register_module()
+@MODELS.register_module()
 class TensorRTRecognizer(EncodeDecodeRecognizer):
     """The class for evaluating TensorRT file of recognition."""

@@ -1,8 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch

-from mmocr.models.builder import build_convertor
-from mmocr.registry import TRANSFORMS
+from mmocr.registry import MODELS, TRANSFORMS


 @TRANSFORMS.register_module()
@@ -18,7 +17,7 @@ class NerTransform:
     """

     def __init__(self, label_convertor, max_len):
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)
         self.max_len = max_len

     def __call__(self, results):
@@ -4,8 +4,7 @@ import numpy as np
 from mmdet.core import BitmapMasks

 import mmocr.utils.check_argument as check_argument
-from mmocr.models.builder import build_convertor
-from mmocr.registry import TRANSFORMS
+from mmocr.registry import MODELS, TRANSFORMS


 @TRANSFORMS.register_module()
@@ -41,7 +40,7 @@ class OCRSegTargets:

         self.attn_shrink_ratio = attn_shrink_ratio
         self.seg_shrink_ratio = seg_shrink_ratio
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)
         self.box_type = box_type
         self.pad_val = pad_val
@@ -1,19 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from . import common, kie, textdet, textrecog
-from .builder import (BACKBONES, CONVERTORS, DECODERS, DETECTORS, ENCODERS,
-                      HEADS, LOSSES, NECKS, PREPROCESSOR, build_backbone,
-                      build_convertor, build_decoder, build_detector,
-                      build_encoder, build_loss, build_preprocessor)
 from .common import * # NOQA
 from .kie import * # NOQA
 from .ner import * # NOQA
 from .textdet import * # NOQA
 from .textrecog import * # NOQA

-__all__ = [
-    'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone',
-    'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS',
-    'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder',
-    'build_preprocessor'
-]
-__all__ += common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__
+__all__ = common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__
@@ -1,115 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-
 import torch.nn as nn
 from mmcv.cnn import ACTIVATION_LAYERS as MMCV_ACTIVATION_LAYERS
 from mmcv.cnn import UPSAMPLE_LAYERS as MMCV_UPSAMPLE_LAYERS
 from mmcv.utils import Registry, build_from_cfg
-
-from mmocr.registry import MODELS
-
-CONVERTORS = MODELS
-ENCODERS = MODELS
-DECODERS = MODELS
-PREPROCESSOR = MODELS
-POSTPROCESSOR = MODELS
-
-UPSAMPLE_LAYERS = MODELS
-BACKBONES = MODELS
-LOSSES = MODELS
-DETECTORS = MODELS
-ROI_EXTRACTORS = MODELS
-HEADS = MODELS
-NECKS = MODELS
-FUSERS = MODELS
-RECOGNIZERS = MODELS
-
-ACTIVATION_LAYERS = MODELS
-
-
-def build_recognizer(cfg, train_cfg=None, test_cfg=None):
-    """Build recognizer."""
-    warnings.warn('``build_recognizer`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return RECOGNIZERS(
-        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
-
-
-def build_convertor(cfg):
-    """Build label convertor for scene text recognizer."""
-    warnings.warn('``build_convertor`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return CONVERTORS.build(cfg)
-
-
-def build_encoder(cfg):
-    """Build encoder for scene text recognizer."""
-    warnings.warn('``build_encoder`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return ENCODERS.build(cfg)
-
-
-def build_decoder(cfg):
-    """Build decoder for scene text recognizer."""
-    warnings.warn('``build_decoder`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return DECODERS.build(cfg)
-
-
-def build_preprocessor(cfg):
-    """Build preprocessor for scene text recognizer."""
-    warnings.warn(
-        '``build_preprocessor`` would be deprecated soon, please use '
-        '``mmocr.registry.MODELS.build()`` ')
-    return PREPROCESSOR(cfg)
-
-
-def build_postprocessor(cfg):
-    """Build postprocessor for scene text detector."""
-    warnings.warn(
-        '``build_postprocessor`` would be deprecated soon, please use '
-        '``mmocr.registry.MODELS.build()`` ')
-    return POSTPROCESSOR.build(cfg)
-
-
-def build_roi_extractor(cfg):
-    """Build roi extractor."""
-    warnings.warn(
-        '``build_roi_extractor`` would be deprecated soon, please use '
-        '``mmocr.registry.MODELS.build()`` ')
-    return ROI_EXTRACTORS.build(cfg)
-
-
-def build_loss(cfg):
-    """Build loss."""
-    warnings.warn('``build_loss`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return LOSSES.build(cfg)
-
-
-def build_backbone(cfg):
-    """Build backbone."""
-    warnings.warn('``build_backbone`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return BACKBONES.build(cfg)
-
-
-def build_head(cfg):
-    """Build head."""
-    warnings.warn('``build_head`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return HEADS.build(cfg)
-
-
-def build_neck(cfg):
-    """Build neck."""
-    warnings.warn('``build_neck`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return NECKS.build(cfg)
-
-
-def build_fuser(cfg):
-    """Build fuser."""
-    warnings.warn('``build_fuser`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return FUSERS.build(cfg)
+
+UPSAMPLE_LAYERS = Registry('upsample layer', parent=MMCV_UPSAMPLE_LAYERS)
+ACTIVATION_LAYERS = Registry('activation layer', parent=MMCV_ACTIVATION_LAYERS)


 def build_upsample_layer(cfg, *args, **kwargs):
@@ -160,21 +56,4 @@ def build_activation_layer(cfg):
     Returns:
         nn.Module: Created activation layer.
     """
-    warnings.warn(
-        '``build_activation_layer`` would be deprecated soon, please use '
-        '``mmocr.registry.MODELS.build()`` ')
-    return ACTIVATION_LAYERS.build(cfg)
-
-
-def build_detector(cfg, train_cfg=None, test_cfg=None):
-    """Build detector."""
-    if train_cfg is not None or test_cfg is not None:
-        warnings.warn(
-            'train_cfg and test_cfg is deprecated, '
-            'please specify them in model', UserWarning)
-    assert cfg.get('train_cfg') is None or train_cfg is None, \
-        'train_cfg specified in both outer field and model field '
-    assert cfg.get('test_cfg') is None or test_cfg is None, \
-        'test_cfg specified in both outer field and model field '
-    return DETECTORS.build(
-        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
+    return build_from_cfg(cfg, ACTIVATION_LAYERS)
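With the deprecated aliases and build_*() helpers deleted from mmocr.models.builder, call sites construct components straight from the unified registry. A minimal migration sketch (the DiceLoss config mirrors the common-loss hunk further down; the import shown in the comment no longer exists after this commit):

    from mmocr.registry import MODELS

    loss_cfg = dict(type='DiceLoss', eps=1e-6)

    # before: from mmocr.models.builder import build_loss; loss = build_loss(loss_cfg)
    loss = MODELS.build(loss_cfg)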
@@ -6,8 +6,9 @@ from mmcv.cnn import ConvModule, build_norm_layer
 from mmcv.runner import BaseModule
 from mmcv.utils.parrots_wrapper import _BatchNorm

-from mmocr.models.builder import (BACKBONES, UPSAMPLE_LAYERS,
-                                  build_activation_layer, build_upsample_layer)
+from mmocr.models.builder import (UPSAMPLE_LAYERS, build_activation_layer,
+                                  build_upsample_layer)
+from mmocr.registry import MODELS


 class UpConvBlock(nn.Module):
@@ -317,7 +318,7 @@ class InterpConv(nn.Module):
         return out


-@BACKBONES.register_module()
+@MODELS.register_module()
 class UNet(BaseModule):
     """UNet backbone.
     U-Net: Convolutional Networks for Biomedical Image Segmentation.
@@ -4,11 +4,10 @@ import warnings
 from mmdet.models.detectors import \
     SingleStageDetector as MMDET_SingleStageDetector

-from mmocr.models.builder import (DETECTORS, build_backbone, build_head,
-                                  build_neck)
+from mmocr.registry import MODELS


-@DETECTORS.register_module()
+@MODELS.register_module()
 class SingleStageDetector(MMDET_SingleStageDetector):
     """Base class for single-stage detectors.
@@ -29,11 +28,11 @@ class SingleStageDetector(MMDET_SingleStageDetector):
             warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
             backbone.pretrained = pretrained
-        self.backbone = build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)
         if neck is not None:
-            self.neck = build_neck(neck)
+            self.neck = MODELS.build(neck)
         bbox_head.update(train_cfg=train_cfg)
         bbox_head.update(test_cfg=test_cfg)
-        self.bbox_head = build_head(bbox_head)
+        self.bbox_head = MODELS.build(bbox_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
@@ -2,10 +2,10 @@
 import torch
 import torch.nn as nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class DiceLoss(nn.Module):

     def __init__(self, eps=1e-6):
@@ -7,12 +7,12 @@ from torch import nn
 from torch.nn import functional as F

 from mmocr.core import imshow_edge, imshow_node
-from mmocr.models.builder import DETECTORS, build_roi_extractor
 from mmocr.models.common.detectors import SingleStageDetector
+from mmocr.registry import MODELS
 from mmocr.utils import list_from_file


-@DETECTORS.register_module()
+@MODELS.register_module()
 class SDMGR(SingleStageDetector):
     """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
     for Key Information Extraction. https://arxiv.org/abs/2103.14470.
@@ -42,7 +42,7 @@ class SDMGR(SingleStageDetector):
             backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
         self.visual_modality = visual_modality
         if visual_modality:
-            self.extractor = build_roi_extractor({
+            self.extractor = MODELS.build({
                 **extractor, 'out_channels':
                 self.backbone.base_channels
             })
@@ -4,10 +4,10 @@ from mmcv.runner import BaseModule
 from torch import nn
 from torch.nn import functional as F

-from mmocr.models.builder import HEADS, build_loss
+from mmocr.registry import MODELS


-@HEADS.register_module()
+@MODELS.register_module()
 class SDMGRHead(BaseModule):

     def __init__(self,
@@ -45,7 +45,7 @@ class SDMGRHead(BaseModule):
             [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
         self.node_cls = nn.Linear(node_embed, num_classes)
         self.edge_cls = nn.Linear(edge_embed, 2)
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)

     def forward(self, relations, texts, x=None):
         node_nums, char_nums = [], []
@@ -3,10 +3,10 @@ import torch
 from mmdet.models.losses import accuracy
 from torch import nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class SDMGRLoss(nn.Module):
     """The implementation the loss of key information extraction proposed in
     the paper: Spatial Dual-Modality Graph Reasoning for Key Information
@@ -1,10 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import (DETECTORS, build_convertor, build_decoder,
-                                  build_encoder, build_loss)
 from mmocr.models.textrecog.recognizer.base import BaseRecognizer
+from mmocr.registry import MODELS


-@DETECTORS.register_module()
+@MODELS.register_module()
 class NerClassifier(BaseRecognizer):
     """Base class for NER classifier."""
@@ -17,15 +16,15 @@ class NerClassifier(BaseRecognizer):
                  test_cfg=None,
                  init_cfg=None):
         super().__init__(init_cfg=init_cfg)
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)

-        self.encoder = build_encoder(encoder)
+        self.encoder = MODELS.build(encoder)

         decoder.update(num_labels=self.label_convertor.num_labels)
-        self.decoder = build_decoder(decoder)
+        self.decoder = MODELS.build(decoder)

         loss.update(num_labels=self.label_convertor.num_labels)
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)

     def extract_feat(self, imgs):
         """Extract features from images."""
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np

-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from mmocr.utils import list_from_file


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class NerConvertor:
     """Convert between text, index and tensor for NER pipeline.
@@ -4,10 +4,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.runner import BaseModule

-from mmocr.models.builder import DECODERS
+from mmocr.registry import MODELS


-@DECODERS.register_module()
+@MODELS.register_module()
 class FCDecoder(BaseModule):
     """FC Decoder class for Ner.
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.runner import BaseModule

-from mmocr.models.builder import ENCODERS
 from mmocr.models.ner.utils.bert import BertModel
+from mmocr.registry import MODELS


-@ENCODERS.register_module()
+@MODELS.register_module()
 class BertEncoder(BaseModule):
     """Bert encoder
     Args:
@@ -2,10 +2,10 @@
 from torch import nn
 from torch.nn import CrossEntropyLoss

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class MaskedCrossEntropyLoss(nn.Module):
     """The implementation of masked cross entropy loss.
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from torch import nn

-from mmocr.models.builder import LOSSES
 from mmocr.models.common.losses.focal_loss import FocalLoss
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class MaskedFocalLoss(nn.Module):
     """The implementation of masked focal loss.
@@ -5,11 +5,11 @@ import torch
 import torch.nn as nn
 from mmcv.runner import BaseModule, Sequential

-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from .head_mixin import HeadMixin


-@HEADS.register_module()
+@MODELS.register_module()
 class DBHead(HeadMixin, BaseModule):
     """The class for DBNet head.
@@ -7,13 +7,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.runner import BaseModule

-from mmocr.models.builder import HEADS, build_loss
 from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 from .head_mixin import HeadMixin


-@HEADS.register_module()
+@MODELS.register_module()
 class DRRGHead(HeadMixin, BaseModule):
     """The class for DRRG head: `Deep Relational Reasoning Graph Network for
     Arbitrary Shape Text Detection <https://arxiv.org/abs/2003.07493>`_.
@@ -118,7 +118,7 @@ class DRRGHead(HeadMixin, BaseModule):
         self.center_region_thr = center_region_thr
         self.center_region_area_thr = center_region_area_thr
         self.local_graph_thr = local_graph_thr
-        self.loss_module = build_loss(loss)
+        self.loss_module = MODELS.build(loss)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
@@ -5,12 +5,12 @@ import torch.nn as nn
 from mmcv.runner import BaseModule
 from mmdet.core import multi_apply

-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from ..postprocess.utils import poly_nms
 from .head_mixin import HeadMixin


-@HEADS.register_module()
+@MODELS.register_module()
 class FCEHead(HeadMixin, BaseModule):
     """The class for implementing FCENet head.
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np

-from mmocr.models.builder import HEADS, build_loss, build_postprocessor
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument


-@HEADS.register_module()
+@MODELS.register_module()
 class HeadMixin:
     """Base head class for text detection, including loss calcalation and
     postprocess.
@@ -19,8 +19,8 @@ class HeadMixin:
         assert isinstance(loss, dict)
         assert isinstance(postprocessor, dict)

-        self.loss_module = build_loss(loss)
-        self.postprocessor = build_postprocessor(postprocessor)
+        self.loss_module = MODELS.build(loss)
+        self.postprocessor = MODELS.build(postprocessor)

     def resize_boundary(self, boundaries, scale_factor):
         """Rescale boundaries via scale_factor.
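HeadMixin now resolves both its loss and its postprocessor through MODELS, so a detection-head config carries plain dicts for them and the head's __init__ builds each one from the same registry. An illustrative sketch in the spirit of the DBNet head above; the argument names beyond type are assumptions, not taken from this diff:

    from mmocr.registry import MODELS

    head_cfg = dict(
        type='DBHead',
        in_channels=256,
        loss=dict(type='DBLoss'),
        postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'))
    head = MODELS.build(head_cfg)  # loss and postprocessor built via MODELS inside HeadMixin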
@@ -6,12 +6,12 @@ import torch
 import torch.nn as nn
 from mmcv.runner import BaseModule

-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 from .head_mixin import HeadMixin


-@HEADS.register_module()
+@MODELS.register_module()
 class PANHead(HeadMixin, BaseModule):
     """The class for PANet head.
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from . import PANHead


-@HEADS.register_module()
+@MODELS.register_module()
 class PSEHead(PANHead):
     """The class for PSENet head.
@@ -4,11 +4,11 @@ import warnings
 import torch.nn as nn
 from mmcv.runner import BaseModule

-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from .head_mixin import HeadMixin


-@HEADS.register_module()
+@MODELS.register_module()
 class TextSnakeHead(HeadMixin, BaseModule):
     """The class for TextSnake head: TextSnake: A Flexible Representation for
     Detecting Text of Arbitrary Shapes.
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin


-@DETECTORS.register_module()
+@MODELS.register_module()
 class DBNet(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing DBNet text detector: Real-time Scene Text
     Detection with Differentiable Binarization.
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin


-@DETECTORS.register_module()
+@MODELS.register_module()
 class DRRG(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing DRRG text detector. Deep Relational Reasoning
     Graph Network for Arbitrary Shape Text Detection.
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin


-@DETECTORS.register_module()
+@MODELS.register_module()
 class FCENet(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing FCENet text detector
     FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
@@ -2,11 +2,11 @@
 from mmdet.models.detectors import MaskRCNN

 from mmocr.core import seg2boundary
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .text_detector_mixin import TextDetectorMixin


-@DETECTORS.register_module()
+@MODELS.register_module()
 class OCRMaskRCNN(TextDetectorMixin, MaskRCNN):
     """Mask RCNN tailored for OCR."""

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin


-@DETECTORS.register_module()
+@MODELS.register_module()
 class PANet(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing PANet text detector:
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin


-@DETECTORS.register_module()
+@MODELS.register_module()
 class PSENet(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing PSENet text detector: Shape Robust Text
     Detection with Progressive Scale Expansion Network.
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch

-from mmocr.models.builder import DETECTORS
 from mmocr.models.common.detectors import SingleStageDetector
+from mmocr.registry import MODELS


-@DETECTORS.register_module()
+@MODELS.register_module()
 class SingleStageTextDetector(SingleStageDetector):
     """The class for implementing single stage text detector."""

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin


-@DETECTORS.register_module()
+@MODELS.register_module()
 class TextSnake(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing TextSnake text detector: TextSnake: A
     Flexible Representation for Detecting Text of Arbitrary Shapes.
@@ -3,11 +3,11 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from mmocr.models.builder import LOSSES
 from mmocr.models.common.losses.dice_loss import DiceLoss
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class DBLoss(nn.Module):
     """The class for implementing DBNet loss.
@@ -4,11 +4,11 @@ import torch.nn.functional as F
 from mmdet.core import BitmapMasks
 from torch import nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument


-@LOSSES.register_module()
+@MODELS.register_module()
 class DRRGLoss(nn.Module):
     """The class for implementing DRRG loss. This is partially adapted from
     https://github.com/GXYM/DRRG licensed under the MIT license.
@@ -5,10 +5,10 @@ import torch.nn.functional as F
 from mmdet.core import multi_apply
 from torch import nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class FCELoss(nn.Module):
     """The class for implementing FCENet loss.
@@ -8,11 +8,11 @@ import torch.nn.functional as F
 from mmdet.core import BitmapMasks
 from torch import nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument


-@LOSSES.register_module()
+@MODELS.register_module()
 class PANLoss(nn.Module):
     """The class for implementing PANet loss. This was partially adapted from
     https://github.com/WenmuZhou/PAN.pytorch.
@@ -1,12 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmdet.core import BitmapMasks

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 from . import PANLoss


-@LOSSES.register_module()
+@MODELS.register_module()
 class PSELoss(PANLoss):
     r"""The class for implementing PSENet loss. This is partially adapted from
     https://github.com/whai362/PSENet.
@@ -4,11 +4,11 @@ import torch.nn.functional as F
 from mmdet.core import BitmapMasks
 from torch import nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument


-@LOSSES.register_module()
+@MODELS.register_module()
 class TextSnakeLoss(nn.Module):
     """The class for implementing TextSnake loss. This is partially adapted
     from https://github.com/princewang1994/TextSnake.pytorch.
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 from mmcv.runner import BaseModule, ModuleList
 from torch import nn

-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS


 class FPEM(BaseModule):
@@ -72,7 +72,7 @@ class SeparableConv2d(BaseModule):
         return x


-@NECKS.register_module()
+@MODELS.register_module()
 class FPEM_FFM(BaseModule):
     """This code is from https://github.com/WenmuZhou/PAN.pytorch.

@@ -5,10 +5,10 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule, ModuleList, Sequential, auto_fp16

-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS


-@NECKS.register_module()
+@MODELS.register_module()
 class FPNC(BaseModule):
     """FPN-like fusion module in Real-time Scene Text Detection with
     Differentiable Binarization.
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from mmcv.runner import BaseModule
 from torch import nn

-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS


 class UpBlock(BaseModule):
@@ -30,7 +30,7 @@ class UpBlock(BaseModule):
         return x


-@NECKS.register_module()
+@MODELS.register_module()
 class FPN_UNet(BaseModule):
     """The class for implementing DRRG and TextSnake U-Net-like FPN.

@@ -4,10 +4,10 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule, ModuleList, auto_fp16

-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS


-@NECKS.register_module()
+@MODELS.register_module()
 class FPNF(BaseModule):
     """FPN-like fusion module in Shape Robust Text Detection with Progressive
     Scale Expansion Network.
@@ -3,12 +3,12 @@ import cv2
 import numpy as np

 from mmocr.core import points2boundary
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 from .utils import box_score_fast, unclip


-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class DBPostprocessor(BasePostprocessor):
     """Decoding predictions of DbNet to instances. This is partially adapted
     from https://github.com/MhLiao/DB.
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 from .utils import (clusters2labels, comps2boundaries, connected_components,
                     graph_propagation, remove_single)


-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class DRRGPostprocessor(BasePostprocessor):
     """Merge text components and construct boundaries of text instances.

@@ -2,12 +2,12 @@
 import cv2
 import numpy as np

-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 from .utils import fill_hole, fourier2poly, poly_nms


-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class FCEPostprocessor(BasePostprocessor):
     """Decoding predictions of FCENet to instances.

@@ -5,11 +5,11 @@ import torch
 from mmcv.ops import pixel_group

 from mmocr.core import points2boundary
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor


-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class PANPostprocessor(BasePostprocessor):
     """Convert scores to quadrangles via post processing in PANet. This is
     partially adapted from https://github.com/WenmuZhou/PAN.pytorch.
@@ -6,11 +6,11 @@ import torch
 from mmcv.ops import contour_expand

 from mmocr.core import points2boundary
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor


-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class PSEPostprocessor(BasePostprocessor):
     """Decoding predictions of PSENet to instances. This is partially adapted
     from https://github.com/whai362/PSENet.
@@ -5,12 +5,12 @@ import numpy as np
 import torch
 from skimage.morphology import skeletonize

-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 from .utils import centralize, fill_hole, merge_disks


-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class TextSnakePostprocessor(BasePostprocessor):
     """Decoding predictions of TextSnake to instances. This was partially
     adapted from https://github.com/princewang1994/TextSnake.pytorch.
@@ -2,10 +2,10 @@
 import torch.nn as nn
 from mmcv.runner import BaseModule

-from mmocr.models.builder import BACKBONES
+from mmocr.registry import MODELS


-@BACKBONES.register_module()
+@MODELS.register_module()
 class NRTRModalityTransform(BaseModule):

     def __init__(self,
@@ -3,11 +3,11 @@ from mmcv.cnn import ConvModule, build_plugin_layer
 from mmcv.runner import BaseModule, Sequential

 import mmocr.utils as utils
-from mmocr.models.builder import BACKBONES
 from mmocr.models.textrecog.layers import BasicBlock
+from mmocr.registry import MODELS


-@BACKBONES.register_module()
+@MODELS.register_module()
 class ResNet(BaseModule):
     """
     Args:
@@ -3,11 +3,11 @@ import torch.nn as nn
 from mmcv.runner import BaseModule, Sequential

 import mmocr.utils as utils
-from mmocr.models.builder import BACKBONES
 from mmocr.models.textrecog.layers import BasicBlock
+from mmocr.registry import MODELS


-@BACKBONES.register_module()
+@MODELS.register_module()
 class ResNet31OCR(BaseModule):
     """Implement ResNet backbone for text recognition, modified from
     `ResNet <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -3,11 +3,11 @@ import torch.nn as nn
 from mmcv.runner import BaseModule, Sequential

 import mmocr.utils as utils
-from mmocr.models.builder import BACKBONES
 from mmocr.models.textrecog.layers import BasicBlock
+from mmocr.registry import MODELS


-@BACKBONES.register_module()
+@MODELS.register_module()
 class ResNetABI(BaseModule):
     """Implement ResNet backbone for text recognition, modified from `ResNet.

@@ -3,10 +3,10 @@ import torch.nn as nn
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule

-from mmocr.models.builder import BACKBONES
+from mmocr.registry import MODELS


-@BACKBONES.register_module()
+@MODELS.register_module()
 class ShallowCNN(BaseModule):
     """Implement Shallow CNN block for SATRN.

@@ -2,10 +2,10 @@
 import torch.nn as nn
 from mmcv.runner import BaseModule, Sequential

-from mmocr.models.builder import BACKBONES
+from mmocr.registry import MODELS


-@BACKBONES.register_module()
+@MODELS.register_module()
 class VeryDeepVgg(BaseModule):
     """Implement VGG-VeryDeep backbone for text recognition, modified from
     `VGG-VeryDeep <https://arxiv.org/pdf/1409.1556.pdf>`_
@@ -2,11 +2,11 @@
 import torch

 import mmocr.utils as utils
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from .attn import AttnConvertor


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class ABIConvertor(AttnConvertor):
     """Convert between text, index and tensor for encoder-decoder based
     pipeline. Modified from AttnConvertor to get closer to ABINet's original
@@ -2,11 +2,11 @@
 import torch

 import mmocr.utils as utils
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from .base import BaseConvertor


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class AttnConvertor(BaseConvertor):
     """Convert between text, index and tensor for encoder-decoder based
     pipeline.
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from mmocr.utils import list_from_file


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class BaseConvertor:
     """Convert between text, index and tensor for text recognize pipeline.

@@ -5,11 +5,11 @@ import torch
 import torch.nn.functional as F

 import mmocr.utils as utils
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from .base import BaseConvertor


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class CTCConvertor(BaseConvertor):
     """Convert between text, index and tensor for CTC loss-based pipeline.

@@ -4,11 +4,11 @@ import numpy as np
 import torch

 import mmocr.utils as utils
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from .base import BaseConvertor


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class SegConvertor(BaseConvertor):
     """Convert between text, index and tensor for segmentation based pipeline.

@@ -6,12 +6,12 @@ import torch.nn as nn
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import ModuleList

-from mmocr.models.builder import DECODERS
 from mmocr.models.common.modules import PositionalEncoding
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class ABILanguageDecoder(BaseDecoder):
     r"""Transformer-based language model responsible for spell correction.
     Implementation of language model of \
@@ -3,12 +3,12 @@ import torch
 import torch.nn as nn
 from mmcv.cnn import ConvModule

-from mmocr.models.builder import DECODERS
 from mmocr.models.common.modules import PositionalEncoding
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class ABIVisionDecoder(BaseDecoder):
     """Converts visual features into text characters.

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.runner import BaseModule

-from mmocr.models.builder import DECODERS
+from mmocr.registry import MODELS


-@DECODERS.register_module()
+@MODELS.register_module()
 class BaseDecoder(BaseModule):
     """Base decoder class for text recognition."""

@@ -2,12 +2,12 @@
 import torch.nn as nn
 from mmcv.runner import Sequential

-from mmocr.models.builder import DECODERS
 from mmocr.models.textrecog.layers import BidirectionalLSTM
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class CRNNDecoder(BaseDecoder):
     """Decoder for CRNN.

@@ -8,8 +8,8 @@ import torch.nn.functional as F
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import ModuleList

-from mmocr.models.builder import DECODERS
 from mmocr.models.common.modules import PositionalEncoding
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder

@@ -30,7 +30,7 @@ class Embeddings(nn.Module):
         return self.lut(x) * math.sqrt(self.d_model)


-@DECODERS.register_module()
+@MODELS.register_module()
 class MasterDecoder(BaseDecoder):
     """Decoder module in `MASTER <https://arxiv.org/abs/1910.02562>`_.

@@ -6,12 +6,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.runner import ModuleList

-from mmocr.models.builder import DECODERS
 from mmocr.models.common import PositionalEncoding, TFDecoderLayer
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class NRTRDecoder(BaseDecoder):
     """Transformer Decoder block with self attention mechanism.

@@ -4,13 +4,13 @@ import math
 import torch
 import torch.nn as nn

-from mmocr.models.builder import DECODERS
 from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
                                            PositionAwareLayer)
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class PositionAttentionDecoder(BaseDecoder):
     """Position attention decoder for RobustScanner.

@@ -3,12 +3,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import DECODERS, build_decoder
 from mmocr.models.textrecog.layers import RobustScannerFusionLayer
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class RobustScannerDecoder(BaseDecoder):
     """Decoder for RobustScanner.

@@ -72,7 +72,7 @@ class RobustScannerDecoder(BaseDecoder):
         hybrid_decoder.update(encode_value=self.encode_value)
         hybrid_decoder.update(return_feature=True)

-        self.hybrid_decoder = build_decoder(hybrid_decoder)
+        self.hybrid_decoder = MODELS.build(hybrid_decoder)

         # init position decoder
         position_decoder.update(num_classes=self.num_classes)
@@ -83,7 +83,7 @@ class RobustScannerDecoder(BaseDecoder):
         position_decoder.update(encode_value=self.encode_value)
         position_decoder.update(return_feature=True)

-        self.position_decoder = build_decoder(position_decoder)
+        self.position_decoder = MODELS.build(position_decoder)

         self.fusion_module = RobustScannerFusionLayer(
             self.dim_model if encode_value else dim_input)
@@ -6,11 +6,11 @@ import torch.nn as nn
 import torch.nn.functional as F

 import mmocr.utils as utils
-from mmocr.models.builder import DECODERS
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class ParallelSARDecoder(BaseDecoder):
     """Implementation Parallel Decoder module in `SAR.

@@ -255,7 +255,7 @@ class ParallelSARDecoder(BaseDecoder):
         return outputs


-@DECODERS.register_module()
+@MODELS.register_module()
 class SequentialSARDecoder(BaseDecoder):
     """Implementation Sequential Decoder module in `SAR.
@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F

 import mmocr.utils as utils
-from mmocr.models.builder import DECODERS
+from mmocr.registry import MODELS
 from . import ParallelSARDecoder

@@ -31,7 +31,7 @@ class DecodeNode:
         return accu_score


-@DECODERS.register_module()
+@MODELS.register_module()
 class ParallelSARDecoderWithBS(ParallelSARDecoder):
     """Parallel Decoder module with beam-search in SAR.

@@ -5,12 +5,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import DECODERS
 from mmocr.models.textrecog.layers import DotProductAttentionLayer
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class SequenceAttentionDecoder(BaseDecoder):
     """Sequence attention decoder for RobustScanner.

@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import ENCODERS, build_decoder, build_encoder
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class ABIVisionModel(BaseEncoder):
     """A wrapper of visual feature encoder and language token decoder that
     converts visual features into text tokens.
@@ -23,8 +23,8 @@ class ABIVisionModel(BaseEncoder):
                  init_cfg=dict(type='Xavier', layer='Conv2d'),
                  **kwargs):
         super().__init__(init_cfg=init_cfg)
-        self.encoder = build_encoder(encoder)
-        self.decoder = build_decoder(decoder)
+        self.encoder = MODELS.build(encoder)
+        self.decoder = MODELS.build(decoder)

     def forward(self, feat, img_metas=None):
         """
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.runner import BaseModule

-from mmocr.models.builder import ENCODERS
+from mmocr.registry import MODELS


-@ENCODERS.register_module()
+@MODELS.register_module()
 class BaseEncoder(BaseModule):
     """Base Encoder class for text recognition."""

@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn

-from mmocr.models.builder import ENCODERS
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class ChannelReductionEncoder(BaseEncoder):
     """Change the channel number with a one by one convoluational layer.

@@ -4,12 +4,12 @@ import math
 import torch.nn as nn
 from mmcv.runner import ModuleList

-from mmocr.models.builder import ENCODERS
 from mmocr.models.common import TFEncoderLayer
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class NRTREncoder(BaseEncoder):
     """Transformer Encoder block with self attention mechanism.

@@ -6,11 +6,11 @@ import torch.nn as nn
 import torch.nn.functional as F

 import mmocr.utils as utils
-from mmocr.models.builder import ENCODERS
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class SAREncoder(BaseEncoder):
     """Implementation of encoder module in `SAR.

@@ -4,13 +4,13 @@ import math
 import torch.nn as nn
 from mmcv.runner import ModuleList

-from mmocr.models.builder import ENCODERS
 from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding,
                                            SatrnEncoderLayer)
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class SatrnEncoder(BaseEncoder):
     """Implement encoder for SATRN, see `SATRN.

@@ -4,11 +4,11 @@ import copy
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import BaseModule, ModuleList

-from mmocr.models.builder import ENCODERS
 from mmocr.models.common.modules import PositionalEncoding
+from mmocr.registry import MODELS


-@ENCODERS.register_module()
+@MODELS.register_module()
 class TransformerEncoder(BaseModule):
     """Implement transformer encoder for text recognition, modified from
     `<https://github.com/FangShancheng/ABINet>`.
@@ -3,10 +3,10 @@ import torch
 import torch.nn as nn
 from mmcv.runner import BaseModule

-from mmocr.models.builder import FUSERS
+from mmocr.registry import MODELS


-@FUSERS.register_module()
+@MODELS.register_module()
 class ABIFuser(BaseModule):
     """Mix and align visual feature and linguistic feature Implementation of
     language model of `ABINet <https://arxiv.org/abs/1910.04396>`_.
@@ -4,10 +4,10 @@ from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule
 from torch import nn

-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS


-@HEADS.register_module()
+@MODELS.register_module()
 class SegHead(BaseModule):
     """Head for segmentation based text recognition.

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class CELoss(nn.Module):
     """Implementation of loss module for encoder-decoder based text recognition
     method with CrossEntropy loss.
@@ -63,7 +63,7 @@ class CELoss(nn.Module):
         return losses


-@LOSSES.register_module()
+@MODELS.register_module()
 class SARLoss(CELoss):
     """Implementation of loss module in `SAR.

@@ -95,7 +95,7 @@ class SARLoss(CELoss):
         return outputs, targets


-@LOSSES.register_module()
+@MODELS.register_module()
 class TFLoss(CELoss):
     """Implementation of loss module for transformer.

@@ -4,10 +4,10 @@ import math
 import torch
 import torch.nn as nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class CTCLoss(nn.Module):
     """Implementation of loss module for CTC-loss based text recognition.

@@ -3,10 +3,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class ABILoss(nn.Module):
     """Implementation of ABINet multiloss that allows mixing different types of
     losses with weights.
@@ -3,10 +3,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class SegLoss(nn.Module):
     """Implementation of loss module for segmentation based text recognition
     method.
@@ -4,10 +4,10 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule, ModuleList

-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS


-@NECKS.register_module()
+@MODELS.register_module()
 class FPNOCR(BaseModule):
     """FPN-like Network for segmentation based text recognition.

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.runner import BaseModule

-from mmocr.models.builder import PREPROCESSOR
+from mmocr.registry import MODELS


-@PREPROCESSOR.register_module()
+@MODELS.register_module()
 class BasePreprocessor(BaseModule):
     """Base Preprocessor class for text recognition."""

@@ -17,11 +17,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import PREPROCESSOR
+from mmocr.registry import MODELS
 from .base_preprocessor import BasePreprocessor


-@PREPROCESSOR.register_module()
+@MODELS.register_module()
 class TPSPreprocessor(BasePreprocessor):
     """Rectification Network of RARE, namely TPS based STN in
     https://arxiv.org/pdf/1603.03915.pdf.
@@ -3,13 +3,11 @@ import warnings

 import torch

-from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor,
-                                  build_decoder, build_encoder, build_fuser,
-                                  build_loss, build_preprocessor)
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class ABINet(EncodeDecodeRecognizer):
     """Implementation of `Read Like Humans: Autonomous, Bidirectional and
     Iterative LanguageModeling for Scene Text Recognition.

@@ -36,21 +34,21 @@ class ABINet(EncodeDecodeRecognizer):
         # Label convertor (str2tensor, tensor2str)
         assert label_convertor is not None
         label_convertor.update(max_seq_len=max_seq_len)
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)

         # Preprocessor module, e.g., TPS
         self.preprocessor = None
         if preprocessor is not None:
-            self.preprocessor = build_preprocessor(preprocessor)
+            self.preprocessor = MODELS.build(preprocessor)

         # Backbone
         assert backbone is not None
-        self.backbone = build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)

         # Encoder module
         self.encoder = None
         if encoder is not None:
-            self.encoder = build_encoder(encoder)
+            self.encoder = MODELS.build(encoder)

         # Decoder module
         self.decoder = None

@@ -59,11 +57,11 @@ class ABINet(EncodeDecodeRecognizer):
         decoder.update(start_idx=self.label_convertor.start_idx)
         decoder.update(padding_idx=self.label_convertor.padding_idx)
         decoder.update(max_seq_len=max_seq_len)
-        self.decoder = build_decoder(decoder)
+        self.decoder = MODELS.build(decoder)

         # Loss
         assert loss is not None
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)

         self.train_cfg = train_cfg
         self.test_cfg = test_cfg

@@ -78,7 +76,7 @@ class ABINet(EncodeDecodeRecognizer):

         self.fuser = None
         if fuser is not None:
-            self.fuser = build_fuser(fuser)
+            self.fuser = MODELS.build(fuser)

     def forward_train(self, img, img_metas):
         """

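In the ABINet hunks above, the component-specific helpers (build_convertor, build_preprocessor, build_backbone, build_encoder, build_decoder, build_fuser, build_loss) all collapse into MODELS.build, which dispatches on the 'type' key of each config dict. A hedged sketch of the resulting construction pattern; the helper function and its arguments are illustrative, not taken from an actual ABINet config:

from mmocr.registry import MODELS


def build_submodules(backbone_cfg, decoder_cfg, loss_cfg):
    """Hypothetical helper mirroring the post-refactor __init__ pattern."""
    # Before this commit: build_backbone(...), build_decoder(...), build_loss(...).
    # After: a single registry resolves every component by its 'type' key.
    backbone = MODELS.build(backbone_cfg)
    decoder = MODELS.build(decoder_cfg)
    loss = MODELS.build(loss_cfg)
    return backbone, decoder, loss
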
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class CRNNNet(EncodeDecodeRecognizer):
     """CTC-loss based recognizer."""

@@ -3,13 +3,11 @@ import warnings

 import torch

-from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor,
-                                  build_decoder, build_encoder, build_loss,
-                                  build_preprocessor)
+from mmocr.registry import MODELS
 from .base import BaseRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class EncodeDecodeRecognizer(BaseRecognizer):
     """Base class for encode-decode recognizer."""

@@ -31,21 +29,21 @@ class EncodeDecodeRecognizer(BaseRecognizer):
         # Label convertor (str2tensor, tensor2str)
         assert label_convertor is not None
         label_convertor.update(max_seq_len=max_seq_len)
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)

         # Preprocessor module, e.g., TPS
         self.preprocessor = None
         if preprocessor is not None:
-            self.preprocessor = build_preprocessor(preprocessor)
+            self.preprocessor = MODELS.build(preprocessor)

         # Backbone
         assert backbone is not None
-        self.backbone = build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)

         # Encoder module
         self.encoder = None
         if encoder is not None:
-            self.encoder = build_encoder(encoder)
+            self.encoder = MODELS.build(encoder)

         # Decoder module
         assert decoder is not None

@@ -53,12 +51,12 @@ class EncodeDecodeRecognizer(BaseRecognizer):
         decoder.update(start_idx=self.label_convertor.start_idx)
         decoder.update(padding_idx=self.label_convertor.padding_idx)
         decoder.update(max_seq_len=max_seq_len)
-        self.decoder = build_decoder(decoder)
+        self.decoder = MODELS.build(decoder)

         # Loss
         assert loss is not None
         loss.update(ignore_index=self.label_convertor.padding_idx)
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)

         self.train_cfg = train_cfg
         self.test_cfg = test_cfg

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@DETECTORS.register_module()
+@MODELS.register_module()
 class MASTER(EncodeDecodeRecognizer):
     """Implementation of `MASTER <https://arxiv.org/abs/1910.02562>`_"""

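Note that MASTER was previously registered through DETECTORS even though it is a recognizer; after the union, the lookup key is simply the class name, regardless of which legacy registry the class used to live in. A small sketch, assuming mmocr.models has been imported so the register_module decorators have run:

import mmocr.models  # noqa: F401  # importing the package runs the @MODELS.register_module() decorators

from mmocr.registry import MODELS

# Registry.get returns the registered class (or None if the key is unknown).
master_cls = MODELS.get('MASTER')
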
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class NRTR(EncodeDecodeRecognizer):
     """Implementation of `NRTR <https://arxiv.org/pdf/1806.00926.pdf>`_"""

@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class RobustScanner(EncodeDecodeRecognizer):
     """Implementation of `RobustScanner.

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class SARNet(EncodeDecodeRecognizer):
     """Implementation of `SAR <https://arxiv.org/abs/1811.00751>`_"""

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class SATRN(EncodeDecodeRecognizer):
     """Implementation of `SATRN <https://arxiv.org/abs/1910.04396>`_"""

@@ -1,13 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings

-from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor,
-                                  build_head, build_loss, build_neck,
-                                  build_preprocessor)
+from mmocr.registry import MODELS
 from .base import BaseRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class SegRecognizer(BaseRecognizer):
     """Base class for segmentation based recognizer."""

@@ -26,29 +24,29 @@ class SegRecognizer(BaseRecognizer):

         # Label_convertor
         assert label_convertor is not None
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)

         # Preprocessor module, e.g., TPS
         self.preprocessor = None
         if preprocessor is not None:
-            self.preprocessor = build_preprocessor(preprocessor)
+            self.preprocessor = MODELS.build(preprocessor)

         # Backbone
         assert backbone is not None
-        self.backbone = build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)

         # Neck
         assert neck is not None
-        self.neck = build_neck(neck)
+        self.neck = MODELS.build(neck)

         # Head
         assert head is not None
         head.update(num_classes=self.label_convertor.num_classes())
-        self.head = build_head(head)
+        self.head = MODELS.build(head)

         # Loss
         assert loss is not None
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)

         self.train_cfg = train_cfg
         self.test_cfg = test_cfg

@@ -24,9 +24,9 @@ from mmocr.apis.inference import model_inference
 from mmocr.core.visualize import det_recog_show_result
 from mmocr.datasets.kie_dataset import KIEDataset
 from mmocr.datasets.pipelines.crop import crop_img
-from mmocr.models import build_detector
 from mmocr.models.textdet.detectors import TextDetectorMixin
 from mmocr.models.textrecog.recognizer import BaseRecognizer
+from mmocr.registry import MODELS
 from mmocr.utils import is_type_list
 from mmocr.utils.box_util import stitch_boxes_into_lines
 from mmocr.utils.fileio import list_from_file

@@ -427,7 +427,7 @@ class MMOCR:
                 'kie/' + kie_models[self.kie]['ckpt']

             kie_cfg = Config.fromfile(kie_config)
-            self.kie_model = build_detector(
+            self.kie_model = MODELS.build(
                 kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
             self.kie_model = revert_sync_batchnorm(self.kie_model)
             self.kie_model.cfg = kie_cfg

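The kie_model hunk above also shows the end-to-end pattern for config-driven construction: parse a config file and hand its model section to MODELS.build. A hedged usage sketch mirroring that call; the config path below is a placeholder:

from mmcv import Config

from mmocr.registry import MODELS

cfg = Config.fromfile('path/to/your_model_config.py')  # placeholder path
model = MODELS.build(cfg.model, test_cfg=cfg.get('test_cfg'))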