mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] refactor DATASETS and TRANSFORMS
parent
b5fc589320
commit
a90b9600ce
|
@ -3,8 +3,8 @@ from argparse import ArgumentParser
|
|||
|
||||
from mmocr.apis import init_detector
|
||||
from mmocr.apis.inference import text_model_inference
|
||||
from mmocr.datasets import build_dataset # NOQA
|
||||
from mmocr.models import build_detector # NOQA
|
||||
from mmocr.registry import DATASETS # NOQA
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
@ -5,8 +5,8 @@ import cv2
|
|||
import torch
|
||||
|
||||
from mmocr.apis import init_detector, model_inference
|
||||
from mmocr.datasets import build_dataset # noqa: F401
|
||||
from mmocr.models import build_detector # noqa: F401
|
||||
from mmocr.registry import DATASETS # noqa: F401
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
|
|
@ -10,11 +10,12 @@ from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
|
|||
Fp16OptimizerHook, OptimizerHook, build_optimizer,
|
||||
build_runner, get_dist_info)
|
||||
from mmdet.core import DistEvalHook, EvalHook
|
||||
from mmdet.datasets import build_dataloader, build_dataset
|
||||
from mmdet.datasets import build_dataloader
|
||||
|
||||
from mmocr import digit_version
|
||||
from mmocr.apis.utils import (disable_text_recog_aug_test,
|
||||
replace_image_to_tensor)
|
||||
from mmocr.registry import DATASETS
|
||||
from mmocr.utils import get_root_logger
|
||||
|
||||
|
||||
|
@ -132,7 +133,7 @@ def train_detector(model,
|
|||
cfg = disable_text_recog_aug_test(cfg)
|
||||
cfg = replace_image_to_tensor(cfg)
|
||||
|
||||
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
||||
val_dataset = DATASETS.build(cfg.data.val, dict(test_mode=True))
|
||||
|
||||
val_loader_cfg = {
|
||||
**default_loader_cfg,
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset
|
||||
|
||||
from . import utils
|
||||
from .base_dataset import BaseDataset
|
||||
from .builder import DATASETS, LOADERS, PARSERS, TRANSFORMS
|
||||
from .icdar_dataset import IcdarDataset
|
||||
from .kie_dataset import KIEDataset
|
||||
from .ner_dataset import NerDataset
|
||||
|
@ -15,10 +14,10 @@ from .uniform_concat_dataset import UniformConcatDataset
|
|||
from .utils import * # NOQA
|
||||
|
||||
__all__ = [
|
||||
'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
|
||||
'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
|
||||
'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets',
|
||||
'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset'
|
||||
'DATASETS', 'IcdarDataset', 'BaseDataset', 'OCRDataset', 'TextDetDataset',
|
||||
'CustomFormatBundle', 'DBNetTargets', 'OCRSegDataset', 'KIEDataset',
|
||||
'FCENetTargets', 'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset',
|
||||
'TRANSFORMS', 'PARSERS', 'LOADERS'
|
||||
]
|
||||
|
||||
__all__ += utils.__all__
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
from mmcv.utils import print_log
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
from mmdet.datasets.pipelines import Compose
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmocr.datasets.builder import build_loader
|
||||
from mmocr.datasets.builder import LOADERS
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -66,7 +66,7 @@ class BaseDataset(Dataset):
|
|||
self.ann_file = ann_file
|
||||
# load annotations
|
||||
loader.update(ann_file=ann_file)
|
||||
self.data_infos = build_loader(loader)
|
||||
self.data_infos = LOADERS.build(loader)
|
||||
# processing pipeline
|
||||
self.pipeline = Compose(pipeline)
|
||||
# set group flag and class, no meaning
|
||||
|
|
|
@ -1,15 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.utils import Registry, build_from_cfg
|
||||
|
||||
LOADERS = Registry('loader')
|
||||
PARSERS = Registry('parser')
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
def build_loader(cfg):
|
||||
"""Build anno file loader."""
|
||||
return build_from_cfg(cfg, LOADERS)
|
||||
|
||||
|
||||
def build_parser(cfg):
|
||||
"""Build anno file parser."""
|
||||
return build_from_cfg(cfg, PARSERS)
|
||||
LOADERS = TRANSFORMS
|
||||
PARSERS = TRANSFORMS
|
||||
|
|
|
@ -2,12 +2,12 @@
|
|||
import mmcv
|
||||
import numpy as np
|
||||
from mmdet.datasets.api_wrappers import COCO
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
from mmdet.datasets.coco import CocoDataset
|
||||
|
||||
import mmocr.utils as utils
|
||||
from mmocr import digit_version
|
||||
from mmocr.core.evaluation.hmean import eval_hmean
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
|
|
@ -5,11 +5,11 @@ from os import path as osp
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
|
||||
from mmocr.core import compute_f1_score
|
||||
from mmocr.datasets.base_dataset import BaseDataset
|
||||
from mmocr.datasets.pipelines import sort_vertex8
|
||||
from mmocr.registry import DATASETS
|
||||
from mmocr.utils import is_type_list, list_from_file
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
|
||||
from mmocr.core.evaluation.ner_metric import eval_ner_f1
|
||||
from mmocr.datasets.base_dataset import BaseDataset
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
|
||||
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
|
||||
from mmocr.datasets.base_dataset import BaseDataset
|
||||
from mmocr.registry import DATASETS
|
||||
from mmocr.utils import is_type_list
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
|
||||
import mmocr.utils as utils
|
||||
from mmocr.datasets.ocr_dataset import OCRDataset
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
|
|
@ -3,9 +3,9 @@ import copy
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
|
||||
from mmocr.datasets import KIEDataset
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
from mmcv.parallel import DataContainer as DC
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from mmdet.datasets.pipelines.formatting import DefaultFormatBundle
|
||||
|
||||
from mmocr.core.visualize import overlay_mask_img, show_feature
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class CustomFormatBundle(DefaultFormatBundle):
|
||||
"""Custom formatting bundle.
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@ import imgaug.augmenters as iaa
|
|||
import mmcv
|
||||
import numpy as np
|
||||
from mmdet.core.mask import PolygonMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
class AugmenterBuilder:
|
||||
|
@ -45,7 +46,7 @@ class AugmenterBuilder:
|
|||
return obj
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ImgAug:
|
||||
"""A wrapper to use imgaug https://github.com/aleju/imgaug.
|
||||
|
||||
|
@ -174,7 +175,7 @@ class ImgAug:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class EastRandomCrop:
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
@ -2,11 +2,12 @@
|
|||
import numpy as np
|
||||
from mmcv import rescale_size
|
||||
from mmcv.parallel import DataContainer as DC
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from mmdet.datasets.pipelines.formatting import DefaultFormatBundle, to_tensor
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
@PIPELINES.register_module()
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ResizeNoImg:
|
||||
"""Image resizing without img.
|
||||
|
||||
|
@ -39,7 +40,7 @@ class ResizeNoImg:
|
|||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class KIEFormatBundle(DefaultFormatBundle):
|
||||
"""Key information extraction formatting bundle.
|
||||
|
||||
|
|
|
@ -5,11 +5,12 @@ import lmdb
|
|||
import mmcv
|
||||
import numpy as np
|
||||
from mmdet.core import BitmapMasks, PolygonMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from mmdet.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
@PIPELINES.register_module()
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadTextAnnotations(LoadAnnotations):
|
||||
"""Load annotations for text detection.
|
||||
|
||||
|
@ -99,7 +100,7 @@ class LoadTextAnnotations(LoadAnnotations):
|
|||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageFromNdarray(LoadImageFromFile):
|
||||
"""Load an image from np.ndarray.
|
||||
|
||||
|
@ -136,7 +137,7 @@ class LoadImageFromNdarray(LoadImageFromFile):
|
|||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageFromLMDB(object):
|
||||
"""Load an image from lmdb file.
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
|
||||
from mmocr.models.builder import build_convertor
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class NerTransform:
|
||||
"""Convert text to ID and entity in ground truth to label ID. The masks and
|
||||
tokens are generated at the same time. The four parameters will be used as
|
||||
|
@ -42,7 +42,7 @@ class NerTransform:
|
|||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ToTensorNER:
|
||||
"""Convert data with ``list`` type to tensor."""
|
||||
|
||||
|
|
|
@ -2,13 +2,13 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
|
||||
import mmocr.utils.check_argument as check_argument
|
||||
from mmocr.models.builder import build_convertor
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class OCRSegTargets:
|
||||
"""Generate gt shrunk kernels for segmentation based OCR framework.
|
||||
|
||||
|
|
|
@ -6,16 +6,16 @@ import numpy as np
|
|||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from mmcv.runner.dist_utils import get_dist_info
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from PIL import Image
|
||||
from shapely.geometry import Polygon
|
||||
from shapely.geometry import box as shapely_box
|
||||
|
||||
import mmocr.utils as utils
|
||||
from mmocr.datasets.pipelines.crop import warp_img
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ResizeOCR:
|
||||
"""Image resizing and padding for OCR.
|
||||
|
||||
|
@ -129,7 +129,7 @@ class ResizeOCR:
|
|||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ToTensorOCR:
|
||||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor."""
|
||||
|
||||
|
@ -142,7 +142,7 @@ class ToTensorOCR:
|
|||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class NormalizeOCR:
|
||||
"""Normalize a tensor image with mean and standard deviation."""
|
||||
|
||||
|
@ -156,7 +156,7 @@ class NormalizeOCR:
|
|||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class OnlineCropOCR:
|
||||
"""Crop text areas from whole image with bounding box jitter. If no bbox is
|
||||
given, return directly.
|
||||
|
@ -216,7 +216,7 @@ class OnlineCropOCR:
|
|||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class FancyPCA:
|
||||
"""Implementation of PCA based image augmentation, proposed in the paper
|
||||
``Imagenet Classification With Deep Convolutional Neural Networks``.
|
||||
|
@ -257,7 +257,7 @@ class FancyPCA:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomPaddingOCR:
|
||||
"""Pad the given image on all sides, as well as modify the coordinates of
|
||||
character bounding box in image.
|
||||
|
@ -319,7 +319,7 @@ class RandomPaddingOCR:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomRotateImageBox:
|
||||
"""Rotate augmentation for segmentation based text recognition.
|
||||
|
||||
|
@ -416,7 +416,7 @@ class RandomRotateImageBox:
|
|||
return [new_x, new_y]
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class OpencvToPil:
|
||||
"""Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb)."""
|
||||
|
||||
|
@ -435,7 +435,7 @@ class OpencvToPil:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class PilToOpencv:
|
||||
"""Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr)."""
|
||||
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from mmdet.datasets.pipelines.compose import Compose
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
@PIPELINES.register_module()
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class MultiRotateAugOCR:
|
||||
"""Test-time augmentation with multiple rotations in the case that
|
||||
img_height > img_width.
|
||||
|
|
|
@ -3,13 +3,13 @@ import cv2
|
|||
import numpy as np
|
||||
import pyclipper
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from . import BaseTextDetTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class DBNetTargets(BaseTextDetTargets):
|
||||
"""Generate gt shrunk text, gt threshold map, and their effective region
|
||||
masks to learn DBNet: Real-time Scene Text Detection with Differentiable
|
||||
|
|
|
@ -3,14 +3,14 @@ import cv2
|
|||
import numpy as np
|
||||
from lanms import merge_quadrangle_n9 as la_nms
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from numpy.linalg import norm
|
||||
|
||||
import mmocr.utils.check_argument as check_argument
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from .textsnake_targets import TextSnakeTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class DRRGTargets(TextSnakeTargets):
|
||||
"""Generate the ground truth targets of DRRG: Deep Relational Reasoning
|
||||
Graph Network for Arbitrary Shape Text Detection.
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import cv2
|
||||
import numpy as np
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from numpy.fft import fft
|
||||
from numpy.linalg import norm
|
||||
|
||||
import mmocr.utils.check_argument as check_argument
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from .textsnake_targets import TextSnakeTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class FCENetTargets(TextSnakeTargets):
|
||||
"""Generate the ground truth targets of FCENet: Fourier Contour Embedding
|
||||
for Arbitrary-Shaped Text Detection.
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from . import BaseTextDetTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class PANetTargets(BaseTextDetTargets):
|
||||
"""Generate the ground truths for PANet: Efficient and Accurate Arbitrary-
|
||||
Shaped Text Detection with Pixel Aggregation Network.
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from . import PANetTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class PSENetTargets(PANetTargets):
|
||||
"""Generate the ground truth targets of PSENet: Shape robust text detection
|
||||
with progressive scale expansion network.
|
||||
|
|
|
@ -2,14 +2,14 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from numpy.linalg import norm
|
||||
|
||||
import mmocr.utils.check_argument as check_argument
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from . import BaseTextDetTargets
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class TextSnakeTargets(BaseTextDetTargets):
|
||||
"""Generate the ground truth targets of TextSnake: TextSnake: A Flexible
|
||||
Representation for Detecting Text of Arbitrary Shapes.
|
||||
|
|
|
@ -5,13 +5,13 @@ import random
|
|||
import mmcv
|
||||
import numpy as np
|
||||
import torchvision.transforms as torchvision_transforms
|
||||
from mmcv.utils import build_from_cfg
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from mmdet.datasets.pipelines import Compose
|
||||
from PIL import Image
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
@PIPELINES.register_module()
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class OneOfWrapper:
|
||||
"""Randomly select and apply one of the transforms, each with the equal
|
||||
chance.
|
||||
|
@ -31,7 +31,7 @@ class OneOfWrapper:
|
|||
self.transforms = []
|
||||
for t in transforms:
|
||||
if isinstance(t, dict):
|
||||
self.transforms.append(build_from_cfg(t, PIPELINES))
|
||||
self.transforms.append(TRANSFORMS.build(t))
|
||||
elif callable(t):
|
||||
self.transforms.append(t)
|
||||
else:
|
||||
|
@ -46,7 +46,7 @@ class OneOfWrapper:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomWrapper:
|
||||
"""Run a transform or a sequence of transforms with probability p.
|
||||
|
||||
|
@ -71,7 +71,7 @@ class RandomWrapper:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class TorchVisionWrapper:
|
||||
"""A wrapper of torchvision trasnforms. It applies specific transform to
|
||||
``img`` and updates ``img_shape`` accordingly.
|
||||
|
|
|
@ -6,16 +6,16 @@ import mmcv
|
|||
import numpy as np
|
||||
import torchvision.transforms as transforms
|
||||
from mmdet.core import BitmapMasks, PolygonMasks
|
||||
from mmdet.datasets.builder import PIPELINES
|
||||
from mmdet.datasets.pipelines.transforms import Resize
|
||||
from PIL import Image
|
||||
from shapely.geometry import Polygon as plg
|
||||
|
||||
import mmocr.core.evaluation.utils as eval_utils
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from mmocr.utils import check_argument
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomCropInstances:
|
||||
"""Randomly crop images and make sure to contain text instances.
|
||||
|
||||
|
@ -176,7 +176,7 @@ class RandomCropInstances:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomRotateTextDet:
|
||||
"""Randomly rotate images."""
|
||||
|
||||
|
@ -223,7 +223,7 @@ class RandomRotateTextDet:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ColorJitter:
|
||||
"""An interface for torch color jitter so that it can be invoked in
|
||||
mmdetection pipeline."""
|
||||
|
@ -246,7 +246,7 @@ class ColorJitter:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ScaleAspectJitter(Resize):
|
||||
"""Resize image and segmentation mask encoded by coordinates.
|
||||
|
||||
|
@ -335,7 +335,7 @@ class ScaleAspectJitter(Resize):
|
|||
results['scale_idx'] = None
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class AffineJitter:
|
||||
"""An interface for torchvision random affine so that it can be invoked in
|
||||
mmdet pipeline."""
|
||||
|
@ -370,7 +370,7 @@ class AffineJitter:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomCropPolyInstances:
|
||||
"""Randomly crop images and make sure to contain at least one intact
|
||||
instance."""
|
||||
|
@ -513,7 +513,7 @@ class RandomCropPolyInstances:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomRotatePolyInstances:
|
||||
|
||||
def __init__(self,
|
||||
|
@ -639,7 +639,7 @@ class RandomRotatePolyInstances:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class SquareResizePad:
|
||||
|
||||
def __init__(self,
|
||||
|
@ -737,7 +737,7 @@ class SquareResizePad:
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomScaling:
|
||||
|
||||
def __init__(self, size=800, scale=(3. / 4, 5. / 2)):
|
||||
|
@ -774,7 +774,7 @@ class RandomScaling:
|
|||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomCropFlip:
|
||||
|
||||
def __init__(self,
|
||||
|
@ -969,7 +969,7 @@ class RandomCropFlip:
|
|||
return h_axis, w_axis
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class PyramidRescale:
|
||||
"""Resize the image to the base shape, downsample it with gaussian pyramid,
|
||||
and rescale it back to original size.
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
|
||||
from mmocr.core.evaluation.hmean import eval_hmean
|
||||
from mmocr.datasets.base_dataset import BaseDataset
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
|
|
@ -4,8 +4,9 @@ from collections import defaultdict
|
|||
|
||||
import numpy as np
|
||||
from mmcv.utils import print_log
|
||||
from mmdet.datasets import DATASETS, ConcatDataset, build_dataset
|
||||
from mmdet.datasets import ConcatDataset
|
||||
|
||||
from mmocr.registry import DATASETS
|
||||
from mmocr.utils import is_2dlist, is_type_list
|
||||
|
||||
|
||||
|
@ -64,7 +65,7 @@ class UniformConcatDataset(ConcatDataset):
|
|||
new_datasets.extend(sub_datasets)
|
||||
else:
|
||||
new_datasets = datasets
|
||||
datasets = [build_dataset(c, kwargs) for c in new_datasets]
|
||||
datasets = [DATASETS.build(c, kwargs) for c in new_datasets]
|
||||
super().__init__(datasets, separate_eval)
|
||||
|
||||
if not separate_eval:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmocr.datasets.builder import LOADERS, build_parser
|
||||
from mmocr.datasets.builder import LOADERS, PARSERS
|
||||
from .backend import (HardDiskAnnFileBackend, HTTPAnnFileBackend,
|
||||
PetrelAnnFileBackend)
|
||||
|
||||
|
@ -46,7 +46,7 @@ class AnnFileLoader:
|
|||
raise ValueError('We only support using LineJsonParser '
|
||||
'to parse lmdb file. Please use LineJsonParser '
|
||||
'in the dataset config')
|
||||
self.parser = build_parser(parser)
|
||||
self.parser = PARSERS.build(parser)
|
||||
self.repeat = repeat
|
||||
self.ann_file_backend = self._backends[file_storage_backend](
|
||||
file_format, **kwargs)
|
||||
|
|
|
@ -6,8 +6,8 @@ import pytest
|
|||
from mmcv.image import imread
|
||||
|
||||
from mmocr.apis.inference import init_detector, model_inference
|
||||
from mmocr.datasets import build_dataset # noqa: F401
|
||||
from mmocr.models import build_detector # noqa: F401
|
||||
from mmocr.registry import DATASETS # noqa: F401
|
||||
from mmocr.utils import revert_sync_batchnorm
|
||||
|
||||
|
||||
|
|
|
@ -13,8 +13,9 @@ from mmcv import Config
|
|||
from mmcv.parallel import MMDataParallel
|
||||
|
||||
from mmocr.apis.test import single_gpu_test
|
||||
from mmocr.datasets import build_dataloader, build_dataset
|
||||
from mmocr.datasets import build_dataloader
|
||||
from mmocr.models import build_detector
|
||||
from mmocr.registry import DATASETS
|
||||
from mmocr.utils import check_argument, list_to_file, revert_sync_batchnorm
|
||||
|
||||
|
||||
|
@ -45,7 +46,7 @@ def generate_sample_dataloader(cfg, curr_dir, img_prefix='', ann_file=''):
|
|||
test.ann_file = ann_file
|
||||
cfg.data.workers_per_gpu = 0
|
||||
cfg.data.test.datasets = [test]
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
dataset = DATASETS.build(cfg.data.test)
|
||||
|
||||
loader_cfg = {
|
||||
**dict((k, cfg.data[k]) for k in [
|
||||
|
@ -140,7 +141,7 @@ def gene_sdmgr_model_dataloader(cfg, dirname, curr_dir, empty_img=False):
|
|||
cfg.model.class_list = osp.join(curr_dir,
|
||||
'data/kie_toy_dataset/class_list.txt')
|
||||
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
dataset = DATASETS.build(cfg.data.test)
|
||||
|
||||
loader_cfg = {
|
||||
**dict((k, cfg.data[k]) for k in [
|
||||
|
|
|
@ -19,9 +19,7 @@ import mmcv
|
|||
from mmcv import Config
|
||||
from mmdet.datasets import build_dataloader
|
||||
|
||||
from mmocr.datasets import build_dataset
|
||||
|
||||
assert build_dataset is not None
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -30,7 +28,7 @@ def main():
|
|||
args = parser.parse_args()
|
||||
cfg = Config.fromfile(args.config)
|
||||
|
||||
dataset = build_dataset(cfg.data.train)
|
||||
dataset = DATASETS.build(cfg.data.train)
|
||||
|
||||
# prepare data loaders
|
||||
if 'imgs_per_gpu' in cfg.data:
|
||||
|
|
|
@ -10,7 +10,8 @@ from mmdet.apis import single_gpu_test
|
|||
from mmocr.apis.inference import disable_text_recog_aug_test
|
||||
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
|
||||
TensorRTDetector, TensorRTRecognizer)
|
||||
from mmocr.datasets import build_dataloader, build_dataset
|
||||
from mmocr.datasets import build_dataloader
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -79,7 +80,7 @@ def main():
|
|||
# build the dataloader
|
||||
samples_per_gpu = 1
|
||||
cfg = disable_text_recog_aug_test(cfg)
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
dataset = DATASETS.build(cfg.data.test)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
samples_per_gpu=samples_per_gpu,
|
||||
|
|
|
@ -13,8 +13,9 @@ from mmcv.image import tensor2imgs
|
|||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import load_checkpoint
|
||||
|
||||
from mmocr.datasets import build_dataloader, build_dataset
|
||||
from mmocr.datasets import build_dataloader
|
||||
from mmocr.models import build_detector
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
|
||||
def save_results(model, img_meta, gt_bboxes, result, out_dir):
|
||||
|
@ -140,7 +141,7 @@ def main():
|
|||
distributed = False
|
||||
|
||||
# build the dataloader
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
dataset = DATASETS.build(cfg.data.test)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
samples_per_gpu=1,
|
||||
|
|
|
@ -11,8 +11,8 @@ from mmcv.utils import ProgressBar
|
|||
|
||||
from mmocr.apis import init_detector, model_inference
|
||||
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
|
||||
from mmocr.datasets import build_dataset # noqa: F401
|
||||
from mmocr.models import build_detector # noqa: F401
|
||||
from mmocr.registry import DATASETS # noqa: F401
|
||||
from mmocr.utils import get_root_logger, list_from_file, list_to_file
|
||||
|
||||
|
||||
|
|
|
@ -16,8 +16,9 @@ from mmdet.apis import multi_gpu_test
|
|||
from mmocr.apis.test import single_gpu_test
|
||||
from mmocr.apis.utils import (disable_text_recog_aug_test,
|
||||
replace_image_to_tensor)
|
||||
from mmocr.datasets import build_dataloader, build_dataset
|
||||
from mmocr.datasets import build_dataloader
|
||||
from mmocr.models import build_detector
|
||||
from mmocr.registry import DATASETS
|
||||
from mmocr.utils import revert_sync_batchnorm, setup_multi_processes
|
||||
|
||||
|
||||
|
@ -162,7 +163,7 @@ def main():
|
|||
init_dist(args.launcher, **cfg.dist_params)
|
||||
|
||||
# build the dataloader
|
||||
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
|
||||
dataset = DATASETS.build(cfg.data.test, dict(test_mode=True))
|
||||
# step 1: give default values and override (if exist) from cfg.data
|
||||
default_loader_cfg = {
|
||||
**dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
|
||||
|
|
|
@ -16,8 +16,8 @@ from mmcv.utils import get_git_hash
|
|||
|
||||
from mmocr import __version__
|
||||
from mmocr.apis import init_random_seed, train_detector
|
||||
from mmocr.datasets import build_dataset
|
||||
from mmocr.models import build_detector
|
||||
from mmocr.registry import DATASETS
|
||||
from mmocr.utils import (collect_env, get_root_logger, is_2dlist,
|
||||
setup_multi_processes)
|
||||
|
||||
|
@ -189,7 +189,7 @@ def main():
|
|||
test_cfg=cfg.get('test_cfg'))
|
||||
model.init_weights()
|
||||
|
||||
datasets = [build_dataset(cfg.data.train)]
|
||||
datasets = [DATASETS.build(cfg.data.train)]
|
||||
if len(cfg.workflow) == 2:
|
||||
val_dataset = copy.deepcopy(cfg.data.val)
|
||||
if cfg.data.train.get('pipeline', None) is None:
|
||||
|
|
Loading…
Reference in New Issue