139 lines
5.6 KiB
Python
139 lines
5.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
"""MMRazor provides 18 registry nodes to support using modules across projects.
|
|
Each node is a child of the root registry in MMEngine.
|
|
|
|
More details can be found at
|
|
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
|
"""
|
|
from typing import Any, Dict, Optional, Union
|
|
|
|
from mmengine.config import Config, ConfigDict
|
|
from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
|
|
from mmengine.registry import DATASETS as MMENGINE_DATASETS
|
|
from mmengine.registry import HOOKS as MMENGINE_HOOKS
|
|
from mmengine.registry import LOOPS as MMENGINE_LOOPS
|
|
from mmengine.registry import METRICS as MMENGINE_METRICS
|
|
from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
|
|
from mmengine.registry import MODELS as MMENGINE_MODELS
|
|
from mmengine.registry import \
|
|
OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
|
|
from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS
|
|
from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
|
|
from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
|
|
from mmengine.registry import \
|
|
RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
|
|
from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
|
|
from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
|
|
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
|
|
from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
|
|
from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
|
|
from mmengine.registry import \
|
|
WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS
|
|
from mmengine.registry import Registry, build_from_cfg
|
|
|
|
|
|
def build_razor_model_from_cfg(
        cfg: Union[dict, ConfigDict, Config],
        registry: 'Registry',
        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
    """Build a model from ``cfg`` with two MMRazor-specific extensions.

    On top of the plain ``build_from_cfg`` behaviour this function handles:

    1. Configs that carry a truthy ``cfg_path`` but no truthy ``type``: the
       whole config is forwarded as keyword arguments to
       ``mmengine.hub.get_model`` and its result is returned directly.
    2. The optional ``_return_architecture_`` flag: when truthy, the
       ``architecture`` attribute of the built object is returned instead of
       the object itself.

    Args:
        cfg (dict | ConfigDict | Config): Config describing the model.
        registry (Registry): Registry passed through to ``build_from_cfg``.
        default_args (dict | ConfigDict | Config, optional): Default
            arguments passed through to ``build_from_cfg``.
            Defaults to None.

    Returns:
        Any: The result of ``get_model``, the object built by
        ``build_from_cfg``, or that object's ``architecture`` attribute when
        ``_return_architecture_`` is truthy.
    """
    # TODO rely on mmengine:HAOCHENYE/config_new_feature
    if cfg.get('cfg_path', None) and not cfg.get('type', None):
        # Imported lazily so the hub module is only loaded when needed.
        from mmengine.hub import get_model
        return get_model(**cfg)  # type: ignore

    # Work on a shallow copy so the caller's config keeps its
    # ``_return_architecture_`` entry and can be built again with the same
    # result.
    cfg = cfg.copy()
    # Always strip the flag, even when it is falsy; otherwise it would be
    # forwarded to ``build_from_cfg`` as an unexpected constructor argument.
    return_architecture = cfg.pop('_return_architecture_', False)

    razor_model = build_from_cfg(cfg, registry, default_args)
    if return_architecture:
        return razor_model.architecture
    return razor_model
|
|
|
|
|
|
# ---- Runner-related registries ----
# Runners, e.g. `EpochBasedRunner` and `IterBasedRunner`.
RUNNERS = Registry(
    'runner',
    parent=MMENGINE_RUNNERS,
)
# Runner constructors, which define how runners are initialized.
RUNNER_CONSTRUCTORS = Registry(
    'runner constructor',
    parent=MMENGINE_RUNNER_CONSTRUCTORS,
)
# Loops, e.g. `EpochBasedTrainLoop`.
LOOPS = Registry(
    'loop',
    parent=MMENGINE_LOOPS,
)
# Hooks, e.g. `CheckpointHook`.
HOOKS = Registry(
    'hook',
    parent=MMENGINE_HOOKS,
)
|
|
|
|
# ---- Data-related registries ----
# Each registry below is a child of the matching MMEngine registry.
DATASETS = Registry(
    'dataset',
    parent=MMENGINE_DATASETS,
)
DATA_SAMPLERS = Registry(
    'data sampler',
    parent=MMENGINE_DATA_SAMPLERS,
)
TRANSFORMS = Registry(
    'transform',
    parent=MMENGINE_TRANSFORMS,
)
|
|
|
|
# ---- Model-related registries ----
# Modules inheriting `nn.Module`. Building goes through
# `build_razor_model_from_cfg` rather than the registry's default build
# function.
MODELS = Registry(
    'model',
    parent=MMENGINE_MODELS,
    build_func=build_razor_model_from_cfg,
)
# Model wrappers, e.g. 'MMDistributedDataParallel'.
MODEL_WRAPPERS = Registry(
    'model_wrapper',
    parent=MMENGINE_MODEL_WRAPPERS,
)
# Weight initialization modules, e.g. `Uniform`.
WEIGHT_INITIALIZERS = Registry(
    'weight initializer',
    parent=MMENGINE_WEIGHT_INITIALIZERS,
)
|
|
|
|
# ---- Optimizer-related registries ----
# Optimizers, e.g. `SGD` and `Adam`.
OPTIMIZERS = Registry(
    'optimizer',
    parent=MMENGINE_OPTIMIZERS,
)
# Optimizer wrappers.
OPTIM_WRAPPERS = Registry(
    'optimizer_wrapper',
    parent=MMENGINE_OPTIM_WRAPPERS,
)
# Constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
    'optimizer wrapper constructor',
    parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
)
# Parameter schedulers, e.g. `MultiStepLR`.
PARAM_SCHEDULERS = Registry(
    'parameter scheduler',
    parent=MMENGINE_PARAM_SCHEDULERS,
)
|
|
|
|
# Evaluation metrics.
METRICS = Registry(
    'metric',
    parent=MMENGINE_METRICS,
)

# Task-specific modules, e.g. anchor generators and box coders.
TASK_UTILS = Registry(
    'task util',
    parent=MMENGINE_TASK_UTILS,
)
|
|
|
|
# ---- Visualizer-related registries ----
# Visualizers.
VISUALIZERS = Registry(
    'visualizer',
    parent=MMENGINE_VISUALIZERS,
)
# Visualizer backends.
VISBACKENDS = Registry(
    'vis_backend',
    parent=MMENGINE_VISBACKENDS,
)
|
|
|
|
|
|
# manage sub models for downstream repos
|
|
@MODELS.register_module()
def sub_model(
        cfg,
        fix_subnet,
        mode: str = 'mutable',
        prefix: str = '',
        extra_prefix: str = '',
        init_weight_from_supernet: bool = False,
        init_cfg: Optional[Dict] = None):
    """Build a model from ``cfg`` and apply ``load_fix_subnet`` to it.

    Registered in ``MODELS`` so that it can be selected from a config.

    Args:
        cfg: Config handed to ``MODELS.build`` to create the model.
        fix_subnet: Forwarded unchanged as the second argument of
            ``mmrazor.structures.load_fix_subnet``.
        mode (str): Forwarded as ``load_subnet_mode``.
            Defaults to 'mutable'.
        prefix (str): Forwarded as ``prefix``. Defaults to ''.
        extra_prefix (str): Forwarded as ``extra_prefix``. Defaults to ''.
        init_weight_from_supernet (bool): When True, ``init_weights()`` is
            called on the built model before ``load_fix_subnet`` runs, and
            the model's ``init_cfg`` is reset to None afterwards.
            Defaults to False.
        init_cfg (dict, optional): When truthy, assigned to the built
            model's ``init_cfg``. Defaults to None.

    Returns:
        The built model after ``load_fix_subnet`` has been applied.
    """
    supernet = MODELS.build(cfg)

    # Only a truthy ``init_cfg`` overrides whatever the built model has.
    if init_cfg:
        supernet.init_cfg = init_cfg

    # ``load_fix_subnet`` modifies the supernet, so weights are initialized
    # before that call.
    if init_weight_from_supernet:
        supernet.init_weights()

    from mmrazor.structures import load_fix_subnet

    load_kwargs = dict(
        load_subnet_mode=mode,
        prefix=prefix,
        extra_prefix=extra_prefix,
    )
    load_fix_subnet(supernet, fix_subnet, **load_kwargs)

    # The supernet has been modified by ``load_fix_subnet``; clear the
    # initialization config.
    if init_weight_from_supernet:
        supernet.init_cfg = None

    return supernet
|