mirror of https://github.com/open-mmlab/mim.git
[Experimental] Packaging a minimal and flatten algorithm from downstream repos (#229)
parent
5d54183ef9
commit
c9a5ea6a88
|
@ -0,0 +1,99 @@
|
|||
# Export (experimental)
|
||||
|
||||
`mim export` is a new feature in `mim`, which can export a minimum trainable and testable model package through the config file.
|
||||
|
||||
The minimal model compactly combines `model, datasets, engine` and other components. There is no need to search for each module separately based on the config file. Users who do not install from source can also get the model file at a glance.
|
||||
|
||||
In addition, lengthy inheritance relationships, such as `CondInstBboxHead -> FCOSHead -> AnchorFreeHead -> BaseDenseHead -> BaseModule` in `mmdetection`, will be directly flattened into `CondInstBboxHead -> BaseModule`. There is no need to open and compare multiple model files when all functions inherited from the parent class are clearly visible.
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
mim export config_path
|
||||
|
||||
# config_path has the following two optional types:
|
||||
|
||||
# 1. config of downstream repo: Complete the call through repo::model_name/xxx.py. For example:
|
||||
mim export mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py
|
||||
|
||||
# 2. config in a certain folder, for example:
|
||||
mim export config_dir/mask-rcnn_r101_fpn_1x_coco.py
|
||||
```
|
||||
|
||||
### Minimum model package directory structure
|
||||
|
||||
```
|
||||
minimum_package(Named as pack_from_{repo}_20231212_121212)
|
||||
|- pack
|
||||
| |- configs # Configuration folder
|
||||
| | |- model_name
|
||||
| | |- xxx.py # Configuration file
|
||||
| |
|
||||
| |- models # model folder
|
||||
| | |- model_file.py
|
||||
| | |- ...
|
||||
| |
|
||||
| |- data # data folder
|
||||
| |
|
||||
| |- demo # demo folder
|
||||
| |
|
||||
| |-datasets #Dataset class definition
|
||||
| | |- transforms
|
||||
| |
|
||||
| |- registry.py # Registrar
|
||||
|
|
||||
|- tools
|
||||
| |- train.py # training
|
||||
| |- test.py # test
|
||||
|
|
||||
```
|
||||
|
||||
### Limitations
|
||||
|
||||
Currently, `mim export` only supports some config files of `mmpose`, `mmdetection`, `mmagic`, and `mmsegmentation`. Besides, there are some constraints on the downstream repo.
|
||||
|
||||
#### For downstream repos
|
||||
|
||||
1. It is best to name the config without special symbols, otherwise it cannot be parsed through `mmengine.hub.get_config()`, such as:
|
||||
|
||||
- gn+ws/faster-rcnn_r101_fpn_gn-ws-all_1x_coco.py
|
||||
- legacy_1.x/cascade-mask-rcnn_r50_fpn_1x_coco_v1.py
|
||||
|
||||
2. For `mmsegmentation`, before using `mim export` on a config in `mmseg`, you should first move the registry file so that it is no longer wrapped in a directory, i.e. `mmseg/registry/registry.py -> mmseg/registry.py`.
|
||||
|
||||
3. It is recommended that the downstream Registry name inherited from mmengine should not be changed. For example, mmagic renamed `EVALUATOR` to `EVALUATORS`
|
||||
|
||||
```python
|
||||
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
|
||||
|
||||
# Evaluators to define the evaluation process.
|
||||
EVALUATORS = Registry(
|
||||
'evaluator',
|
||||
parent=MMENGINE_EVALUATOR,
|
||||
locations=['mmagic.evaluation'],
|
||||
)
|
||||
```
|
||||
|
||||
4. In addition, if you add a registry that is not in mmengine, such as `DIFFUSION_SCHEDULERS` in mmagic, you need to add a key-value pair to `REGISTRY_TYPES` in `mim/_internal/export/common.py` so that `torch` modules can be registered to `DIFFUSION_SCHEDULERS`:
|
||||
|
||||
```python
|
||||
# "mmagic/mmagic/registry.py"
|
||||
# modules for diffusion models that support adding noise and denoising
|
||||
DIFFUSION_SCHEDULERS = Registry(
|
||||
'diffusion scheduler',
|
||||
locations=['mmagic.models.diffusion_schedulers'],
|
||||
)
|
||||
|
||||
# "mim/utils/mmpack/common.py"
|
||||
REGISTRY_TYPES = {
|
||||
...
|
||||
'diffusion scheduler': 'DIFFUSION_SCHEDULERS',
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
#### To be improved
|
||||
|
||||
1. Currently, the inheritance relationship expansion of double parent classes is not supported. Improvements will be made in the future depending on the needs.
|
||||
2. An error is raised when the original config refers to a `path` of other files that cannot be found from the current directory. You can avoid export errors by manually modifying that `path` in the original config.
|
||||
3. When isinstance() is used, if the parent class is just a class in the inheritance chain, the judgment may be False after expansion, because the original inheritance relationship will not be retained.
|
|
@ -0,0 +1,100 @@
|
|||
# Export (experimental)
|
||||
|
||||
`mim export` 是 `mim` 里面一个新的功能,可以实现通过 config 文件,就能够导出一个最小可训练测试的模型包。
|
||||
|
||||
最小模型将 `model、datasets、engine`等组件紧凑地组合在一起,不用再根据 config 文件单独寻找每一个模块,对于非源码安装的用户也能够获取到一目了然的模型文件。
|
||||
|
||||
此外,对于模型中冗长的继承关系,如 `mmdetection` 中的 `CondInstBboxHead -> FCOSHead -> AnchorFreeHead -> BaseDenseHead -> BaseModule`,将会被直接展平为 `CondInstBboxHead -> BaseModule`,即无需再在多个模型文件之间跳转比较了,所有继承父类的函数一览无遗。
|
||||
|
||||
### 使用方法
|
||||
|
||||
```bash
|
||||
mim export config_path
|
||||
|
||||
# config_path 有以下两种可选类型:
|
||||
# 下游 repo 的 config:通过 repo::configs/xxx.py 来完成调用。例如:
|
||||
mim export mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py
|
||||
|
||||
# 在某文件夹下的 config, 例如:
|
||||
mim export config_dir/mask-rcnn_r101_fpn_1x_coco.py
|
||||
```
|
||||
|
||||
### 最小模型包目录结构
|
||||
|
||||
```
|
||||
minimum_package(Named as pack_from_{repo}_20231212_121212)
|
||||
|- pack
|
||||
| |- configs # 配置文件夹
|
||||
| | |- model_name
|
||||
| | |- xxx.py # 配置文件
|
||||
| |
|
||||
| |- models # 模型文件夹
|
||||
| | |- model_file.py
|
||||
| | |- ...
|
||||
| |
|
||||
| |- data # 数据文件夹
|
||||
| |
|
||||
| |- demo # demo文件夹
|
||||
| |
|
||||
| |- datasets # 数据集类定义
|
||||
| | |- transforms
|
||||
| |
|
||||
| |- registry.py # 注册器
|
||||
|
|
||||
|- tools
|
||||
| |- train.py # 训练
|
||||
| |- test.py # 测试
|
||||
|
|
||||
```
|
||||
|
||||
### 限制
|
||||
|
||||
`mim export` 目前只支持 `mmpose`、`mmdetection`、`mmagic` 和 `mmsegmentation` 的部分 config 配置文件,并且对下游算法库有一些约束。
|
||||
|
||||
#### 针对下游库
|
||||
|
||||
1. config 命名最好**不要有特殊符号**,否则无法通过 `mmengine.hub.get_config()` 进行解析,如:
|
||||
|
||||
- gn+ws/faster-rcnn_r101_fpn_gn-ws-all_1x_coco.py
|
||||
- legacy_1.x/cascade-mask-rcnn_r50_fpn_1x_coco_v1.py
|
||||
|
||||
2. 针对 `mmsegmentation`, 在使用 `mim.export` 导出 `mmseg` 的 config 之前, 首先需要去掉对于 `registry.py` 的外层文件夹封装, 即修改 `mmseg/registry/registry.py -> mmseg/registry.py`。
|
||||
|
||||
3. 建议下游继承于 mmengine 的 Registry 名字不要改动,如 mmagic 中就将 `EVALUATOR` 重新命名为了 `EVALUATORS`
|
||||
|
||||
```python
|
||||
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
|
||||
|
||||
# Evaluators to define the evaluation process.
|
||||
EVALUATORS = Registry(
|
||||
'evaluator',
|
||||
parent=MMENGINE_EVALUATOR,
|
||||
locations=['mmagic.evaluation'],
|
||||
)
|
||||
```
|
||||
|
||||
4. 另外,如果添加了 mmengine 中没有的注册器,如 mmagic 中的 `DIFFUSION_SCHEDULERS`,需要在 `mim/_internal/export/common.py` 的 `REGISTRY_TYPES` 中添加键值对,用于注册 `torch` 模块到 `DIFFUSION_SCHEDULERS`
|
||||
|
||||
```python
|
||||
# "mmagic/mmagic/registry.py"
|
||||
# modules for diffusion models that support adding noise and denoising
|
||||
DIFFUSION_SCHEDULERS = Registry(
|
||||
'diffusion scheduler',
|
||||
locations=['mmagic.models.diffusion_schedulers'],
|
||||
)
|
||||
|
||||
# "mim/utils/mmpack/common.py"
|
||||
REGISTRY_TYPES = {
|
||||
...
|
||||
'diffusion scheduler': 'DIFFUSION_SCHEDULERS',
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
#### 对于 `mim.export` 功能需要改进的地方
|
||||
|
||||
1. 目前还不支持双父类的继承关系展开,后续看需求进行改进
|
||||
|
||||
2. 对于用到 `isinstance()` 时,如果父类只是继承链中某个类,可能展开后判断就会为 False,因为并不会保留原有的继承关系
|
||||
|
||||
3. 当 config 文件中含有当前文件夹没法被访问到的`数据集路径`,导出可能会失败。目前的临时解决方法是:将原来的 config 文件保存到当前文件夹下,然后需要用户手动修改`数据集路径`为当前路径下的可访问路径。如:`data/ADEChallengeData2016/ -> your_data_dir/ADEChallengeData2016/`
|
|
@ -0,0 +1 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
# Template for the ``__init__.py`` that the exporter writes into the exported
# directories that do not have one yet. When the generated file is imported it
# lists its own directory and star-imports every ``*.py`` sibling (except
# ``__init__.py`` itself) as well as every entry whose name contains no dot.
_init_str = """
import os
import os.path as osp

all_files = os.listdir(osp.dirname(__file__))

for file in all_files:
    if (file.endswith('.py') and file != '__init__.py') or '.' not in file:
        exec(f'from .{osp.splitext(file)[0]} import *')
"""

# Snippet inserted into the exported tool scripts: it appends the parent of
# the script's directory to ``sys.path`` and then imports ``pack``.
_import_pack_str = """
import os.path as osp
import sys
sys.path.append(osp.dirname(osp.dirname(__file__)))
import pack

"""
|
||||
|
||||
# Maps a registry variable name to the renamed registry aliases and legacy
# ``build_*`` helper names that stand for that registry.
# NOTE(review): presumably consumed by the flattening code to recognise the
# names provided by ``patch_utils`` -- the consumer is not visible here,
# confirm before relying on it.
OBJECTS_TO_BE_PATCHED = {
    'MODELS': [
        'BACKBONES',
        'NECKS',
        'HEADS',
        'LOSSES',
        'SEGMENTORS',
        'build_backbone',
        'build_neck',
        'build_head',
        'build_loss',
        'build_segmentor',
        'CLASSIFIERS',
        'RETRIEVER',
        'build_classifier',
        'build_retriever',
        'POSE_ESTIMATORS',
        'build_pose_estimator',
        'build_posenet',
    ],
    'TASK_UTILS': [
        'PIXEL_SAMPLERS',
        'build_pixel_sampler',
    ]
}
|
||||
|
||||
# Maps a registry's name string (e.g. ``'evaluator'``) to the name of the
# variable that holds that registry (e.g. ``'EVALUATOR'``).
# NOTE(review): the keys look like the first argument given to ``Registry(...)``
# in mmengine / downstream repos; a downstream registry that is missing here
# has to be added by hand -- confirm against the code that reads this dict.
REGISTRY_TYPES = {
    'runner': 'RUNNERS',
    'runner constructor': 'RUNNER_CONSTRUCTORS',
    'hook': 'HOOKS',
    'strategy': 'STRATEGIES',
    'dataset': 'DATASETS',
    'data sampler': 'DATA_SAMPLERS',
    'transform': 'TRANSFORMS',
    'model': 'MODELS',
    'model wrapper': 'MODEL_WRAPPERS',
    'weight initializer': 'WEIGHT_INITIALIZERS',
    'optimizer': 'OPTIMIZERS',
    'optimizer wrapper': 'OPTIM_WRAPPERS',
    'optimizer wrapper constructor': 'OPTIM_WRAPPER_CONSTRUCTORS',
    # two spellings of the same registry name map to the same variable
    'parameter scheduler': 'PARAM_SCHEDULERS',
    'param scheduler': 'PARAM_SCHEDULERS',
    'metric': 'METRICS',
    'evaluator': 'EVALUATOR',  # TODO EVALUATORS in mmagic
    'task utils': 'TASK_UTILS',
    'loop': 'LOOPS',
    'visualizer': 'VISUALIZERS',
    'vis_backend': 'VISBACKENDS',
    'log processor': 'LOG_PROCESSORS',
    'inferencer': 'INFERENCERS',
    'function': 'FUNCTIONS',
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,312 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from mmengine import MMLogger
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.hub import get_config
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.utils import get_installed_path, mkdir_or_exist
|
||||
|
||||
from mim.utils import echo_error
|
||||
from .common import _import_pack_str, _init_str
|
||||
from .utils import (
|
||||
_get_all_files,
|
||||
_postprocess_importfrom_module_to_pack,
|
||||
_postprocess_registry_locations,
|
||||
_replace_config_scope_to_pack,
|
||||
_wrapper_all_registries_build_func,
|
||||
)
|
||||
|
||||
|
||||
def export_from_cfg(cfg: Union[str, ConfigDict],
                    export_root_dir: Optional[str],
                    model_only: Optional[bool] = False,
                    save_log: Optional[bool] = False) -> int:
    """Pack the minimum available package described by a config file.

    The package is first assembled in temporary directories and only copied
    to ``export_root_dir`` once everything has been built. As a side effect a
    ``SIGINT`` handler is installed that removes the temporary directories
    and exits the process.

    Args:
        cfg (:obj:`ConfigDict` or str): Config used for packing. A string
            containing ``'::'`` is resolved with ``mmengine.hub.get_config``,
            any other string is loaded with ``Config.fromfile``.
        export_root_dir (str, optional): Directory the package is copied to.
            If None, ``pack_from_{default_scope}_{timestamp}`` is used.
        model_only (bool, optional): If True, the config is dumped right
            after the runner has been built and its ``before_run`` /
            ``after_run`` hooks have been called; the train/val/test loops,
            optimizer wrapper, param scheduler, ``__init__.py`` files,
            post-processing steps and tools are skipped. Defaults to False.
        save_log (bool, optional): If True, the temporary log directory is
            copied to an ``export_log`` directory placed next to
            ``export_root_dir`` (or in the current working directory when
            ``export_root_dir`` contains no path separator).
            Defaults to False.

    Returns:
        int: 0 when the export finished. If building one of the loops raises
        ``FileNotFoundError``, ``error_postprocess`` is called, which ends
        the process with ``sys.exit(-1)``.
    """
    # generate temp dir for export
    export_root_tempdir = tempfile.TemporaryDirectory()
    export_log_tempdir = tempfile.TemporaryDirectory()

    # delete the incomplete export package when keyboard interrupt
    signal.signal(signal.SIGINT, lambda sig, frame: keyboardinterupt_handler(
        sig, frame, export_root_tempdir, export_log_tempdir)
    )  # type: ignore[arg-type]

    # get config
    if isinstance(cfg, str):
        if '::' in cfg:
            cfg = get_config(cfg)
        else:
            cfg = Config.fromfile(cfg)

    default_scope = cfg.get('default_scope', 'mmengine')

    # automatically generate ``export_root_dir``
    if export_root_dir is None:
        export_root_dir = f'pack_from_{default_scope}_' + \
            f"{datetime.now().strftime(r'%Y%m%d_%H%M%S')}"

    # generate ``export_log_dir``: next to the export root when a directory
    # part is given, otherwise in the current working directory
    if osp.sep in export_root_dir:
        export_path = osp.dirname(export_root_dir)
    else:
        export_path = os.getcwd()
    export_log_dir = osp.join(export_path, 'export_log')

    # the logger instance is looked up later by name (``logger='export'``)
    export_logger = MMLogger.get_instance(  # noqa: F841
        'export',
        log_file=osp.join(export_log_tempdir.name, 'export.log'))

    export_module_tempdir_name = osp.join(export_root_tempdir.name, 'pack')

    # export config: a config whose path contains ``.mim`` keeps its path
    # from the first ``configs`` occurrence on, any other config is put
    # directly under ``pack/configs``
    if '.mim' in cfg.filename:
        cfg_path = osp.join(export_module_tempdir_name,
                            cfg.filename[cfg.filename.find('configs'):])
    else:
        cfg_path = osp.join(
            osp.join(export_module_tempdir_name, 'configs'),
            osp.basename(cfg.filename))
    mkdir_or_exist(osp.dirname(cfg_path))

    # transform to default_scope
    init_default_scope(default_scope)

    # wrap ``Registry.build()`` for exporting modules
    _wrapper_all_registries_build_func(
        export_module_dir=export_module_tempdir_name, scope=default_scope)

    print_log(
        f'[ Export Package Name ]: {export_root_dir}\n'
        f' package from config: {cfg.filename}\n'
        f" from downstream package: '{default_scope}'\n",
        logger='export',
        level=logging.INFO)

    # create temp work_dirs for export
    cfg['work_dir'] = export_log_tempdir.name

    # use runner to export all needed modules
    runner = Runner.from_cfg(cfg)

    # HARD CODE: In order to deal with some module will build in
    # ``before_run`` or ``after_run``, we can call them without need
    # to call "runner.train())".

    # Example:
    #   >>> @HOOKS.register_module()
    #   >>> class EMAHook(Hook):
    #   >>>     ...
    #   >>>     def before_run(self, runner) -> None:
    #   >>>         """Create an ema copy of the model.
    #   >>>         Args:
    #   >>>             runner (Runner): The runner of the training process.
    #   >>>         """
    #   >>>         model = runner.model
    #   >>>         if is_model_wrapper(model):
    #   >>>             model = model.module
    #   >>>         self.src_model = model
    #   >>>         self.ema_model = MODELS.build(
    #   >>>             self.ema_cfg, default_args=dict(model=self.src_model))

    # It need to build ``self.ema_model`` in ``before_run``.

    for hook in runner.hooks:
        hook.before_run(runner)
        hook.after_run(runner)

    def dump() -> None:
        """Write the config and move the temporary package into place."""
        cfg['work_dir'] = 'work_dirs'  # recover to default.
        _replace_config_scope_to_pack(cfg)
        cfg.dump(cfg_path)

        # copy temp log to export log
        if save_log:
            shutil.copytree(
                export_log_tempdir.name, export_log_dir, dirs_exist_ok=True)

        export_log_tempdir.cleanup()

        # copy temp_package_dir to export_package_dir
        shutil.copytree(
            export_root_tempdir.name, export_root_dir, dirs_exist_ok=True)
        export_root_tempdir.cleanup()

        print_log(
            f'[ Export Package Name ]: '
            f'{osp.join(os.getcwd(), export_root_dir)}\n',
            logger='export',
            level=logging.INFO)

    if model_only:
        dump()
        return 0

    # Building the loops builds the dataloaders; a missing dataset surfaces
    # as ``FileNotFoundError`` and is reported through ``error_postprocess``.
    try:
        runner.build_train_loop(cfg.train_cfg)
    except FileNotFoundError:
        error_postprocess(export_log_dir, default_scope,
                          export_root_tempdir, export_log_tempdir,
                          osp.basename(cfg_path), 'train_dataloader')

    try:
        if 'val_cfg' in cfg and cfg.val_cfg is not None:
            runner.build_val_loop(cfg.val_cfg)
    except FileNotFoundError:
        error_postprocess(export_log_dir, default_scope,
                          export_root_tempdir, export_log_tempdir,
                          osp.basename(cfg_path), 'val_dataloader')

    try:
        if 'test_cfg' in cfg and cfg.test_cfg is not None:
            runner.build_test_loop(cfg.test_cfg)
    except FileNotFoundError:
        error_postprocess(export_log_dir, default_scope,
                          export_root_tempdir, export_log_tempdir,
                          osp.basename(cfg_path), 'test_dataloader')

    if 'optim_wrapper' in cfg and cfg.optim_wrapper is not None:
        runner.optim_wrapper = runner.build_optim_wrapper(cfg.optim_wrapper)
    if 'param_scheduler' in cfg and cfg.param_scheduler is not None:
        runner.build_param_scheduler(cfg.param_scheduler)

    # add ``__init__.py`` to all dirs, for transferring directories
    # to be modules. The ``configs`` test is a substring test on the whole
    # directory path.
    for directory, _, _ in os.walk(export_module_tempdir_name):
        if not osp.exists(osp.join(directory, '__init__.py')) \
                and 'configs' not in directory:
            with open(osp.join(directory, '__init__.py'), 'w') as f:
                f.write(_init_str)

    # postprocess for ``pack/registry.py``
    _postprocess_registry_locations(export_root_tempdir.name)

    # postprocess for ImportFrom Node, turn to import from export path
    all_export_files = _get_all_files(export_module_tempdir_name)
    for file in all_export_files:
        _postprocess_importfrom_module_to_pack(file)

    # copy the tools from the installed downstream package
    tools_dir = osp.join(export_root_tempdir.name, 'tools')
    mkdir_or_exist(tools_dir)

    for tool_name in [
            'train.py', 'test.py', 'dist_train.sh', 'dist_test.sh',
            'slurm_train.sh', 'slurm_test.sh'
    ]:
        pack_tools(
            tool_name=tool_name,
            scope=default_scope,
            tool_dir=tools_dir,
            auto_import=True)

    # TODO: get demo.py

    dump()
    return 0
|
||||
|
||||
|
||||
def keyboardinterupt_handler(
    sig: int,
    frame,
    export_root_tempdir: tempfile.TemporaryDirectory,
    export_log_tempdir: tempfile.TemporaryDirectory,
):
    """Remove the unfinished export and terminate the process.

    Meant to be used as a signal handler (with the two temporary directories
    bound in by the caller).

    Args:
        sig (int): Signal number passed by the signal machinery (unused).
        frame: Current stack frame passed by the signal machinery (unused).
        export_root_tempdir (tempfile.TemporaryDirectory): Temporary
            directory holding the package being assembled.
        export_log_tempdir (tempfile.TemporaryDirectory): Temporary
            directory holding the export log.

    Raises:
        SystemExit: Always, with exit code -1.
    """
    # The log directory is removed first, then the package directory.
    for pending_dir in (export_log_tempdir, export_root_tempdir):
        pending_dir.cleanup()

    raise SystemExit(-1)
|
||||
|
||||
|
||||
def error_postprocess(export_log_dir: str, scope: str,
                      export_root_dir_tempfile: tempfile.TemporaryDirectory,
                      export_log_dir_tempfile: tempfile.TemporaryDirectory,
                      cfg_name: str, error_key: str):
    """Abort an export that failed because a dataset could not be found.

    The temporary log directory is copied to ``export_log_dir``, both
    temporary directories are removed, the active traceback is printed and a
    debug message with two ways to continue is emitted before the process
    exits.

    Args:
        export_log_dir (str): Directory the temporary log is copied to.
        scope (str): Scope of the downstream repo, used in the hint message.
        export_root_dir_tempfile (tempfile.TemporaryDirectory): Temporary
            directory holding the unfinished package.
        export_log_dir_tempfile (tempfile.TemporaryDirectory): Temporary
            directory holding the export log.
        cfg_name (str): File name of the config, used in the hint message.
        error_key (str): Name of the config entry whose data root is
            missing, e.g. ``'train_dataloader'``.

    Raises:
        SystemExit: Always, with exit code -1.
    """
    # Keep the log so that it can be inspected after the abort.
    shutil.copytree(
        export_log_dir_tempfile.name, export_log_dir, dirs_exist_ok=True)
    for temp_dir in (export_root_dir_tempfile, export_log_dir_tempfile):
        temp_dir.cleanup()

    traceback.print_exc()

    rule = '=' * 20
    saved_cfg = f'{export_log_dir}/{cfg_name}'
    hint_parts = [
        f'{rule} Debug Message {rule}',
        f"\nThe data root of '{error_key}' is not found. You can use the "
        'below two method to deal with.\n\n',
        " >>> Method 1: Please modify the 'data_root' in duplicate config "
        f"'{saved_cfg}'.\n",
        " >>> Method 2: Use '--model_only' to export model only.\n\n",
        "After finishing one of the above steps, you can use 'mim export "
        f"{scope} {saved_cfg} [--model-only]' to export again.",
    ]
    echo_error(''.join(hint_parts))

    raise SystemExit(-1)
|
||||
|
||||
|
||||
def pack_tools(tool_name: str,
               scope: str,
               tool_dir: str,
               auto_import: Optional[bool] = False):
    """Copy a tool script from the installed repo into ``tool_dir``.

    The script is looked up in ``<package>/.mim/tools`` first and in
    ``<package>/tools`` otherwise. An existing file with the same name in
    ``tool_dir`` is replaced.

    Args:
        tool_name (str): Tool name in repos' tool dir.
        scope (str): The scope of repo.
        tool_dir (str): Directory to save the tool in.
        auto_import (bool, optional): Automatically add "import pack" to the
            tool file. Only applied to Python tools (``*.py``); other files
            such as shell launchers are copied unchanged.
            Defaults to "False"
    """
    pkg_root = get_installed_path(scope)
    path = osp.join(tool_dir, tool_name)

    if osp.exists(path):
        os.remove(path)

    # tools will be put in package/.mim in PR #68
    tool_script = osp.join(pkg_root, '.mim', 'tools', tool_name)
    if not osp.exists(tool_script):
        tool_script = osp.join(pkg_root, 'tools', tool_name)

    shutil.copy(tool_script, path)

    # automatically import the pack modules. The snippet is Python code, so
    # it must not be injected into ``*.sh`` launchers: that would put
    # ``import ...`` statements into a shell script and break it.
    if auto_import and tool_name.endswith('.py'):
        with open(path, 'r+') as f:
            lines = f.readlines()
            # keep the first line (header) in place, insert right after it
            code = ''.join(lines[:1] + [_import_pack_str] + lines[1:])
            f.seek(0)
            f.write(code)
|
|
@ -0,0 +1,75 @@
|
|||
# Patch Utils
|
||||
|
||||
## Problem
|
||||
|
||||
This patch mainly solves the problem that a module cannot be registered properly when the downstream repo renames a registry, as in this example from `mmsegmentation`:
|
||||
|
||||
```python
|
||||
# "mmsegmentation/mmseg/structures/sampler/builder.py"
|
||||
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
|
||||
PIXEL_SAMPLERS = TASK_UTILS
|
||||
|
||||
|
||||
def build_pixel_sampler(cfg, **default_args):
|
||||
"""Build pixel sampler for segmentation map."""
|
||||
warnings.warn(
|
||||
'``build_pixel_sampler`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.TASK_UTILS.build()`` ')
|
||||
return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
```
|
||||
|
||||
Some modules may use the renamed registry, which makes it difficult for `mim export` to find the original name of the renamed modules.
|
||||
|
||||
```python
|
||||
# "mmsegmentation/mmseg/structures/sampler/ohem_pixel_sampler.py"
|
||||
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base_pixel_sampler import BasePixelSampler
|
||||
from .builder import PIXEL_SAMPLERS
|
||||
|
||||
|
||||
@PIXEL_SAMPLERS.register_module()
|
||||
class OHEMPixelSampler(BasePixelSampler):
|
||||
...
|
||||
```
|
||||
|
||||
## Solution
|
||||
|
||||
Therefore, we have currently migrated the necessary modules of `mmpose/mmdetection/mmseg/mmpretrain` listed below directly to `patch_utils.patch_model` and `patch_utils.patch_task`, in order to build a patch containing the renamed registries and the special module constructor functions.
|
||||
|
||||
```python
|
||||
"mmdetection/mmdet/models/task_modules/builder.py"
|
||||
"mmdetection/build/lib/mmdet/models/task_modules/builder.py"
|
||||
|
||||
"mmsegmentation/mmseg/models/builder.py"
|
||||
"mmsegmentation/mmseg/structures/sampler/builder.py"
|
||||
"mmsegmentation/build/lib/mmseg/models/builder.py"
|
||||
"mmsegmentation/build/lib/mmseg/structures/sampler/builder.py"
|
||||
|
||||
"mmpretrain/mmpretrain/datasets/builder.py"
|
||||
"mmpretrain/mmpretrain/models/builder.py"
|
||||
"mmpretrain/build/lib/mmpretrain/datasets/builder.py"
|
||||
"mmpretrain/build/lib/mmpretrain/models/builder.py"
|
||||
|
||||
"mmpose/mmpose/datasets/builder.py"
|
||||
"mmpose/mmpose/models/builder.py"
|
||||
"mmpose/build/lib/mmpose/datasets/builder.py"
|
||||
"mmpose/build/lib/mmpose/models/builder.py"
|
||||
|
||||
"mmengine/mmengine/evaluator/builder.py"
|
||||
"mmengine/mmengine/model/builder.py"
|
||||
"mmengine/mmengine/optim/optimizer/builder.py"
|
||||
"mmengine/mmengine/visualization/builder.py"
|
||||
"mmengine/build/lib/mmengine/evaluator/builder.py"
|
||||
"mmengine/build/lib/mmengine/model/builder.py"
|
||||
"mmengine/build/lib/mmengine/optim/optimizer/builder.py"
|
||||
```
|
|
@ -0,0 +1,75 @@
|
|||
# Patch Utils
|
||||
|
||||
## 问题
|
||||
|
||||
该补丁主要是为了解决下游 repo 中存在对注册器进行重命名导致模块无法被正确注册的问题,如 `mmsegmentation` 中的一个例子:
|
||||
|
||||
```python
|
||||
# "mmsegmentation/mmseg/structures/sampler/builder.py"
|
||||
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
|
||||
PIXEL_SAMPLERS = TASK_UTILS
|
||||
|
||||
|
||||
def build_pixel_sampler(cfg, **default_args):
|
||||
"""Build pixel sampler for segmentation map."""
|
||||
warnings.warn(
|
||||
'``build_pixel_sampler`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.TASK_UTILS.build()`` ')
|
||||
return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
```
|
||||
|
||||
然后在某些模块中可能会使用改名后的注册器,这对于导出后处理很难找到重命名模块原来的名字
|
||||
|
||||
```python
|
||||
# "mmsegmentation/mmseg/structures/sampler/ohem_pixel_sampler.py"
|
||||
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base_pixel_sampler import BasePixelSampler
|
||||
from .builder import PIXEL_SAMPLERS
|
||||
|
||||
|
||||
@PIXEL_SAMPLERS.register_module()
|
||||
class OHEMPixelSampler(BasePixelSampler):
|
||||
...
|
||||
```
|
||||
|
||||
## 解决方案
|
||||
|
||||
因此我们目前已经将如下 `mmpose / mmdetection / mmseg / mmpretrain` 中必要的模块直接迁移到 `patch_utils.patch_model` 和 `patch_utils.patch_task` 中构建一个包含注册器重命名和特殊模块构造函数的补丁。
|
||||
|
||||
```python
|
||||
"mmdetection/mmdet/models/task_modules/builder.py"
|
||||
"mmdetection/build/lib/mmdet/models/task_modules/builder.py"
|
||||
|
||||
"mmsegmentation/mmseg/models/builder.py"
|
||||
"mmsegmentation/mmseg/structures/sampler/builder.py"
|
||||
"mmsegmentation/build/lib/mmseg/models/builder.py"
|
||||
"mmsegmentation/build/lib/mmseg/structures/sampler/builder.py"
|
||||
|
||||
"mmpretrain/mmpretrain/datasets/builder.py"
|
||||
"mmpretrain/mmpretrain/models/builder.py"
|
||||
"mmpretrain/build/lib/mmpretrain/datasets/builder.py"
|
||||
"mmpretrain/build/lib/mmpretrain/models/builder.py"
|
||||
|
||||
"mmpose/mmpose/datasets/builder.py"
|
||||
"mmpose/mmpose/models/builder.py"
|
||||
"mmpose/build/lib/mmpose/datasets/builder.py"
|
||||
"mmpose/build/lib/mmpose/models/builder.py"
|
||||
|
||||
"mmengine/mmengine/evaluator/builder.py"
|
||||
"mmengine/mmengine/model/builder.py"
|
||||
"mmengine/mmengine/optim/optimizer/builder.py"
|
||||
"mmengine/mmengine/visualization/builder.py"
|
||||
"mmengine/build/lib/mmengine/evaluator/builder.py"
|
||||
"mmengine/build/lib/mmengine/model/builder.py"
|
||||
"mmengine/build/lib/mmengine/optim/optimizer/builder.py"
|
||||
```
|
|
@ -0,0 +1,3 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .patch_model import * # noqa: F401, F403
|
||||
from .patch_task import * # noqa: F401, F403
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmengine.registry import MODELS
|
||||
|
||||
# Renamed registry names used by downstream repos; every alias is the very
# same ``MODELS`` registry object.
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
SEGMENTORS = MODELS
|
||||
|
||||
|
||||
def build_backbone(cfg):
    """Build a backbone from ``cfg`` via the ``BACKBONES`` alias.

    A warning announcing the deprecation of this helper is emitted first.
    """
    deprecation_note = ('``build_backbone`` would be deprecated soon, '
                        'please use ``mmseg.registry.MODELS.build()`` ')
    warnings.warn(deprecation_note)
    return BACKBONES.build(cfg)
|
||||
|
||||
|
||||
def build_neck(cfg):
    """Build a neck from ``cfg`` via the ``NECKS`` alias.

    A warning announcing the deprecation of this helper is emitted first.
    """
    deprecation_note = ('``build_neck`` would be deprecated soon, '
                        'please use ``mmseg.registry.MODELS.build()`` ')
    warnings.warn(deprecation_note)
    return NECKS.build(cfg)
|
||||
|
||||
|
||||
def build_head(cfg):
    """Build a head from ``cfg`` via the ``HEADS`` alias.

    A warning announcing the deprecation of this helper is emitted first.
    """
    deprecation_note = ('``build_head`` would be deprecated soon, '
                        'please use ``mmseg.registry.MODELS.build()`` ')
    warnings.warn(deprecation_note)
    return HEADS.build(cfg)
|
||||
|
||||
|
||||
def build_loss(cfg):
    """Build a loss from ``cfg`` via the ``LOSSES`` alias.

    A warning announcing the deprecation of this helper is emitted first.
    """
    deprecation_note = ('``build_loss`` would be deprecated soon, '
                        'please use ``mmseg.registry.MODELS.build()`` ')
    warnings.warn(deprecation_note)
    return LOSSES.build(cfg)
|
||||
|
||||
|
||||
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build a segmentor from ``cfg`` via the ``SEGMENTORS`` alias.

    Args:
        cfg: Config of the segmentor; read with ``cfg.get``.
        train_cfg: Deprecated outer training config. Must not be given
            together with a ``train_cfg`` entry inside ``cfg``.
        test_cfg: Deprecated outer testing config. Must not be given
            together with a ``test_cfg`` entry inside ``cfg``.

    Raises:
        AssertionError: If ``train_cfg`` or ``test_cfg`` is specified both
            as an argument and inside ``cfg``.
    """
    outer_cfg_given = not (train_cfg is None and test_cfg is None)
    if outer_cfg_given:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)

    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '

    # the outer configs are handed over as default arguments of the build
    outer_cfgs = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return SEGMENTORS.build(cfg, default_args=outer_cfgs)
|
||||
|
||||
|
||||
# Renamed registry names for classifiers and retrievers; both are the same
# ``MODELS`` registry object.
CLASSIFIERS = MODELS
RETRIEVER = MODELS
|
||||
|
||||
|
||||
def build_classifier(cfg):
    """Build a classifier from ``cfg`` via the ``CLASSIFIERS`` alias."""
    classifier = CLASSIFIERS.build(cfg)
    return classifier
|
||||
|
||||
|
||||
def build_retriever(cfg):
    """Build a retriever from ``cfg`` via the ``RETRIEVER`` alias."""
    retriever = RETRIEVER.build(cfg)
    return retriever
|
||||
|
||||
|
||||
# Renamed registry name for pose estimators; same ``MODELS`` registry object.
POSE_ESTIMATORS = MODELS
|
||||
|
||||
|
||||
def build_pose_estimator(cfg):
    """Build a pose estimator from ``cfg`` via ``POSE_ESTIMATORS``."""
    pose_estimator = POSE_ESTIMATORS.build(cfg)
    return pose_estimator
|
||||
|
||||
|
||||
def build_posenet(cfg):
    """Build a posenet; deprecated alias of ``build_pose_estimator``.

    Emits a ``DeprecationWarning`` and forwards ``cfg`` unchanged.
    """
    deprecation_note = ('``build_posenet`` will be deprecated soon, '
                        'please use ``build_pose_estimator`` instead.')
    warnings.warn(deprecation_note, DeprecationWarning)
    return build_pose_estimator(cfg)
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmengine.registry import TASK_UTILS
|
||||
|
||||
# Renamed registry names used by downstream repos; every alias is the very
# same ``TASK_UTILS`` registry object.
PRIOR_GENERATORS = TASK_UTILS
ANCHOR_GENERATORS = TASK_UTILS
BBOX_ASSIGNERS = TASK_UTILS
BBOX_SAMPLERS = TASK_UTILS
BBOX_CODERS = TASK_UTILS
MATCH_COSTS = TASK_UTILS
IOU_CALCULATORS = TASK_UTILS
|
||||
|
||||
|
||||
def build_bbox_coder(cfg, **default_args):
    """Builder of box coder.

    Emits a deprecation warning and builds ``cfg`` through ``TASK_UTILS``.

    Args:
        cfg: Config of the box coder.
        **default_args: Collected into a dict and passed as ``default_args``
            to ``TASK_UTILS.build``.

    Returns:
        The object built by ``TASK_UTILS.build``.
    """
    # The warning used to name ``build_sampler`` (copy-paste slip); it now
    # names the helper that is actually being called.
    warnings.warn('``build_bbox_coder`` would be deprecated soon, please use '
                  '``mmdet.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
|
||||
|
||||
def build_iou_calculator(cfg, default_args=None):
    """Build an IoU calculator from ``cfg`` through ``TASK_UTILS``.

    A warning announcing the deprecation of this helper is emitted first;
    ``default_args`` is forwarded unchanged.
    """
    deprecation_note = (
        '``build_iou_calculator`` would be deprecated soon, '
        'please use ``mmdet.registry.TASK_UTILS.build()`` ')
    warnings.warn(deprecation_note)
    return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
|
||||
|
||||
def build_match_cost(cfg, default_args=None):
    """Build a match cost from ``cfg`` through ``TASK_UTILS``.

    A warning announcing the deprecation of this helper is emitted first;
    ``default_args`` is forwarded unchanged.
    """
    deprecation_note = (
        '``build_match_cost`` would be deprecated soon, '
        'please use ``mmdet.registry.TASK_UTILS.build()`` ')
    warnings.warn(deprecation_note)
    return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
|
||||
|
||||
def build_assigner(cfg, **default_args):
    """Build a box assigner from ``cfg`` through ``TASK_UTILS``.

    A warning announcing the deprecation of this helper is emitted first;
    the extra keyword arguments are passed on as ``default_args``.
    """
    deprecation_note = (
        '``build_assigner`` would be deprecated soon, '
        'please use ``mmdet.registry.TASK_UTILS.build()`` ')
    warnings.warn(deprecation_note)
    return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
|
||||
|
||||
def build_sampler(cfg, **default_args):
    """Build a box sampler from ``cfg`` through ``TASK_UTILS``.

    A warning announcing the deprecation of this helper is emitted first;
    the extra keyword arguments are passed on as ``default_args``.
    """
    deprecation_note = (
        '``build_sampler`` would be deprecated soon, '
        'please use ``mmdet.registry.TASK_UTILS.build()`` ')
    warnings.warn(deprecation_note)
    return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
|
||||
|
||||
def build_prior_generator(cfg, default_args=None):
    """Build a prior generator from ``cfg`` through ``TASK_UTILS``.

    A warning announcing the deprecation of this helper is emitted first;
    ``default_args`` is forwarded unchanged.
    """
    deprecation_note = (
        '``build_prior_generator`` would be deprecated soon, '
        'please use ``mmdet.registry.TASK_UTILS.build()`` ')
    warnings.warn(deprecation_note)
    return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
|
||||
|
||||
def build_anchor_generator(cfg, default_args=None):
    """Build an anchor generator through ``TASK_UTILS``.

    A deprecation warning is issued before delegating.

    Args:
        cfg: Config handed to ``TASK_UTILS.build``.
        default_args: Forwarded unchanged as ``default_args``.

    Returns:
        The object returned by ``TASK_UTILS.build``.
    """
    deprecation_msg = (
        '``build_anchor_generator`` would be deprecated soon, please use '
        '``mmdet.registry.TASK_UTILS.build()`` ')
    warnings.warn(deprecation_msg)
    return TASK_UTILS.build(cfg, default_args=default_args)
|
||||
|
||||
|
||||
# Legacy name bound to the same ``TASK_UTILS`` registry object.
PIXEL_SAMPLERS = TASK_UTILS
|
||||
|
||||
|
||||
def build_pixel_sampler(cfg, **default_args):
    """Build a pixel sampler for segmentation maps through ``TASK_UTILS``.

    A deprecation warning is issued before delegating.

    Args:
        cfg: Config handed to ``TASK_UTILS.build``.
        **default_args: Extra keyword arguments; they are collected into a
            dict and forwarded as the ``default_args`` argument.

    Returns:
        The object returned by ``TASK_UTILS.build``.
    """
    deprecation_msg = (
        '``build_pixel_sampler`` would be deprecated soon, please use '
        '``mmseg.registry.TASK_UTILS.build()`` ')
    warnings.warn(deprecation_msg)
    return TASK_UTILS.build(cfg, default_args=default_args)
|
|
@ -0,0 +1,555 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import ast
|
||||
import copy
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
import torch.nn
|
||||
from mmengine import mkdir_or_exist
|
||||
from mmengine.config import ConfigDict
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model import (
|
||||
BaseDataPreprocessor,
|
||||
BaseModel,
|
||||
BaseModule,
|
||||
ImgDataPreprocessor,
|
||||
)
|
||||
from mmengine.registry import Registry
|
||||
from yapf.yapflib.yapf_api import FormatCode
|
||||
|
||||
from mim.utils import OFFICIAL_MODULES
|
||||
from .common import REGISTRY_TYPES
|
||||
from .flatten_func import * # noqa: F403, F401
|
||||
from .flatten_func import (
|
||||
ImportResolverTransformer,
|
||||
RegisterModuleTransformer,
|
||||
flatten_inheritance_chain,
|
||||
ignore_ast_docstring,
|
||||
postprocess_super_call,
|
||||
)
|
||||
|
||||
|
||||
def format_code(code_text: str):
    """Format the code text with yapf.

    Args:
        code_text (str): Python source code to format.

    Returns:
        str: The source code as reformatted by ``FormatCode``.

    Raises:
        SyntaxError: If ``FormatCode`` fails on ``code_text``. The original
            error is attached as ``__cause__``.
    """
    yapf_style = dict(
        based_on_style='pep8',
        blank_line_before_nested_class_or_def=True,
        split_before_expression_after_opening_paren=True)
    try:
        code_text, _ = FormatCode(code_text, style_config=yapf_style)
    except Exception as e:
        # Catch ``Exception`` only, so ``KeyboardInterrupt``/``SystemExit``
        # are not turned into a misleading ``SyntaxError``.
        raise SyntaxError('Failed to format the config file, please '
                          f'check the syntax of: \n{code_text}') from e

    return code_text
|
||||
|
||||
|
||||
def _postprocess_registry_locations(export_root_dir: str):
    """Remove the Registry.locations if it doesn't exist.

    Check the location path for Registry to load modules if the path hasn't
    been exported, then need to be removed. Finally will use the root Registry
    to find module until it actually doesn't exist.

    The file ``<export_root_dir>/pack/registry.py`` is parsed, every call
    carrying a ``locations=`` keyword is inspected, and the file is rewritten
    (formatted with ``format_code``) afterwards.

    Args:
        export_root_dir (str): Root directory of the exported package, i.e.
            the directory that contains the ``pack`` folder.
    """
    export_module_dir = osp.join(export_root_dir, 'pack')
    registry_file = osp.join(export_module_dir, 'registry.py')

    with open(registry_file, encoding='utf-8') as f:
        ast_tree = ast.parse(f.read())

    for node in ast.walk(ast_tree):
        # Node structure we are looking for:
        #
        # Assign(targets=[Name(id='EVALUATORS', ctx=Store())],
        #        value=Call(func=Name(id='Registry', ctx=Load()),
        #                   args=[Constant(value='evaluator')],
        #                   keywords=[
        #                       keyword(arg='parent', value=Name(...)),
        #                       keyword(arg='locations', value=List(
        #                           elts=[Constant(value='pack.evaluation')],
        #                           ctx=Load()))]))
        if not isinstance(node, ast.Call):
            continue

        need_to_be_remove = None

        for keyword in node.keywords:
            if keyword.arg != 'locations':
                continue

            for sub_node in ast.walk(keyword):
                # The locations of Registry were already transferred to the
                # `pack` scope before. Only string constants can be dotted
                # module paths; anything else is skipped.
                if not (isinstance(sub_node, ast.Constant)
                        and isinstance(sub_node.value, str)
                        and 'pack' in sub_node.value):
                    continue

                # Convert only the dotted module path into a file path; dots
                # that belong to ``export_root_dir`` (e.g. './out') must be
                # left alone.
                location_dir = osp.join(
                    export_root_dir, sub_node.value.replace('.', osp.sep))

                if not osp.exists(location_dir):
                    print_log(
                        '[ Pass ] Remove Registry.locations '
                        f"'{location_dir}', "
                        'which is no need to export.',
                        logger='export',
                        level=logging.DEBUG)
                    need_to_be_remove = keyword
                    break

            if need_to_be_remove is not None:
                break

        # NOTE: the whole ``locations`` keyword is dropped as soon as one of
        # its entries is missing.
        if need_to_be_remove is not None:
            node.keywords.remove(need_to_be_remove)

    with open(registry_file, 'w', encoding='utf-8') as f:
        f.write(format_code(ast.unparse(ast_tree)))
|
||||
|
||||
|
||||
def _get_all_files(directory: str):
    """Collect the files below ``directory``, walking it recursively.

    Files whose name contains ``'__init__'`` or ``'registry.py'`` are left
    out.

    Args:
        directory (str): The directory path.

    Returns:
        List: Paths (``root`` joined with the file name) of every remaining
        file, in ``os.walk`` order.
    """
    skipped_markers = ('__init__', 'registry.py')
    return [
        os.path.join(current_dir, filename)
        for current_dir, _, filenames in os.walk(directory)
        for filename in filenames
        if not any(marker in filename for marker in skipped_markers)
    ]
|
||||
|
||||
|
||||
def _postprocess_importfrom_module_to_pack(file_path: str):
    """Transfer the importfrom path from "downstream repo" to export module.

    The file is parsed, its ``from ... import ...`` statements are rewritten
    and the result is written back to ``file_path`` (formatted with
    ``format_code``). Imports are looked for at module level, directly inside
    top-level functions and classes, and directly inside methods of top-level
    classes.

    Args:
        file_path (str): The path of file needed to be transfer.

    Examples:
        >>> from mmdet.models.detectors.two_stage import TwoStageDetector
        >>> # transfer to below, if "TwoStageDetector" had been exported
        >>> from pack.models.detectors.two_stage import TwoStageDetector
    """
    from mmengine import Registry

    # _module_path_dict is a class attribute,
    # already record all the exported module and their path before
    _module_path_dict = Registry._module_path_dict

    with open(file_path, encoding='utf-8') as f:
        ast_tree = ast.parse(f.read())

    # if the import module have the same name with the object in these file,
    # they import path won't be change
    can_not_change_module = [
        node.name for node in ast_tree.body
        if isinstance(node, (ast.FunctionDef, ast.ClassDef))
    ]

    def check_change_importfrom_node(node: ast.ImportFrom):
        """Check if the ImportFrom node should be changed.

        If the modules in node had already been exported, they will be
        separated and compose a new ast.ImportFrom node with the export
        path as the module path. The moved aliases are removed from
        ``node.names`` in place.

        Args:
            node (ast.ImportFrom): ImportFrom node.

        Returns:
            ast.ImportFrom | None: Return a new ast.ImportFrom node
            if one of the module in node had been export else ``None``.
        """
        export_module_path = None
        needed_change_alias = []

        for alias in node.names:
            # if the import module's name is equal to the class or function
            # name, it can not be transfer for avoiding circular import.
            if alias.name in _module_path_dict \
                    and alias.name not in can_not_change_module:

                if export_module_path is None:
                    export_module_path = _module_path_dict[alias.name]
                else:
                    assert _module_path_dict[alias.name] == \
                        export_module_path, \
                        'There are two module from the same downstream repo,'\
                        " but can't change to the same export path."

                needed_change_alias.append(alias)

        if len(needed_change_alias) != 0:
            for alias in needed_change_alias:
                node.names.remove(alias)

            return ast.ImportFrom(
                module=export_module_path, names=needed_change_alias, level=0)

        return None

    def redirect_imports_in_scope(scope_node):
        """Rewrite the ImportFrom statements directly inside ``scope_node``.

        New import nodes are inserted into the body of ``scope_node`` itself,
        so a name imported inside a function stays visible in that function.
        """
        # presumably the index of the first statement after the docstring;
        # it is used as the insert position, exactly as before.
        scope_insert_idx = ignore_ast_docstring(scope_node)
        new_import_nodes = []

        # Iterate over a snapshot: the body is modified below, and mutating
        # a list while looping over it revisits or skips statements.
        for sub_node in list(scope_node.body):
            if not isinstance(sub_node, ast.ImportFrom):
                continue

            temp_node = check_change_importfrom_node(sub_node)
            if temp_node is not None:
                new_import_nodes.append(temp_node)

            # if importfrom module is empty, the node should be remove
            if len(sub_node.names) == 0:
                scope_node.body.remove(sub_node)

        for temp_node in new_import_nodes:
            scope_node.body.insert(scope_insert_idx, temp_node)

    # record the insert_idx and node needed to be insert for later insert.
    insert_idx_and_node = {}

    # ``insert_idx`` counts the import statements seen so far (plus the new
    # nodes scheduled for insertion), it is not the index in the body.
    insert_idx = 0
    for idx, node in enumerate(ast_tree.body):

        # ast.Module -> ast.ImportFrom
        if isinstance(node, ast.ImportFrom):
            insert_idx += 1
            temp_node = check_change_importfrom_node(node)
            if temp_node is not None:
                if len(node.names) == 0:
                    # every alias moved: replace the statement in place
                    ast_tree.body[idx] = temp_node
                else:
                    insert_idx_and_node[insert_idx] = temp_node
                    insert_idx += 1

        elif isinstance(node, ast.Import):
            insert_idx += 1

        # ast.Module -> ast.FunctionDef -> ast.ImportFrom
        elif isinstance(node, ast.FunctionDef):
            redirect_imports_in_scope(node)

        # ast.Module -> ast.ClassDef -> ast.ImportFrom
        #                            -> ast.FunctionDef -> ast.ImportFrom
        elif isinstance(node, ast.ClassDef):
            redirect_imports_in_scope(node)
            for class_sub_node in node.body:
                if isinstance(class_sub_node, ast.FunctionDef):
                    # Imports found in a method are re-inserted into that
                    # method, not into the class body, where the imported
                    # name would only become a class attribute.
                    redirect_imports_in_scope(class_sub_node)

    # lazy add new ast.ImportFrom node to ast.Module
    for insert_idx, temp_node in insert_idx_and_node.items():
        ast_tree.body.insert(insert_idx, temp_node)

    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(format_code(ast.unparse(ast_tree)))
|
||||
|
||||
|
||||
def _replace_config_scope_to_pack(cfg: ConfigDict):
    """Replace the config scope from "mmxxx" to "pack".

    Every ``_scope_`` and ``default_scope`` entry is set to ``'pack'``, in
    place. Nested dicts are visited recursively, including dicts stored
    inside lists or tuples.

    Args:
        cfg (ConfigDict): The config dict to be replaced.
    """

    def _visit(value):
        # Descend into containers; any other value is left untouched.
        if isinstance(value, dict):
            _replace_config_scope_to_pack(value)
        elif isinstance(value, (list, tuple)):
            for item in value:
                _visit(item)

    # Assigning to an existing key does not resize the dict, so it is safe
    # while iterating over ``items()``.
    for key, value in cfg.items():
        if key == '_scope_' or key == 'default_scope':
            cfg[key] = 'pack'
        else:
            _visit(value)
|
||||
|
||||
|
||||
def _wrapper_all_registries_build_func(export_module_dir: str, scope: str):
    """A function to wrap all registries' build_func.

    Steps:
        1. copy ``<scope>/registry.py`` to ``export_module_dir`` and replace
           ``scope`` by ``'pack'`` in all of its string constants
        2. reset the bookkeeping attributes on ``Registry``
        3. wrap ``Registry.build`` and ``Registry.get`` with ``_wrap_build``
           and ``_wrap_get``

    Args:
        export_module_dir (str): The root dir for packing modules.
        scope (str): The default scope of the config.
    """
    # copy the downstream repo.registry to pack.registry
    # and change all the registry.locations
    repo_registries = importlib.import_module('.registry', scope)
    origin_file = inspect.getfile(repo_registries)
    registry_path = osp.join(export_module_dir, 'registry.py')
    shutil.copy(origin_file, registry_path)

    # replace 'repo' name in Registry.locations to 'pack'
    with open(registry_path, encoding='utf-8') as f:
        ast_tree = ast.parse(f.read())

    for node in ast.walk(ast_tree):
        # Constants may also be numbers, ``None``, bytes, ...; the ``in``
        # test is only meaningful (and only safe) for strings.
        if isinstance(node, ast.Constant) \
                and isinstance(node.value, str) \
                and scope in node.value:
            node.value = node.value.replace(scope, 'pack')

    # write with the same encoding that was used for reading
    with open(registry_path, 'w', encoding='utf-8') as f:
        f.write(format_code(ast.unparse(ast_tree)))

    # prevent circular registration
    Registry._extra_module_set = set()

    # record the exported module for postprocessing the importfrom path
    Registry._module_path_dict = {}

    # prevent circular wrapper: the functions returned by ``_wrap_build``
    # are named ``wrapper``, so that name tells a previous call already
    # patched ``Registry``. In that case wrap the saved originals again.
    if Registry.build.__name__ == 'wrapper':
        Registry.build = _wrap_build(Registry.init_build_func,
                                     export_module_dir)
        Registry.get = _wrap_get(Registry.init_get_func, export_module_dir)
    else:
        Registry.init_build_func = copy.deepcopy(Registry.build)
        Registry.init_get_func = copy.deepcopy(Registry.get)
        Registry.build = _wrap_build(Registry.build, export_module_dir)
        Registry.get = _wrap_get(Registry.get, export_module_dir)
|
||||
|
||||
|
||||
def ignore_self_cache(func):
    """Make ``func`` run only once per distinct positional arguments.

    The first positional argument (``self``) and all keyword arguments are
    not part of the key. The decorated function always returns ``None``;
    the return value of ``func`` is discarded.

    Args:
        func (Callable): The function to be wrapped.

    Returns:
        Callable: The wrapper that skips repeated calls.
    """
    seen_keys = set()

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if args in seen_keys:
            return
        # mark as seen before the call, so a failing call is not retried
        seen_keys.add(args)
        func(self, *args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
@ignore_self_cache
def _export_module(self, obj_cls: type, pack_module_dir, obj_type: str):
    """Export module.

    This function will get the object's file and export to
    ``pack_module_dir``.

    If the object is built by ``MODELS`` registry, all the objects
    as the top classes in this file, will be iteratively flattened.
    Else will be directly exported.

    The flatten logic is:
        1. get the origin file of object, which built
            by ``MODELS.build()``
        2. get all the classes in the origin file
        3. flatten all the classes but not only the object
        4. call ``flatten_module()`` to finish flatten
            according to ``class.mro()``

    Args:
        self (Registry): The registry the class was built or fetched from.
        obj_cls (type): The class whose source file is exported.
        pack_module_dir (str): Directory the exported files are written to.
        obj_type (str): The key ``obj_cls`` is registered under.

    Raises:
        FileExistsError: If the source file of ``obj_cls`` can't be found.
    """
    # find the file by obj class
    file_path = inspect.getfile(obj_cls)

    if osp.exists(file_path):
        print_log(
            f'building class: '
            f'{obj_cls.__name__} from file: {file_path}.',
            logger='export',
            level=logging.DEBUG)
    else:
        # NOTE(review): ``FileNotFoundError`` would describe this better;
        # the type is kept so existing handlers keep working.
        raise FileExistsError(f"file [{file_path}] doesn't exist.")

    # local origin module
    module = obj_cls.__module__
    parent = module.split('.')[0]
    # Only the leading package name is swapped; a later occurrence of the
    # same text inside the dotted path must stay as it is.
    new_module = module.replace(parent, 'pack', 1)

    # Not necessary to export module implemented in `mmcv` and `mmengine`
    if parent in set(OFFICIAL_MODULES) - {'mmcv', 'mmengine'}:

        with open(file_path, encoding='utf-8') as f:
            top_ast_tree = ast.parse(f.read())

        # deal with relative import
        ImportResolverTransformer(module).visit(top_ast_tree)

        # NOTE: ``MODELS.build()`` means to flatten model module
        if self.name == 'model':

            # record all the class needed to be flattened
            need_to_be_flattened_class_names = [
                node.name for node in top_ast_tree.body
                if isinstance(node, ast.ClassDef)
            ]

            imported_module = importlib.import_module(obj_cls.__module__)
            for cls_name in need_to_be_flattened_class_names:

                # record the exported module for postprocessing the
                # importfrom path
                self._module_path_dict[cls_name] = new_module

                cls = getattr(imported_module, cls_name)

                for super_cls in cls.__bases__:

                    # the class only will be flattened when:
                    #   1. super class doesn't exist in this file
                    #   2. and super class is not base class
                    #   3. and super class is not torch module
                    if (super_cls.__name__
                            not in need_to_be_flattened_class_names
                            and super_cls not in [
                                BaseModule, BaseModel, BaseDataPreprocessor,
                                ImgDataPreprocessor
                            ] and 'torch' not in super_cls.__module__):

                        print_log(
                            f'need_flatten: {cls_name}\n',
                            logger='export',
                            level=logging.INFO)

                        flatten_inheritance_chain(top_ast_tree, cls)
                        break

            postprocess_super_call(top_ast_tree)

        else:
            self._module_path_dict[obj_cls.__name__] = new_module

        # add ``register_module(force=True)`` to cover the registered modules
        RegisterModuleTransformer().visit(top_ast_tree)

        # unparse ast tree and save reformat code
        new_file_path = new_module.split('.', 1)[1].replace('.',
                                                            osp.sep) + '.py'
        new_file_path = osp.join(pack_module_dir, new_file_path)
        new_dir = osp.dirname(new_file_path)
        mkdir_or_exist(new_dir)

        # the source was read as utf-8, write it back the same way
        with open(new_file_path, mode='w', encoding='utf-8') as f:
            f.write(format_code(ast.unparse(top_ast_tree)))

    # Downstream repo could register torch module into Registry, such as
    # registering `torch.nn.Linear` into `MODELS`. We need to reserve these
    # codes in the exported module.
    elif 'torch' in module.split('.')[0]:

        # get the root registry, because it can get all the modules
        # had been registered.
        root_registry = self if self.parent is None else self.parent
        if (obj_type not in self._extra_module_set) and (
                root_registry.init_get_func(obj_type) is None):
            self._extra_module_set.add(obj_type)
            with open(
                    osp.join(pack_module_dir, 'registry.py'),
                    'a',
                    encoding='utf-8') as f:

                # TODO: When the downstream repo registries' name are
                # different with mmengine, the module may not be registried
                # to the right register.
                # For example: `EVALUATOR` in mmengine, `EVALUATORS` in mmdet.
                f.write('\n')
                f.write(f'from {module} import {obj_cls.__name__}\n')
                f.write(
                    f"{REGISTRY_TYPES[self.name]}.register_module('{obj_type}', module={obj_cls.__name__}, force=True)"  # noqa: E501
                )
|
||||
|
||||
|
||||
def _wrap_build(build_func: Callable, pack_module_dir: str):
    """Wrap ``Registry.build()`` so that built classes get exported.

    Args:
        build_func (Callable): ``Registry.build()``, which will be wrapped.
        pack_module_dir (str): Modules export path.

    Returns:
        Callable: A function named ``wrapper`` that builds the object with
        ``build_func``, exports its class through ``_export_module`` and
        returns the built object.
    """

    # The inner function has to keep the name ``wrapper``: that name is what
    # ``_wrapper_all_registries_build_func`` checks to detect an earlier wrap.
    def wrapper(self, cfg: dict, *args, **kwargs):
        # built is the class instance
        built = build_func(self, cfg, *args, **kwargs)

        # read 'type' from a copy, ``cfg`` itself is left untouched
        type_entry = cfg.copy().pop('type')  # type: ignore
        if isinstance(type_entry, str):
            type_name = type_entry
        else:
            type_name = type_entry.__name__

        # modules in ``torch.nn.Sequential`` should be respectively exported
        if isinstance(built, torch.nn.Sequential):
            classes_to_export = [
                child.__class__ for child in built.children()
            ]
        else:
            classes_to_export = [built.__class__]

        for exported_cls in classes_to_export:
            _export_module(self, exported_cls, pack_module_dir, type_name)

        return built

    return wrapper
|
||||
|
||||
|
||||
def _wrap_get(get_func: Callable, pack_module_dir: str):
    """Wrap ``Registry.get()`` so that fetched classes get exported.

    Args:
        get_func (Callable): ``Registry.get()``, which will be wrapped.
        pack_module_dir (str): Modules export path.

    Returns:
        Callable: A function that returns whatever ``get_func`` returns and
        exports the result through ``_export_module`` when it is not
        ``None``.
    """

    def wrapper(self, key: str):

        obj_cls = get_func(self, key)

        # A lookup that found nothing has no source file to export. Return
        # ``None`` to the caller unchanged instead of failing inside
        # ``inspect.getfile(None)``.
        if obj_cls is not None:
            _export_module(self, obj_cls, pack_module_dir, key)

        return obj_cls

    return wrapper
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mim.commands.list import list_package
|
||||
from mim.utils.default import OFFICIAL_MODULES
|
||||
|
||||
|
||||
def get_installed_package(ctx, args, incomplete):
|
||||
|
@ -19,19 +20,4 @@ def get_downstream_package(ctx, args, incomplete):
|
|||
|
||||
|
||||
def get_official_package(ctx=None, args=None, incomplete=None):
    """Return the names listed in ``OFFICIAL_MODULES``.

    Args:
        ctx: Unused.
        args: Unused.
        incomplete: Unused.

    Returns:
        list: A copy of ``OFFICIAL_MODULES``, so callers can't modify the
        shared constant by accident.
    """
    # The hand-maintained list that used to live here was missing a comma
    # ('mmselfsup' 'mmrotate' became one string) and made the second
    # ``return`` unreachable; the shared constant is the single source now.
    return list(OFFICIAL_MODULES)
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import sys
|
||||
|
||||
import click
|
||||
from mmengine.config import Config
|
||||
from mmengine.hub import get_config
|
||||
|
||||
from mim._internal.export.pack_cfg import export_from_cfg
|
||||
from mim.click import CustomCommand
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
||||
@click.command(
    name='export',
    context_settings=dict(ignore_unknown_options=True),
    cls=CustomCommand)
@click.argument('package', type=str)
@click.argument('config', type=str)
@click.argument('export_dir', type=str, default=None, required=False)
@click.option(
    '-f',
    '--fast-test',
    is_flag=True,
    help='The fast_test mode. In order to quickly test if'
    ' there is any error in the export package,'
    ' it only use the first two data of your datasets'
    ' which only be used to train 2 iters/epoches.')
@click.option(
    '--save-log',
    is_flag=True,
    help='The flag to keep the export log of the process. The log of export'
    " process will be save to directory 'export_log'. Default will"
    ' automatically delete the log after export.')
def cli(config: str,
        package: str,
        export_dir: str,
        fast_test: bool = False,
        save_log: bool = False) -> None:
    """Export package from config file.

    Example:

    \b
    >>> # Export package from downstream config file.
    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
    ...     dab_detr
    >>>
    >>> # Export package from specified config file.
    >>> mim export mmdet mmdetection/configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py dab_detr # noqa: E501
    >>>
    >>> # It can auto generate a export dir when not specified.
    >>> # like:'pack_from_mmdet_20231026_052704.
    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py
    >>>
    >>> # Only export the model of config file.
    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
    ...     mask_rcnn_package --model-only
    >>>
    >>> # Keep the export log of the process.
    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
    ...     mask_rcnn_package --save-log
    >>>
    >>> # Print the help information of export command.
    >>> mim export -h
    """

    # get config: a path that exists locally wins, otherwise the config is
    # looked up in the downstream package as ``package::config``.
    if osp.exists(config):
        cfg = Config.fromfile(config)  # from local
    else:
        try:
            cfg = get_config(package + '::' + config)  # from downstream
        except Exception as e:
            # keep the original error as the cause, it says why the lookup
            # in the downstream package failed
            raise FileNotFoundError(
                f"Config file '{config}' or '{package}::{config}' "
                "doesn't exist.") from e

    fast_test_mode(cfg, fast_test)

    export_from_cfg(cfg, export_root_dir=export_dir, save_log=save_log)
|
||||
|
||||
|
||||
def fast_test_mode(cfg, fast_test: bool = False):
    """Shrink the config in place so that a run finishes quickly.

    Nothing is changed when ``fast_test`` is false. Otherwise the
    dataloaders are limited to the samples ``[0, 1]`` with batch size 2,
    training is limited to 2 iterations or 2 epochs, validation and logging
    intervals are set to 1 and the parameter schedulers are set to span
    ``0`` to ``2``.

    Args:
        cfg (Config): Config of export package.
        fast_test (bool, optional): Fast testing mode. Defaults to False.
    """
    if not fast_test:
        return

    # for batch_norm using at least 2 data
    train_dataset = cfg.train_dataloader.dataset
    if 'dataset' in train_dataset:
        # the dataset config wraps another dataset config
        train_dataset.dataset.indices = [0, 1]
    else:
        train_dataset.indices = [0, 1]
    cfg.train_dataloader.batch_size = 2

    for loader_name in ('test_dataloader', 'val_dataloader'):
        if cfg.get(loader_name) is None:
            continue
        loader = getattr(cfg, loader_name)
        loader.dataset.indices = [0, 1]
        loader.batch_size = 2

    loop_type = cfg.train_cfg.get('type')
    iter_based = loop_type == 'IterBasedTrainLoop' or (
        cfg.train_cfg.get('by_epoch') is None
        and loop_type != 'EpochBasedTrainLoop')
    if iter_based:
        cfg.train_cfg.max_iters = 2
    else:
        cfg.train_cfg.max_epochs = 2

    cfg.train_cfg.val_interval = 1
    cfg.default_hooks.logger.interval = 1

    if 'param_scheduler' in cfg and cfg.param_scheduler is not None:
        if isinstance(cfg.param_scheduler, list):
            schedulers = cfg.param_scheduler
        else:
            schedulers = [cfg.param_scheduler]
        for scheduler in schedulers:
            scheduler.begin = 0
            scheduler.end = 2
|
|
@ -4,6 +4,7 @@ from .default import (
|
|||
DEFAULT_MMCV_BASE_URL,
|
||||
DEFAULT_URL,
|
||||
MODULE2PKG,
|
||||
OFFICIAL_MODULES,
|
||||
PKG2MODULE,
|
||||
PKG2PROJECT,
|
||||
RAW_GITHUB_URL,
|
||||
|
@ -90,4 +91,5 @@ __all__ = [
|
|||
'parse_home_page',
|
||||
'ensure_installation',
|
||||
'rich_progress_bar',
|
||||
'OFFICIAL_MODULES',
|
||||
]
|
||||
|
|
|
@ -13,6 +13,13 @@ WHEEL_URL = {
|
|||
'{torch_version}/index.html',
|
||||
}
|
||||
RAW_GITHUB_URL = 'https://raw.githubusercontent.com/{owner}/{repo}/{branch}'
|
||||
|
||||
# Names treated as official packages. The export code compares them with the
# first component of a class' ``__module__``, so these are presumably
# top-level import names.
OFFICIAL_MODULES = [
    'mmcls', 'mmdet', 'mmdet3d', 'mmseg', 'mmaction2', 'mmtrack', 'mmpose',
    'mmedit', 'mmocr', 'mmgen', 'mmselfsup', 'mmrotate', 'mmflow', 'mmyolo',
    'mmpretrain', 'mmagic'
]
|
||||
|
||||
PKG2PROJECT = {
|
||||
'mmcv-full': 'mmcv',
|
||||
'mmcls': 'mmclassification',
|
||||
|
|
Loading…
Reference in New Issue