[Experimental] Packaging a minimal and flatten algorithm from downstream repos (#229)

pull/234/head
Guoping Pan 2023-10-31 11:08:16 +08:00 committed by GitHub
parent 5d54183ef9
commit c9a5ea6a88
16 changed files with 2666 additions and 16 deletions

View File

@ -0,0 +1,99 @@
# Export (experimental)
`mim export` is a new feature of `mim` that exports a minimal trainable and testable model package from a config file.
The minimal package compactly combines the `model`, `datasets`, `engine`, and other components, so there is no need to search for each module separately based on the config file. Users who did not install from source can also see the model files at a glance.
In addition, lengthy inheritance chains, such as `CondInstBboxHead -> FCOSHead -> AnchorFreeHead -> BaseDenseHead -> BaseModule` in `mmdetection`, are flattened directly into `CondInstBboxHead -> BaseModule`. There is no longer any need to open and compare multiple model files, since all functions inherited from parent classes are visible in one place.
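To make the flattening concrete, here is a minimal sketch of what it does, using toy classes (illustrative only, not the real `mmdetection` classes):
```python
# Before flattening: behavior is spread across an inheritance chain.
class BaseModule:
    def init_weights(self):
        print('init weights')

class AnchorFreeHead(BaseModule):
    def forward(self, x):
        return x * 2

class CondInstBboxHead(AnchorFreeHead):
    def loss(self, x):
        return self.forward(x) - 1

# After flattening: every inherited method is copied into a single class
# that derives directly from BaseModule.
class FlatCondInstBboxHead(BaseModule):
    def forward(self, x):
        return x * 2

    def loss(self, x):
        return self.forward(x) - 1
```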
### Usage
```bash
mim export config_path
# config_path accepts the following two forms:
# 1. A config from a downstream repo, referenced as repo::configs/xxx.py, for example:
mim export mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py
# 2. A config file in a local folder, for example:
mim export config_dir/mask-rcnn_r101_fpn_1x_coco.py
```
### Minimal package directory structure
```
minimal_package (named pack_from_{repo}_20231212_121212)
|- pack
| |- configs # Configuration folder
| | |- model_name
| | |- xxx.py # Configuration file
| |
| |- models # model folder
| | |- model_file.py
| | |- ...
| |
| |- data # data folder
| |
| |- demo # demo folder
| |
| |- datasets # Dataset class definitions
| | |- transforms
| |
| |- registry.py # Registries
|
|- tools
| |- train.py # training
| |- test.py # test
|
```
### Limitations
Currently, `mim export` only supports some of the config files in `mmpose`, `mmdetection`, `mmagic`, and `mmsegmentation`. In addition, there are some constraints on the downstream repos.
#### For downstream repos
1. Config names should preferably contain no special symbols; otherwise they cannot be parsed by `mmengine.hub.get_config()`, for example:
- gn+ws/faster-rcnn_r101_fpn_gn-ws-all_1x_coco.py
- legacy_1.x/cascade-mask-rcnn_r50_fpn_1x_coco_v1.py
2. For `mmsegmentation`, before using `mim export` on an `mmseg` config, you must first remove the extra directory wrapping `registry.py`, i.e. move `mmseg/registry/registry.py` to `mmseg/registry.py`.
3. Downstream registries inherited from mmengine should keep their original names. For example, mmagic renamed `EVALUATOR` to `EVALUATORS`:
```python
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
# Evaluators to define the evaluation process.
EVALUATORS = Registry(
'evaluator',
parent=MMENGINE_EVALUATOR,
locations=['mmagic.evaluation'],
)
```
4. In addition, if a downstream repo adds a registry that does not exist in mmengine, such as `DIFFUSION_SCHEDULERS` in mmagic, you need to add a key-value pair to `REGISTRY_TYPES` in `mim/_internal/export/common.py` so that `torch` modules can be registered to `DIFFUSION_SCHEDULERS`:
```python
# "mmagic/mmagic/registry.py"
# modules for diffusion models that support adding noise and denoising
DIFFUSION_SCHEDULERS = Registry(
'diffusion scheduler',
locations=['mmagic.models.diffusion_schedulers'],
)
# "mim/utils/mmpack/common.py"
REGISTRY_TYPE = {
...
'diffusion scheduler': 'DIFFUSION_SCHEDULERS',
...
}
```
#### To be improved
1. Flattening classes with multiple parent classes is not yet supported; support will be added as needed.
2. Export fails when the original config references a `path` (such as a dataset root) that cannot be found from the current directory. You can avoid such errors by manually modifying the `path` in the original config.
3. When `isinstance()` is used, a check against an intermediate class in the inheritance chain may return `False` after flattening, because the original inheritance relationship is not retained (see the sketch below).
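A minimal sketch of the `isinstance()` pitfall, using toy classes (illustrative only):
```python
class BaseModule: ...
class FCOSHead(BaseModule): ...        # intermediate parent
class CondInstBboxHead(FCOSHead): ...  # original class

class FlatCondInstBboxHead(BaseModule):  # flattened version
    ...

print(isinstance(CondInstBboxHead(), FCOSHead))      # True
print(isinstance(FlatCondInstBboxHead(), FCOSHead))  # False
```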

View File

@ -0,0 +1,100 @@
# Export (experimental)
`mim export` is a new feature in `mim` that can export a minimal trainable and testable model package from a config file.
The minimal package compactly combines the `model`, `datasets`, `engine`, and other components, so there is no need to look up each module separately from the config file; users who did not install from source can also see the model files at a glance.
In addition, lengthy inheritance chains inside models, such as `CondInstBboxHead -> FCOSHead -> AnchorFreeHead -> BaseDenseHead -> BaseModule` in `mmdetection`, are flattened directly into `CondInstBboxHead -> BaseModule`, so there is no longer any need to jump between multiple model files: all functions inherited from parent classes are visible in one place.
### Usage
```bash
mim export config_path
# config_path accepts the following two forms:
# 1. A config from a downstream repo, referenced as repo::configs/xxx.py, for example:
mim export mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py
# 2. A config file in a local folder, for example:
mim export config_dir/mask-rcnn_r101_fpn_1x_coco.py
```
### Minimal package directory structure
```
minimal_package (named pack_from_{repo}_20231212_121212)
|- pack
| |- configs # configs folder
| | |- model_name
| | |- xxx.py # config file
| |
| |- models # model definitions
| | |- model_file.py
| | |- ...
| |
| |- data # data folder
| |
| |- demo # demo folder
| |
| |- datasets # dataset class definitions
| | |- transforms
| |
| |- registry.py # registries
|
|- tools
| |- train.py # training script
| |- test.py # testing script
|
```
### Limitations
`mim export` currently only supports some of the configs in `mmpose`, `mmdetection`, `mmagic`, and `mmsegmentation`, and there are some constraints on the downstream repos.
#### For downstream repos
1. Config names should preferably **contain no special symbols**; otherwise they cannot be parsed by `mmengine.hub.get_config()`, for example:
- gn+ws/faster-rcnn_r101_fpn_gn-ws-all_1x_coco.py
- legacy_1.x/cascade-mask-rcnn_r50_fpn_1x_coco_v1.py
2. For `mmsegmentation`, before using `mim export` on an `mmseg` config, you must first remove the extra directory wrapping `registry.py`, i.e. move `mmseg/registry/registry.py` to `mmseg/registry.py`.
3. Downstream registries inherited from mmengine should keep their original names. For example, mmagic renamed `EVALUATOR` to `EVALUATORS`:
```python
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
# Evaluators to define the evaluation process.
EVALUATORS = Registry(
'evaluator',
parent=MMENGINE_EVALUATOR,
locations=['mmagic.evaluation'],
)
```
4. In addition, if a registry that does not exist in mmengine is added, such as `DIFFUSION_SCHEDULERS` in mmagic, you need to add a key-value pair to `REGISTRY_TYPES` in `mim/_internal/export/common.py` so that `torch` modules can be registered to `DIFFUSION_SCHEDULERS`:
```python
# "mmagic/mmagic/registry.py"
# modules for diffusion models that support adding noise and denoising
DIFFUSION_SCHEDULERS = Registry(
'diffusion scheduler',
locations=['mmagic.models.diffusion_schedulers'],
)
# "mim/utils/mmpack/common.py"
REGISTRY_TYPE = {
...
'diffusion scheduler': 'DIFFUSION_SCHEDULERS',
...
}
```
#### To be improved in `mim export`
1. Flattening classes with multiple parent classes is not yet supported; support will be added as needed.
2. When `isinstance()` is used, a check against an intermediate class in the inheritance chain may return `False` after flattening, because the original inheritance relationship is not retained.
3. When the config contains a dataset path that cannot be accessed from the current directory, export may fail. The current workaround is to save the original config to the current directory and manually change the dataset path to one accessible from the current directory, e.g. `data/ADEChallengeData2016/ -> your_data_dir/ADEChallengeData2016/`.

View File

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
# initialization template for __init__.py
_init_str = """
import os
import os.path as osp
all_files = os.listdir(osp.dirname(__file__))
for file in all_files:
if (file.endswith('.py') and file != '__init__.py') or '.' not in file:
exec(f'from .{osp.splitext(file)[0]} import *')
"""
# import pack path for tools
_import_pack_str = """
import os.path as osp
import sys
sys.path.append(osp.dirname(osp.dirname(__file__)))
import pack
"""
OBJECTS_TO_BE_PATCHED = {
'MODELS': [
'BACKBONES',
'NECKS',
'HEADS',
'LOSSES',
'SEGMENTORS',
'build_backbone',
'build_neck',
'build_head',
'build_loss',
'build_segmentor',
'CLASSIFIERS',
'RETRIEVER',
'build_classifier',
'build_retriever',
'POSE_ESTIMATORS',
'build_pose_estimator',
'build_posenet',
],
'TASK_UTILS': [
'PIXEL_SAMPLERS',
'build_pixel_sampler',
]
}
REGISTRY_TYPES = {
'runner': 'RUNNERS',
'runner constructor': 'RUNNER_CONSTRUCTORS',
'hook': 'HOOKS',
'strategy': 'STRATEGIES',
'dataset': 'DATASETS',
'data sampler': 'DATA_SAMPLERS',
'transform': 'TRANSFORMS',
'model': 'MODELS',
'model wrapper': 'MODEL_WRAPPERS',
'weight initializer': 'WEIGHT_INITIALIZERS',
'optimizer': 'OPTIMIZERS',
'optimizer wrapper': 'OPTIM_WRAPPERS',
'optimizer wrapper constructor': 'OPTIM_WRAPPER_CONSTRUCTORS',
'parameter scheduler': 'PARAM_SCHEDULERS',
'param scheduler': 'PARAM_SCHEDULERS',
'metric': 'METRICS',
'evaluator': 'EVALUATOR', # TODO EVALUATORS in mmagic
'task utils': 'TASK_UTILS',
'loop': 'LOOPS',
'visualizer': 'VISUALIZERS',
'vis_backend': 'VISBACKENDS',
'log processor': 'LOG_PROCESSORS',
'inferencer': 'INFERENCERS',
'function': 'FUNCTIONS',
}
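# A hypothetical usage sketch (not part of the original file): REGISTRY_TYPES
# maps a Registry instance's human-readable name to the variable name used in
# the exported ``pack/registry.py``, for example:
#
#   from mmengine.registry import MODELS
#   REGISTRY_TYPES[MODELS.name]  # MODELS.name == 'model' -> 'MODELS'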

File diff suppressed because it is too large

View File

@ -0,0 +1,312 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
import shutil
import signal
import sys
import tempfile
import traceback
from datetime import datetime
from typing import Optional, Union
from mmengine import MMLogger
from mmengine.config import Config, ConfigDict
from mmengine.hub import get_config
from mmengine.logging import print_log
from mmengine.registry import init_default_scope
from mmengine.runner import Runner
from mmengine.utils import get_installed_path, mkdir_or_exist
from mim.utils import echo_error
from .common import _import_pack_str, _init_str
from .utils import (
_get_all_files,
_postprocess_importfrom_module_to_pack,
_postprocess_registry_locations,
_replace_config_scope_to_pack,
_wrapper_all_registries_build_func,
)
def export_from_cfg(cfg: Union[str, ConfigDict],
export_root_dir: str,
model_only: Optional[bool] = False,
save_log: Optional[bool] = False):
"""A function to pack the minimum available package according to config
file.
Args:
cfg (:obj:`ConfigDict` or str): Config file for packing the
minimum package.
export_root_dir (str, optional): The pack directory to save the
packed package.
fast_test (bool, optional): Turn to fast testing mode.
Defaults to False.
"""
# generate temp dir for export
export_root_tempdir = tempfile.TemporaryDirectory()
export_log_tempdir = tempfile.TemporaryDirectory()
# delete the incomplete exported package on keyboard interrupt
signal.signal(signal.SIGINT, lambda sig, frame: keyboard_interrupt_handler(
sig, frame, export_root_tempdir, export_log_tempdir)
) # type: ignore[arg-type]
# get config
if isinstance(cfg, str):
if '::' in cfg:
cfg = get_config(cfg)
else:
cfg = Config.fromfile(cfg)
default_scope = cfg.get('default_scope', 'mmengine')
# automatically generate ``export_root_dir``
if export_root_dir is None:
export_root_dir = f'pack_from_{default_scope}_' + \
f"{datetime.now().strftime(r'%Y%m%d_%H%M%S')}"
# generate ``export_log_dir``
if osp.sep in export_root_dir:
export_path = osp.dirname(export_root_dir)
else:
export_path = os.getcwd()
export_log_dir = osp.join(export_path, 'export_log')
export_logger = MMLogger.get_instance( # noqa: F841
'export',
log_file=osp.join(export_log_tempdir.name, 'export.log'))
export_module_tempdir_name = osp.join(export_root_tempdir.name, 'pack')
# export config
if '.mim' in cfg.filename:
cfg_path = osp.join(export_module_tempdir_name,
cfg.filename[cfg.filename.find('configs'):])
else:
cfg_path = osp.join(
osp.join(export_module_tempdir_name, 'configs'),
osp.basename(cfg.filename))
mkdir_or_exist(osp.dirname(cfg_path))
# transform to default_scope
init_default_scope(default_scope)
# wrap ``Registry.build()`` for exporting modules
_wrapper_all_registries_build_func(
export_module_dir=export_module_tempdir_name, scope=default_scope)
print_log(
f'[ Export Package Name ]: {export_root_dir}\n'
f' package from config: {cfg.filename}\n'
f" from downstream package: '{default_scope}'\n",
logger='export',
level=logging.INFO)
# create a temp work_dir for export
cfg['work_dir'] = export_log_tempdir.name
# use runner to export all needed modules
runner = Runner.from_cfg(cfg)
# HARD CODE: Some modules are only built inside a hook's
# ``before_run`` or ``after_run``, so we call these hooks directly
# instead of calling ``runner.train()``.
# Example:
# >>> @HOOKS.register_module()
# >>> class EMAHook(Hook):
# >>> ...
# >>> def before_run(self, runner) -> None:
# >>> """Create an ema copy of the model.
# >>> Args:
# >>> runner (Runner): The runner of the training process.
# >>> """
# >>> model = runner.model
# >>> if is_model_wrapper(model):
# >>> model = model.module
# >>> self.src_model = model
# >>> self.ema_model = MODELS.build(
# >>> self.ema_cfg, default_args=dict(model=self.src_model))
# ``self.ema_model`` needs to be built in ``before_run``.
for hook in runner.hooks:
hook.before_run(runner)
hook.after_run(runner)
def dump():
cfg['work_dir'] = 'work_dirs' # recover to default.
_replace_config_scope_to_pack(cfg)
cfg.dump(cfg_path)
# copy temp log to export log
if save_log:
shutil.copytree(
export_log_tempdir.name, export_log_dir, dirs_exist_ok=True)
export_log_tempdir.cleanup()
# copy temp_package_dir to export_package_dir
shutil.copytree(
export_root_tempdir.name, export_root_dir, dirs_exist_ok=True)
export_root_tempdir.cleanup()
print_log(
f'[ Export Package Name ]: '
f'{osp.join(os.getcwd(), export_root_dir)}\n',
logger='export',
level=logging.INFO)
if model_only:
dump()
return 0
try:
runner.build_train_loop(cfg.train_cfg)
except FileNotFoundError:
error_postprocess(export_log_dir, default_scope,
export_root_tempdir, export_log_tempdir,
osp.basename(cfg_path), 'train_dataloader')
try:
if 'val_cfg' in cfg and cfg.val_cfg is not None:
runner.build_val_loop(cfg.val_cfg)
except FileNotFoundError:
error_postprocess(export_log_dir, default_scope,
export_root_tempdir, export_log_tempdir,
osp.basename(cfg_path), 'val_dataloader')
try:
if 'test_cfg' in cfg and cfg.test_cfg is not None:
runner.build_test_loop(cfg.test_cfg)
except FileNotFoundError:
error_postprocess(export_log_dir, default_scope,
export_root_tempdir, export_log_tempdir,
osp.basename(cfg_path), 'test_dataloader')
if 'optim_wrapper' in cfg and cfg.optim_wrapper is not None:
runner.optim_wrapper = runner.build_optim_wrapper(cfg.optim_wrapper)
if 'param_scheduler' in cfg and cfg.param_scheduler is not None:
runner.build_param_scheduler(cfg.param_scheduler)
# add ``__init__.py`` to all dirs to turn directories
# into importable modules
for directory, _, _ in os.walk(export_module_tempdir_name):
if not osp.exists(osp.join(directory, '__init__.py')) \
and 'configs' not in directory:
with open(osp.join(directory, '__init__.py'), 'w') as f:
f.write(_init_str)
# postprocess for ``pack/registry.py``
_postprocess_registry_locations(export_root_tempdir.name)
# postprocess for ImportFrom Node, turn to import from export path
all_export_files = _get_all_files(export_module_tempdir_name)
for file in all_export_files:
_postprocess_importfrom_module_to_pack(file)
# get tools from web
tools_dir = osp.join(export_root_tempdir.name, 'tools')
mkdir_or_exist(tools_dir)
for tool_name in [
'train.py', 'test.py', 'dist_train.sh', 'dist_test.sh',
'slurm_train.sh', 'slurm_test.sh'
]:
pack_tools(
tool_name=tool_name,
scope=default_scope,
tool_dir=tools_dir,
auto_import=True)
# TODO: get demo.py
dump()
return 0
def keyboard_interrupt_handler(
sig: int,
frame,
export_root_tempdir: tempfile.TemporaryDirectory,
export_log_tempdir: tempfile.TemporaryDirectory,
):
"""Clear uncompleted exported package by interrupting with keyboard."""
export_log_tempdir.cleanup()
export_root_tempdir.cleanup()
sys.exit(-1)
def error_postprocess(export_log_dir: str, scope: str,
export_root_dir_tempfile: tempfile.TemporaryDirectory,
export_log_dir_tempfile: tempfile.TemporaryDirectory,
cfg_name: str, error_key: str):
"""Print Debug message when package can't successfully export for missing
datasets.
Args:
export_root_dir (str): _description_
absolute_cfg_path (str): _description_
origin_cfg (ConfigDict): _description_
error_key (str): _description_
logger (_type_): _description_
"""
shutil.copytree(
export_log_dir_tempfile.name, export_log_dir, dirs_exist_ok=True)
export_root_dir_tempfile.cleanup()
export_log_dir_tempfile.cleanup()
traceback.print_exc()
error_msg = f"{'=' * 20} Debug Message {'=' * 20}"\
f"\nThe data root of '{error_key}' is not found. You can"\
' use one of the two methods below to deal with it.\n\n'\
" >>> Method 1: Please modify the 'data_root' in the"\
f" duplicated config '{export_log_dir}/{cfg_name}'.\n"\
" >>> Method 2: Use '--model-only' to export the model only.\n\n"\
"After finishing one of the above steps, you can use 'mim export"\
f" {scope} {export_log_dir}/{cfg_name} [--model-only]' to export"\
' again.'
echo_error(error_msg)
sys.exit(-1)
def pack_tools(tool_name: str,
scope: str,
tool_dir: str,
auto_import: Optional[bool] = False):
"""pack tools from installed repo.
Args:
tool_name (str): Tool name in repos' tool dir.
scope (str): The scope of repo.
path (str): Path to save tool.
auto_import (bool, optional): Automatically add "import pack" to the
tool file. Defaults to "False"
"""
pkg_root = get_installed_path(scope)
path = osp.join(tool_dir, tool_name)
if os.path.exists(path):
os.remove(path)
# tools will be put in package/.mim in PR #68
tool_script = osp.join(pkg_root, '.mim', 'tools', tool_name)
if not osp.exists(tool_script):
tool_script = osp.join(pkg_root, 'tools', tool_name)
shutil.copy(tool_script, path)
# automatically import the pack modules
if auto_import:
with open(path, 'r+') as f:
lines = f.readlines()
code = ''.join(lines[:1] + [_import_pack_str] + lines[1:])
f.seek(0)
f.write(code)
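# A hypothetical usage sketch (not part of the original file):
#
#   from mim._internal.export.pack_cfg import export_from_cfg
#   export_from_cfg(
#       'mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py',
#       export_root_dir='pack_from_mmdet',
#       model_only=True)  # skip the dataloaders, export the model only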

View File

@ -0,0 +1,75 @@
# Patch Utils
## Problem
This patch mainly solves the problem that modules cannot be properly registered when a downstream repo renames a registry, as in this example from `mmsegmentation`:
```python
# "mmsegmentation/mmseg/structures/sampler/builder.py"
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmseg.registry import TASK_UTILS
PIXEL_SAMPLERS = TASK_UTILS
def build_pixel_sampler(cfg, **default_args):
"""Build pixel sampler for segmentation map."""
warnings.warn(
'``build_pixel_sampler`` would be deprecated soon, please use '
'``mmseg.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
```
Some modules register themselves through the renamed registry, which makes it hard for `mim export` to recover the original names of the renamed modules.
```python
# "mmsegmentation/mmseg/structures/sampler/ohem_pixel_sampler.py"
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_pixel_sampler import BasePixelSampler
from .builder import PIXEL_SAMPLERS
@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
...
```
## Solution
Therefore, we have migrated the necessary modules from the `mmpose`/`mmdetection`/`mmseg`/`mmpretrain` files listed below directly into `patch_utils.patch_model` and `patch_utils.patch_task`, in order to build a patch containing the renamed registries and the special module constructor functions.
```python
"mmdetection/mmdet/models/task_modules/builder.py"
"mmdetection/build/lib/mmdet/models/task_modules/builder.py"
"mmsegmentation/mmseg/models/builder.py"
"mmsegmentation/mmseg/structures/sampler/builder.py"
"mmsegmentation/build/lib/mmseg/models/builder.py"
"mmsegmentation/build/lib/mmseg/structures/sampler/builder.py"
"mmpretrain/mmpretrain/datasets/builder.py"
"mmpretrain/mmpretrain/models/builder.py"
"mmpretrain/build/lib/mmpretrain/datasets/builder.py"
"mmpretrain/build/lib/mmpretrain/models/builder.py"
"mmpose/mmpose/datasets/builder.py"
"mmpose/mmpose/models/builder.py"
"mmpose/build/lib/mmpose/datasets/builder.py"
"mmpose/build/lib/mmpose/models/builder.py"
"mmengine/mmengine/evaluator/builder.py"
"mmengine/mmengine/model/builder.py"
"mmengine/mmengine/optim/optimizer/builder.py"
"mmengine/mmengine/visualization/builder.py"
"mmengine/build/lib/mmengine/evaluator/builder.py"
"mmengine/build/lib/mmengine/model/builder.py"
"mmengine/build/lib/mmengine/optim/optimizer/builder.py"
```

View File

@ -0,0 +1,75 @@
# Patch Utils
## Problem
This patch mainly solves the problem that modules cannot be properly registered when a downstream repo renames a registry, as in this example from `mmsegmentation`:
```python
# "mmsegmentation/mmseg/structures/sampler/builder.py"
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmseg.registry import TASK_UTILS
PIXEL_SAMPLERS = TASK_UTILS
def build_pixel_sampler(cfg, **default_args):
"""Build pixel sampler for segmentation map."""
warnings.warn(
'``build_pixel_sampler`` would be deprecated soon, please use '
'``mmseg.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
```
Some modules may then use the renamed registry, which makes it hard for the export post-processing to recover the renamed modules' original names:
```python
# "mmsegmentation/mmseg/structures/sampler/ohem_pixel_sampler.py"
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_pixel_sampler import BasePixelSampler
from .builder import PIXEL_SAMPLERS
@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
...
```
## Solution
Therefore, we have migrated the necessary modules from the `mmpose` / `mmdetection` / `mmseg` / `mmpretrain` files listed below directly into `patch_utils.patch_model` and `patch_utils.patch_task`, building a patch that contains the renamed registries and the special module constructor functions.
```python
"mmdetection/mmdet/models/task_modules/builder.py"
"mmdetection/build/lib/mmdet/models/task_modules/builder.py"
"mmsegmentation/mmseg/models/builder.py"
"mmsegmentation/mmseg/structures/sampler/builder.py"
"mmsegmentation/build/lib/mmseg/models/builder.py"
"mmsegmentation/build/lib/mmseg/structures/sampler/builder.py"
"mmpretrain/mmpretrain/datasets/builder.py"
"mmpretrain/mmpretrain/models/builder.py"
"mmpretrain/build/lib/mmpretrain/datasets/builder.py"
"mmpretrain/build/lib/mmpretrain/models/builder.py"
"mmpose/mmpose/datasets/builder.py"
"mmpose/mmpose/models/builder.py"
"mmpose/build/lib/mmpose/datasets/builder.py"
"mmpose/build/lib/mmpose/models/builder.py"
"mmengine/mmengine/evaluator/builder.py"
"mmengine/mmengine/model/builder.py"
"mmengine/mmengine/optim/optimizer/builder.py"
"mmengine/mmengine/visualization/builder.py"
"mmengine/build/lib/mmengine/evaluator/builder.py"
"mmengine/build/lib/mmengine/model/builder.py"
"mmengine/build/lib/mmengine/optim/optimizer/builder.py"
```

View File

@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .patch_model import * # noqa: F401, F403
from .patch_task import * # noqa: F401, F403

View File

@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmengine.registry import MODELS
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
SEGMENTORS = MODELS
def build_backbone(cfg):
"""Build backbone."""
warnings.warn('``build_backbone`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return BACKBONES.build(cfg)
def build_neck(cfg):
"""Build neck."""
warnings.warn('``build_neck`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return NECKS.build(cfg)
def build_head(cfg):
"""Build head."""
warnings.warn('``build_head`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return HEADS.build(cfg)
def build_loss(cfg):
"""Build loss."""
warnings.warn('``build_loss`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return LOSSES.build(cfg)
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
"""Build segmentor."""
if train_cfg is not None or test_cfg is not None:
warnings.warn(
'train_cfg and test_cfg is deprecated, '
'please specify them in model', UserWarning)
assert cfg.get('train_cfg') is None or train_cfg is None, \
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return SEGMENTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
CLASSIFIERS = MODELS
RETRIEVER = MODELS
def build_classifier(cfg):
"""Build classifier."""
return CLASSIFIERS.build(cfg)
def build_retriever(cfg):
"""Build retriever."""
return RETRIEVER.build(cfg)
POSE_ESTIMATORS = MODELS
def build_pose_estimator(cfg):
"""Build pose estimator."""
return POSE_ESTIMATORS.build(cfg)
def build_posenet(cfg):
"""Build posenet."""
warnings.warn(
'``build_posenet`` will be deprecated soon, '
'please use ``build_pose_estimator`` instead.', DeprecationWarning)
return build_pose_estimator(cfg)

View File

@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmengine.registry import TASK_UTILS
PRIOR_GENERATORS = TASK_UTILS
ANCHOR_GENERATORS = TASK_UTILS
BBOX_ASSIGNERS = TASK_UTILS
BBOX_SAMPLERS = TASK_UTILS
BBOX_CODERS = TASK_UTILS
MATCH_COSTS = TASK_UTILS
IOU_CALCULATORS = TASK_UTILS
def build_bbox_coder(cfg, **default_args):
"""Builder of box coder."""
warnings.warn('``build_sampler`` would be deprecated soon, please use '
'``mmdet.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
def build_iou_calculator(cfg, default_args=None):
"""Builder of IoU calculator."""
warnings.warn(
'``build_iou_calculator`` would be deprecated soon, please use '
'``mmdet.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
def build_match_cost(cfg, default_args=None):
"""Builder of IoU calculator."""
warnings.warn('``build_match_cost`` would be deprecated soon, please use '
'``mmdet.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
def build_assigner(cfg, **default_args):
"""Builder of box assigner."""
warnings.warn('``build_assigner`` would be deprecated soon, please use '
'``mmdet.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
def build_sampler(cfg, **default_args):
"""Builder of box sampler."""
warnings.warn('``build_sampler`` would be deprecated soon, please use '
'``mmdet.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
def build_prior_generator(cfg, default_args=None):
warnings.warn(
'``build_prior_generator`` would be deprecated soon, please use '
'``mmdet.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
def build_anchor_generator(cfg, default_args=None):
warnings.warn(
'``build_anchor_generator`` would be deprecated soon, please use '
'``mmdet.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)
PIXEL_SAMPLERS = TASK_UTILS
def build_pixel_sampler(cfg, **default_args):
"""Build pixel sampler for segmentation map."""
warnings.warn(
'``build_pixel_sampler`` would be deprecated soon, please use '
'``mmseg.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)

View File

@ -0,0 +1,555 @@
# Copyright (c) OpenMMLab. All rights reserved.
import ast
import copy
import importlib
import inspect
import logging
import os
import os.path as osp
import shutil
from functools import wraps
from typing import Callable
import torch.nn
from mmengine import mkdir_or_exist
from mmengine.config import ConfigDict
from mmengine.logging import print_log
from mmengine.model import (
BaseDataPreprocessor,
BaseModel,
BaseModule,
ImgDataPreprocessor,
)
from mmengine.registry import Registry
from yapf.yapflib.yapf_api import FormatCode
from mim.utils import OFFICIAL_MODULES
from .common import REGISTRY_TYPES
from .flatten_func import * # noqa: F403, F401
from .flatten_func import (
ImportResolverTransformer,
RegisterModuleTransformer,
flatten_inheritance_chain,
ignore_ast_docstring,
postprocess_super_call,
)
def format_code(code_text: str):
"""Format the code text with yapf."""
yapf_style = dict(
based_on_style='pep8',
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True)
try:
code_text, _ = FormatCode(code_text, style_config=yapf_style)
except: # noqa: E722
raise SyntaxError('Failed to format the config file, please '
f'check the syntax of: \n{code_text}')
return code_text
def _postprocess_registry_locations(export_root_dir: str):
"""Remove the Registry.locations if it doesn't exist.
Check the location path for Registry to load modules if the path hasn't
been exported, then need to be removed. Finally will use the root Registry
to find module until it actually doesn't exist.
"""
export_module_dir = osp.join(export_root_dir, 'pack')
with open(
osp.join(export_module_dir, 'registry.py'), encoding='utf-8') as f:
ast_tree = ast.parse(f.read())
for node in ast.walk(ast_tree):
"""node structure.
Assign( targets=[ Name(id='EVALUATORS', ctx=Store())],
value=Call( func=Name(id='Registry', ctx=Load()), args=[
Constant(value='evaluator')], keywords=[ keyword( arg='parent',
value=Name(id='MMENGINE_EVALUATOR', ctx=Load())), keyword(
arg='locations', value=List( elts=[
Constant(value='pack.evaluation')], ctx=Load()))])),
"""
if isinstance(node, ast.Call):
need_to_be_remove = None
for keyword in node.keywords:
if keyword.arg == 'locations':
for sub_node in ast.walk(keyword):
# the Registry locations were already transferred to
# the 'pack' scope before; if a location path was not
# exported, the keyword needs to be removed
if isinstance(
sub_node,
ast.Constant) and 'pack' in sub_node.value:
path = sub_node.value
if not osp.exists(
osp.join(export_root_dir, path).replace(
'.', osp.sep)):
print_log(
'[ Pass ] Remove Registry.locations '
f"'{osp.join(export_root_dir, path).replace('.',osp.sep)}', " # noqa: E501
'which is no need to export.',
logger='export',
level=logging.DEBUG)
need_to_be_remove = keyword
break
if need_to_be_remove is not None:
break
if need_to_be_remove is not None:
node.keywords.remove(need_to_be_remove)
with open(
osp.join(export_module_dir, 'registry.py'), 'w',
encoding='utf-8') as f:
f.write(format_code(ast.unparse(ast_tree)))
def _get_all_files(directory: str):
"""Get all files of the directory.
Args:
directory (str): The directory path.
Returns:
List: Return the a list containing all the files in the directory.
"""
file_paths = []
for root, dirs, files in os.walk(directory):
for file in files:
if '__init__' not in file and 'registry.py' not in file:
file_paths.append(os.path.join(root, file))
return file_paths
def _postprocess_importfrom_module_to_pack(file_path: str):
"""Transfer the importfrom path from "downstream repo" to export module.
Args:
file_path (str): The path of file needed to be transfer.
Examples:
>>> from mmdet.models.detectors.two_stage import TwoStageDetector
>>> # transfer to below, if "TwoStageDetector" had been exported
>>> from pack.models.detectors.two_stage import TwoStageDetector
"""
from mmengine import Registry
# _module_path_dict is a class attribute,
# already record all the exported module and their path before
_module_path_dict = Registry._module_path_dict
with open(file_path, encoding='utf-8') as f:
ast_tree = ast.parse(f.read())
# if an imported module has the same name as an object defined in this
# file, its import path won't be changed
can_not_change_module = []
for node in ast_tree.body:
if isinstance(node, ast.FunctionDef) or isinstance(node, ast.ClassDef):
can_not_change_module.append(node.name)
def check_change_importfrom_node(node: ast.ImportFrom):
"""Check if the ImportFrom node should be changed.
If the modules in node had already been exported, they will be
separated and compose a new ast.ImportFrom node with the export
path as the module path.
Args:
node (ast.ImportFrom): ImportFrom node.
Returns:
ast.ImportFrom | None: A new ast.ImportFrom node if any of the
modules in the node have been exported, else ``None``.
"""
export_module_path = None
needed_change_alias = []
for alias in node.names:
# if an imported module's name equals a class or function name
# defined in this file, it cannot be transferred (avoids circular imports).
if alias.name in _module_path_dict.keys(
) and alias.name not in can_not_change_module:
if export_module_path is None:
export_module_path = _module_path_dict[alias.name]
else:
assert _module_path_dict[alias.name] == \
export_module_path,\
'Modules in the same ImportFrom node are expected'\
' to map to the same export path.'
needed_change_alias.append(alias)
if len(needed_change_alias) != 0:
for alias in needed_change_alias:
node.names.remove(alias)
return ast.ImportFrom(
module=export_module_path, names=needed_change_alias, level=0)
return None
# Naming rules for searching ast syntax tree
# - node: node of ast.Module
# - func_sub_node: sub_node of ast.FunctionDef
# - class_sub_node: sub_node of ast.ClassDef
# - func_sub_class_sub_node: sub_node ast.FunctionDef in ast.ClassDef
# record the insert index and the nodes to be inserted later.
insert_idx_and_node = {}
insert_idx = 0
for idx, node in enumerate(ast_tree.body):
# search ast.ImportFrom in ast.Module scope
# ast.Module -> ast.ImportFrom
if isinstance(node, ast.ImportFrom):
insert_idx += 1
temp_node = check_change_importfrom_node(node)
if temp_node is not None:
if len(node.names) == 0:
ast_tree.body[idx] = temp_node
else:
insert_idx_and_node[insert_idx] = temp_node
insert_idx += 1
elif isinstance(node, ast.Import):
insert_idx += 1
else:
# search ast.ImportFrom in ast.FunctionDef scope
# ast.Module -> ast.FunctionDef -> ast.ImportFrom
if isinstance(node, ast.FunctionDef):
temp_func_insert_idx = ignore_ast_docstring(node)
func_need_to_be_removed_nodes = []
for func_sub_node in node.body:
if isinstance(func_sub_node, ast.ImportFrom):
temp_node = check_change_importfrom_node(
func_sub_node) # noqa: E501
if temp_node is not None:
node.body.insert(temp_func_insert_idx, temp_node)
# if the ImportFrom node has no names left, it should be removed # noqa: E501
if len(func_sub_node.names) == 0:
func_need_to_be_removed_nodes.append(
func_sub_node) # noqa: E501
for need_to_be_removed_node in func_need_to_be_removed_nodes:
node.body.remove(need_to_be_removed_node)
# search ast.ImportFrom in ast.ClassDef scope
# ast.Module -> ast.ClassDef -> ast.ImportFrom
# -> ast.FunctionDef -> ast.ImportFrom
elif isinstance(node, ast.ClassDef):
temp_class_insert_idx = ignore_ast_docstring(node)
class_need_to_be_removed_nodes = []
for class_sub_node in node.body:
# ast.Module -> ast.ClassDef -> ast.ImportFrom
if isinstance(class_sub_node, ast.ImportFrom):
temp_node = check_change_importfrom_node(
class_sub_node)
if temp_node is not None:
node.body.insert(temp_class_insert_idx, temp_node)
if len(class_sub_node.names) == 0:
class_need_to_be_removed_nodes.append(
class_sub_node)
# ast.Module -> ast.ClassDef -> ast.FunctionDef -> ast.ImportFrom # noqa: E501
elif isinstance(class_sub_node, ast.FunctionDef):
temp_class_sub_insert_idx = ignore_ast_docstring(node)
func_need_to_be_removed_nodes = []
for func_sub_class_sub_node in class_sub_node.body:
if isinstance(func_sub_class_sub_node,
ast.ImportFrom):
temp_node = check_change_importfrom_node(
func_sub_class_sub_node)
if temp_node is not None:
node.body.insert(temp_class_sub_insert_idx,
temp_node)
if len(func_sub_class_sub_node.names) == 0:
func_need_to_be_removed_nodes.append(
func_sub_class_sub_node)
for need_to_be_removed_node in func_need_to_be_removed_nodes: # noqa: E501
class_sub_node.body.remove(need_to_be_removed_node)
for class_need_to_be_removed_node in class_need_to_be_removed_nodes: # noqa: E501
node.body.remove(class_need_to_be_removed_node)
# lazy add new ast.ImportFrom node to ast.Module
for insert_idx, temp_node in insert_idx_and_node.items():
ast_tree.body.insert(insert_idx, temp_node)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(format_code(ast.unparse(ast_tree)))
def _replace_config_scope_to_pack(cfg: ConfigDict):
"""Replace the config scope from "mmxxx" to "pack".
Args:
cfg (ConfigDict): The config dict to be replaced.
"""
for key, value in cfg.items():
if key == '_scope_' or key == 'default_scope':
cfg[key] = 'pack'
elif isinstance(value, dict):
_replace_config_scope_to_pack(value)
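# A sketch of the effect with hypothetical values:
#   cfg = ConfigDict(default_scope='mmdet', model=dict(_scope_='mmdet'))
#   _replace_config_scope_to_pack(cfg)
#   # -> cfg.default_scope == 'pack' and cfg.model._scope_ == 'pack'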
def _wrapper_all_registries_build_func(export_module_dir: str, scope: str):
"""A function to wrap all registries' build_func.
Args:
pack_module_dir (str): The root dir for packing modules.
scope (str): The default scope of the config.
"""
# copy the downstream repo.registry to pack.registry
# and change all the registry.locations
repo_registries = importlib.import_module('.registry', scope)
origin_file = inspect.getfile(repo_registries)
registry_path = osp.join(export_module_dir, 'registry.py')
shutil.copy(origin_file, registry_path)
# replace 'repo' name in Registry.locations to 'pack'
with open(
osp.join(export_module_dir, 'registry.py'), encoding='utf-8') as f:
ast_tree = ast.parse(f.read())
for node in ast.walk(ast_tree):
if isinstance(node, ast.Constant):
if scope in node.value:
node.value = node.value.replace(scope, 'pack')
with open(osp.join(export_module_dir, 'registry.py'), 'w') as f:
f.write(format_code(ast.unparse(ast_tree)))
# prevent circular registration
Registry._extra_module_set = set()
# record the exported module for postprocessing the importfrom path
Registry._module_path_dict = {}
# prevent circular wrapper
if Registry.build.__name__ == 'wrapper':
Registry.build = _wrap_build(Registry.init_build_func,
export_module_dir)
Registry.get = _wrap_get(Registry.init_get_func, export_module_dir)
else:
Registry.init_build_func = copy.deepcopy(Registry.build)
Registry.init_get_func = copy.deepcopy(Registry.get)
Registry.build = _wrap_build(Registry.build, export_module_dir)
Registry.get = _wrap_get(Registry.get, export_module_dir)
def ignore_self_cache(func):
"""Ignore the ``@lru_cache`` for function.
Args:
func (Callable): The function to be ignored.
Returns:
Callable: The function without ``@lru_cache``.
"""
cache = {}
@wraps(func)
def wrapper(self, *args, **kwargs):
key = args
if key not in cache:
cache[key] = 1
func(self, *args, **kwargs)
else:
return
return wrapper
@ignore_self_cache
def _export_module(self, obj_cls: type, pack_module_dir, obj_type: str):
"""Export module.
This function will get the object's file and export to
``pack_module_dir``.
If the object is built by ``MODELS`` registry, all the objects
as the top classes in this file, will be iteratively flattened.
Else will be directly exported.
The flatten logic is:
1. get the origin file of object, which built
by ``MODELS.build()``
2. get all the classes in the origin file
3. flatten all the classes but not only the object
4. call ``flatten_module()`` to finish flatten
according to ``class.mro()``
Args:
obj (object): The object to be flatten.
"""
# find the file by obj class
file_path = inspect.getfile(obj_cls)
if osp.exists(file_path):
print_log(
f'building class: '
f'{obj_cls.__name__} from file: {file_path}.',
logger='export',
level=logging.DEBUG)
else:
raise FileNotFoundError(f"file [{file_path}] doesn't exist.")
# locate the origin module
module = obj_cls.__module__
parent = module.split('.')[0]
new_module = module.replace(parent, 'pack')
# Not necessary to export module implemented in `mmcv` and `mmengine`
if parent in set(OFFICIAL_MODULES) - {'mmcv', 'mmengine'}:
with open(file_path, encoding='utf-8') as f:
top_ast_tree = ast.parse(f.read())
# deal with relative import
ImportResolverTransformer(module).visit(top_ast_tree)
# NOTE: ``MODELS.build()`` means to flatten model module
if self.name == 'model':
# record all the class needed to be flattened
need_to_be_flattened_class_names = []
for node in top_ast_tree.body:
if isinstance(node, ast.ClassDef):
need_to_be_flattened_class_names.append(node.name)
imported_module = importlib.import_module(obj_cls.__module__)
for cls_name in need_to_be_flattened_class_names:
# record the exported module for postprocessing the importfrom path # noqa: E501
self._module_path_dict[cls_name] = new_module
cls = getattr(imported_module, cls_name)
for super_cls in cls.__bases__:
# the class only will be flattened when:
# 1. super class doesn't exist in this file
# 2. and super class is not base class
# 3. and super class is not torch module
if super_cls.__name__\
not in need_to_be_flattened_class_names \
and (super_cls not in [BaseModule,
BaseModel,
BaseDataPreprocessor,
ImgDataPreprocessor]) \
and 'torch' not in super_cls.__module__: # noqa: E501
print_log(
f'need_flatten: {cls_name}\n',
logger='export',
level=logging.INFO)
flatten_inheritance_chain(top_ast_tree, cls)
break
postprocess_super_call(top_ast_tree)
else:
self._module_path_dict[obj_cls.__name__] = new_module
# add ``register_module(force=True)`` to cover the registered modules # noqa: E501
RegisterModuleTransformer().visit(top_ast_tree)
# unparse ast tree and save reformat code
new_file_path = new_module.split('.', 1)[1].replace('.',
osp.sep) + '.py'
new_file_path = osp.join(pack_module_dir, new_file_path)
new_dir = osp.dirname(new_file_path)
mkdir_or_exist(new_dir)
with open(new_file_path, mode='w') as f:
f.write(format_code(ast.unparse(top_ast_tree)))
# Downstream repo could register torch module into Registry, such as
# registering `torch.nn.Linear` into `MODELS`. We need to reserve these
# codes in the exported module.
elif 'torch' in module.split('.')[0]:
# get the root registry, because it can get all the modules
# had been registered.
root_registry = self if self.parent is None else self.parent
if (obj_type not in self._extra_module_set) and (
root_registry.init_get_func(obj_type) is None):
self._extra_module_set.add(obj_type)
with open(osp.join(pack_module_dir, 'registry.py'), 'a') as f:
# TODO: When a downstream repo's registry names are
# different from mmengine's, the module may not be
# registered to the right registry.
# For example: `EVALUATOR` in mmengine, `EVALUATORS` in mmagic.
f.write('\n')
f.write(f'from {module} import {obj_cls.__name__}\n')
f.write(
f"{REGISTRY_TYPES[self.name]}.register_module('{obj_type}', module={obj_cls.__name__}, force=True)" # noqa: E501
)
def _wrap_build(build_func: Callable, pack_module_dir: str):
"""wrap Registry.build()
Args:
build_func (Callable): ``Registry.build()``, which will be wrapped.
pack_module_dir (str): Modules export path.
"""
def wrapper(self, cfg: dict, *args, **kwargs):
# obj is a class instance
obj = build_func(self, cfg, *args, **kwargs)
args = cfg.copy() # type: ignore
obj_type = args.pop('type') # type: ignore
obj_type = obj_type if isinstance(obj_type, str) else obj_type.__name__
# modules in ``torch.nn.Sequential`` should be respectively exported
if isinstance(obj, torch.nn.Sequential):
for children in obj.children():
_export_module(self, children.__class__, pack_module_dir,
obj_type)
else:
_export_module(self, obj.__class__, pack_module_dir, obj_type)
return obj
return wrapper
def _wrap_get(get_func: Callable, pack_module_dir: str):
"""wrap Registry.get()
Args:
get_func (Callable): ``Registry.get()``, which will be wrapped.
pack_module_dir (str): Modules export path.
"""
def wrapper(self, key: str):
obj_cls = get_func(self, key)
_export_module(self, obj_cls, pack_module_dir, key)
return obj_cls
return wrapper
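# A hypothetical sketch of the wrapping flow (not part of the original file):
#
#   _wrapper_all_registries_build_func(
#       export_module_dir='pack_dir/pack', scope='mmdet')
#   model = MODELS.build(cfg.model)  # building now also exports module files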

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mim.commands.list import list_package
from mim.utils.default import OFFICIAL_MODULES
def get_installed_package(ctx, args, incomplete):
@ -19,19 +20,4 @@ def get_downstream_package(ctx, args, incomplete):
def get_official_package(ctx=None, args=None, incomplete=None):
return [
'mmcls',
'mmdet',
'mmdet3d',
'mmseg',
'mmaction2',
'mmtrack',
'mmpose',
'mmedit',
'mmocr',
'mmgen',
'mmselfsup'
'mmrotate',
'mmflow',
'mmyolo',
]
return OFFICIAL_MODULES

View File

@ -0,0 +1,124 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import sys
import click
from mmengine.config import Config
from mmengine.hub import get_config
from mim._internal.export.pack_cfg import export_from_cfg
from mim.click import CustomCommand
PYTHON = sys.executable
@click.command(
name='export',
context_settings=dict(ignore_unknown_options=True),
cls=CustomCommand)
@click.argument('package', type=str)
@click.argument('config', type=str)
@click.argument('export_dir', type=str, default=None, required=False)
@click.option(
'-f',
'--fast-test',
is_flag=True,
help='Fast test mode. To quickly check whether the exported'
' package has any errors, only the first two samples of the'
' datasets are used, and training runs for only 2'
' iters/epochs.')
@click.option(
'--save-log',
is_flag=True,
help='Keep the log of the export process. The log will be saved'
" to the 'export_log' directory. By default the log is deleted"
' automatically after export.')
def cli(config: str,
package: str,
export_dir: str,
fast_test: bool = False,
save_log: bool = False) -> None:
"""Export package from config file.
Example:
\b
>>> # Export package from downstream config file.
>>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
... dab_detr
>>>
>>> # Export package from specified config file.
>>> mim export mmdet mmdetection/configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py dab_detr # noqa: E501
>>>
>>> # An export dir is generated automatically when not specified,
>>> # e.g. 'pack_from_mmdet_20231026_052704'.
>>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py
>>>
>>> # Only export the model of config file.
>>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
... mask_rcnn_package --model-only
>>>
>>> # Keep the export log of the process.
>>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
... mask_rcnn_package --save-log
>>>
>>> # Print the help information of export command.
>>> mim export -h
"""
# get config
if osp.exists(config):
config = Config.fromfile(config) # from local
else:
try:
config = get_config(package + '::' + config) # from downstream
except Exception:
raise FileNotFoundError(
f"Config file '{config}' or '{package}::{config}' not found.")
fast_test_mode(config, fast_test)
export_from_cfg(config, export_root_dir=export_dir, save_log=save_log)
def fast_test_mode(cfg, fast_test: bool = False):
"""Use less data for faster testing.
Args:
cfg (Config): Config of export package.
fast_test (bool, optional): Fast testing mode. Defaults to False.
"""
if fast_test:
# batch norm needs at least 2 samples
if 'dataset' in cfg.train_dataloader.dataset:
cfg.train_dataloader.dataset.dataset.indices = [0, 1]
else:
cfg.train_dataloader.dataset.indices = [0, 1]
cfg.train_dataloader.batch_size = 2
if cfg.get('test_dataloader') is not None:
cfg.test_dataloader.dataset.indices = [0, 1]
cfg.test_dataloader.batch_size = 2
if cfg.get('val_dataloader') is not None:
cfg.val_dataloader.dataset.indices = [0, 1]
cfg.val_dataloader.batch_size = 2
if (cfg.train_cfg.get('type') == 'IterBasedTrainLoop') \
or (cfg.train_cfg.get('by_epoch') is None
and cfg.train_cfg.get('type') != 'EpochBasedTrainLoop'):
cfg.train_cfg.max_iters = 2
else:
cfg.train_cfg.max_epochs = 2
cfg.train_cfg.val_interval = 1
cfg.default_hooks.logger.interval = 1
if 'param_scheduler' in cfg and cfg.param_scheduler is not None:
if isinstance(cfg.param_scheduler, list):
for lr_sc in cfg.param_scheduler:
lr_sc.begin = 0
lr_sc.end = 2
else:
cfg.param_scheduler.begin = 0
cfg.param_scheduler.end = 2
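# A hypothetical sketch (not part of the original file):
#
#   cfg = Config.fromfile('configs/mask-rcnn_r101_fpn_1x_coco.py')
#   fast_test_mode(cfg, fast_test=True)
#   # -> datasets truncated to 2 samples, training truncated to 2 iters/epochs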

View File

@ -4,6 +4,7 @@ from .default import (
DEFAULT_MMCV_BASE_URL,
DEFAULT_URL,
MODULE2PKG,
OFFICIAL_MODULES,
PKG2MODULE,
PKG2PROJECT,
RAW_GITHUB_URL,
@ -90,4 +91,5 @@ __all__ = [
'parse_home_page',
'ensure_installation',
'rich_progress_bar',
'OFFICIAL_MODULES',
]

View File

@ -13,6 +13,13 @@ WHEEL_URL = {
'{torch_version}/index.html',
}
RAW_GITHUB_URL = 'https://raw.githubusercontent.com/{owner}/{repo}/{branch}'
OFFICIAL_MODULES = [
'mmcls', 'mmdet', 'mmdet3d', 'mmseg', 'mmaction2', 'mmtrack', 'mmpose',
'mmedit', 'mmocr', 'mmgen', 'mmselfsup', 'mmrotate', 'mmflow', 'mmyolo',
'mmpretrain', 'mmagic'
]
PKG2PROJECT = {
'mmcv-full': 'mmcv',
'mmcls': 'mmclassification',