mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] union to MODELS
parent 3f24e34a5d
commit 23458f8a47
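Every hunk below applies the same refactor: call sites stop importing the per-component registries (DETECTORS, HEADS, LOSSES, BACKBONES, NECKS, CONVERTORS, POSTPROCESSOR, ...) and the build_* helpers from mmocr.models.builder, and instead register and build every component through the single MODELS registry in mmocr.registry. A minimal, self-contained sketch of that pattern follows (illustrative only: it uses a local mmcv-style Registry as a stand-in for mmocr.registry.MODELS and a toy DiceLoss, not mmocr's own classes):

# Sketch of the unified-registry pattern this commit applies everywhere.
# `MODELS` below is a local stand-in, not mmocr.registry.MODELS itself.
from mmcv.utils import Registry

MODELS = Registry('models')


@MODELS.register_module()  # replaces @LOSSES / @HEADS / @DETECTORS / ...
class DiceLoss:

    def __init__(self, eps=1e-6):
        self.eps = eps


# Old style: from mmocr.models.builder import build_loss; build_loss(cfg)
# New style: one registry builds every component type from its config dict.
loss = MODELS.build(dict(type='DiceLoss', eps=1e-6))
assert isinstance(loss, DiceLoss)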
@@ -3,7 +3,6 @@ from argparse import ArgumentParser
 
 from mmocr.apis import init_detector
 from mmocr.apis.inference import text_model_inference
-from mmocr.models import build_detector  # NOQA
 from mmocr.registry import DATASETS  # NOQA
 
 
@@ -5,7 +5,6 @@ import cv2
 import torch
 
 from mmocr.apis import init_detector, model_inference
-from mmocr.models import build_detector  # noqa: F401
 from mmocr.registry import DATASETS  # noqa: F401
 
 
@@ -11,7 +11,7 @@ from mmdet.core import get_classes
 from mmdet.datasets import replace_ImageToTensor
 from mmdet.datasets.pipelines import Compose
 
-from mmocr.models import build_detector
+from mmocr.registry import MODELS
 from mmocr.utils import is_2dlist
 from .utils import disable_text_recog_aug_test
 
@@ -40,7 +40,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
     if config.model.get('pretrained'):
         config.model.pretrained = None
     config.model.train_cfg = None
-    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
+    model = MODELS.build(config.model, test_cfg=config.get('test_cfg'))
     if checkpoint is not None:
         checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
         if 'CLASSES' in checkpoint.get('meta', {}):
@@ -5,7 +5,6 @@ from typing import Any, Iterable
 
 import numpy as np
 import torch
-from mmdet.models.builder import DETECTORS
 
 from mmocr.models.textdet.detectors.single_stage_text_detector import \
     SingleStageTextDetector
@@ -13,6 +12,7 @@ from mmocr.models.textdet.detectors.text_detector_mixin import \
     TextDetectorMixin
 from mmocr.models.textrecog.recognizer.encode_decode_recognizer import \
     EncodeDecodeRecognizer
+from mmocr.registry import MODELS
 
 
 def inference_with_session(sess, io_binding, input_name, output_names,
@@ -34,7 +34,7 @@ def inference_with_session(sess, io_binding, input_name, output_names,
     return pred
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
     """The class for evaluating onnx file of detection."""
 
@@ -110,7 +110,7 @@ class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
         return boundaries
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
     """The class for evaluating onnx file of recognition."""
 
@@ -201,7 +201,7 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
         return results
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
     """The class for evaluating TensorRT file of detection."""
 
@@ -257,7 +257,7 @@ class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
         return boundaries
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class TensorRTRecognizer(EncodeDecodeRecognizer):
     """The class for evaluating TensorRT file of recognition."""
 
@@ -1,8 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 
-from mmocr.models.builder import build_convertor
-from mmocr.registry import TRANSFORMS
+from mmocr.registry import MODELS, TRANSFORMS
 
 
 @TRANSFORMS.register_module()
@@ -18,7 +17,7 @@ class NerTransform:
     """
 
     def __init__(self, label_convertor, max_len):
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)
         self.max_len = max_len
 
     def __call__(self, results):
@@ -4,8 +4,7 @@ import numpy as np
 from mmdet.core import BitmapMasks
 
 import mmocr.utils.check_argument as check_argument
-from mmocr.models.builder import build_convertor
-from mmocr.registry import TRANSFORMS
+from mmocr.registry import MODELS, TRANSFORMS
 
 
 @TRANSFORMS.register_module()
@@ -41,7 +40,7 @@ class OCRSegTargets:
 
         self.attn_shrink_ratio = attn_shrink_ratio
         self.seg_shrink_ratio = seg_shrink_ratio
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)
         self.box_type = box_type
         self.pad_val = pad_val
 
@@ -1,19 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from . import common, kie, textdet, textrecog
-from .builder import (BACKBONES, CONVERTORS, DECODERS, DETECTORS, ENCODERS,
-                      HEADS, LOSSES, NECKS, PREPROCESSOR, build_backbone,
-                      build_convertor, build_decoder, build_detector,
-                      build_encoder, build_loss, build_preprocessor)
 from .common import *  # NOQA
 from .kie import *  # NOQA
 from .ner import *  # NOQA
 from .textdet import *  # NOQA
 from .textrecog import *  # NOQA
 
-__all__ = [
-    'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone',
-    'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS',
-    'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder',
-    'build_preprocessor'
-]
-__all__ += common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__
+__all__ = common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__
@@ -1,115 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import warnings
 
 import torch.nn as nn
+from mmcv.cnn import ACTIVATION_LAYERS as MMCV_ACTIVATION_LAYERS
+from mmcv.cnn import UPSAMPLE_LAYERS as MMCV_UPSAMPLE_LAYERS
+from mmcv.utils import Registry, build_from_cfg
 
-from mmocr.registry import MODELS
-
-CONVERTORS = MODELS
-ENCODERS = MODELS
-DECODERS = MODELS
-PREPROCESSOR = MODELS
-POSTPROCESSOR = MODELS
-
-UPSAMPLE_LAYERS = MODELS
-BACKBONES = MODELS
-LOSSES = MODELS
-DETECTORS = MODELS
-ROI_EXTRACTORS = MODELS
-HEADS = MODELS
-NECKS = MODELS
-FUSERS = MODELS
-RECOGNIZERS = MODELS
-
-ACTIVATION_LAYERS = MODELS
-
-
-def build_recognizer(cfg, train_cfg=None, test_cfg=None):
-    """Build recognizer."""
-    warnings.warn('``build_recognizer`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return RECOGNIZERS(
-        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
-
-
-def build_convertor(cfg):
-    """Build label convertor for scene text recognizer."""
-    warnings.warn('``build_convertor`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return CONVERTORS.build(cfg)
-
-
-def build_encoder(cfg):
-    """Build encoder for scene text recognizer."""
-    warnings.warn('``build_encoder`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return ENCODERS.build(cfg)
-
-
-def build_decoder(cfg):
-    """Build decoder for scene text recognizer."""
-    warnings.warn('``build_decoder`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return DECODERS.build(cfg)
-
-
-def build_preprocessor(cfg):
-    """Build preprocessor for scene text recognizer."""
-    warnings.warn(
-        '``build_preprocessor`` would be deprecated soon, please use '
-        '``mmocr.registry.MODELS.build()`` ')
-    return PREPROCESSOR(cfg)
-
-
-def build_postprocessor(cfg):
-    """Build postprocessor for scene text detector."""
-    warnings.warn(
-        '``build_postprocessor`` would be deprecated soon, please use '
-        '``mmocr.registry.MODELS.build()`` ')
-    return POSTPROCESSOR.build(cfg)
-
-
-def build_roi_extractor(cfg):
-    """Build roi extractor."""
-    warnings.warn(
-        '``build_roi_extractor`` would be deprecated soon, please use '
-        '``mmocr.registry.MODELS.build()`` ')
-    return ROI_EXTRACTORS.build(cfg)
-
-
-def build_loss(cfg):
-    """Build loss."""
-    warnings.warn('``build_loss`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return LOSSES.build(cfg)
-
-
-def build_backbone(cfg):
-    """Build backbone."""
-    warnings.warn('``build_backbone`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return BACKBONES.build(cfg)
-
-
-def build_head(cfg):
-    """Build head."""
-    warnings.warn('``build_head`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return HEADS.build(cfg)
-
-
-def build_neck(cfg):
-    """Build neck."""
-    warnings.warn('``build_neck`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return NECKS.build(cfg)
-
-
-def build_fuser(cfg):
-    """Build fuser."""
-    warnings.warn('``build_fuser`` would be deprecated soon, please use '
-                  '``mmocr.registry.MODELS.build()`` ')
-    return FUSERS.build(cfg)
+UPSAMPLE_LAYERS = Registry('upsample layer', parent=MMCV_UPSAMPLE_LAYERS)
+ACTIVATION_LAYERS = Registry('activation layer', parent=MMCV_ACTIVATION_LAYERS)
 
 
 def build_upsample_layer(cfg, *args, **kwargs):
@@ -160,21 +56,4 @@ def build_activation_layer(cfg):
     Returns:
         nn.Module: Created activation layer.
     """
-    warnings.warn(
-        '``build_activation_layer`` would be deprecated soon, please use '
-        '``mmocr.registry.MODELS.build()`` ')
-    return ACTIVATION_LAYERS.build(cfg)
-
-
-def build_detector(cfg, train_cfg=None, test_cfg=None):
-    """Build detector."""
-    if train_cfg is not None or test_cfg is not None:
-        warnings.warn(
-            'train_cfg and test_cfg is deprecated, '
-            'please specify them in model', UserWarning)
-    assert cfg.get('train_cfg') is None or train_cfg is None, \
-        'train_cfg specified in both outer field and model field '
-    assert cfg.get('test_cfg') is None or test_cfg is None, \
-        'test_cfg specified in both outer field and model field '
-    return DETECTORS.build(
-        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
+    return build_from_cfg(cfg, ACTIVATION_LAYERS)
@@ -6,8 +6,9 @@ from mmcv.cnn import ConvModule, build_norm_layer
 from mmcv.runner import BaseModule
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
-from mmocr.models.builder import (BACKBONES, UPSAMPLE_LAYERS,
-                                  build_activation_layer, build_upsample_layer)
+from mmocr.models.builder import (UPSAMPLE_LAYERS, build_activation_layer,
+                                  build_upsample_layer)
+from mmocr.registry import MODELS
 
 
 class UpConvBlock(nn.Module):
@@ -317,7 +318,7 @@ class InterpConv(nn.Module):
         return out
 
 
-@BACKBONES.register_module()
+@MODELS.register_module()
 class UNet(BaseModule):
     """UNet backbone.
     U-Net: Convolutional Networks for Biomedical Image Segmentation.
@@ -4,11 +4,10 @@ import warnings
 from mmdet.models.detectors import \
     SingleStageDetector as MMDET_SingleStageDetector
 
-from mmocr.models.builder import (DETECTORS, build_backbone, build_head,
-                                  build_neck)
+from mmocr.registry import MODELS
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class SingleStageDetector(MMDET_SingleStageDetector):
     """Base class for single-stage detectors.
 
@@ -29,11 +28,11 @@ class SingleStageDetector(MMDET_SingleStageDetector):
             warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
             backbone.pretrained = pretrained
-        self.backbone = build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)
         if neck is not None:
-            self.neck = build_neck(neck)
+            self.neck = MODELS.build(neck)
         bbox_head.update(train_cfg=train_cfg)
         bbox_head.update(test_cfg=test_cfg)
-        self.bbox_head = build_head(bbox_head)
+        self.bbox_head = MODELS.build(bbox_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
 
@@ -2,10 +2,10 @@
 import torch
 import torch.nn as nn
 
-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class DiceLoss(nn.Module):
 
     def __init__(self, eps=1e-6):
@@ -7,12 +7,12 @@ from torch import nn
 from torch.nn import functional as F
 
 from mmocr.core import imshow_edge, imshow_node
-from mmocr.models.builder import DETECTORS, build_roi_extractor
 from mmocr.models.common.detectors import SingleStageDetector
+from mmocr.registry import MODELS
 from mmocr.utils import list_from_file
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class SDMGR(SingleStageDetector):
     """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
     for Key Information Extraction. https://arxiv.org/abs/2103.14470.
@@ -42,7 +42,7 @@ class SDMGR(SingleStageDetector):
             backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
         self.visual_modality = visual_modality
         if visual_modality:
-            self.extractor = build_roi_extractor({
+            self.extractor = MODELS.build({
                 **extractor, 'out_channels':
                 self.backbone.base_channels
             })
@@ -4,10 +4,10 @@ from mmcv.runner import BaseModule
 from torch import nn
 from torch.nn import functional as F
 
-from mmocr.models.builder import HEADS, build_loss
+from mmocr.registry import MODELS
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class SDMGRHead(BaseModule):
 
     def __init__(self,
@@ -45,7 +45,7 @@ class SDMGRHead(BaseModule):
             [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
         self.node_cls = nn.Linear(node_embed, num_classes)
         self.edge_cls = nn.Linear(edge_embed, 2)
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)
 
     def forward(self, relations, texts, x=None):
         node_nums, char_nums = [], []
@@ -3,10 +3,10 @@ import torch
 from mmdet.models.losses import accuracy
 from torch import nn
 
-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class SDMGRLoss(nn.Module):
     """The implementation the loss of key information extraction proposed in
     the paper: Spatial Dual-Modality Graph Reasoning for Key Information
@@ -1,10 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import (DETECTORS, build_convertor, build_decoder,
-                                  build_encoder, build_loss)
 from mmocr.models.textrecog.recognizer.base import BaseRecognizer
+from mmocr.registry import MODELS
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class NerClassifier(BaseRecognizer):
     """Base class for NER classifier."""
 
@@ -17,15 +16,15 @@ class NerClassifier(BaseRecognizer):
                  test_cfg=None,
                  init_cfg=None):
         super().__init__(init_cfg=init_cfg)
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)
 
-        self.encoder = build_encoder(encoder)
+        self.encoder = MODELS.build(encoder)
 
         decoder.update(num_labels=self.label_convertor.num_labels)
-        self.decoder = build_decoder(decoder)
+        self.decoder = MODELS.build(decoder)
 
         loss.update(num_labels=self.label_convertor.num_labels)
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)
 
     def extract_feat(self, imgs):
         """Extract features from images."""
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
 
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from mmocr.utils import list_from_file
 
 
-@CONVERTORS.register_module()
+@MODELS.register_module()
 class NerConvertor:
     """Convert between text, index and tensor for NER pipeline.
 
@@ -4,10 +4,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.runner import BaseModule
 
-from mmocr.models.builder import DECODERS
+from mmocr.registry import MODELS
 
 
-@DECODERS.register_module()
+@MODELS.register_module()
 class FCDecoder(BaseModule):
     """FC Decoder class for Ner.
 
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.runner import BaseModule
 
-from mmocr.models.builder import ENCODERS
 from mmocr.models.ner.utils.bert import BertModel
+from mmocr.registry import MODELS
 
 
-@ENCODERS.register_module()
+@MODELS.register_module()
 class BertEncoder(BaseModule):
     """Bert encoder
     Args:
@@ -2,10 +2,10 @@
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class MaskedCrossEntropyLoss(nn.Module):
     """The implementation of masked cross entropy loss.
 
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from torch import nn
 
-from mmocr.models.builder import LOSSES
 from mmocr.models.common.losses.focal_loss import FocalLoss
+from mmocr.registry import MODELS
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class MaskedFocalLoss(nn.Module):
     """The implementation of masked focal loss.
 
@@ -5,11 +5,11 @@ import torch
 import torch.nn as nn
 from mmcv.runner import BaseModule, Sequential
 
-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from .head_mixin import HeadMixin
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class DBHead(HeadMixin, BaseModule):
     """The class for DBNet head.
 
@@ -7,13 +7,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.runner import BaseModule
 
-from mmocr.models.builder import HEADS, build_loss
 from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 from .head_mixin import HeadMixin
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class DRRGHead(HeadMixin, BaseModule):
     """The class for DRRG head: `Deep Relational Reasoning Graph Network for
     Arbitrary Shape Text Detection <https://arxiv.org/abs/2003.07493>`_.
@@ -118,7 +118,7 @@ class DRRGHead(HeadMixin, BaseModule):
         self.center_region_thr = center_region_thr
         self.center_region_area_thr = center_region_area_thr
         self.local_graph_thr = local_graph_thr
-        self.loss_module = build_loss(loss)
+        self.loss_module = MODELS.build(loss)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
 
@@ -5,12 +5,12 @@ import torch.nn as nn
 from mmcv.runner import BaseModule
 from mmdet.core import multi_apply
 
-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from ..postprocess.utils import poly_nms
 from .head_mixin import HeadMixin
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class FCEHead(HeadMixin, BaseModule):
     """The class for implementing FCENet head.
 
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
 
-from mmocr.models.builder import HEADS, build_loss, build_postprocessor
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class HeadMixin:
     """Base head class for text detection, including loss calcalation and
     postprocess.
@@ -19,8 +19,8 @@ class HeadMixin:
         assert isinstance(loss, dict)
         assert isinstance(postprocessor, dict)
 
-        self.loss_module = build_loss(loss)
-        self.postprocessor = build_postprocessor(postprocessor)
+        self.loss_module = MODELS.build(loss)
+        self.postprocessor = MODELS.build(postprocessor)
 
     def resize_boundary(self, boundaries, scale_factor):
         """Rescale boundaries via scale_factor.
@@ -6,12 +6,12 @@ import torch
 import torch.nn as nn
 from mmcv.runner import BaseModule
 
-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 from .head_mixin import HeadMixin
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class PANHead(HeadMixin, BaseModule):
     """The class for PANet head.
 
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from . import PANHead
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class PSEHead(PANHead):
     """The class for PSENet head.
 
@@ -4,11 +4,11 @@ import warnings
 import torch.nn as nn
 from mmcv.runner import BaseModule
 
-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS
 from .head_mixin import HeadMixin
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class TextSnakeHead(HeadMixin, BaseModule):
     """The class for TextSnake head: TextSnake: A Flexible Representation for
     Detecting Text of Arbitrary Shapes.
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class DBNet(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing DBNet text detector: Real-time Scene Text
     Detection with Differentiable Binarization.
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class DRRG(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing DRRG text detector. Deep Relational Reasoning
     Graph Network for Arbitrary Shape Text Detection.
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class FCENet(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing FCENet text detector
     FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
@@ -2,11 +2,11 @@
 from mmdet.models.detectors import MaskRCNN
 
 from mmocr.core import seg2boundary
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .text_detector_mixin import TextDetectorMixin
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class OCRMaskRCNN(TextDetectorMixin, MaskRCNN):
     """Mask RCNN tailored for OCR."""
 
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class PANet(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing PANet text detector:
 
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class PSENet(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing PSENet text detector: Shape Robust Text
     Detection with Progressive Scale Expansion Network.
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 
-from mmocr.models.builder import DETECTORS
 from mmocr.models.common.detectors import SingleStageDetector
+from mmocr.registry import MODELS
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class SingleStageTextDetector(SingleStageDetector):
     """The class for implementing single stage text detector."""
 
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .single_stage_text_detector import SingleStageTextDetector
 from .text_detector_mixin import TextDetectorMixin
 
 
-@DETECTORS.register_module()
+@MODELS.register_module()
 class TextSnake(TextDetectorMixin, SingleStageTextDetector):
     """The class for implementing TextSnake text detector: TextSnake: A
     Flexible Representation for Detecting Text of Arbitrary Shapes.
@@ -3,11 +3,11 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from mmocr.models.builder import LOSSES
 from mmocr.models.common.losses.dice_loss import DiceLoss
+from mmocr.registry import MODELS
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class DBLoss(nn.Module):
     """The class for implementing DBNet loss.
 
@@ -4,11 +4,11 @@ import torch.nn.functional as F
 from mmdet.core import BitmapMasks
 from torch import nn
 
-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class DRRGLoss(nn.Module):
     """The class for implementing DRRG loss. This is partially adapted from
     https://github.com/GXYM/DRRG licensed under the MIT license.
@@ -5,10 +5,10 @@ import torch.nn.functional as F
 from mmdet.core import multi_apply
 from torch import nn
 
-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class FCELoss(nn.Module):
     """The class for implementing FCENet loss.
 
@@ -8,11 +8,11 @@ import torch.nn.functional as F
 from mmdet.core import BitmapMasks
 from torch import nn
 
-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class PANLoss(nn.Module):
     """The class for implementing PANet loss. This was partially adapted from
     https://github.com/WenmuZhou/PAN.pytorch.
@@ -1,12 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmdet.core import BitmapMasks
 
-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 from . import PANLoss
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class PSELoss(PANLoss):
     r"""The class for implementing PSENet loss. This is partially adapted from
     https://github.com/whai362/PSENet.
@@ -4,11 +4,11 @@ import torch.nn.functional as F
 from mmdet.core import BitmapMasks
 from torch import nn
 
-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS
 from mmocr.utils import check_argument
 
 
-@LOSSES.register_module()
+@MODELS.register_module()
 class TextSnakeLoss(nn.Module):
     """The class for implementing TextSnake loss. This is partially adapted
     from https://github.com/princewang1994/TextSnake.pytorch.
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 from mmcv.runner import BaseModule, ModuleList
 from torch import nn
 
-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS
 
 
 class FPEM(BaseModule):
@@ -72,7 +72,7 @@ class SeparableConv2d(BaseModule):
         return x
 
 
-@NECKS.register_module()
+@MODELS.register_module()
 class FPEM_FFM(BaseModule):
     """This code is from https://github.com/WenmuZhou/PAN.pytorch.
 
@@ -5,10 +5,10 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule, ModuleList, Sequential, auto_fp16
 
-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS
 
 
-@NECKS.register_module()
+@MODELS.register_module()
 class FPNC(BaseModule):
     """FPN-like fusion module in Real-time Scene Text Detection with
     Differentiable Binarization.
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from mmcv.runner import BaseModule
 from torch import nn
 
-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS
 
 
 class UpBlock(BaseModule):
@@ -30,7 +30,7 @@ class UpBlock(BaseModule):
         return x
 
 
-@NECKS.register_module()
+@MODELS.register_module()
 class FPN_UNet(BaseModule):
     """The class for implementing DRRG and TextSnake U-Net-like FPN.
 
@@ -4,10 +4,10 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule, ModuleList, auto_fp16
 
-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS
 
 
-@NECKS.register_module()
+@MODELS.register_module()
 class FPNF(BaseModule):
     """FPN-like fusion module in Shape Robust Text Detection with Progressive
     Scale Expansion Network.
@@ -3,12 +3,12 @@ import cv2
 import numpy as np
 
 from mmocr.core import points2boundary
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 from .utils import box_score_fast, unclip
 
 
-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class DBPostprocessor(BasePostprocessor):
     """Decoding predictions of DbNet to instances. This is partially adapted
     from https://github.com/MhLiao/DB.
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 from .utils import (clusters2labels, comps2boundaries, connected_components,
                     graph_propagation, remove_single)
 
 
-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class DRRGPostprocessor(BasePostprocessor):
     """Merge text components and construct boundaries of text instances.
 
@@ -2,12 +2,12 @@
 import cv2
 import numpy as np
 
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 from .utils import fill_hole, fourier2poly, poly_nms
 
 
-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class FCEPostprocessor(BasePostprocessor):
     """Decoding predictions of FCENet to instances.
 
@@ -5,11 +5,11 @@ import torch
 from mmcv.ops import pixel_group
 
 from mmocr.core import points2boundary
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 
 
-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class PANPostprocessor(BasePostprocessor):
     """Convert scores to quadrangles via post processing in PANet. This is
     partially adapted from https://github.com/WenmuZhou/PAN.pytorch.
@@ -6,11 +6,11 @@ import torch
 from mmcv.ops import contour_expand
 
 from mmocr.core import points2boundary
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 
 
-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class PSEPostprocessor(BasePostprocessor):
     """Decoding predictions of PSENet to instances. This is partially adapted
     from https://github.com/whai362/PSENet.
@@ -5,12 +5,12 @@ import numpy as np
 import torch
 from skimage.morphology import skeletonize
 
-from mmocr.models.builder import POSTPROCESSOR
+from mmocr.registry import MODELS
 from .base_postprocessor import BasePostprocessor
 from .utils import centralize, fill_hole, merge_disks
 
 
-@POSTPROCESSOR.register_module()
+@MODELS.register_module()
 class TextSnakePostprocessor(BasePostprocessor):
     """Decoding predictions of TextSnake to instances. This was partially
     adapted from https://github.com/princewang1994/TextSnake.pytorch.
@@ -2,10 +2,10 @@
 import torch.nn as nn
 from mmcv.runner import BaseModule
 
-from mmocr.models.builder import BACKBONES
+from mmocr.registry import MODELS
 
 
-@BACKBONES.register_module()
+@MODELS.register_module()
 class NRTRModalityTransform(BaseModule):
 
     def __init__(self,
@@ -3,11 +3,11 @@ from mmcv.cnn import ConvModule, build_plugin_layer
 from mmcv.runner import BaseModule, Sequential
 
 import mmocr.utils as utils
-from mmocr.models.builder import BACKBONES
 from mmocr.models.textrecog.layers import BasicBlock
+from mmocr.registry import MODELS
 
 
-@BACKBONES.register_module()
+@MODELS.register_module()
 class ResNet(BaseModule):
     """
     Args:
@@ -3,11 +3,11 @@ import torch.nn as nn
 from mmcv.runner import BaseModule, Sequential
 
 import mmocr.utils as utils
-from mmocr.models.builder import BACKBONES
 from mmocr.models.textrecog.layers import BasicBlock
+from mmocr.registry import MODELS
 
 
-@BACKBONES.register_module()
+@MODELS.register_module()
 class ResNet31OCR(BaseModule):
     """Implement ResNet backbone for text recognition, modified from
     `ResNet <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -3,11 +3,11 @@ import torch.nn as nn
 from mmcv.runner import BaseModule, Sequential
 
 import mmocr.utils as utils
-from mmocr.models.builder import BACKBONES
 from mmocr.models.textrecog.layers import BasicBlock
+from mmocr.registry import MODELS
 
 
-@BACKBONES.register_module()
+@MODELS.register_module()
 class ResNetABI(BaseModule):
     """Implement ResNet backbone for text recognition, modified from `ResNet.
 
@@ -3,10 +3,10 @@ import torch.nn as nn
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule
 
-from mmocr.models.builder import BACKBONES
+from mmocr.registry import MODELS
 
 
-@BACKBONES.register_module()
+@MODELS.register_module()
 class ShallowCNN(BaseModule):
     """Implement Shallow CNN block for SATRN.
 
@@ -2,10 +2,10 @@
 import torch.nn as nn
 from mmcv.runner import BaseModule, Sequential
 
-from mmocr.models.builder import BACKBONES
+from mmocr.registry import MODELS
 
 
-@BACKBONES.register_module()
+@MODELS.register_module()
 class VeryDeepVgg(BaseModule):
     """Implement VGG-VeryDeep backbone for text recognition, modified from
     `VGG-VeryDeep <https://arxiv.org/pdf/1409.1556.pdf>`_
@@ -2,11 +2,11 @@
 import torch

 import mmocr.utils as utils
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from .attn import AttnConvertor


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class ABIConvertor(AttnConvertor):
     """Convert between text, index and tensor for encoder-decoder based
     pipeline. Modified from AttnConvertor to get closer to ABINet's original

@@ -2,11 +2,11 @@
 import torch

 import mmocr.utils as utils
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from .base import BaseConvertor


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class AttnConvertor(BaseConvertor):
     """Convert between text, index and tensor for encoder-decoder based
     pipeline.

@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from mmocr.utils import list_from_file


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class BaseConvertor:
     """Convert between text, index and tensor for text recognize pipeline.

@@ -5,11 +5,11 @@ import torch
 import torch.nn.functional as F

 import mmocr.utils as utils
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from .base import BaseConvertor


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class CTCConvertor(BaseConvertor):
     """Convert between text, index and tensor for CTC loss-based pipeline.

@@ -4,11 +4,11 @@ import numpy as np
 import torch

 import mmocr.utils as utils
-from mmocr.models.builder import CONVERTORS
+from mmocr.registry import MODELS
 from .base import BaseConvertor


-@CONVERTORS.register_module()
+@MODELS.register_module()
 class SegConvertor(BaseConvertor):
     """Convert between text, index and tensor for segmentation based pipeline.
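The label convertors registered above are plain Python classes rather than nn.Module subclasses, which the shared registry also accepts. A minimal sketch under that assumption follows; the ToyConvertor name and its fields are illustrative only.

# Hedged sketch: registration is not limited to torch modules; a plain class
# can be registered and built the same way. `ToyConvertor` is illustrative.
from mmocr.registry import MODELS


@MODELS.register_module()
class ToyConvertor:

    def __init__(self, dict_type='DICT36', max_seq_len=40):
        self.dict_type = dict_type
        self.max_seq_len = max_seq_len


convertor = MODELS.build(dict(type='ToyConvertor', max_seq_len=25))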
@@ -6,12 +6,12 @@ import torch.nn as nn
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import ModuleList

-from mmocr.models.builder import DECODERS
 from mmocr.models.common.modules import PositionalEncoding
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class ABILanguageDecoder(BaseDecoder):
     r"""Transformer-based language model responsible for spell correction.
     Implementation of language model of \

@@ -3,12 +3,12 @@ import torch
 import torch.nn as nn
 from mmcv.cnn import ConvModule

-from mmocr.models.builder import DECODERS
 from mmocr.models.common.modules import PositionalEncoding
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class ABIVisionDecoder(BaseDecoder):
     """Converts visual features into text characters.

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.runner import BaseModule

-from mmocr.models.builder import DECODERS
+from mmocr.registry import MODELS


-@DECODERS.register_module()
+@MODELS.register_module()
 class BaseDecoder(BaseModule):
     """Base decoder class for text recognition."""

@@ -2,12 +2,12 @@
 import torch.nn as nn
 from mmcv.runner import Sequential

-from mmocr.models.builder import DECODERS
 from mmocr.models.textrecog.layers import BidirectionalLSTM
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class CRNNDecoder(BaseDecoder):
     """Decoder for CRNN.

@@ -8,8 +8,8 @@ import torch.nn.functional as F
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import ModuleList

-from mmocr.models.builder import DECODERS
 from mmocr.models.common.modules import PositionalEncoding
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder

@@ -30,7 +30,7 @@ class Embeddings(nn.Module):
         return self.lut(x) * math.sqrt(self.d_model)


-@DECODERS.register_module()
+@MODELS.register_module()
 class MasterDecoder(BaseDecoder):
     """Decoder module in `MASTER <https://arxiv.org/abs/1910.02562>`_.

@@ -6,12 +6,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.runner import ModuleList

-from mmocr.models.builder import DECODERS
 from mmocr.models.common import PositionalEncoding, TFDecoderLayer
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class NRTRDecoder(BaseDecoder):
     """Transformer Decoder block with self attention mechanism.

@@ -4,13 +4,13 @@ import math
 import torch
 import torch.nn as nn

-from mmocr.models.builder import DECODERS
 from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
                                            PositionAwareLayer)
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class PositionAttentionDecoder(BaseDecoder):
     """Position attention decoder for RobustScanner.
@@ -3,12 +3,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import DECODERS, build_decoder
 from mmocr.models.textrecog.layers import RobustScannerFusionLayer
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class RobustScannerDecoder(BaseDecoder):
     """Decoder for RobustScanner.

@@ -72,7 +72,7 @@ class RobustScannerDecoder(BaseDecoder):
         hybrid_decoder.update(encode_value=self.encode_value)
         hybrid_decoder.update(return_feature=True)

-        self.hybrid_decoder = build_decoder(hybrid_decoder)
+        self.hybrid_decoder = MODELS.build(hybrid_decoder)

         # init position decoder
         position_decoder.update(num_classes=self.num_classes)

@@ -83,7 +83,7 @@ class RobustScannerDecoder(BaseDecoder):
         position_decoder.update(encode_value=self.encode_value)
         position_decoder.update(return_feature=True)

-        self.position_decoder = build_decoder(position_decoder)
+        self.position_decoder = MODELS.build(position_decoder)

         self.fusion_module = RobustScannerFusionLayer(
             self.dim_model if encode_value else dim_input)
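Composite decoders such as RobustScannerDecoder previously relied on the build_decoder helper for their sub-modules; after this change they resolve sub-module configs through MODELS.build as well. A rough sketch of that nested-build pattern follows; ToyDecoder, ToyComposite and the config values are illustrative, not from the commit.

# Hedged sketch of nested builds through the unified registry.
from mmcv.runner import BaseModule

from mmocr.registry import MODELS


@MODELS.register_module()
class ToyDecoder(BaseModule):

    def __init__(self, dim=512):
        super().__init__()
        self.dim = dim


class ToyComposite(BaseModule):

    def __init__(self, hybrid_decoder=None, position_decoder=None):
        super().__init__()
        # Both sub-decoders come from the same MODELS registry.
        self.hybrid_decoder = MODELS.build(hybrid_decoder)
        self.position_decoder = MODELS.build(position_decoder)


model = ToyComposite(
    hybrid_decoder=dict(type='ToyDecoder', dim=512),
    position_decoder=dict(type='ToyDecoder', dim=256))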
@@ -6,11 +6,11 @@ import torch.nn as nn
 import torch.nn.functional as F

 import mmocr.utils as utils
-from mmocr.models.builder import DECODERS
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class ParallelSARDecoder(BaseDecoder):
     """Implementation Parallel Decoder module in `SAR.

@@ -255,7 +255,7 @@ class ParallelSARDecoder(BaseDecoder):
         return outputs


-@DECODERS.register_module()
+@MODELS.register_module()
 class SequentialSARDecoder(BaseDecoder):
     """Implementation Sequential Decoder module in `SAR.

@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F

 import mmocr.utils as utils
-from mmocr.models.builder import DECODERS
+from mmocr.registry import MODELS
 from . import ParallelSARDecoder

@@ -31,7 +31,7 @@ class DecodeNode:
         return accu_score


-@DECODERS.register_module()
+@MODELS.register_module()
 class ParallelSARDecoderWithBS(ParallelSARDecoder):
     """Parallel Decoder module with beam-search in SAR.

@@ -5,12 +5,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import DECODERS
 from mmocr.models.textrecog.layers import DotProductAttentionLayer
+from mmocr.registry import MODELS
 from .base_decoder import BaseDecoder


-@DECODERS.register_module()
+@MODELS.register_module()
 class SequenceAttentionDecoder(BaseDecoder):
     """Sequence attention decoder for RobustScanner.
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import ENCODERS, build_decoder, build_encoder
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class ABIVisionModel(BaseEncoder):
     """A wrapper of visual feature encoder and language token decoder that
     converts visual features into text tokens.

@@ -23,8 +23,8 @@ class ABIVisionModel(BaseEncoder):
                  init_cfg=dict(type='Xavier', layer='Conv2d'),
                  **kwargs):
         super().__init__(init_cfg=init_cfg)
-        self.encoder = build_encoder(encoder)
-        self.decoder = build_decoder(decoder)
+        self.encoder = MODELS.build(encoder)
+        self.decoder = MODELS.build(decoder)

     def forward(self, feat, img_metas=None):
         """

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.runner import BaseModule

-from mmocr.models.builder import ENCODERS
+from mmocr.registry import MODELS


-@ENCODERS.register_module()
+@MODELS.register_module()
 class BaseEncoder(BaseModule):
     """Base Encoder class for text recognition."""

@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn

-from mmocr.models.builder import ENCODERS
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class ChannelReductionEncoder(BaseEncoder):
     """Change the channel number with a one by one convoluational layer.

@@ -4,12 +4,12 @@ import math
 import torch.nn as nn
 from mmcv.runner import ModuleList

-from mmocr.models.builder import ENCODERS
 from mmocr.models.common import TFEncoderLayer
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class NRTREncoder(BaseEncoder):
     """Transformer Encoder block with self attention mechanism.

@@ -6,11 +6,11 @@ import torch.nn as nn
 import torch.nn.functional as F

 import mmocr.utils as utils
-from mmocr.models.builder import ENCODERS
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class SAREncoder(BaseEncoder):
     """Implementation of encoder module in `SAR.

@@ -4,13 +4,13 @@ import math
 import torch.nn as nn
 from mmcv.runner import ModuleList

-from mmocr.models.builder import ENCODERS
 from mmocr.models.textrecog.layers import (Adaptive2DPositionalEncoding,
                                            SatrnEncoderLayer)
+from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder


-@ENCODERS.register_module()
+@MODELS.register_module()
 class SatrnEncoder(BaseEncoder):
     """Implement encoder for SATRN, see `SATRN.

@@ -4,11 +4,11 @@ import copy
 from mmcv.cnn.bricks.transformer import BaseTransformerLayer
 from mmcv.runner import BaseModule, ModuleList

-from mmocr.models.builder import ENCODERS
 from mmocr.models.common.modules import PositionalEncoding
+from mmocr.registry import MODELS


-@ENCODERS.register_module()
+@MODELS.register_module()
 class TransformerEncoder(BaseModule):
     """Implement transformer encoder for text recognition, modified from
     `<https://github.com/FangShancheng/ABINet>`.
@@ -3,10 +3,10 @@ import torch
 import torch.nn as nn
 from mmcv.runner import BaseModule

-from mmocr.models.builder import FUSERS
+from mmocr.registry import MODELS


-@FUSERS.register_module()
+@MODELS.register_module()
 class ABIFuser(BaseModule):
     """Mix and align visual feature and linguistic feature Implementation of
     language model of `ABINet <https://arxiv.org/abs/1910.04396>`_.

@@ -4,10 +4,10 @@ from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule
 from torch import nn

-from mmocr.models.builder import HEADS
+from mmocr.registry import MODELS


-@HEADS.register_module()
+@MODELS.register_module()
 class SegHead(BaseModule):
     """Head for segmentation based text recognition.
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class CELoss(nn.Module):
     """Implementation of loss module for encoder-decoder based text recognition
     method with CrossEntropy loss.

@@ -63,7 +63,7 @@ class CELoss(nn.Module):
         return losses


-@LOSSES.register_module()
+@MODELS.register_module()
 class SARLoss(CELoss):
     """Implementation of loss module in `SAR.

@@ -95,7 +95,7 @@ class SARLoss(CELoss):
         return outputs, targets


-@LOSSES.register_module()
+@MODELS.register_module()
 class TFLoss(CELoss):
     """Implementation of loss module for transformer.

@@ -4,10 +4,10 @@ import math
 import torch
 import torch.nn as nn

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class CTCLoss(nn.Module):
     """Implementation of loss module for CTC-loss based text recognition.

@@ -3,10 +3,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class ABILoss(nn.Module):
     """Implementation of ABINet multiloss that allows mixing different types of
     losses with weights.

@@ -3,10 +3,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import LOSSES
+from mmocr.registry import MODELS


-@LOSSES.register_module()
+@MODELS.register_module()
 class SegLoss(nn.Module):
     """Implementation of loss module for segmentation based text recognition
     method.
@@ -4,10 +4,10 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule, ModuleList

-from mmocr.models.builder import NECKS
+from mmocr.registry import MODELS


-@NECKS.register_module()
+@MODELS.register_module()
 class FPNOCR(BaseModule):
     """FPN-like Network for segmentation based text recognition.

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.runner import BaseModule

-from mmocr.models.builder import PREPROCESSOR
+from mmocr.registry import MODELS


-@PREPROCESSOR.register_module()
+@MODELS.register_module()
 class BasePreprocessor(BaseModule):
     """Base Preprocessor class for text recognition."""

@@ -17,11 +17,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from mmocr.models.builder import PREPROCESSOR
+from mmocr.registry import MODELS
 from .base_preprocessor import BasePreprocessor


-@PREPROCESSOR.register_module()
+@MODELS.register_module()
 class TPSPreprocessor(BasePreprocessor):
     """Rectification Network of RARE, namely TPS based STN in
     https://arxiv.org/pdf/1603.03915.pdf.
@@ -3,13 +3,11 @@ import warnings

 import torch

-from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor,
-                                  build_decoder, build_encoder, build_fuser,
-                                  build_loss, build_preprocessor)
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class ABINet(EncodeDecodeRecognizer):
     """Implementation of `Read Like Humans: Autonomous, Bidirectional and
     Iterative LanguageModeling for Scene Text Recognition.

@@ -36,21 +34,21 @@ class ABINet(EncodeDecodeRecognizer):
         # Label convertor (str2tensor, tensor2str)
         assert label_convertor is not None
         label_convertor.update(max_seq_len=max_seq_len)
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)

         # Preprocessor module, e.g., TPS
         self.preprocessor = None
         if preprocessor is not None:
-            self.preprocessor = build_preprocessor(preprocessor)
+            self.preprocessor = MODELS.build(preprocessor)

         # Backbone
         assert backbone is not None
-        self.backbone = build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)

         # Encoder module
         self.encoder = None
         if encoder is not None:
-            self.encoder = build_encoder(encoder)
+            self.encoder = MODELS.build(encoder)

         # Decoder module
         self.decoder = None

@@ -59,11 +57,11 @@ class ABINet(EncodeDecodeRecognizer):
             decoder.update(start_idx=self.label_convertor.start_idx)
             decoder.update(padding_idx=self.label_convertor.padding_idx)
             decoder.update(max_seq_len=max_seq_len)
-            self.decoder = build_decoder(decoder)
+            self.decoder = MODELS.build(decoder)

         # Loss
         assert loss is not None
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)

         self.train_cfg = train_cfg
         self.test_cfg = test_cfg

@@ -78,7 +76,7 @@ class ABINet(EncodeDecodeRecognizer):

         self.fuser = None
         if fuser is not None:
-            self.fuser = build_fuser(fuser)
+            self.fuser = MODELS.build(fuser)

     def forward_train(self, img, img_metas):
         """
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class CRNNNet(EncodeDecodeRecognizer):
     """CTC-loss based recognizer."""
@@ -3,13 +3,11 @@ import warnings

 import torch

-from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor,
-                                  build_decoder, build_encoder, build_loss,
-                                  build_preprocessor)
+from mmocr.registry import MODELS
 from .base import BaseRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class EncodeDecodeRecognizer(BaseRecognizer):
     """Base class for encode-decode recognizer."""

@@ -31,21 +29,21 @@ class EncodeDecodeRecognizer(BaseRecognizer):
         # Label convertor (str2tensor, tensor2str)
         assert label_convertor is not None
         label_convertor.update(max_seq_len=max_seq_len)
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)

         # Preprocessor module, e.g., TPS
         self.preprocessor = None
         if preprocessor is not None:
-            self.preprocessor = build_preprocessor(preprocessor)
+            self.preprocessor = MODELS.build(preprocessor)

         # Backbone
         assert backbone is not None
-        self.backbone = build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)

         # Encoder module
         self.encoder = None
         if encoder is not None:
-            self.encoder = build_encoder(encoder)
+            self.encoder = MODELS.build(encoder)

         # Decoder module
         assert decoder is not None

@@ -53,12 +51,12 @@ class EncodeDecodeRecognizer(BaseRecognizer):
         decoder.update(start_idx=self.label_convertor.start_idx)
         decoder.update(padding_idx=self.label_convertor.padding_idx)
         decoder.update(max_seq_len=max_seq_len)
-        self.decoder = build_decoder(decoder)
+        self.decoder = MODELS.build(decoder)

         # Loss
         assert loss is not None
         loss.update(ignore_index=self.label_convertor.padding_idx)
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)

         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
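With the per-type helpers (build_backbone, build_convertor, build_encoder, build_decoder, build_loss, ...) removed, a recognizer's constructor resolves every component through the same call. A rough sketch of that assembly is shown below, assuming the component types named in the config are registered as in the hunks above; the config values themselves are illustrative.

# Hedged sketch: each sub-module config is resolved by MODELS.build; the
# concrete config values are illustrative.
from mmocr.registry import MODELS

cfg = dict(
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(type='SAREncoder'),
    decoder=dict(type='ParallelSARDecoder'),
    loss=dict(type='SARLoss'),
)

backbone = MODELS.build(cfg['backbone'])
encoder = MODELS.build(cfg['encoder'])
decoder = MODELS.build(cfg['decoder'])
loss = MODELS.build(cfg['loss'])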
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import DETECTORS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@DETECTORS.register_module()
+@MODELS.register_module()
 class MASTER(EncodeDecodeRecognizer):
     """Implementation of `MASTER <https://arxiv.org/abs/1910.02562>`_"""

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class NRTR(EncodeDecodeRecognizer):
     """Implementation of `NRTR <https://arxiv.org/pdf/1806.00926.pdf>`_"""

@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class RobustScanner(EncodeDecodeRecognizer):
     """Implementation of `RobustScanner.

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class SARNet(EncodeDecodeRecognizer):
     """Implementation of `SAR <https://arxiv.org/abs/1811.00751>`_"""

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmocr.models.builder import RECOGNIZERS
+from mmocr.registry import MODELS
 from .encode_decode_recognizer import EncodeDecodeRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class SATRN(EncodeDecodeRecognizer):
     """Implementation of `SATRN <https://arxiv.org/abs/1910.04396>`_"""
@@ -1,13 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings

-from mmocr.models.builder import (RECOGNIZERS, build_backbone, build_convertor,
-                                  build_head, build_loss, build_neck,
-                                  build_preprocessor)
+from mmocr.registry import MODELS
 from .base import BaseRecognizer


-@RECOGNIZERS.register_module()
+@MODELS.register_module()
 class SegRecognizer(BaseRecognizer):
     """Base class for segmentation based recognizer."""

@@ -26,29 +24,29 @@ class SegRecognizer(BaseRecognizer):

         # Label_convertor
         assert label_convertor is not None
-        self.label_convertor = build_convertor(label_convertor)
+        self.label_convertor = MODELS.build(label_convertor)

         # Preprocessor module, e.g., TPS
         self.preprocessor = None
         if preprocessor is not None:
-            self.preprocessor = build_preprocessor(preprocessor)
+            self.preprocessor = MODELS.build(preprocessor)

         # Backbone
         assert backbone is not None
-        self.backbone = build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)

         # Neck
         assert neck is not None
-        self.neck = build_neck(neck)
+        self.neck = MODELS.build(neck)

         # Head
         assert head is not None
         head.update(num_classes=self.label_convertor.num_classes())
-        self.head = build_head(head)
+        self.head = MODELS.build(head)

         # Loss
         assert loss is not None
-        self.loss = build_loss(loss)
+        self.loss = MODELS.build(loss)

         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
@@ -24,9 +24,9 @@ from mmocr.apis.inference import model_inference
 from mmocr.core.visualize import det_recog_show_result
 from mmocr.datasets.kie_dataset import KIEDataset
 from mmocr.datasets.pipelines.crop import crop_img
-from mmocr.models import build_detector
 from mmocr.models.textdet.detectors import TextDetectorMixin
 from mmocr.models.textrecog.recognizer import BaseRecognizer
+from mmocr.registry import MODELS
 from mmocr.utils import is_type_list
 from mmocr.utils.box_util import stitch_boxes_into_lines
 from mmocr.utils.fileio import list_from_file

@@ -427,7 +427,7 @@ class MMOCR:
                 'kie/' + kie_models[self.kie]['ckpt']

             kie_cfg = Config.fromfile(kie_config)
-            self.kie_model = build_detector(
+            self.kie_model = MODELS.build(
                 kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
             self.kie_model = revert_sync_batchnorm(self.kie_model)
             self.kie_model.cfg = kie_cfg