mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Support auto import modules from registry. (#1731)
* [Feature] Support auto import modules from registry. * limit mmdet version * location parrent dir if it not existpull/1722/head^2
parent
df0be646ea
commit
1127240108
|
@ -193,6 +193,6 @@ MMOCR has different version requirements on MMEngine, MMCV and MMDetection at ea
|
|||
|
||||
| MMOCR | MMEngine | MMCV | MMDetection |
|
||||
| -------------- | --------------------------- | -------------------------- | --------------------------- |
|
||||
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
|
||||
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||
|
|
|
@ -194,6 +194,6 @@ docker run --gpus all --shm-size=8g -it -v {实际数据目录}:/mmocr/data mmoc
|
|||
|
||||
| MMOCR | MMEngine | MMCV | MMDetection |
|
||||
| -------------- | --------------------------- | -------------------------- | --------------------------- |
|
||||
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
|
||||
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||
|
|
|
@ -41,7 +41,7 @@ assert (mmengine_version >= digit_version(mmengine_minimum_version)
|
|||
f'Please install mmengine>={mmengine_minimum_version}, ' \
|
||||
f'<{mmengine_maximum_version}.'
|
||||
|
||||
mmdet_minimum_version = '3.0.0rc0'
|
||||
mmdet_minimum_version = '3.0.0rc5'
|
||||
mmdet_maximum_version = '3.1.0'
|
||||
mmdet_version = digit_version(mmdet.__version__)
|
||||
|
||||
|
|
|
@ -7,10 +7,11 @@ import mmengine
|
|||
import numpy as np
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.infer.infer import BaseInferencer, ModelType
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmocr.utils import ConfigType, register_all_modules
|
||||
from mmocr.utils import ConfigType
|
||||
|
||||
InstanceList = List[InstanceData]
|
||||
InputType = Union[str, np.ndarray]
|
||||
|
@ -58,7 +59,7 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
# A global counter tracking the number of images processed, for
|
||||
# naming of the output images
|
||||
self.num_visualized_imgs = 0
|
||||
register_all_modules()
|
||||
init_default_scope(scope)
|
||||
super().__init__(
|
||||
model=model, weights=weights, device=device, scope=scope)
|
||||
|
||||
|
|
|
@ -32,51 +32,103 @@ from mmengine.registry import \
|
|||
from mmengine.registry import Registry
|
||||
|
||||
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
|
||||
RUNNERS = Registry(
|
||||
'runner',
|
||||
parent=MMENGINE_RUNNERS,
|
||||
# TODO: update the location when mmocr has its own runner
|
||||
locations=['mmocr.engine'])
|
||||
# manage runner constructors that define how to initialize runners
|
||||
RUNNER_CONSTRUCTORS = Registry(
|
||||
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
|
||||
'runner constructor',
|
||||
parent=MMENGINE_RUNNER_CONSTRUCTORS,
|
||||
# TODO: update the location when mmocr has its own runner constructor
|
||||
locations=['mmocr.engine'])
|
||||
# manage all kinds of loops like `EpochBasedTrainLoop`
|
||||
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
|
||||
LOOPS = Registry(
|
||||
'loop',
|
||||
parent=MMENGINE_LOOPS,
|
||||
# TODO: update the location when mmocr has its own loop
|
||||
locations=['mmocr.engine'])
|
||||
# manage all kinds of hooks like `CheckpointHook`
|
||||
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
|
||||
HOOKS = Registry(
|
||||
'hook', parent=MMENGINE_HOOKS, locations=['mmocr.engine.hooks'])
|
||||
|
||||
# manage data-related modules
|
||||
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
|
||||
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
|
||||
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
|
||||
DATASETS = Registry(
|
||||
'dataset', parent=MMENGINE_DATASETS, locations=['mmocr.datasets'])
|
||||
DATA_SAMPLERS = Registry(
|
||||
'data sampler',
|
||||
parent=MMENGINE_DATA_SAMPLERS,
|
||||
locations=['mmocr.datasets.samplers'])
|
||||
TRANSFORMS = Registry(
|
||||
'transform',
|
||||
parent=MMENGINE_TRANSFORMS,
|
||||
locations=['mmocr.datasets.transforms'])
|
||||
|
||||
# manage all kinds of modules inheriting `nn.Module`
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS)
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmocr.models'])
|
||||
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
|
||||
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
|
||||
MODEL_WRAPPERS = Registry(
|
||||
'model wrapper',
|
||||
parent=MMENGINE_MODEL_WRAPPERS,
|
||||
locations=['mmocr.models'])
|
||||
# manage all kinds of weight initialization modules like `Uniform`
|
||||
WEIGHT_INITIALIZERS = Registry(
|
||||
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
|
||||
'weight initializer',
|
||||
parent=MMENGINE_WEIGHT_INITIALIZERS,
|
||||
locations=['mmocr.models'])
|
||||
|
||||
# manage all kinds of optimizers like `SGD` and `Adam`
|
||||
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
|
||||
OPTIMIZERS = Registry(
|
||||
'optimizer',
|
||||
parent=MMENGINE_OPTIMIZERS,
|
||||
# TODO: update the location when mmocr has its own optimizer
|
||||
locations=['mmocr.engine'])
|
||||
# manage optimizer wrapper
|
||||
OPTIM_WRAPPERS = Registry('optim wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
|
||||
OPTIM_WRAPPERS = Registry(
|
||||
'optimizer wrapper',
|
||||
parent=MMENGINE_OPTIM_WRAPPERS,
|
||||
# TODO: update the location when mmocr has its own optimizer wrapper
|
||||
locations=['mmocr.engine'])
|
||||
# manage constructors that customize the optimization hyperparameters.
|
||||
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
|
||||
'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
|
||||
'optimizer constructor',
|
||||
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
|
||||
# TODO: update the location when mmocr has its own optimizer constructor
|
||||
locations=['mmocr.engine'])
|
||||
# manage all kinds of parameter schedulers like `MultiStepLR`
|
||||
PARAM_SCHEDULERS = Registry(
|
||||
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
|
||||
|
||||
'parameter scheduler',
|
||||
parent=MMENGINE_PARAM_SCHEDULERS,
|
||||
# TODO: update the location when mmocr has its own parameter scheduler
|
||||
locations=['mmocr.engine'])
|
||||
# manage all kinds of metrics
|
||||
METRICS = Registry('metric', parent=MMENGINE_METRICS)
|
||||
METRICS = Registry(
|
||||
'metric', parent=MMENGINE_METRICS, locations=['mmocr.evaluation.metrics'])
|
||||
# manage evaluator
|
||||
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR)
|
||||
EVALUATOR = Registry(
|
||||
'evaluator',
|
||||
parent=MMENGINE_EVALUATOR,
|
||||
locations=['mmocr.evaluation.evaluator'])
|
||||
|
||||
# manage task-specific modules like anchor generators and box coders
|
||||
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
|
||||
TASK_UTILS = Registry(
|
||||
'task util', parent=MMENGINE_TASK_UTILS, locations=['mmocr.models'])
|
||||
|
||||
# manage visualizer
|
||||
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
|
||||
VISUALIZERS = Registry(
|
||||
'visualizer',
|
||||
parent=MMENGINE_VISUALIZERS,
|
||||
locations=['mmocr.visualization'])
|
||||
# manage visualizer backend
|
||||
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
|
||||
VISBACKENDS = Registry(
|
||||
'visualizer backend',
|
||||
parent=MMENGINE_VISBACKENDS,
|
||||
locations=['mmocr.visualization'])
|
||||
|
||||
# manage logprocessor
|
||||
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
|
||||
LOG_PROCESSORS = Registry(
|
||||
'logger processor',
|
||||
parent=MMENGINE_LOG_PROCESSORS,
|
||||
# TODO: update the location when mmocr has its own log processor
|
||||
locations=['mmocr.engine'])
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
mmcv>==2.0.0rc1,<2.1.0
|
||||
mmdet>=3.0.0rc0,<3.1.0
|
||||
mmengine>= 0.1.0, <1.0.0
|
||||
mmcv>==2.0.0rc4,<2.1.0
|
||||
mmdet>=3.0.0rc5,<3.1.0
|
||||
mmengine>= 0.5.0, <1.0.0
|
||||
|
|
|
@ -4,9 +4,10 @@ from copy import deepcopy
|
|||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmocr.datasets import ConcatDataset, OCRDataset
|
||||
from mmocr.registry import TRANSFORMS
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
class TestConcatDataset(TestCase):
|
||||
|
@ -22,7 +23,7 @@ class TestConcatDataset(TestCase):
|
|||
|
||||
def setUp(self):
|
||||
|
||||
register_all_modules()
|
||||
init_default_scope('mmocr')
|
||||
dataset = OCRDataset
|
||||
|
||||
# create dataset_a
|
||||
|
|
|
@ -7,10 +7,10 @@ from unittest import mock
|
|||
import numpy as np
|
||||
import torch
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.testing.data import create_dummy_textdet_inputs
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
class TestDRRG(unittest.TestCase):
|
||||
|
@ -18,7 +18,8 @@ class TestDRRG(unittest.TestCase):
|
|||
def setUp(self):
|
||||
cfg_path = 'textdet/drrg/drrg_resnet50_fpn-unet_1200e_ctw1500.py'
|
||||
self.model_cfg = self._get_detector_cfg(cfg_path)
|
||||
register_all_modules()
|
||||
cfg = self._get_config_module(cfg_path)
|
||||
init_default_scope(cfg.get('default_scope', 'mmocr'))
|
||||
self.model = MODELS.build(self.model_cfg)
|
||||
self.inputs = create_dummy_textdet_inputs(input_shape=(1, 3, 224, 224))
|
||||
|
||||
|
|
|
@ -5,17 +5,17 @@ import torch
|
|||
from mmdet.structures import DetDataSample
|
||||
from mmdet.testing import demo_mm_inputs
|
||||
from mmengine.config import Config
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.structures import InstanceData
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
class TestMMDetWrapper(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
register_all_modules()
|
||||
init_default_scope('mmocr')
|
||||
model_cfg_fcos = dict(
|
||||
type='MMDetWrapper',
|
||||
cfg=dict(
|
||||
|
|
|
@ -2,16 +2,16 @@
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmocr.models.textrecog.backbones import ResNet
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
class TestResNet(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.img = torch.rand(1, 3, 32, 100)
|
||||
register_all_modules()
|
||||
init_default_scope('mmocr')
|
||||
|
||||
def test_resnet45_aster(self):
|
||||
resnet45_aster = ResNet(
|
||||
|
|
|
@ -9,11 +9,11 @@ import mmcv
|
|||
import numpy as np
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.utils import ProgressBar
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
from mmocr.registry import DATASETS, VISUALIZERS
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
# TODO: Support for printing the change in key of results
|
||||
|
@ -331,8 +331,7 @@ def main():
|
|||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# register all modules in mmyolo into the registries
|
||||
register_all_modules()
|
||||
init_default_scope(cfg.get('default_scope', 'mmocr'))
|
||||
|
||||
dataset_cfg, visualizer_cfg = obtain_dataset_cfg(cfg, args.phase,
|
||||
args.mode, args.task)
|
||||
|
|
|
@ -4,11 +4,9 @@ import argparse
|
|||
import torch
|
||||
from fvcore.nn import FlopCountAnalysis, flop_count_table
|
||||
from mmengine import Config
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -38,6 +36,7 @@ def main():
|
|||
input_shape = (1, 3, h, w)
|
||||
|
||||
cfg = Config.fromfile(args.config)
|
||||
init_default_scope(cfg.get('default_scope', 'mmocr'))
|
||||
model = MODELS.build(cfg.model)
|
||||
|
||||
flops = FlopCountAnalysis(model, torch.ones(input_shape))
|
||||
|
|
|
@ -5,8 +5,7 @@ import json
|
|||
import mmengine
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.evaluator import Evaluator
|
||||
|
||||
from mmocr.utils import register_all_modules
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -33,10 +32,9 @@ def parse_args():
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
register_all_modules()
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
init_default_scope(cfg.get('default_scope', 'mmocr'))
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import os.path as osp
|
|||
import warnings
|
||||
|
||||
from mmocr.datasets.preparers import DatasetPreparer
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -39,7 +38,6 @@ def parse_args():
|
|||
|
||||
def main():
|
||||
args = parse_args()
|
||||
register_all_modules()
|
||||
for dataset in args.datasets:
|
||||
if not osp.isdir(osp.join(args.dataset_zoo_path, dataset)):
|
||||
warnings.warn(f'{dataset} is not supported yet. Please check '
|
||||
|
|
|
@ -7,8 +7,6 @@ from mmengine.config import Config, DictAction
|
|||
from mmengine.registry import RUNNERS
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Test (and eval) a model')
|
||||
|
@ -80,10 +78,6 @@ def trigger_visualization_hook(cfg, args):
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmocr into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg.launcher = args.launcher
|
||||
|
|
|
@ -9,8 +9,6 @@ from mmengine.logging import print_log
|
|||
from mmengine.registry import RUNNERS
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Train a model')
|
||||
|
@ -54,10 +52,6 @@ def parse_args():
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmdet into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg.launcher = args.launcher
|
||||
|
|
Loading…
Reference in New Issue