[Feature] Support auto import modules from registry. (#1731)

* [Feature] Support auto import modules from registry.

* limit mmdet version

* location parrent dir if it not exist
pull/1722/head^2
liukuikun 2023-02-17 10:28:34 +08:00 committed by GitHub
parent df0be646ea
commit 1127240108
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 98 additions and 61 deletions

View File

@ -193,6 +193,6 @@ MMOCR has different version requirements on MMEngine, MMCV and MMDetection at ea
| MMOCR | MMEngine | MMCV | MMDetection |
| -------------- | --------------------------- | -------------------------- | --------------------------- |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |

View File

@ -194,6 +194,6 @@ docker run --gpus all --shm-size=8g -it -v {实际数据目录}:/mmocr/data mmoc
| MMOCR | MMEngine | MMCV | MMDetection |
| -------------- | --------------------------- | -------------------------- | --------------------------- |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |

View File

@ -41,7 +41,7 @@ assert (mmengine_version >= digit_version(mmengine_minimum_version)
f'Please install mmengine>={mmengine_minimum_version}, ' \
f'<{mmengine_maximum_version}.'
mmdet_minimum_version = '3.0.0rc0'
mmdet_minimum_version = '3.0.0rc5'
mmdet_maximum_version = '3.1.0'
mmdet_version = digit_version(mmdet.__version__)

View File

@ -7,10 +7,11 @@ import mmengine
import numpy as np
from mmengine.dataset import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData
from torch import Tensor
from mmocr.utils import ConfigType, register_all_modules
from mmocr.utils import ConfigType
InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
@ -58,7 +59,7 @@ class BaseMMOCRInferencer(BaseInferencer):
# A global counter tracking the number of images processed, for
# naming of the output images
self.num_visualized_imgs = 0
register_all_modules()
init_default_scope(scope)
super().__init__(
model=model, weights=weights, device=device, scope=scope)

View File

@ -32,51 +32,103 @@ from mmengine.registry import \
from mmengine.registry import Registry
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
RUNNERS = Registry(
'runner',
parent=MMENGINE_RUNNERS,
# TODO: update the location when mmocr has its own runner
locations=['mmocr.engine'])
# manage runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry(
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
'runner constructor',
parent=MMENGINE_RUNNER_CONSTRUCTORS,
# TODO: update the location when mmocr has its own runner constructor
locations=['mmocr.engine'])
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
LOOPS = Registry(
'loop',
parent=MMENGINE_LOOPS,
# TODO: update the location when mmocr has its own loop
locations=['mmocr.engine'])
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
HOOKS = Registry(
'hook', parent=MMENGINE_HOOKS, locations=['mmocr.engine.hooks'])
# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
DATASETS = Registry(
'dataset', parent=MMENGINE_DATASETS, locations=['mmocr.datasets'])
DATA_SAMPLERS = Registry(
'data sampler',
parent=MMENGINE_DATA_SAMPLERS,
locations=['mmocr.datasets.samplers'])
TRANSFORMS = Registry(
'transform',
parent=MMENGINE_TRANSFORMS,
locations=['mmocr.datasets.transforms'])
# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmocr.models'])
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
MODEL_WRAPPERS = Registry(
'model wrapper',
parent=MMENGINE_MODEL_WRAPPERS,
locations=['mmocr.models'])
# manage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
'weight initializer',
parent=MMENGINE_WEIGHT_INITIALIZERS,
locations=['mmocr.models'])
# manage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
OPTIMIZERS = Registry(
'optimizer',
parent=MMENGINE_OPTIMIZERS,
# TODO: update the location when mmocr has its own optimizer
locations=['mmocr.engine'])
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
OPTIM_WRAPPERS = Registry(
'optimizer wrapper',
parent=MMENGINE_OPTIM_WRAPPERS,
# TODO: update the location when mmocr has its own optimizer wrapper
locations=['mmocr.engine'])
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
'optimizer constructor',
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
# TODO: update the location when mmocr has its own optimizer constructor
locations=['mmocr.engine'])
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
'parameter scheduler',
parent=MMENGINE_PARAM_SCHEDULERS,
# TODO: update the location when mmocr has its own parameter scheduler
locations=['mmocr.engine'])
# manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS)
METRICS = Registry(
'metric', parent=MMENGINE_METRICS, locations=['mmocr.evaluation.metrics'])
# manage evaluator
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR)
EVALUATOR = Registry(
'evaluator',
parent=MMENGINE_EVALUATOR,
locations=['mmocr.evaluation.evaluator'])
# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
TASK_UTILS = Registry(
'task util', parent=MMENGINE_TASK_UTILS, locations=['mmocr.models'])
# manage visualizer
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
VISUALIZERS = Registry(
'visualizer',
parent=MMENGINE_VISUALIZERS,
locations=['mmocr.visualization'])
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
VISBACKENDS = Registry(
'visualizer backend',
parent=MMENGINE_VISBACKENDS,
locations=['mmocr.visualization'])
# manage logprocessor
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
LOG_PROCESSORS = Registry(
'logger processor',
parent=MMENGINE_LOG_PROCESSORS,
# TODO: update the location when mmocr has its own log processor
locations=['mmocr.engine'])

View File

@ -1,3 +1,3 @@
mmcv>==2.0.0rc1,<2.1.0
mmdet>=3.0.0rc0,<3.1.0
mmengine>= 0.1.0, <1.0.0
mmcv>==2.0.0rc4,<2.1.0
mmdet>=3.0.0rc5,<3.1.0
mmengine>= 0.5.0, <1.0.0

View File

@ -4,9 +4,10 @@ from copy import deepcopy
from unittest import TestCase
from unittest.mock import MagicMock
from mmengine.registry import init_default_scope
from mmocr.datasets import ConcatDataset, OCRDataset
from mmocr.registry import TRANSFORMS
from mmocr.utils import register_all_modules
class TestConcatDataset(TestCase):
@ -22,7 +23,7 @@ class TestConcatDataset(TestCase):
def setUp(self):
register_all_modules()
init_default_scope('mmocr')
dataset = OCRDataset
# create dataset_a

View File

@ -7,10 +7,10 @@ from unittest import mock
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.registry import init_default_scope
from mmocr.registry import MODELS
from mmocr.testing.data import create_dummy_textdet_inputs
from mmocr.utils import register_all_modules
class TestDRRG(unittest.TestCase):
@ -18,7 +18,8 @@ class TestDRRG(unittest.TestCase):
def setUp(self):
cfg_path = 'textdet/drrg/drrg_resnet50_fpn-unet_1200e_ctw1500.py'
self.model_cfg = self._get_detector_cfg(cfg_path)
register_all_modules()
cfg = self._get_config_module(cfg_path)
init_default_scope(cfg.get('default_scope', 'mmocr'))
self.model = MODELS.build(self.model_cfg)
self.inputs = create_dummy_textdet_inputs(input_shape=(1, 3, 224, 224))

View File

@ -5,17 +5,17 @@ import torch
from mmdet.structures import DetDataSample
from mmdet.testing import demo_mm_inputs
from mmengine.config import Config
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import register_all_modules
class TestMMDetWrapper(unittest.TestCase):
def setUp(self):
register_all_modules()
init_default_scope('mmocr')
model_cfg_fcos = dict(
type='MMDetWrapper',
cfg=dict(

View File

@ -2,16 +2,16 @@
from unittest import TestCase
import torch
from mmengine.registry import init_default_scope
from mmocr.models.textrecog.backbones import ResNet
from mmocr.utils import register_all_modules
class TestResNet(TestCase):
def setUp(self) -> None:
self.img = torch.rand(1, 3, 32, 100)
register_all_modules()
init_default_scope('mmocr')
def test_resnet45_aster(self):
resnet45_aster = ResNet(

View File

@ -9,11 +9,11 @@ import mmcv
import numpy as np
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmengine.visualization import Visualizer
from mmocr.registry import DATASETS, VISUALIZERS
from mmocr.utils import register_all_modules
# TODO: Support for printing the change in key of results
@ -331,8 +331,7 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# register all modules in mmyolo into the registries
register_all_modules()
init_default_scope(cfg.get('default_scope', 'mmocr'))
dataset_cfg, visualizer_cfg = obtain_dataset_cfg(cfg, args.phase,
args.mode, args.task)

View File

@ -4,11 +4,9 @@ import argparse
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table
from mmengine import Config
from mmengine.registry import init_default_scope
from mmocr.registry import MODELS
from mmocr.utils import register_all_modules
register_all_modules()
def parse_args():
@ -38,6 +36,7 @@ def main():
input_shape = (1, 3, h, w)
cfg = Config.fromfile(args.config)
init_default_scope(cfg.get('default_scope', 'mmocr'))
model = MODELS.build(cfg.model)
flops = FlopCountAnalysis(model, torch.ones(input_shape))

View File

@ -5,8 +5,7 @@ import json
import mmengine
from mmengine.config import Config, DictAction
from mmengine.evaluator import Evaluator
from mmocr.utils import register_all_modules
from mmengine.registry import init_default_scope
def parse_args():
@ -33,10 +32,9 @@ def parse_args():
def main():
args = parse_args()
register_all_modules()
# load config
cfg = Config.fromfile(args.config)
init_default_scope(cfg.get('default_scope', 'mmocr'))
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

View File

@ -4,7 +4,6 @@ import os.path as osp
import warnings
from mmocr.datasets.preparers import DatasetPreparer
from mmocr.utils import register_all_modules
def parse_args():
@ -39,7 +38,6 @@ def parse_args():
def main():
args = parse_args()
register_all_modules()
for dataset in args.datasets:
if not osp.isdir(osp.join(args.dataset_zoo_path, dataset)):
warnings.warn(f'{dataset} is not supported yet. Please check '

View File

@ -7,8 +7,6 @@ from mmengine.config import Config, DictAction
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmocr.utils import register_all_modules
def parse_args():
parser = argparse.ArgumentParser(description='Test (and eval) a model')
@ -80,10 +78,6 @@ def trigger_visualization_hook(cfg, args):
def main():
args = parse_args()
# register all modules in mmocr into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher

View File

@ -9,8 +9,6 @@ from mmengine.logging import print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmocr.utils import register_all_modules
def parse_args():
parser = argparse.ArgumentParser(description='Train a model')
@ -54,10 +52,6 @@ def parse_args():
def main():
args = parse_args()
# register all modules in mmdet into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher