From a90b9600ce8b0bda92eb5e17c3f77b5e252f881f Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Wed, 11 May 2022 06:10:36 +0000 Subject: [PATCH] [Refactor] refactor DATASETS and TRANSFORMS --- demo/ner_demo.py | 2 +- demo/webcam_demo.py | 2 +- mmocr/apis/train.py | 5 ++-- mmocr/datasets/__init__.py | 11 ++++----- mmocr/datasets/base_dataset.py | 6 ++--- mmocr/datasets/builder.py | 15 +++--------- mmocr/datasets/icdar_dataset.py | 2 +- mmocr/datasets/kie_dataset.py | 2 +- mmocr/datasets/ner_dataset.py | 3 +-- mmocr/datasets/ocr_dataset.py | 3 +-- mmocr/datasets/ocr_seg_dataset.py | 3 +-- mmocr/datasets/openset_kie_dataset.py | 2 +- .../pipelines/custom_format_bundle.py | 4 ++-- mmocr/datasets/pipelines/dbnet_transforms.py | 7 +++--- mmocr/datasets/pipelines/kie_transforms.py | 7 +++--- mmocr/datasets/pipelines/loading.py | 9 +++---- mmocr/datasets/pipelines/ner_transforms.py | 6 ++--- mmocr/datasets/pipelines/ocr_seg_targets.py | 4 ++-- mmocr/datasets/pipelines/ocr_transforms.py | 20 ++++++++-------- mmocr/datasets/pipelines/test_time_aug.py | 5 ++-- .../textdet_targets/dbnet_targets.py | 4 ++-- .../pipelines/textdet_targets/drrg_targets.py | 4 ++-- .../textdet_targets/fcenet_targets.py | 4 ++-- .../textdet_targets/panet_targets.py | 4 ++-- .../textdet_targets/psenet_targets.py | 5 ++-- .../textdet_targets/textsnake_targets.py | 4 ++-- .../datasets/pipelines/transform_wrappers.py | 12 +++++----- mmocr/datasets/pipelines/transforms.py | 24 +++++++++---------- mmocr/datasets/text_det_dataset.py | 2 +- mmocr/datasets/uniform_concat_dataset.py | 5 ++-- mmocr/datasets/utils/loader.py | 4 ++-- old_tests/test_apis/test_model_inference.py | 2 +- old_tests/test_apis/test_single_gpu_test.py | 7 +++--- tools/benchmark_processing.py | 6 ++--- tools/deployment/deploy_test.py | 5 ++-- tools/kie_test_imgs.py | 5 ++-- tools/recog_test_imgs.py | 2 +- tools/test.py | 5 ++-- tools/train.py | 4 ++-- 39 files changed, 110 insertions(+), 116 deletions(-) diff --git a/demo/ner_demo.py b/demo/ner_demo.py index 113d4e31..f61ee390 100755 --- a/demo/ner_demo.py +++ b/demo/ner_demo.py @@ -3,8 +3,8 @@ from argparse import ArgumentParser from mmocr.apis import init_detector from mmocr.apis.inference import text_model_inference -from mmocr.datasets import build_dataset # NOQA from mmocr.models import build_detector # NOQA +from mmocr.registry import DATASETS # NOQA def main(): diff --git a/demo/webcam_demo.py b/demo/webcam_demo.py index 475c29c2..9d2d6859 100644 --- a/demo/webcam_demo.py +++ b/demo/webcam_demo.py @@ -5,8 +5,8 @@ import cv2 import torch from mmocr.apis import init_detector, model_inference -from mmocr.datasets import build_dataset # noqa: F401 from mmocr.models import build_detector # noqa: F401 +from mmocr.registry import DATASETS # noqa: F401 def parse_args(): diff --git a/mmocr/apis/train.py b/mmocr/apis/train.py index e9178e00..56bcc4b7 100644 --- a/mmocr/apis/train.py +++ b/mmocr/apis/train.py @@ -10,11 +10,12 @@ from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, Fp16OptimizerHook, OptimizerHook, build_optimizer, build_runner, get_dist_info) from mmdet.core import DistEvalHook, EvalHook -from mmdet.datasets import build_dataloader, build_dataset +from mmdet.datasets import build_dataloader from mmocr import digit_version from mmocr.apis.utils import (disable_text_recog_aug_test, replace_image_to_tensor) +from mmocr.registry import DATASETS from mmocr.utils import get_root_logger @@ -132,7 +133,7 @@ def train_detector(model, cfg = disable_text_recog_aug_test(cfg) cfg = replace_image_to_tensor(cfg) - val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) + val_dataset = DATASETS.build(cfg.data.val, dict(test_mode=True)) val_loader_cfg = { **default_loader_cfg, diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py index c16565b1..ac121774 100644 --- a/mmocr/datasets/__init__.py +++ b/mmocr/datasets/__init__.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset - from . import utils from .base_dataset import BaseDataset +from .builder import DATASETS, LOADERS, PARSERS, TRANSFORMS from .icdar_dataset import IcdarDataset from .kie_dataset import KIEDataset from .ner_dataset import NerDataset @@ -15,10 +14,10 @@ from .uniform_concat_dataset import UniformConcatDataset from .utils import * # NOQA __all__ = [ - 'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset', - 'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle', - 'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets', - 'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset' + 'DATASETS', 'IcdarDataset', 'BaseDataset', 'OCRDataset', 'TextDetDataset', + 'CustomFormatBundle', 'DBNetTargets', 'OCRSegDataset', 'KIEDataset', + 'FCENetTargets', 'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset', + 'TRANSFORMS', 'PARSERS', 'LOADERS' ] __all__ += utils.__all__ diff --git a/mmocr/datasets/base_dataset.py b/mmocr/datasets/base_dataset.py index 5a39bf46..27632a02 100644 --- a/mmocr/datasets/base_dataset.py +++ b/mmocr/datasets/base_dataset.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np from mmcv.utils import print_log -from mmdet.datasets.builder import DATASETS from mmdet.datasets.pipelines import Compose from torch.utils.data import Dataset -from mmocr.datasets.builder import build_loader +from mmocr.datasets.builder import LOADERS +from mmocr.registry import DATASETS @DATASETS.register_module() @@ -66,7 +66,7 @@ class BaseDataset(Dataset): self.ann_file = ann_file # load annotations loader.update(ann_file=ann_file) - self.data_infos = build_loader(loader) + self.data_infos = LOADERS.build(loader) # processing pipeline self.pipeline = Compose(pipeline) # set group flag and class, no meaning diff --git a/mmocr/datasets/builder.py b/mmocr/datasets/builder.py index 1e4cc66e..b935a236 100644 --- a/mmocr/datasets/builder.py +++ b/mmocr/datasets/builder.py @@ -1,15 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmcv.utils import Registry, build_from_cfg -LOADERS = Registry('loader') -PARSERS = Registry('parser') +from mmocr.registry import TRANSFORMS - -def build_loader(cfg): - """Build anno file loader.""" - return build_from_cfg(cfg, LOADERS) - - -def build_parser(cfg): - """Build anno file parser.""" - return build_from_cfg(cfg, PARSERS) +LOADERS = TRANSFORMS +PARSERS = TRANSFORMS diff --git a/mmocr/datasets/icdar_dataset.py b/mmocr/datasets/icdar_dataset.py index 31d14df5..77a1879f 100644 --- a/mmocr/datasets/icdar_dataset.py +++ b/mmocr/datasets/icdar_dataset.py @@ -2,12 +2,12 @@ import mmcv import numpy as np from mmdet.datasets.api_wrappers import COCO -from mmdet.datasets.builder import DATASETS from mmdet.datasets.coco import CocoDataset import mmocr.utils as utils from mmocr import digit_version from mmocr.core.evaluation.hmean import eval_hmean +from mmocr.registry import DATASETS @DATASETS.register_module() diff --git a/mmocr/datasets/kie_dataset.py b/mmocr/datasets/kie_dataset.py index bcbf324f..91fb62f9 100644 --- a/mmocr/datasets/kie_dataset.py +++ b/mmocr/datasets/kie_dataset.py @@ -5,11 +5,11 @@ from os import path as osp import numpy as np import torch -from mmdet.datasets.builder import DATASETS from mmocr.core import compute_f1_score from mmocr.datasets.base_dataset import BaseDataset from mmocr.datasets.pipelines import sort_vertex8 +from mmocr.registry import DATASETS from mmocr.utils import is_type_list, list_from_file diff --git a/mmocr/datasets/ner_dataset.py b/mmocr/datasets/ner_dataset.py index 923942c3..de084fba 100644 --- a/mmocr/datasets/ner_dataset.py +++ b/mmocr/datasets/ner_dataset.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdet.datasets.builder import DATASETS - from mmocr.core.evaluation.ner_metric import eval_ner_f1 from mmocr.datasets.base_dataset import BaseDataset +from mmocr.registry import DATASETS @DATASETS.register_module() diff --git a/mmocr/datasets/ocr_dataset.py b/mmocr/datasets/ocr_dataset.py index a5e39523..ce3f0fd7 100644 --- a/mmocr/datasets/ocr_dataset.py +++ b/mmocr/datasets/ocr_dataset.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdet.datasets.builder import DATASETS - from mmocr.core.evaluation.ocr_metric import eval_ocr_metric from mmocr.datasets.base_dataset import BaseDataset +from mmocr.registry import DATASETS from mmocr.utils import is_type_list diff --git a/mmocr/datasets/ocr_seg_dataset.py b/mmocr/datasets/ocr_seg_dataset.py index cd4b727d..064dfcfe 100644 --- a/mmocr/datasets/ocr_seg_dataset.py +++ b/mmocr/datasets/ocr_seg_dataset.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdet.datasets.builder import DATASETS - import mmocr.utils as utils from mmocr.datasets.ocr_dataset import OCRDataset +from mmocr.registry import DATASETS @DATASETS.register_module() diff --git a/mmocr/datasets/openset_kie_dataset.py b/mmocr/datasets/openset_kie_dataset.py index 6973e891..a2c94368 100644 --- a/mmocr/datasets/openset_kie_dataset.py +++ b/mmocr/datasets/openset_kie_dataset.py @@ -3,9 +3,9 @@ import copy import numpy as np import torch -from mmdet.datasets.builder import DATASETS from mmocr.datasets import KIEDataset +from mmocr.registry import DATASETS @DATASETS.register_module() diff --git a/mmocr/datasets/pipelines/custom_format_bundle.py b/mmocr/datasets/pipelines/custom_format_bundle.py index 7f069ad5..59bc1f16 100644 --- a/mmocr/datasets/pipelines/custom_format_bundle.py +++ b/mmocr/datasets/pipelines/custom_format_bundle.py @@ -1,13 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np from mmcv.parallel import DataContainer as DC -from mmdet.datasets.builder import PIPELINES from mmdet.datasets.pipelines.formatting import DefaultFormatBundle from mmocr.core.visualize import overlay_mask_img, show_feature +from mmocr.registry import TRANSFORMS -@PIPELINES.register_module() +@TRANSFORMS.register_module() class CustomFormatBundle(DefaultFormatBundle): """Custom formatting bundle. diff --git a/mmocr/datasets/pipelines/dbnet_transforms.py b/mmocr/datasets/pipelines/dbnet_transforms.py index b736d337..1c154ef8 100644 --- a/mmocr/datasets/pipelines/dbnet_transforms.py +++ b/mmocr/datasets/pipelines/dbnet_transforms.py @@ -4,7 +4,8 @@ import imgaug.augmenters as iaa import mmcv import numpy as np from mmdet.core.mask import PolygonMasks -from mmdet.datasets.builder import PIPELINES + +from mmocr.registry import TRANSFORMS class AugmenterBuilder: @@ -45,7 +46,7 @@ class AugmenterBuilder: return obj -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ImgAug: """A wrapper to use imgaug https://github.com/aleju/imgaug. @@ -174,7 +175,7 @@ class ImgAug: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class EastRandomCrop: def __init__(self, diff --git a/mmocr/datasets/pipelines/kie_transforms.py b/mmocr/datasets/pipelines/kie_transforms.py index 21b4de70..e8b7ba68 100644 --- a/mmocr/datasets/pipelines/kie_transforms.py +++ b/mmocr/datasets/pipelines/kie_transforms.py @@ -2,11 +2,12 @@ import numpy as np from mmcv import rescale_size from mmcv.parallel import DataContainer as DC -from mmdet.datasets.builder import PIPELINES from mmdet.datasets.pipelines.formatting import DefaultFormatBundle, to_tensor +from mmocr.registry import TRANSFORMS -@PIPELINES.register_module() + +@TRANSFORMS.register_module() class ResizeNoImg: """Image resizing without img. @@ -39,7 +40,7 @@ class ResizeNoImg: return results -@PIPELINES.register_module() +@TRANSFORMS.register_module() class KIEFormatBundle(DefaultFormatBundle): """Key information extraction formatting bundle. diff --git a/mmocr/datasets/pipelines/loading.py b/mmocr/datasets/pipelines/loading.py index f6c540c1..71e1920b 100644 --- a/mmocr/datasets/pipelines/loading.py +++ b/mmocr/datasets/pipelines/loading.py @@ -5,11 +5,12 @@ import lmdb import mmcv import numpy as np from mmdet.core import BitmapMasks, PolygonMasks -from mmdet.datasets.builder import PIPELINES from mmdet.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile +from mmocr.registry import TRANSFORMS -@PIPELINES.register_module() + +@TRANSFORMS.register_module() class LoadTextAnnotations(LoadAnnotations): """Load annotations for text detection. @@ -99,7 +100,7 @@ class LoadTextAnnotations(LoadAnnotations): return results -@PIPELINES.register_module() +@TRANSFORMS.register_module() class LoadImageFromNdarray(LoadImageFromFile): """Load an image from np.ndarray. @@ -136,7 +137,7 @@ class LoadImageFromNdarray(LoadImageFromFile): return results -@PIPELINES.register_module() +@TRANSFORMS.register_module() class LoadImageFromLMDB(object): """Load an image from lmdb file. diff --git a/mmocr/datasets/pipelines/ner_transforms.py b/mmocr/datasets/pipelines/ner_transforms.py index b26fe74b..d230fe48 100644 --- a/mmocr/datasets/pipelines/ner_transforms.py +++ b/mmocr/datasets/pipelines/ner_transforms.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmdet.datasets.builder import PIPELINES from mmocr.models.builder import build_convertor +from mmocr.registry import TRANSFORMS -@PIPELINES.register_module() +@TRANSFORMS.register_module() class NerTransform: """Convert text to ID and entity in ground truth to label ID. The masks and tokens are generated at the same time. The four parameters will be used as @@ -42,7 +42,7 @@ class NerTransform: return results -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ToTensorNER: """Convert data with ``list`` type to tensor.""" diff --git a/mmocr/datasets/pipelines/ocr_seg_targets.py b/mmocr/datasets/pipelines/ocr_seg_targets.py index 8c9c8aba..9a4258fa 100644 --- a/mmocr/datasets/pipelines/ocr_seg_targets.py +++ b/mmocr/datasets/pipelines/ocr_seg_targets.py @@ -2,13 +2,13 @@ import cv2 import numpy as np from mmdet.core import BitmapMasks -from mmdet.datasets.builder import PIPELINES import mmocr.utils.check_argument as check_argument from mmocr.models.builder import build_convertor +from mmocr.registry import TRANSFORMS -@PIPELINES.register_module() +@TRANSFORMS.register_module() class OCRSegTargets: """Generate gt shrunk kernels for segmentation based OCR framework. diff --git a/mmocr/datasets/pipelines/ocr_transforms.py b/mmocr/datasets/pipelines/ocr_transforms.py index 9081d4b8..f04db7e0 100644 --- a/mmocr/datasets/pipelines/ocr_transforms.py +++ b/mmocr/datasets/pipelines/ocr_transforms.py @@ -6,16 +6,16 @@ import numpy as np import torch import torchvision.transforms.functional as TF from mmcv.runner.dist_utils import get_dist_info -from mmdet.datasets.builder import PIPELINES from PIL import Image from shapely.geometry import Polygon from shapely.geometry import box as shapely_box import mmocr.utils as utils from mmocr.datasets.pipelines.crop import warp_img +from mmocr.registry import TRANSFORMS -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ResizeOCR: """Image resizing and padding for OCR. @@ -129,7 +129,7 @@ class ResizeOCR: return results -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ToTensorOCR: """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.""" @@ -142,7 +142,7 @@ class ToTensorOCR: return results -@PIPELINES.register_module() +@TRANSFORMS.register_module() class NormalizeOCR: """Normalize a tensor image with mean and standard deviation.""" @@ -156,7 +156,7 @@ class NormalizeOCR: return results -@PIPELINES.register_module() +@TRANSFORMS.register_module() class OnlineCropOCR: """Crop text areas from whole image with bounding box jitter. If no bbox is given, return directly. @@ -216,7 +216,7 @@ class OnlineCropOCR: return results -@PIPELINES.register_module() +@TRANSFORMS.register_module() class FancyPCA: """Implementation of PCA based image augmentation, proposed in the paper ``Imagenet Classification With Deep Convolutional Neural Networks``. @@ -257,7 +257,7 @@ class FancyPCA: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomPaddingOCR: """Pad the given image on all sides, as well as modify the coordinates of character bounding box in image. @@ -319,7 +319,7 @@ class RandomPaddingOCR: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomRotateImageBox: """Rotate augmentation for segmentation based text recognition. @@ -416,7 +416,7 @@ class RandomRotateImageBox: return [new_x, new_y] -@PIPELINES.register_module() +@TRANSFORMS.register_module() class OpencvToPil: """Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb).""" @@ -435,7 +435,7 @@ class OpencvToPil: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class PilToOpencv: """Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr).""" diff --git a/mmocr/datasets/pipelines/test_time_aug.py b/mmocr/datasets/pipelines/test_time_aug.py index 773ea14b..e2b52d02 100644 --- a/mmocr/datasets/pipelines/test_time_aug.py +++ b/mmocr/datasets/pipelines/test_time_aug.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import mmcv import numpy as np -from mmdet.datasets.builder import PIPELINES from mmdet.datasets.pipelines.compose import Compose +from mmocr.registry import TRANSFORMS -@PIPELINES.register_module() + +@TRANSFORMS.register_module() class MultiRotateAugOCR: """Test-time augmentation with multiple rotations in the case that img_height > img_width. diff --git a/mmocr/datasets/pipelines/textdet_targets/dbnet_targets.py b/mmocr/datasets/pipelines/textdet_targets/dbnet_targets.py index 71088c58..ae77430d 100644 --- a/mmocr/datasets/pipelines/textdet_targets/dbnet_targets.py +++ b/mmocr/datasets/pipelines/textdet_targets/dbnet_targets.py @@ -3,13 +3,13 @@ import cv2 import numpy as np import pyclipper from mmdet.core import BitmapMasks -from mmdet.datasets.builder import PIPELINES from shapely.geometry import Polygon +from mmocr.registry import TRANSFORMS from . import BaseTextDetTargets -@PIPELINES.register_module() +@TRANSFORMS.register_module() class DBNetTargets(BaseTextDetTargets): """Generate gt shrunk text, gt threshold map, and their effective region masks to learn DBNet: Real-time Scene Text Detection with Differentiable diff --git a/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py b/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py index fdf3a494..7ca110df 100644 --- a/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py +++ b/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py @@ -3,14 +3,14 @@ import cv2 import numpy as np from lanms import merge_quadrangle_n9 as la_nms from mmdet.core import BitmapMasks -from mmdet.datasets.builder import PIPELINES from numpy.linalg import norm import mmocr.utils.check_argument as check_argument +from mmocr.registry import TRANSFORMS from .textsnake_targets import TextSnakeTargets -@PIPELINES.register_module() +@TRANSFORMS.register_module() class DRRGTargets(TextSnakeTargets): """Generate the ground truth targets of DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection. diff --git a/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py b/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py index 2d667b58..c40d577f 100644 --- a/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py +++ b/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py @@ -1,15 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import cv2 import numpy as np -from mmdet.datasets.builder import PIPELINES from numpy.fft import fft from numpy.linalg import norm import mmocr.utils.check_argument as check_argument +from mmocr.registry import TRANSFORMS from .textsnake_targets import TextSnakeTargets -@PIPELINES.register_module() +@TRANSFORMS.register_module() class FCENetTargets(TextSnakeTargets): """Generate the ground truth targets of FCENet: Fourier Contour Embedding for Arbitrary-Shaped Text Detection. diff --git a/mmocr/datasets/pipelines/textdet_targets/panet_targets.py b/mmocr/datasets/pipelines/textdet_targets/panet_targets.py index 92449cdb..ab961d1c 100644 --- a/mmocr/datasets/pipelines/textdet_targets/panet_targets.py +++ b/mmocr/datasets/pipelines/textdet_targets/panet_targets.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmdet.core import BitmapMasks -from mmdet.datasets.builder import PIPELINES +from mmocr.registry import TRANSFORMS from . import BaseTextDetTargets -@PIPELINES.register_module() +@TRANSFORMS.register_module() class PANetTargets(BaseTextDetTargets): """Generate the ground truths for PANet: Efficient and Accurate Arbitrary- Shaped Text Detection with Pixel Aggregation Network. diff --git a/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py b/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py index 0bdc77fa..49577235 100644 --- a/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py +++ b/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmdet.datasets.builder import PIPELINES - +from mmocr.registry import TRANSFORMS from . import PANetTargets -@PIPELINES.register_module() +@TRANSFORMS.register_module() class PSENetTargets(PANetTargets): """Generate the ground truth targets of PSENet: Shape robust text detection with progressive scale expansion network. diff --git a/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py index 3a8e4d21..207c7500 100644 --- a/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py +++ b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py @@ -2,14 +2,14 @@ import cv2 import numpy as np from mmdet.core import BitmapMasks -from mmdet.datasets.builder import PIPELINES from numpy.linalg import norm import mmocr.utils.check_argument as check_argument +from mmocr.registry import TRANSFORMS from . import BaseTextDetTargets -@PIPELINES.register_module() +@TRANSFORMS.register_module() class TextSnakeTargets(BaseTextDetTargets): """Generate the ground truth targets of TextSnake: TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes. diff --git a/mmocr/datasets/pipelines/transform_wrappers.py b/mmocr/datasets/pipelines/transform_wrappers.py index c85f3d11..eb94a9fd 100644 --- a/mmocr/datasets/pipelines/transform_wrappers.py +++ b/mmocr/datasets/pipelines/transform_wrappers.py @@ -5,13 +5,13 @@ import random import mmcv import numpy as np import torchvision.transforms as torchvision_transforms -from mmcv.utils import build_from_cfg -from mmdet.datasets.builder import PIPELINES from mmdet.datasets.pipelines import Compose from PIL import Image +from mmocr.registry import TRANSFORMS -@PIPELINES.register_module() + +@TRANSFORMS.register_module() class OneOfWrapper: """Randomly select and apply one of the transforms, each with the equal chance. @@ -31,7 +31,7 @@ class OneOfWrapper: self.transforms = [] for t in transforms: if isinstance(t, dict): - self.transforms.append(build_from_cfg(t, PIPELINES)) + self.transforms.append(TRANSFORMS.build(t)) elif callable(t): self.transforms.append(t) else: @@ -46,7 +46,7 @@ class OneOfWrapper: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomWrapper: """Run a transform or a sequence of transforms with probability p. @@ -71,7 +71,7 @@ class RandomWrapper: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class TorchVisionWrapper: """A wrapper of torchvision trasnforms. It applies specific transform to ``img`` and updates ``img_shape`` accordingly. diff --git a/mmocr/datasets/pipelines/transforms.py b/mmocr/datasets/pipelines/transforms.py index 1ad1d2bc..0d7805bc 100644 --- a/mmocr/datasets/pipelines/transforms.py +++ b/mmocr/datasets/pipelines/transforms.py @@ -6,16 +6,16 @@ import mmcv import numpy as np import torchvision.transforms as transforms from mmdet.core import BitmapMasks, PolygonMasks -from mmdet.datasets.builder import PIPELINES from mmdet.datasets.pipelines.transforms import Resize from PIL import Image from shapely.geometry import Polygon as plg import mmocr.core.evaluation.utils as eval_utils +from mmocr.registry import TRANSFORMS from mmocr.utils import check_argument -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomCropInstances: """Randomly crop images and make sure to contain text instances. @@ -176,7 +176,7 @@ class RandomCropInstances: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomRotateTextDet: """Randomly rotate images.""" @@ -223,7 +223,7 @@ class RandomRotateTextDet: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ColorJitter: """An interface for torch color jitter so that it can be invoked in mmdetection pipeline.""" @@ -246,7 +246,7 @@ class ColorJitter: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ScaleAspectJitter(Resize): """Resize image and segmentation mask encoded by coordinates. @@ -335,7 +335,7 @@ class ScaleAspectJitter(Resize): results['scale_idx'] = None -@PIPELINES.register_module() +@TRANSFORMS.register_module() class AffineJitter: """An interface for torchvision random affine so that it can be invoked in mmdet pipeline.""" @@ -370,7 +370,7 @@ class AffineJitter: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomCropPolyInstances: """Randomly crop images and make sure to contain at least one intact instance.""" @@ -513,7 +513,7 @@ class RandomCropPolyInstances: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomRotatePolyInstances: def __init__(self, @@ -639,7 +639,7 @@ class RandomRotatePolyInstances: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class SquareResizePad: def __init__(self, @@ -737,7 +737,7 @@ class SquareResizePad: return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomScaling: def __init__(self, size=800, scale=(3. / 4, 5. / 2)): @@ -774,7 +774,7 @@ class RandomScaling: return results -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomCropFlip: def __init__(self, @@ -969,7 +969,7 @@ class RandomCropFlip: return h_axis, w_axis -@PIPELINES.register_module() +@TRANSFORMS.register_module() class PyramidRescale: """Resize the image to the base shape, downsample it with gaussian pyramid, and rescale it back to original size. diff --git a/mmocr/datasets/text_det_dataset.py b/mmocr/datasets/text_det_dataset.py index ea1610a0..19ddd0dc 100644 --- a/mmocr/datasets/text_det_dataset.py +++ b/mmocr/datasets/text_det_dataset.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np -from mmdet.datasets.builder import DATASETS from mmocr.core.evaluation.hmean import eval_hmean from mmocr.datasets.base_dataset import BaseDataset +from mmocr.registry import DATASETS @DATASETS.register_module() diff --git a/mmocr/datasets/uniform_concat_dataset.py b/mmocr/datasets/uniform_concat_dataset.py index 9fdc6063..95480741 100644 --- a/mmocr/datasets/uniform_concat_dataset.py +++ b/mmocr/datasets/uniform_concat_dataset.py @@ -4,8 +4,9 @@ from collections import defaultdict import numpy as np from mmcv.utils import print_log -from mmdet.datasets import DATASETS, ConcatDataset, build_dataset +from mmdet.datasets import ConcatDataset +from mmocr.registry import DATASETS from mmocr.utils import is_2dlist, is_type_list @@ -64,7 +65,7 @@ class UniformConcatDataset(ConcatDataset): new_datasets.extend(sub_datasets) else: new_datasets = datasets - datasets = [build_dataset(c, kwargs) for c in new_datasets] + datasets = [DATASETS.build(c, kwargs) for c in new_datasets] super().__init__(datasets, separate_eval) if not separate_eval: diff --git a/mmocr/datasets/utils/loader.py b/mmocr/datasets/utils/loader.py index ccdd7601..d962d174 100644 --- a/mmocr/datasets/utils/loader.py +++ b/mmocr/datasets/utils/loader.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings -from mmocr.datasets.builder import LOADERS, build_parser +from mmocr.datasets.builder import LOADERS, PARSERS from .backend import (HardDiskAnnFileBackend, HTTPAnnFileBackend, PetrelAnnFileBackend) @@ -46,7 +46,7 @@ class AnnFileLoader: raise ValueError('We only support using LineJsonParser ' 'to parse lmdb file. Please use LineJsonParser ' 'in the dataset config') - self.parser = build_parser(parser) + self.parser = PARSERS.build(parser) self.repeat = repeat self.ann_file_backend = self._backends[file_storage_backend]( file_format, **kwargs) diff --git a/old_tests/test_apis/test_model_inference.py b/old_tests/test_apis/test_model_inference.py index 9c09fa80..2ad73e43 100644 --- a/old_tests/test_apis/test_model_inference.py +++ b/old_tests/test_apis/test_model_inference.py @@ -6,8 +6,8 @@ import pytest from mmcv.image import imread from mmocr.apis.inference import init_detector, model_inference -from mmocr.datasets import build_dataset # noqa: F401 from mmocr.models import build_detector # noqa: F401 +from mmocr.registry import DATASETS # noqa: F401 from mmocr.utils import revert_sync_batchnorm diff --git a/old_tests/test_apis/test_single_gpu_test.py b/old_tests/test_apis/test_single_gpu_test.py index 64fd99fe..8a0735c0 100644 --- a/old_tests/test_apis/test_single_gpu_test.py +++ b/old_tests/test_apis/test_single_gpu_test.py @@ -13,8 +13,9 @@ from mmcv import Config from mmcv.parallel import MMDataParallel from mmocr.apis.test import single_gpu_test -from mmocr.datasets import build_dataloader, build_dataset +from mmocr.datasets import build_dataloader from mmocr.models import build_detector +from mmocr.registry import DATASETS from mmocr.utils import check_argument, list_to_file, revert_sync_batchnorm @@ -45,7 +46,7 @@ def generate_sample_dataloader(cfg, curr_dir, img_prefix='', ann_file=''): test.ann_file = ann_file cfg.data.workers_per_gpu = 0 cfg.data.test.datasets = [test] - dataset = build_dataset(cfg.data.test) + dataset = DATASETS.build(cfg.data.test) loader_cfg = { **dict((k, cfg.data[k]) for k in [ @@ -140,7 +141,7 @@ def gene_sdmgr_model_dataloader(cfg, dirname, curr_dir, empty_img=False): cfg.model.class_list = osp.join(curr_dir, 'data/kie_toy_dataset/class_list.txt') - dataset = build_dataset(cfg.data.test) + dataset = DATASETS.build(cfg.data.test) loader_cfg = { **dict((k, cfg.data[k]) for k in [ diff --git a/tools/benchmark_processing.py b/tools/benchmark_processing.py index 13b215ef..6d4e245b 100755 --- a/tools/benchmark_processing.py +++ b/tools/benchmark_processing.py @@ -19,9 +19,7 @@ import mmcv from mmcv import Config from mmdet.datasets import build_dataloader -from mmocr.datasets import build_dataset - -assert build_dataset is not None +from mmocr.registry import DATASETS def main(): @@ -30,7 +28,7 @@ def main(): args = parser.parse_args() cfg = Config.fromfile(args.config) - dataset = build_dataset(cfg.data.train) + dataset = DATASETS.build(cfg.data.train) # prepare data loaders if 'imgs_per_gpu' in cfg.data: diff --git a/tools/deployment/deploy_test.py b/tools/deployment/deploy_test.py index 11e0fa2d..22de9bb4 100644 --- a/tools/deployment/deploy_test.py +++ b/tools/deployment/deploy_test.py @@ -10,7 +10,8 @@ from mmdet.apis import single_gpu_test from mmocr.apis.inference import disable_text_recog_aug_test from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer, TensorRTDetector, TensorRTRecognizer) -from mmocr.datasets import build_dataloader, build_dataset +from mmocr.datasets import build_dataloader +from mmocr.registry import DATASETS def parse_args(): @@ -79,7 +80,7 @@ def main(): # build the dataloader samples_per_gpu = 1 cfg = disable_text_recog_aug_test(cfg) - dataset = build_dataset(cfg.data.test) + dataset = DATASETS.build(cfg.data.test) data_loader = build_dataloader( dataset, samples_per_gpu=samples_per_gpu, diff --git a/tools/kie_test_imgs.py b/tools/kie_test_imgs.py index caabc5d5..acb0c7c8 100755 --- a/tools/kie_test_imgs.py +++ b/tools/kie_test_imgs.py @@ -13,8 +13,9 @@ from mmcv.image import tensor2imgs from mmcv.parallel import MMDataParallel from mmcv.runner import load_checkpoint -from mmocr.datasets import build_dataloader, build_dataset +from mmocr.datasets import build_dataloader from mmocr.models import build_detector +from mmocr.registry import DATASETS def save_results(model, img_meta, gt_bboxes, result, out_dir): @@ -140,7 +141,7 @@ def main(): distributed = False # build the dataloader - dataset = build_dataset(cfg.data.test) + dataset = DATASETS.build(cfg.data.test) data_loader = build_dataloader( dataset, samples_per_gpu=1, diff --git a/tools/recog_test_imgs.py b/tools/recog_test_imgs.py index a6db8cb5..c44c2193 100755 --- a/tools/recog_test_imgs.py +++ b/tools/recog_test_imgs.py @@ -11,8 +11,8 @@ from mmcv.utils import ProgressBar from mmocr.apis import init_detector, model_inference from mmocr.core.evaluation.ocr_metric import eval_ocr_metric -from mmocr.datasets import build_dataset # noqa: F401 from mmocr.models import build_detector # noqa: F401 +from mmocr.registry import DATASETS # noqa: F401 from mmocr.utils import get_root_logger, list_from_file, list_to_file diff --git a/tools/test.py b/tools/test.py index 2774bf4c..5663c8e9 100755 --- a/tools/test.py +++ b/tools/test.py @@ -16,8 +16,9 @@ from mmdet.apis import multi_gpu_test from mmocr.apis.test import single_gpu_test from mmocr.apis.utils import (disable_text_recog_aug_test, replace_image_to_tensor) -from mmocr.datasets import build_dataloader, build_dataset +from mmocr.datasets import build_dataloader from mmocr.models import build_detector +from mmocr.registry import DATASETS from mmocr.utils import revert_sync_batchnorm, setup_multi_processes @@ -162,7 +163,7 @@ def main(): init_dist(args.launcher, **cfg.dist_params) # build the dataloader - dataset = build_dataset(cfg.data.test, dict(test_mode=True)) + dataset = DATASETS.build(cfg.data.test, dict(test_mode=True)) # step 1: give default values and override (if exist) from cfg.data default_loader_cfg = { **dict(seed=cfg.get('seed'), drop_last=False, dist=distributed), diff --git a/tools/train.py b/tools/train.py index d4c0e7a0..901ada85 100755 --- a/tools/train.py +++ b/tools/train.py @@ -16,8 +16,8 @@ from mmcv.utils import get_git_hash from mmocr import __version__ from mmocr.apis import init_random_seed, train_detector -from mmocr.datasets import build_dataset from mmocr.models import build_detector +from mmocr.registry import DATASETS from mmocr.utils import (collect_env, get_root_logger, is_2dlist, setup_multi_processes) @@ -189,7 +189,7 @@ def main(): test_cfg=cfg.get('test_cfg')) model.init_weights() - datasets = [build_dataset(cfg.data.train)] + datasets = [DATASETS.build(cfg.data.train)] if len(cfg.workflow) == 2: val_dataset = copy.deepcopy(cfg.data.val) if cfg.data.train.get('pipeline', None) is None: