[Refactor] refactor DATASETS and TRANSFORMS

pull/1178/head
liukuikun 2022-05-11 06:10:36 +00:00 committed by gaotongxiao
parent b5fc589320
commit a90b9600ce
39 changed files with 110 additions and 116 deletions

View File

@ -3,8 +3,8 @@ from argparse import ArgumentParser
from mmocr.apis import init_detector
from mmocr.apis.inference import text_model_inference
from mmocr.datasets import build_dataset # NOQA
from mmocr.models import build_detector # NOQA
from mmocr.registry import DATASETS # NOQA
def main():

View File

@ -5,8 +5,8 @@ import cv2
import torch
from mmocr.apis import init_detector, model_inference
from mmocr.datasets import build_dataset # noqa: F401
from mmocr.models import build_detector # noqa: F401
from mmocr.registry import DATASETS # noqa: F401
def parse_args():

View File

@ -10,11 +10,12 @@ from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner, get_dist_info)
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import build_dataloader, build_dataset
from mmdet.datasets import build_dataloader
from mmocr import digit_version
from mmocr.apis.utils import (disable_text_recog_aug_test,
replace_image_to_tensor)
from mmocr.registry import DATASETS
from mmocr.utils import get_root_logger
@ -132,7 +133,7 @@ def train_detector(model,
cfg = disable_text_recog_aug_test(cfg)
cfg = replace_image_to_tensor(cfg)
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_dataset = DATASETS.build(cfg.data.val, dict(test_mode=True))
val_loader_cfg = {
**default_loader_cfg,

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset
from . import utils
from .base_dataset import BaseDataset
from .builder import DATASETS, LOADERS, PARSERS, TRANSFORMS
from .icdar_dataset import IcdarDataset
from .kie_dataset import KIEDataset
from .ner_dataset import NerDataset
@ -15,10 +14,10 @@ from .uniform_concat_dataset import UniformConcatDataset
from .utils import * # NOQA
__all__ = [
'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets',
'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset'
'DATASETS', 'IcdarDataset', 'BaseDataset', 'OCRDataset', 'TextDetDataset',
'CustomFormatBundle', 'DBNetTargets', 'OCRSegDataset', 'KIEDataset',
'FCENetTargets', 'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset',
'TRANSFORMS', 'PARSERS', 'LOADERS'
]
__all__ += utils.__all__

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.utils import print_log
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.pipelines import Compose
from torch.utils.data import Dataset
from mmocr.datasets.builder import build_loader
from mmocr.datasets.builder import LOADERS
from mmocr.registry import DATASETS
@DATASETS.register_module()
@ -66,7 +66,7 @@ class BaseDataset(Dataset):
self.ann_file = ann_file
# load annotations
loader.update(ann_file=ann_file)
self.data_infos = build_loader(loader)
self.data_infos = LOADERS.build(loader)
# processing pipeline
self.pipeline = Compose(pipeline)
# set group flag and class, no meaning

View File

@ -1,15 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg
LOADERS = Registry('loader')
PARSERS = Registry('parser')
from mmocr.registry import TRANSFORMS
def build_loader(cfg):
"""Build anno file loader."""
return build_from_cfg(cfg, LOADERS)
def build_parser(cfg):
"""Build anno file parser."""
return build_from_cfg(cfg, PARSERS)
LOADERS = TRANSFORMS
PARSERS = TRANSFORMS

View File

@ -2,12 +2,12 @@
import mmcv
import numpy as np
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset
import mmocr.utils as utils
from mmocr import digit_version
from mmocr.core.evaluation.hmean import eval_hmean
from mmocr.registry import DATASETS
@DATASETS.register_module()

View File

@ -5,11 +5,11 @@ from os import path as osp
import numpy as np
import torch
from mmdet.datasets.builder import DATASETS
from mmocr.core import compute_f1_score
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.datasets.pipelines import sort_vertex8
from mmocr.registry import DATASETS
from mmocr.utils import is_type_list, list_from_file

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets.builder import DATASETS
from mmocr.core.evaluation.ner_metric import eval_ner_f1
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.registry import DATASETS
@DATASETS.register_module()

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets.builder import DATASETS
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.registry import DATASETS
from mmocr.utils import is_type_list

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets.builder import DATASETS
import mmocr.utils as utils
from mmocr.datasets.ocr_dataset import OCRDataset
from mmocr.registry import DATASETS
@DATASETS.register_module()

View File

@ -3,9 +3,9 @@ import copy
import numpy as np
import torch
from mmdet.datasets.builder import DATASETS
from mmocr.datasets import KIEDataset
from mmocr.registry import DATASETS
@DATASETS.register_module()

View File

@ -1,13 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.formatting import DefaultFormatBundle
from mmocr.core.visualize import overlay_mask_img, show_feature
from mmocr.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class CustomFormatBundle(DefaultFormatBundle):
"""Custom formatting bundle.

View File

@ -4,7 +4,8 @@ import imgaug.augmenters as iaa
import mmcv
import numpy as np
from mmdet.core.mask import PolygonMasks
from mmdet.datasets.builder import PIPELINES
from mmocr.registry import TRANSFORMS
class AugmenterBuilder:
@ -45,7 +46,7 @@ class AugmenterBuilder:
return obj
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ImgAug:
"""A wrapper to use imgaug https://github.com/aleju/imgaug.
@ -174,7 +175,7 @@ class ImgAug:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class EastRandomCrop:
def __init__(self,

View File

@ -2,11 +2,12 @@
import numpy as np
from mmcv import rescale_size
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.formatting import DefaultFormatBundle, to_tensor
from mmocr.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ResizeNoImg:
"""Image resizing without img.
@ -39,7 +40,7 @@ class ResizeNoImg:
return results
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class KIEFormatBundle(DefaultFormatBundle):
"""Key information extraction formatting bundle.

View File

@ -5,11 +5,12 @@ import lmdb
import mmcv
import numpy as np
from mmdet.core import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile
from mmocr.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class LoadTextAnnotations(LoadAnnotations):
"""Load annotations for text detection.
@ -99,7 +100,7 @@ class LoadTextAnnotations(LoadAnnotations):
return results
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class LoadImageFromNdarray(LoadImageFromFile):
"""Load an image from np.ndarray.
@ -136,7 +137,7 @@ class LoadImageFromNdarray(LoadImageFromFile):
return results
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class LoadImageFromLMDB(object):
"""Load an image from lmdb file.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.datasets.builder import PIPELINES
from mmocr.models.builder import build_convertor
from mmocr.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class NerTransform:
"""Convert text to ID and entity in ground truth to label ID. The masks and
tokens are generated at the same time. The four parameters will be used as
@ -42,7 +42,7 @@ class NerTransform:
return results
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ToTensorNER:
"""Convert data with ``list`` type to tensor."""

View File

@ -2,13 +2,13 @@
import cv2
import numpy as np
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
import mmocr.utils.check_argument as check_argument
from mmocr.models.builder import build_convertor
from mmocr.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class OCRSegTargets:
"""Generate gt shrunk kernels for segmentation based OCR framework.

View File

@ -6,16 +6,16 @@ import numpy as np
import torch
import torchvision.transforms.functional as TF
from mmcv.runner.dist_utils import get_dist_info
from mmdet.datasets.builder import PIPELINES
from PIL import Image
from shapely.geometry import Polygon
from shapely.geometry import box as shapely_box
import mmocr.utils as utils
from mmocr.datasets.pipelines.crop import warp_img
from mmocr.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ResizeOCR:
"""Image resizing and padding for OCR.
@ -129,7 +129,7 @@ class ResizeOCR:
return results
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ToTensorOCR:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor."""
@ -142,7 +142,7 @@ class ToTensorOCR:
return results
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class NormalizeOCR:
"""Normalize a tensor image with mean and standard deviation."""
@ -156,7 +156,7 @@ class NormalizeOCR:
return results
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class OnlineCropOCR:
"""Crop text areas from whole image with bounding box jitter. If no bbox is
given, return directly.
@ -216,7 +216,7 @@ class OnlineCropOCR:
return results
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class FancyPCA:
"""Implementation of PCA based image augmentation, proposed in the paper
``Imagenet Classification With Deep Convolutional Neural Networks``.
@ -257,7 +257,7 @@ class FancyPCA:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomPaddingOCR:
"""Pad the given image on all sides, as well as modify the coordinates of
character bounding box in image.
@ -319,7 +319,7 @@ class RandomPaddingOCR:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomRotateImageBox:
"""Rotate augmentation for segmentation based text recognition.
@ -416,7 +416,7 @@ class RandomRotateImageBox:
return [new_x, new_y]
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class OpencvToPil:
"""Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb)."""
@ -435,7 +435,7 @@ class OpencvToPil:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class PilToOpencv:
"""Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr)."""

View File

@ -1,11 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.compose import Compose
from mmocr.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class MultiRotateAugOCR:
"""Test-time augmentation with multiple rotations in the case that
img_height > img_width.

View File

@ -3,13 +3,13 @@ import cv2
import numpy as np
import pyclipper
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from shapely.geometry import Polygon
from mmocr.registry import TRANSFORMS
from . import BaseTextDetTargets
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class DBNetTargets(BaseTextDetTargets):
"""Generate gt shrunk text, gt threshold map, and their effective region
masks to learn DBNet: Real-time Scene Text Detection with Differentiable

View File

@ -3,14 +3,14 @@ import cv2
import numpy as np
from lanms import merge_quadrangle_n9 as la_nms
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from numpy.linalg import norm
import mmocr.utils.check_argument as check_argument
from mmocr.registry import TRANSFORMS
from .textsnake_targets import TextSnakeTargets
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class DRRGTargets(TextSnakeTargets):
"""Generate the ground truth targets of DRRG: Deep Relational Reasoning
Graph Network for Arbitrary Shape Text Detection.

View File

@ -1,15 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
from mmdet.datasets.builder import PIPELINES
from numpy.fft import fft
from numpy.linalg import norm
import mmocr.utils.check_argument as check_argument
from mmocr.registry import TRANSFORMS
from .textsnake_targets import TextSnakeTargets
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class FCENetTargets(TextSnakeTargets):
"""Generate the ground truth targets of FCENet: Fourier Contour Embedding
for Arbitrary-Shaped Text Detection.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from mmocr.registry import TRANSFORMS
from . import BaseTextDetTargets
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class PANetTargets(BaseTextDetTargets):
"""Generate the ground truths for PANet: Efficient and Accurate Arbitrary-
Shaped Text Detection with Pixel Aggregation Network.

View File

@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets.builder import PIPELINES
from mmocr.registry import TRANSFORMS
from . import PANetTargets
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class PSENetTargets(PANetTargets):
"""Generate the ground truth targets of PSENet: Shape robust text detection
with progressive scale expansion network.

View File

@ -2,14 +2,14 @@
import cv2
import numpy as np
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from numpy.linalg import norm
import mmocr.utils.check_argument as check_argument
from mmocr.registry import TRANSFORMS
from . import BaseTextDetTargets
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class TextSnakeTargets(BaseTextDetTargets):
"""Generate the ground truth targets of TextSnake: TextSnake: A Flexible
Representation for Detecting Text of Arbitrary Shapes.

View File

@ -5,13 +5,13 @@ import random
import mmcv
import numpy as np
import torchvision.transforms as torchvision_transforms
from mmcv.utils import build_from_cfg
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import Compose
from PIL import Image
from mmocr.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class OneOfWrapper:
"""Randomly select and apply one of the transforms, each with the equal
chance.
@ -31,7 +31,7 @@ class OneOfWrapper:
self.transforms = []
for t in transforms:
if isinstance(t, dict):
self.transforms.append(build_from_cfg(t, PIPELINES))
self.transforms.append(TRANSFORMS.build(t))
elif callable(t):
self.transforms.append(t)
else:
@ -46,7 +46,7 @@ class OneOfWrapper:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomWrapper:
"""Run a transform or a sequence of transforms with probability p.
@ -71,7 +71,7 @@ class RandomWrapper:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class TorchVisionWrapper:
"""A wrapper of torchvision transforms. It applies specific transform to
``img`` and updates ``img_shape`` accordingly.

View File

@ -6,16 +6,16 @@ import mmcv
import numpy as np
import torchvision.transforms as transforms
from mmdet.core import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.transforms import Resize
from PIL import Image
from shapely.geometry import Polygon as plg
import mmocr.core.evaluation.utils as eval_utils
from mmocr.registry import TRANSFORMS
from mmocr.utils import check_argument
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomCropInstances:
"""Randomly crop images and make sure to contain text instances.
@ -176,7 +176,7 @@ class RandomCropInstances:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomRotateTextDet:
"""Randomly rotate images."""
@ -223,7 +223,7 @@ class RandomRotateTextDet:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ColorJitter:
"""An interface for torch color jitter so that it can be invoked in
mmdetection pipeline."""
@ -246,7 +246,7 @@ class ColorJitter:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ScaleAspectJitter(Resize):
"""Resize image and segmentation mask encoded by coordinates.
@ -335,7 +335,7 @@ class ScaleAspectJitter(Resize):
results['scale_idx'] = None
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class AffineJitter:
"""An interface for torchvision random affine so that it can be invoked in
mmdet pipeline."""
@ -370,7 +370,7 @@ class AffineJitter:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomCropPolyInstances:
"""Randomly crop images and make sure to contain at least one intact
instance."""
@ -513,7 +513,7 @@ class RandomCropPolyInstances:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomRotatePolyInstances:
def __init__(self,
@ -639,7 +639,7 @@ class RandomRotatePolyInstances:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class SquareResizePad:
def __init__(self,
@ -737,7 +737,7 @@ class SquareResizePad:
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomScaling:
def __init__(self, size=800, scale=(3. / 4, 5. / 2)):
@ -774,7 +774,7 @@ class RandomScaling:
return results
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomCropFlip:
def __init__(self,
@ -969,7 +969,7 @@ class RandomCropFlip:
return h_axis, w_axis
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class PyramidRescale:
"""Resize the image to the base shape, downsample it with gaussian pyramid,
and rescale it back to original size.

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmocr.core.evaluation.hmean import eval_hmean
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.registry import DATASETS
@DATASETS.register_module()

View File

@ -4,8 +4,9 @@ from collections import defaultdict
import numpy as np
from mmcv.utils import print_log
from mmdet.datasets import DATASETS, ConcatDataset, build_dataset
from mmdet.datasets import ConcatDataset
from mmocr.registry import DATASETS
from mmocr.utils import is_2dlist, is_type_list
@ -64,7 +65,7 @@ class UniformConcatDataset(ConcatDataset):
new_datasets.extend(sub_datasets)
else:
new_datasets = datasets
datasets = [build_dataset(c, kwargs) for c in new_datasets]
datasets = [DATASETS.build(c, kwargs) for c in new_datasets]
super().__init__(datasets, separate_eval)
if not separate_eval:

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmocr.datasets.builder import LOADERS, build_parser
from mmocr.datasets.builder import LOADERS, PARSERS
from .backend import (HardDiskAnnFileBackend, HTTPAnnFileBackend,
PetrelAnnFileBackend)
@ -46,7 +46,7 @@ class AnnFileLoader:
raise ValueError('We only support using LineJsonParser '
'to parse lmdb file. Please use LineJsonParser '
'in the dataset config')
self.parser = build_parser(parser)
self.parser = PARSERS.build(parser)
self.repeat = repeat
self.ann_file_backend = self._backends[file_storage_backend](
file_format, **kwargs)

View File

@ -6,8 +6,8 @@ import pytest
from mmcv.image import imread
from mmocr.apis.inference import init_detector, model_inference
from mmocr.datasets import build_dataset # noqa: F401
from mmocr.models import build_detector # noqa: F401
from mmocr.registry import DATASETS # noqa: F401
from mmocr.utils import revert_sync_batchnorm

View File

@ -13,8 +13,9 @@ from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmocr.apis.test import single_gpu_test
from mmocr.datasets import build_dataloader, build_dataset
from mmocr.datasets import build_dataloader
from mmocr.models import build_detector
from mmocr.registry import DATASETS
from mmocr.utils import check_argument, list_to_file, revert_sync_batchnorm
@ -45,7 +46,7 @@ def generate_sample_dataloader(cfg, curr_dir, img_prefix='', ann_file=''):
test.ann_file = ann_file
cfg.data.workers_per_gpu = 0
cfg.data.test.datasets = [test]
dataset = build_dataset(cfg.data.test)
dataset = DATASETS.build(cfg.data.test)
loader_cfg = {
**dict((k, cfg.data[k]) for k in [
@ -140,7 +141,7 @@ def gene_sdmgr_model_dataloader(cfg, dirname, curr_dir, empty_img=False):
cfg.model.class_list = osp.join(curr_dir,
'data/kie_toy_dataset/class_list.txt')
dataset = build_dataset(cfg.data.test)
dataset = DATASETS.build(cfg.data.test)
loader_cfg = {
**dict((k, cfg.data[k]) for k in [

View File

@ -19,9 +19,7 @@ import mmcv
from mmcv import Config
from mmdet.datasets import build_dataloader
from mmocr.datasets import build_dataset
assert build_dataset is not None
from mmocr.registry import DATASETS
def main():
@ -30,7 +28,7 @@ def main():
args = parser.parse_args()
cfg = Config.fromfile(args.config)
dataset = build_dataset(cfg.data.train)
dataset = DATASETS.build(cfg.data.train)
# prepare data loaders
if 'imgs_per_gpu' in cfg.data:

View File

@ -10,7 +10,8 @@ from mmdet.apis import single_gpu_test
from mmocr.apis.inference import disable_text_recog_aug_test
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets import build_dataloader, build_dataset
from mmocr.datasets import build_dataloader
from mmocr.registry import DATASETS
def parse_args():
@ -79,7 +80,7 @@ def main():
# build the dataloader
samples_per_gpu = 1
cfg = disable_text_recog_aug_test(cfg)
dataset = build_dataset(cfg.data.test)
dataset = DATASETS.build(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,

View File

@ -13,8 +13,9 @@ from mmcv.image import tensor2imgs
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from mmocr.datasets import build_dataloader, build_dataset
from mmocr.datasets import build_dataloader
from mmocr.models import build_detector
from mmocr.registry import DATASETS
def save_results(model, img_meta, gt_bboxes, result, out_dir):
@ -140,7 +141,7 @@ def main():
distributed = False
# build the dataloader
dataset = build_dataset(cfg.data.test)
dataset = DATASETS.build(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=1,

View File

@ -11,8 +11,8 @@ from mmcv.utils import ProgressBar
from mmocr.apis import init_detector, model_inference
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets import build_dataset # noqa: F401
from mmocr.models import build_detector # noqa: F401
from mmocr.registry import DATASETS # noqa: F401
from mmocr.utils import get_root_logger, list_from_file, list_to_file

View File

@ -16,8 +16,9 @@ from mmdet.apis import multi_gpu_test
from mmocr.apis.test import single_gpu_test
from mmocr.apis.utils import (disable_text_recog_aug_test,
replace_image_to_tensor)
from mmocr.datasets import build_dataloader, build_dataset
from mmocr.datasets import build_dataloader
from mmocr.models import build_detector
from mmocr.registry import DATASETS
from mmocr.utils import revert_sync_batchnorm, setup_multi_processes
@ -162,7 +163,7 @@ def main():
init_dist(args.launcher, **cfg.dist_params)
# build the dataloader
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
dataset = DATASETS.build(cfg.data.test, dict(test_mode=True))
# step 1: give default values and override (if exist) from cfg.data
default_loader_cfg = {
**dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),

View File

@ -16,8 +16,8 @@ from mmcv.utils import get_git_hash
from mmocr import __version__
from mmocr.apis import init_random_seed, train_detector
from mmocr.datasets import build_dataset
from mmocr.models import build_detector
from mmocr.registry import DATASETS
from mmocr.utils import (collect_env, get_root_logger, is_2dlist,
setup_multi_processes)
@ -189,7 +189,7 @@ def main():
test_cfg=cfg.get('test_cfg'))
model.init_weights()
datasets = [build_dataset(cfg.data.train)]
datasets = [DATASETS.build(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
if cfg.data.train.get('pipeline', None) is None: