Mirror of https://github.com/open-mmlab/mim.git, synced 2025-06-03 14:59:11 +08:00

Compare commits (23 commits)
b39fbd9730
b40fe72e2f
c9a5ea6a88
5d54183ef9
8021d1b0eb
bf69e00b6d
30dc4fa9c0
53536fc6a7
bc5aec2abe
706cdc58b2
3d742a9dac
2b495c9abd
406e447bcd
2a884557e1
81b54d6836
2a6adf55cb
dd7f11c679
959f9c9e8c
213f1b2235
231ffc73cb
75c45590c8
444862746d
8fb7776ebe
@@ -5,7 +5,7 @@ repos:
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/isort
-    rev: 5.10.1
+    rev: 5.11.5
     hooks:
       - id: isort
   - repo: https://github.com/pre-commit/mirrors-yapf
README.md (56 lines changed)
@@ -102,7 +102,7 @@ Please refer to [installation.md](docs/en/installation.md) for installation.
 uninstall('mmcv-full')

 # uninstall mmcls
-uninstall('mmcls)
+uninstall('mmcls')
 ```

 </details>
@@ -182,7 +182,7 @@ Please refer to [installation.md](docs/en/installation.md) for installation.
 from mim import download

 download('mmcls', ['resnet18_8xb16_cifar10'])
-download('mmcls', ['resnet18_8xb16_cifar10'], dest_dir='.')
+download('mmcls', ['resnet18_8xb16_cifar10'], dest_root='.')
 ```

 </details>
@@ -217,14 +217,14 @@ Please refer to [installation.md](docs/en/installation.md) for installation.
 from mim import train

 train(repo='mmcls', config='resnet18_8xb16_cifar10.py', gpus=0,
-      other_args='--work-dir tmp')
+      other_args=('--work-dir', 'tmp'))
 train(repo='mmcls', config='resnet18_8xb16_cifar10.py', gpus=1,
-      other_args='--work-dir tmp')
+      other_args=('--work-dir', 'tmp'))
 train(repo='mmcls', config='resnet18_8xb16_cifar10.py', gpus=4,
-      launcher='pytorch', other_args='--work-dir tmp')
+      launcher='pytorch', other_args=('--work-dir', 'tmp'))
 train(repo='mmcls', config='resnet18_8xb16_cifar10.py', gpus=8,
       launcher='slurm', gpus_per_node=8, partition='partition_name',
-      other_args='--work-dir tmp')
+      other_args=('--work-dir', 'tmp'))
 ```

 </details>
@@ -260,15 +260,15 @@ Please refer to [installation.md](docs/en/installation.md) for installation.
 ```python
 from mim import test
 test(repo='mmcls', config='resnet101_b16x8_cifar10.py',
-     checkpoint='tmp/epoch_3.pth', gpus=1, other_args='--metrics accuracy')
+     checkpoint='tmp/epoch_3.pth', gpus=1, other_args=('--metrics', 'accuracy'))
 test(repo='mmcls', config='resnet101_b16x8_cifar10.py',
-     checkpoint='tmp/epoch_3.pth', gpus=1, other_args='--out tmp.pkl')
+     checkpoint='tmp/epoch_3.pth', gpus=1, other_args=('--out', 'tmp.pkl'))
 test(repo='mmcls', config='resnet101_b16x8_cifar10.py',
      checkpoint='tmp/epoch_3.pth', gpus=4, launcher='pytorch',
-     other_args='--metrics accuracy')
+     other_args=('--metrics', 'accuracy'))
 test(repo='mmcls', config='resnet101_b16x8_cifar10.py',
      checkpoint='tmp/epoch_3.pth', gpus=8, partition='partition_name',
-     launcher='slurm', gpus_per_node=8, other_args='--metrics accuracy')
+     launcher='slurm', gpus_per_node=8, other_args=('--metrics', 'accuracy'))
 ```

 </details>
@@ -305,13 +305,13 @@ Please refer to [installation.md](docs/en/installation.md) for installation.
 from mim import run

 run(repo='mmcls', command='get_flops',
-    other_args='resnet101_b16x8_cifar10.py')
+    other_args=('resnet101_b16x8_cifar10.py',))
 run(repo='mmcls', command='publish_model',
-    other_args='input.pth output.pth')
+    other_args=('input.pth', 'output.pth'))
 run(repo='mmcls', command='train',
-    other_args='resnet101_b16x8_cifar10.py --work-dir tmp')
+    other_args=('resnet101_b16x8_cifar10.py', '--work-dir', 'tmp'))
 run(repo='mmcls', command='test',
-    other_args='resnet101_b16x8_cifar10.py tmp/epoch_3.pth --metrics accuracy')
+    other_args=('resnet101_b16x8_cifar10.py', 'tmp/epoch_3.pth', '--metrics accuracy'))
 ```

 </details>
@@ -366,28 +366,28 @@ Please refer to [installation.md](docs/en/installation.md) for installation.

 gridsearch(repo='mmcls', config='resnet101_b16x8_cifar10.py', gpus=0,
            search_args='--optimizer.lr 1e-2 1e-3',
-           other_args='--work-dir tmp')
+           other_args=('--work-dir', 'tmp'))
 gridsearch(repo='mmcls', config='resnet101_b16x8_cifar10.py', gpus=1,
            search_args='--optimizer.lr 1e-2 1e-3',
-           other_args='--work-dir tmp')
+           other_args=('--work-dir', 'tmp'))
 gridsearch(repo='mmcls', config='resnet101_b16x8_cifar10.py', gpus=1,
            search_args='--optimizer.weight_decay 1e-3 1e-4',
-           other_args='--work-dir tmp')
+           other_args=('--work-dir', 'tmp'))
 gridsearch(repo='mmcls', config='resnet101_b16x8_cifar10.py', gpus=1,
            search_args='--optimizer.lr 1e-2 1e-3 --optimizer.weight_decay'
                        '1e-3 1e-4',
-           other_args='--work-dir tmp')
+           other_args=('--work-dir', 'tmp'))
 gridsearch(repo='mmcls', config='resnet101_b16x8_cifar10.py', gpus=8,
            partition='partition_name', gpus_per_node=8, launcher='slurm',
            search_args='--optimizer.lr 1e-2 1e-3 --optimizer.weight_decay'
                        ' 1e-3 1e-4',
-           other_args='--work-dir tmp')
+           other_args=('--work-dir', 'tmp'))
 gridsearch(repo='mmcls', config='resnet101_b16x8_cifar10.py', gpus=8,
            partition='partition_name', gpus_per_node=8, launcher='slurm',
            max_workers=2,
            search_args='--optimizer.lr 1e-2 1e-3 --optimizer.weight_decay'
                        ' 1e-3 1e-4',
-           other_args='--work-dir tmp')
+           other_args=('--work-dir', 'tmp'))
 ```

 </details>
@@ -402,22 +402,24 @@ This project is released under the [Apache 2.0 license](LICENSE).

 ## Projects in OpenMMLab

 - [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab foundational library for training deep learning models.
 - [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
 - [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
-- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
 - [MMEval](https://github.com/open-mmlab/mmeval): A unified evaluation library for multiple machine learning libraries.
+- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab pre-training toolbox and benchmark.
+- [MMagic](https://github.com/open-mmlab/mmagic): Open**MM**Lab **A**dvanced, **G**enerative and **I**ntelligent **C**reation toolbox.
 - [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
 - [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO series toolbox and benchmark.
 - [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
 - [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
+- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
+- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
 - [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
 - [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
-- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
 - [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
 - [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
-- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
 - [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
 - [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
-- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
 - [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
 - [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
 - [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
 - [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework.
+- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
+- [Playground](https://github.com/open-mmlab/playground): A central hub for gathering and showcasing amazing projects built upon OpenMMLab.
@@ -1,4 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.

+# NOTE: We could get an AssertionError when importing pip before
+# setuptools. A workaround is to import setuptools first and filter
+# warnings that are caused by setuptools replacing distutils.
+# Related issues:
+# - https://github.com/pypa/setuptools/issues/3621
+# - https://github.com/open-mmlab/mmclassification/issues/1343
+try:
+    import setuptools  # noqa: F401
+    import warnings
+    warnings.filterwarnings('ignore', 'Setuptools is replacing distutils')
+except ImportError:
+    pass
+
 from .commands import (
     download,
     get_model_info,
mim/_internal/export/README.md (new file, 99 lines)

@@ -0,0 +1,99 @@
# Export (experimental)

`mim export` is a new feature of `mim` that can export a minimal trainable and testable model package from a config file.

The minimal package compactly combines `model`, `datasets`, `engine` and other components, so there is no need to search for each module separately based on the config file. Users who do not install from source can also see the model files at a glance.

In addition, lengthy inheritance chains, such as `CondInstBboxHead -> FCOSHead -> AnchorFreeHead -> BaseDenseHead -> BaseModule` in `mmdetection`, are flattened directly into `CondInstBboxHead -> BaseModule`. There is no need to open and compare multiple model files, because every function inherited from a parent class is clearly visible in one place.
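Below is a minimal, runnable sketch of what this flattening amounts to. The classes are hypothetical stand-ins, not the real mmdet heads:

```python
class Base:                    # stands in for mmengine's BaseModule
    pass


class Parent(Base):            # stands in for FCOSHead / AnchorFreeHead / ...
    def loss(self):
        return 'loss computed in Parent'


class Child(Parent):           # original chain: Child -> Parent -> Base
    pass


class FlattenedChild(Base):    # exported chain: Child -> Base
    def loss(self):            # body copied down from Parent by the export
        return 'loss computed in Parent'


assert Child().loss() == FlattenedChild().loss()
```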
### Usage

```bash
mim export config_path

# config_path accepts the following two forms:

# 1. a config of a downstream repo, called as repo::configs/xxx.py, e.g.:
mim export mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py

# 2. a config in some folder, e.g.:
mim export config_dir/mask-rcnn_r101_fpn_1x_coco.py
```

### Minimal model package directory structure

```
minimum_package (named pack_from_{repo}_20231212_121212)
|- pack
|  |- configs          # config folder
|  |  |- model_name
|  |     |- xxx.py     # config file
|  |
|  |- models           # model folder
|  |  |- model_file.py
|  |  |- ...
|  |
|  |- data             # data folder
|  |
|  |- demo             # demo folder
|  |
|  |- datasets         # dataset class definitions
|  |  |- transforms
|  |
|  |- registry.py      # registry
|
|- tools
   |- train.py         # training
   |- test.py          # testing
```

### Limitations

Currently, `mim export` only supports some config files of `mmpose`, `mmdetection`, `mmagic`, and `mmsegmentation`. Besides, there are some constraints on the downstream repos.

#### For downstream repos

1. It is best to name configs without special symbols; otherwise they cannot be parsed through `mmengine.hub.get_config()`, such as:

   - gn+ws/faster-rcnn_r101_fpn_gn-ws-all_1x_coco.py
   - legacy_1.x/cascade-mask-rcnn_r50_fpn_1x_coco_v1.py

2. For `mmsegmentation`, before using `mim export` on a config in `mmseg`, you should first remove the directory that wraps `registry.py`, i.e. move `mmseg/registry/registry.py -> mmseg/registry.py`.

3. It is recommended that downstream Registry names inherited from mmengine not be changed. For example, mmagic renamed `EVALUATOR` to `EVALUATORS`:

   ```python
   from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR

   # Evaluators to define the evaluation process.
   EVALUATORS = Registry(
       'evaluator',
       parent=MMENGINE_EVALUATOR,
       locations=['mmagic.evaluation'],
   )
   ```

4. In addition, if you add a registry that is not in mmengine, such as `DIFFUSION_SCHEDULERS` in mmagic, you need to add a key-value pair to `REGISTRY_TYPE` in `mim/_internal/export/common.py` for registering `torch` modules to `DIFFUSION_SCHEDULERS`:

   ```python
   # "mmagic/mmagic/registry.py"
   # modules for diffusion models that support adding noise and denoising
   DIFFUSION_SCHEDULERS = Registry(
       'diffusion scheduler',
       locations=['mmagic.models.diffusion_schedulers'],
   )

   # "mim/utils/mmpack/common.py"
   REGISTRY_TYPE = {
       ...
       'diffusion scheduler': 'DIFFUSION_SCHEDULERS',
       ...
   }
   ```

#### Needs improvement

1. Flattening an inheritance chain with two parent classes is not supported yet. Improvements will be made in the future depending on need.
2. An error occurs when the original config contains a file `path` that cannot be found from the current directory. You can avoid export errors by manually modifying the `path` in the original config.
3. When `isinstance()` is used and the queried parent class is an intermediate class of the inheritance chain, the check may evaluate to `False` after flattening, because the original inheritance relationship is not retained (see the sketch below).
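A minimal sketch (hypothetical classes) of the `isinstance()` pitfall in item 3 above: after flattening, intermediate parents disappear from the MRO, so checks against them change their result:

```python
class Base:
    pass


class Mid(Base):               # removed by flattening
    pass


class Leaf(Mid):               # original chain: Leaf -> Mid -> Base
    pass


class FlatLeaf(Base):          # exported chain: Leaf -> Base
    pass


assert isinstance(Leaf(), Mid)          # True before export
assert not isinstance(FlatLeaf(), Mid)  # False after export
```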
mim/_internal/export/README_zh-CN.md (new file, 100 lines)

@@ -0,0 +1,100 @@
# Export (experimental)

`mim export` is a new feature in `mim`: from a config file alone, it can export a minimal trainable and testable model package.

The minimal package compactly combines `model`, `datasets`, `engine` and other components, so there is no need to look up every module separately from the config file, and users who do not install from source also get model files they can read at a glance.

In addition, lengthy inheritance chains in a model, such as `CondInstBboxHead -> FCOSHead -> AnchorFreeHead -> BaseDenseHead -> BaseModule` in `mmdetection`, are flattened directly into `CondInstBboxHead -> BaseModule`, so there is no more jumping between model files to compare them: every function inherited from a parent class is in plain view.

### Usage

```bash
mim export config_path

# config_path accepts the following two forms:
# a config of a downstream repo, called as repo::configs/xxx.py, e.g.:
mim export mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py

# a config in some folder, e.g.:
mim export config_dir/mask-rcnn_r101_fpn_1x_coco.py
```

### Minimal model package directory structure

```
minimum_package (named pack_from_{repo}_20231212_121212)
|- pack
|  |- configs          # config folder
|  |  |- model_name
|  |     |- xxx.py     # config file
|  |
|  |- models           # model folder
|  |  |- model_file.py
|  |  |- ...
|  |
|  |- data             # data folder
|  |
|  |- demo             # demo folder
|  |
|  |- datasets         # dataset class definitions
|  |  |- transforms
|  |
|  |- registry.py      # registry
|
|- tools
   |- train.py         # training
   |- test.py          # testing
```

### Limitations

`mim export` currently supports only some of the configs of `mmpose`, `mmdetection`, `mmagic` and `mmsegmentation`, and places some constraints on downstream repos.

#### For downstream repos

1. Config names are best kept **free of special symbols**; otherwise they cannot be parsed by `mmengine.hub.get_config()`, e.g.:

   - gn+ws/faster-rcnn_r101_fpn_gn-ws-all_1x_coco.py
   - legacy_1.x/cascade-mask-rcnn_r50_fpn_1x_coco_v1.py

2. For `mmsegmentation`, before exporting an `mmseg` config with `mim.export`, first remove the directory wrapper around `registry.py`, i.e. `mmseg/registry/registry.py -> mmseg/registry.py`.

3. It is recommended not to rename a downstream Registry inherited from mmengine; mmagic, for example, renamed `EVALUATOR` to `EVALUATORS`:

   ```python
   from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR

   # Evaluators to define the evaluation process.
   EVALUATORS = Registry(
       'evaluator',
       parent=MMENGINE_EVALUATOR,
       locations=['mmagic.evaluation'],
   )
   ```

4. Also, if a registry that does not exist in mmengine is added, such as `DIFFUSION_SCHEDULERS` in mmagic, a key-value pair needs to be added to `REGISTRY_TYPE` in `mim/_internal/export/common.py`, used for registering `torch` modules to `DIFFUSION_SCHEDULERS`:

   ```python
   # "mmagic/mmagic/registry.py"
   # modules for diffusion models that support adding noise and denoising
   DIFFUSION_SCHEDULERS = Registry(
       'diffusion scheduler',
       locations=['mmagic.models.diffusion_schedulers'],
   )

   # "mim/utils/mmpack/common.py"
   REGISTRY_TYPE = {
       ...
       'diffusion scheduler': 'DIFFUSION_SCHEDULERS',
       ...
   }
   ```

#### Where `mim export` needs improvement

1. Flattening an inheritance chain with two parent classes is not supported yet; it will be improved as needed.

2. When `isinstance()` is used and the queried parent class is an intermediate class of the inheritance chain, the check may evaluate to `False` after flattening, because the original inheritance relationship is not preserved.

3. Export may fail when the config contains a `dataset path` that cannot be reached from the current folder. The current workaround is to save the original config to the current folder and manually change the `dataset path` to one accessible from the current path, e.g. `data/ADEChallengeData2016/ -> your_data_dir/ADEChallengeData2016/` (see the sketch below).
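A hedged sketch of that workaround (the file names and the exact config key are hypothetical; which dataloader field holds `data_root` depends on the config):

```python
from mmengine.config import Config

# load the duplicate config that the export saved locally
cfg = Config.fromfile('export_log/your_config.py')

# point data_root at a directory that is reachable from here
cfg.train_dataloader.dataset.data_root = 'your_data_dir/ADEChallengeData2016/'

cfg.dump('your_config_fixed.py')
```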
mim/_internal/export/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
mim/_internal/export/common.py (new file, 75 lines)

@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.

# initialization template for __init__.py
_init_str = """
import os
import os.path as osp

all_files = os.listdir(osp.dirname(__file__))

for file in all_files:
    if (file.endswith('.py') and file != '__init__.py') or '.' not in file:
        exec(f'from .{osp.splitext(file)[0]} import *')
"""

# import pack path for tools
_import_pack_str = """
import os.path as osp
import sys
sys.path.append(osp.dirname(osp.dirname(__file__)))
import pack

"""

OBJECTS_TO_BE_PATCHED = {
    'MODELS': [
        'BACKBONES',
        'NECKS',
        'HEADS',
        'LOSSES',
        'SEGMENTORS',
        'build_backbone',
        'build_neck',
        'build_head',
        'build_loss',
        'build_segmentor',
        'CLASSIFIERS',
        'RETRIEVER',
        'build_classifier',
        'build_retriever',
        'POSE_ESTIMATORS',
        'build_pose_estimator',
        'build_posenet',
    ],
    'TASK_UTILS': [
        'PIXEL_SAMPLERS',
        'build_pixel_sampler',
    ]
}

REGISTRY_TYPES = {
    'runner': 'RUNNERS',
    'runner constructor': 'RUNNER_CONSTRUCTORS',
    'hook': 'HOOKS',
    'strategy': 'STRATEGIES',
    'dataset': 'DATASETS',
    'data sampler': 'DATA_SAMPLERS',
    'transform': 'TRANSFORMS',
    'model': 'MODELS',
    'model wrapper': 'MODEL_WRAPPERS',
    'weight initializer': 'WEIGHT_INITIALIZERS',
    'optimizer': 'OPTIMIZERS',
    'optimizer wrapper': 'OPTIM_WRAPPERS',
    'optimizer wrapper constructor': 'OPTIM_WRAPPER_CONSTRUCTORS',
    'parameter scheduler': 'PARAM_SCHEDULERS',
    'param scheduler': 'PARAM_SCHEDULERS',
    'metric': 'METRICS',
    'evaluator': 'EVALUATOR',  # TODO EVALUATORS in mmagic
    'task utils': 'TASK_UTILS',
    'loop': 'LOOPS',
    'visualizer': 'VISUALIZERS',
    'vis_backend': 'VISBACKENDS',
    'log processor': 'LOG_PROCESSORS',
    'inferencer': 'INFERENCERS',
    'function': 'FUNCTIONS',
}
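As a hedged illustration of how this table is keyed: each mmengine `Registry` carries a lowercase `name`, and `REGISTRY_TYPES` maps that name back to the variable that holds the registry in the generated `pack/registry.py`:

```python
from mmengine.registry import MODELS

from mim._internal.export.common import REGISTRY_TYPES

assert MODELS.name == 'model'
assert REGISTRY_TYPES[MODELS.name] == 'MODELS'
```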
mim/_internal/export/flatten_func.py (new file, 1081 lines)

File diff suppressed because it is too large.
mim/_internal/export/pack_cfg.py (new file, 312 lines)

@@ -0,0 +1,312 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
import shutil
import signal
import sys
import tempfile
import traceback
from datetime import datetime
from typing import Optional, Union

from mmengine import MMLogger
from mmengine.config import Config, ConfigDict
from mmengine.hub import get_config
from mmengine.logging import print_log
from mmengine.registry import init_default_scope
from mmengine.runner import Runner
from mmengine.utils import get_installed_path, mkdir_or_exist

from mim.utils import echo_error
from .common import _import_pack_str, _init_str
from .utils import (
    _get_all_files,
    _postprocess_importfrom_module_to_pack,
    _postprocess_registry_locations,
    _replace_config_scope_to_pack,
    _wrapper_all_registries_build_func,
)


def export_from_cfg(cfg: Union[str, ConfigDict],
                    export_root_dir: str,
                    model_only: Optional[bool] = False,
                    save_log: Optional[bool] = False):
    """A function to pack the minimum available package according to the
    config file.

    Args:
        cfg (:obj:`ConfigDict` or str): Config file for packing the
            minimum package.
        export_root_dir (str, optional): The pack directory to save the
            packed package.
        model_only (bool, optional): Only export the model modules.
            Defaults to False.
        save_log (bool, optional): Keep the export log. Defaults to False.
    """
    # generate temp dirs for export
    export_root_tempdir = tempfile.TemporaryDirectory()
    export_log_tempdir = tempfile.TemporaryDirectory()

    # delete the incomplete export package on keyboard interrupt
    signal.signal(signal.SIGINT, lambda sig, frame: keyboardinterupt_handler(
        sig, frame, export_root_tempdir, export_log_tempdir)
    )  # type: ignore[arg-type]

    # get config
    if isinstance(cfg, str):
        if '::' in cfg:
            cfg = get_config(cfg)
        else:
            cfg = Config.fromfile(cfg)

    default_scope = cfg.get('default_scope', 'mmengine')

    # automatically generate ``export_root_dir``
    if export_root_dir is None:
        export_root_dir = f'pack_from_{default_scope}_' + \
                          f"{datetime.now().strftime(r'%Y%m%d_%H%M%S')}"

    # generate ``export_log_dir``
    if osp.sep in export_root_dir:
        export_path = osp.dirname(export_root_dir)
    else:
        export_path = os.getcwd()
    export_log_dir = osp.join(export_path, 'export_log')

    export_logger = MMLogger.get_instance(  # noqa: F841
        'export',
        log_file=osp.join(export_log_tempdir.name, 'export.log'))

    export_module_tempdir_name = osp.join(export_root_tempdir.name, 'pack')

    # export config
    if '.mim' in cfg.filename:
        cfg_path = osp.join(export_module_tempdir_name,
                            cfg.filename[cfg.filename.find('configs'):])
    else:
        cfg_path = osp.join(
            osp.join(export_module_tempdir_name, 'configs'),
            osp.basename(cfg.filename))
    mkdir_or_exist(osp.dirname(cfg_path))

    # transform to default_scope
    init_default_scope(default_scope)

    # wrap ``Registry.build()`` for exporting modules
    _wrapper_all_registries_build_func(
        export_module_dir=export_module_tempdir_name, scope=default_scope)

    print_log(
        f'[ Export Package Name ]: {export_root_dir}\n'
        f'    package from config: {cfg.filename}\n'
        f"    from downstream package: '{default_scope}'\n",
        logger='export',
        level=logging.INFO)

    # create a temp work_dir for export
    cfg['work_dir'] = export_log_tempdir.name

    # use the runner to export all needed modules
    runner = Runner.from_cfg(cfg)

    # HARD CODE: Some modules are only built in ``before_run`` or
    # ``after_run``, so we call these hooks directly, without needing to
    # call ``runner.train()``.
    #
    # Example:
    # >>> @HOOKS.register_module()
    # >>> class EMAHook(Hook):
    # >>>     ...
    # >>>     def before_run(self, runner) -> None:
    # >>>         """Create an ema copy of the model.
    # >>>         Args:
    # >>>             runner (Runner): The runner of the training process.
    # >>>         """
    # >>>         model = runner.model
    # >>>         if is_model_wrapper(model):
    # >>>             model = model.module
    # >>>         self.src_model = model
    # >>>         self.ema_model = MODELS.build(
    # >>>             self.ema_cfg, default_args=dict(model=self.src_model))
    #
    # It needs to build ``self.ema_model`` in ``before_run``.
    for hook in runner.hooks:
        hook.before_run(runner)
        hook.after_run(runner)

    def dump():
        cfg['work_dir'] = 'work_dirs'  # recover to default.
        _replace_config_scope_to_pack(cfg)
        cfg.dump(cfg_path)

        # copy the temp log to the export log
        if save_log:
            shutil.copytree(
                export_log_tempdir.name, export_log_dir, dirs_exist_ok=True)
        export_log_tempdir.cleanup()

        # copy temp_package_dir to export_package_dir
        shutil.copytree(
            export_root_tempdir.name, export_root_dir, dirs_exist_ok=True)
        export_root_tempdir.cleanup()

        print_log(
            f'[ Export Package Name ]: '
            f'{osp.join(os.getcwd(), export_root_dir)}\n',
            logger='export',
            level=logging.INFO)

    if model_only:
        dump()
        return 0

    try:
        runner.build_train_loop(cfg.train_cfg)
    except FileNotFoundError:
        error_postprocess(export_log_dir, default_scope,
                          export_root_tempdir, export_log_tempdir,
                          osp.basename(cfg_path), 'train_dataloader')

    try:
        if 'val_cfg' in cfg and cfg.val_cfg is not None:
            runner.build_val_loop(cfg.val_cfg)
    except FileNotFoundError:
        error_postprocess(export_log_dir, default_scope,
                          export_root_tempdir, export_log_tempdir,
                          osp.basename(cfg_path), 'val_dataloader')

    try:
        if 'test_cfg' in cfg and cfg.test_cfg is not None:
            runner.build_test_loop(cfg.test_cfg)
    except FileNotFoundError:
        error_postprocess(export_log_dir, default_scope,
                          export_root_tempdir, export_log_tempdir,
                          osp.basename(cfg_path), 'test_dataloader')

    if 'optim_wrapper' in cfg and cfg.optim_wrapper is not None:
        runner.optim_wrapper = runner.build_optim_wrapper(cfg.optim_wrapper)
    if 'param_scheduler' in cfg and cfg.param_scheduler is not None:
        runner.build_param_scheduler(cfg.param_scheduler)

    # add ``__init__.py`` to all dirs, for transferring directories
    # to be modules
    for directory, _, _ in os.walk(export_module_tempdir_name):
        if not osp.exists(osp.join(directory, '__init__.py')) \
                and 'configs' not in directory:
            with open(osp.join(directory, '__init__.py'), 'w') as f:
                f.write(_init_str)

    # postprocess for ``pack/registry.py``
    _postprocess_registry_locations(export_root_tempdir.name)

    # postprocess the ImportFrom nodes, turning them into imports from the
    # export path
    all_export_files = _get_all_files(export_module_tempdir_name)
    for file in all_export_files:
        _postprocess_importfrom_module_to_pack(file)

    # get tools from the installed package
    tools_dir = osp.join(export_root_tempdir.name, 'tools')
    mkdir_or_exist(tools_dir)

    for tool_name in [
            'train.py', 'test.py', 'dist_train.sh', 'dist_test.sh',
            'slurm_train.sh', 'slurm_test.sh'
    ]:
        pack_tools(
            tool_name=tool_name,
            scope=default_scope,
            tool_dir=tools_dir,
            auto_import=True)

    # TODO: get demo.py

    dump()
    return 0


def keyboardinterupt_handler(
    sig: int,
    frame,
    export_root_tempdir: tempfile.TemporaryDirectory,
    export_log_tempdir: tempfile.TemporaryDirectory,
):
    """Clear the uncompleted exported package on keyboard interrupt."""
    export_log_tempdir.cleanup()
    export_root_tempdir.cleanup()

    sys.exit(-1)


def error_postprocess(export_log_dir: str, scope: str,
                      export_root_dir_tempfile: tempfile.TemporaryDirectory,
                      export_log_dir_tempfile: tempfile.TemporaryDirectory,
                      cfg_name: str, error_key: str):
    """Print a debug message when the package can't be exported because a
    dataset is missing.

    Args:
        export_root_dir (str): _description_
        absolute_cfg_path (str): _description_
        origin_cfg (ConfigDict): _description_
        error_key (str): _description_
        logger (_type_): _description_
    """
    shutil.copytree(
        export_log_dir_tempfile.name, export_log_dir, dirs_exist_ok=True)
    export_root_dir_tempfile.cleanup()
    export_log_dir_tempfile.cleanup()

    traceback.print_exc()

    error_msg = f"{'=' * 20} Debug Message {'=' * 20}"\
        f"\nThe data root of '{error_key}' is not found. You can"\
        ' use the below two methods to deal with it.\n\n'\
        "    >>> Method 1: Please modify the 'data_root' in"\
        f" duplicate config '{export_log_dir}/{cfg_name}'.\n"\
        "    >>> Method 2: Use '--model_only' to export the model only.\n\n"\
        "After finishing one of the above steps, you can use 'mim export"\
        f" {scope} {export_log_dir}/{cfg_name} [--model-only]' to export"\
        ' again.'

    echo_error(error_msg)

    sys.exit(-1)


def pack_tools(tool_name: str,
               scope: str,
               tool_dir: str,
               auto_import: Optional[bool] = False):
    """Pack tools from the installed repo.

    Args:
        tool_name (str): Tool name in the repo's tools dir.
        scope (str): The scope of the repo.
        tool_dir (str): Directory to save the tool.
        auto_import (bool, optional): Automatically add "import pack" to the
            tool file. Defaults to False.
    """
    pkg_root = get_installed_path(scope)
    path = osp.join(tool_dir, tool_name)

    if os.path.exists(path):
        os.remove(path)

    # tools will be put in package/.mim in PR #68
    tool_script = osp.join(pkg_root, '.mim', 'tools', tool_name)
    if not osp.exists(tool_script):
        tool_script = osp.join(pkg_root, 'tools', tool_name)

    shutil.copy(tool_script, path)

    # automatically import the pack modules
    if auto_import:
        with open(path, 'r+') as f:
            lines = f.readlines()
            code = ''.join(lines[:1] + [_import_pack_str] + lines[1:])
            f.seek(0)
            f.write(code)
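A hedged usage sketch for `export_from_cfg`, mirroring the CLI call from the README (assumes `mmdet` is installed; the output directory name is arbitrary):

```python
from mim._internal.export.pack_cfg import export_from_cfg

export_from_cfg(
    'mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py',
    export_root_dir='pack_from_mmdet_demo',
    model_only=True)  # skip dataloader/optimizer building, export model only
```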
mim/_internal/export/patch_utils/README.md (new file, 75 lines)

@@ -0,0 +1,75 @@
# Patch Utils

## Problem

This patch mainly solves the problem that a module cannot be properly registered because the downstream repo renamed its registry, as in this example from `mmsegmentation`:

```python
# "mmsegmentation/mmseg/structures/sampler/builder.py"

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmseg.registry import TASK_UTILS

PIXEL_SAMPLERS = TASK_UTILS


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map."""
    warnings.warn(
        '``build_pixel_sampler`` would be deprecated soon, please use '
        '``mmseg.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)
```

Some modules may use the renamed registry, which makes it difficult for `mim export` to find the original name of the renamed module:

```python
# "mmsegmentation/mmseg/structures/sampler/ohem_pixel_sampler.py"

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_pixel_sampler import BasePixelSampler
from .builder import PIXEL_SAMPLERS


@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
    ...
```

## Solution

Therefore, we have migrated the necessary modules of `mmpose/mmdetection/mmseg/mmpretrain` listed below directly into `patch_utils.patch_model` and `patch_utils.patch_task`, in order to build a patch containing the renamed registries and the special module constructor functions.

```python
"mmdetection/mmdet/models/task_modules/builder.py"
"mmdetection/build/lib/mmdet/models/task_modules/builder.py"

"mmsegmentation/mmseg/models/builder.py"
"mmsegmentation/mmseg/structures/sampler/builder.py"
"mmsegmentation/build/lib/mmseg/models/builder.py"
"mmsegmentation/build/lib/mmseg/structures/sampler/builder.py"

"mmpretrain/mmpretrain/datasets/builder.py"
"mmpretrain/mmpretrain/models/builder.py"
"mmpretrain/build/lib/mmpretrain/datasets/builder.py"
"mmpretrain/build/lib/mmpretrain/models/builder.py"

"mmpose/mmpose/datasets/builder.py"
"mmpose/mmpose/models/builder.py"
"mmpose/build/lib/mmpose/datasets/builder.py"
"mmpose/build/lib/mmpose/models/builder.py"

"mmengine/mmengine/evaluator/builder.py"
"mmengine/mmengine/model/builder.py"
"mmengine/mmengine/optim/optimizer/builder.py"
"mmengine/mmengine/visualization/builder.py"
"mmengine/build/lib/mmengine/evaluator/builder.py"
"mmengine/build/lib/mmengine/model/builder.py"
"mmengine/build/lib/mmengine/optim/optimizer/builder.py"
```
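The reason the migrated aliases work is plain Python aliasing: a name like `PIXEL_SAMPLERS` is just another reference to the canonical registry, so registering through it lands in the same table. A self-contained sketch (standalone `Registry` and a hypothetical class, not the real mmseg objects):

```python
from mmengine.registry import Registry

TASK_UTILS = Registry('task util')   # stands in for mmseg's TASK_UTILS
PIXEL_SAMPLERS = TASK_UTILS          # the alias the downstream repo defines


@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler:
    pass


# the module registered through the alias is visible under the
# canonical registry as well
assert TASK_UTILS.get('OHEMPixelSampler') is OHEMPixelSampler
```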
mim/_internal/export/patch_utils/README_zh-CN.md (new file, 75 lines)

@@ -0,0 +1,75 @@
# Patch Utils

## Problem

This patch mainly solves the problem that a module cannot be registered properly because a downstream repo renamed its registry, as in this example from `mmsegmentation`:

```python
# "mmsegmentation/mmseg/structures/sampler/builder.py"

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmseg.registry import TASK_UTILS

PIXEL_SAMPLERS = TASK_UTILS


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map."""
    warnings.warn(
        '``build_pixel_sampler`` would be deprecated soon, please use '
        '``mmseg.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)
```

Some modules may then use the renamed registry, which makes it hard for the export postprocessing to recover the original name of the renamed module:

```python
# "mmsegmentation/mmseg/structures/sampler/ohem_pixel_sampler.py"

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_pixel_sampler import BasePixelSampler
from .builder import PIXEL_SAMPLERS


@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
    ...
```

## Solution

Therefore, we have migrated the necessary modules of `mmpose / mmdetection / mmseg / mmpretrain` below directly into `patch_utils.patch_model` and `patch_utils.patch_task`, building a patch that contains the renamed registries and the special module constructor functions.

```python
"mmdetection/mmdet/models/task_modules/builder.py"
"mmdetection/build/lib/mmdet/models/task_modules/builder.py"

"mmsegmentation/mmseg/models/builder.py"
"mmsegmentation/mmseg/structures/sampler/builder.py"
"mmsegmentation/build/lib/mmseg/models/builder.py"
"mmsegmentation/build/lib/mmseg/structures/sampler/builder.py"

"mmpretrain/mmpretrain/datasets/builder.py"
"mmpretrain/mmpretrain/models/builder.py"
"mmpretrain/build/lib/mmpretrain/datasets/builder.py"
"mmpretrain/build/lib/mmpretrain/models/builder.py"

"mmpose/mmpose/datasets/builder.py"
"mmpose/mmpose/models/builder.py"
"mmpose/build/lib/mmpose/datasets/builder.py"
"mmpose/build/lib/mmpose/models/builder.py"

"mmengine/mmengine/evaluator/builder.py"
"mmengine/mmengine/model/builder.py"
"mmengine/mmengine/optim/optimizer/builder.py"
"mmengine/mmengine/visualization/builder.py"
"mmengine/build/lib/mmengine/evaluator/builder.py"
"mmengine/build/lib/mmengine/model/builder.py"
"mmengine/build/lib/mmengine/optim/optimizer/builder.py"
```
mim/_internal/export/patch_utils/__init__.py (new file, 3 lines)

@@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .patch_model import *  # noqa: F401, F403
from .patch_task import *  # noqa: F401, F403
mim/_internal/export/patch_utils/patch_model.py (new file, 82 lines)

@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmengine.registry import MODELS

BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
SEGMENTORS = MODELS


def build_backbone(cfg):
    """Build backbone."""
    warnings.warn('``build_backbone`` would be deprecated soon, please use '
                  '``mmseg.registry.MODELS.build()`` ')
    return BACKBONES.build(cfg)


def build_neck(cfg):
    """Build neck."""
    warnings.warn('``build_neck`` would be deprecated soon, please use '
                  '``mmseg.registry.MODELS.build()`` ')
    return NECKS.build(cfg)


def build_head(cfg):
    """Build head."""
    warnings.warn('``build_head`` would be deprecated soon, please use '
                  '``mmseg.registry.MODELS.build()`` ')
    return HEADS.build(cfg)


def build_loss(cfg):
    """Build loss."""
    warnings.warn('``build_loss`` would be deprecated soon, please use '
                  '``mmseg.registry.MODELS.build()`` ')
    return LOSSES.build(cfg)


def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build segmentor."""
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    return SEGMENTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))


CLASSIFIERS = MODELS
RETRIEVER = MODELS


def build_classifier(cfg):
    """Build classifier."""
    return CLASSIFIERS.build(cfg)


def build_retriever(cfg):
    """Build retriever."""
    return RETRIEVER.build(cfg)


POSE_ESTIMATORS = MODELS


def build_pose_estimator(cfg):
    """Build pose estimator."""
    return POSE_ESTIMATORS.build(cfg)


def build_posenet(cfg):
    """Build posenet."""
    warnings.warn(
        '``build_posenet`` will be deprecated soon, '
        'please use ``build_pose_estimator`` instead.', DeprecationWarning)
    return build_pose_estimator(cfg)
mim/_internal/export/patch_utils/patch_task.py (new file, 73 lines)

@@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmengine.registry import TASK_UTILS

PRIOR_GENERATORS = TASK_UTILS
ANCHOR_GENERATORS = TASK_UTILS
BBOX_ASSIGNERS = TASK_UTILS
BBOX_SAMPLERS = TASK_UTILS
BBOX_CODERS = TASK_UTILS
MATCH_COSTS = TASK_UTILS
IOU_CALCULATORS = TASK_UTILS


def build_bbox_coder(cfg, **default_args):
    """Builder of box coder."""
    warnings.warn('``build_bbox_coder`` would be deprecated soon, please use '
                  '``mmdet.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)


def build_iou_calculator(cfg, default_args=None):
    """Builder of IoU calculator."""
    warnings.warn(
        '``build_iou_calculator`` would be deprecated soon, please use '
        '``mmdet.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)


def build_match_cost(cfg, default_args=None):
    """Builder of match cost."""
    warnings.warn('``build_match_cost`` would be deprecated soon, please use '
                  '``mmdet.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)


def build_assigner(cfg, **default_args):
    """Builder of box assigner."""
    warnings.warn('``build_assigner`` would be deprecated soon, please use '
                  '``mmdet.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)


def build_sampler(cfg, **default_args):
    """Builder of box sampler."""
    warnings.warn('``build_sampler`` would be deprecated soon, please use '
                  '``mmdet.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)


def build_prior_generator(cfg, default_args=None):
    warnings.warn(
        '``build_prior_generator`` would be deprecated soon, please use '
        '``mmdet.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)


def build_anchor_generator(cfg, default_args=None):
    warnings.warn(
        '``build_anchor_generator`` would be deprecated soon, please use '
        '``mmdet.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)


PIXEL_SAMPLERS = TASK_UTILS


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map."""
    warnings.warn(
        '``build_pixel_sampler`` would be deprecated soon, please use '
        '``mmseg.registry.TASK_UTILS.build()`` ')
    return TASK_UTILS.build(cfg, default_args=default_args)
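A hedged sketch of how these patched builders behave, using a dummy class registered into mmengine's root `TASK_UTILS` (`DummySampler` is hypothetical):

```python
from mmengine.registry import TASK_UTILS

from mim._internal.export.patch_utils.patch_task import build_pixel_sampler


@TASK_UTILS.register_module()
class DummySampler:

    def __init__(self, thresh=0.7):
        self.thresh = thresh


# the legacy helper resolves through the same canonical registry,
# merging the extra keyword arguments as default_args
sampler = build_pixel_sampler(dict(type='DummySampler'), thresh=0.5)
assert isinstance(sampler, DummySampler) and sampler.thresh == 0.5
```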
mim/_internal/export/utils.py (new file, 555 lines)

@@ -0,0 +1,555 @@
# Copyright (c) OpenMMLab. All rights reserved.
import ast
import copy
import importlib
import inspect
import logging
import os
import os.path as osp
import shutil
from functools import wraps
from typing import Callable

import torch.nn
from mmengine import mkdir_or_exist
from mmengine.config import ConfigDict
from mmengine.logging import print_log
from mmengine.model import (
    BaseDataPreprocessor,
    BaseModel,
    BaseModule,
    ImgDataPreprocessor,
)
from mmengine.registry import Registry
from yapf.yapflib.yapf_api import FormatCode

from mim.utils import OFFICIAL_MODULES
from .common import REGISTRY_TYPES
from .flatten_func import *  # noqa: F403, F401
from .flatten_func import (
    ImportResolverTransformer,
    RegisterModuleTransformer,
    flatten_inheritance_chain,
    ignore_ast_docstring,
    postprocess_super_call,
)


def format_code(code_text: str):
    """Format the code text with yapf."""
    yapf_style = dict(
        based_on_style='pep8',
        blank_line_before_nested_class_or_def=True,
        split_before_expression_after_opening_paren=True)
    try:
        code_text, _ = FormatCode(code_text, style_config=yapf_style)
    except:  # noqa: E722
        raise SyntaxError('Failed to format the config file, please '
                          f'check the syntax of: \n{code_text}')

    return code_text


def _postprocess_registry_locations(export_root_dir: str):
    """Remove Registry.locations entries that don't exist.

    Check the location paths that a Registry uses to load modules; if a path
    hasn't been exported, it needs to be removed. The root Registry will then
    be used to find the module until it actually doesn't exist.
    """
    export_module_dir = osp.join(export_root_dir, 'pack')

    with open(
            osp.join(export_module_dir, 'registry.py'),
            encoding='utf-8') as f:
        ast_tree = ast.parse(f.read())

    for node in ast.walk(ast_tree):
        """node structure.

        Assign( targets=[ Name(id='EVALUATORS', ctx=Store())],
        value=Call( func=Name(id='Registry', ctx=Load()), args=[
        Constant(value='evaluator')], keywords=[ keyword( arg='parent',
        value=Name(id='MMENGINE_EVALUATOR', ctx=Load())), keyword(
        arg='locations', value=List( elts=[
        Constant(value='pack.evaluation')], ctx=Load()))])),
        """
        if isinstance(node, ast.Call):
            need_to_be_remove = None

            for keyword in node.keywords:
                if keyword.arg == 'locations':
                    for sub_node in ast.walk(keyword):

                        # the locations of the Registry were already
                        # transferred to the `pack` scope before; if the
                        # location path exists, it stays in the pack scope
                        if isinstance(
                                sub_node,
                                ast.Constant) and 'pack' in sub_node.value:

                            path = sub_node.value
                            if not osp.exists(
                                    osp.join(export_root_dir, path).replace(
                                        '.', osp.sep)):
                                print_log(
                                    '[ Pass ] Remove Registry.locations '
                                    f"'{osp.join(export_root_dir, path).replace('.', osp.sep)}', "  # noqa: E501
                                    'which is no need to export.',
                                    logger='export',
                                    level=logging.DEBUG)
                                need_to_be_remove = keyword
                                break

                    if need_to_be_remove is not None:
                        break

            if need_to_be_remove is not None:
                node.keywords.remove(need_to_be_remove)

    with open(
            osp.join(export_module_dir, 'registry.py'), 'w',
            encoding='utf-8') as f:
        f.write(format_code(ast.unparse(ast_tree)))


def _get_all_files(directory: str):
    """Get all files of the directory.

    Args:
        directory (str): The directory path.

    Returns:
        List: A list containing all the files in the directory.
    """
    file_paths = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if '__init__' not in file and 'registry.py' not in file:
                file_paths.append(os.path.join(root, file))

    return file_paths


def _postprocess_importfrom_module_to_pack(file_path: str):
    """Transfer the ImportFrom paths from the downstream repo to the export
    module.

    Args:
        file_path (str): The path of the file to be transferred.

    Examples:
        >>> from mmdet.models.detectors.two_stage import TwoStageDetector
        >>> # transfers to the below, if "TwoStageDetector" had been exported
        >>> from pack.models.detectors.two_stage import TwoStageDetector
    """
    from mmengine import Registry

    # _module_path_dict is a class attribute that already records all the
    # exported modules and their paths
    _module_path_dict = Registry._module_path_dict

    with open(file_path, encoding='utf-8') as f:
        ast_tree = ast.parse(f.read())

    # if an imported module has the same name as an object in this file,
    # its import path won't be changed
    can_not_change_module = []
    for node in ast_tree.body:
        if isinstance(node, ast.FunctionDef) or isinstance(node, ast.ClassDef):
            can_not_change_module.append(node.name)

    def check_change_importfrom_node(node: ast.ImportFrom):
        """Check if the ImportFrom node should be changed.

        If the modules in the node had already been exported, they will be
        separated and compose a new ast.ImportFrom node with the export
        path as the module path.

        Args:
            node (ast.ImportFrom): ImportFrom node.

        Returns:
            ast.ImportFrom | None: Return a new ast.ImportFrom node
            if one of the modules in the node had been exported, else
            ``None``.
        """
        export_module_path = None
        needed_change_alias = []

        for alias in node.names:
            # if the imported module's name is equal to a class or function
            # name, it can not be transferred, to avoid a circular import.
            if alias.name in _module_path_dict.keys(
            ) and alias.name not in can_not_change_module:

                if export_module_path is None:
                    export_module_path = _module_path_dict[alias.name]
                else:
                    assert _module_path_dict[alias.name] == \
                        export_module_path,\
                        'There are two modules from the same downstream'\
                        " repo, but they can't be changed to the same"\
                        ' export path.'

                needed_change_alias.append(alias)

        if len(needed_change_alias) != 0:
            for alias in needed_change_alias:
                node.names.remove(alias)

            return ast.ImportFrom(
                module=export_module_path, names=needed_change_alias, level=0)

        return None

    # Naming rules for searching the ast syntax tree
    # - node: node of ast.Module
    # - func_sub_node: sub_node of ast.FunctionDef
    # - class_sub_node: sub_node of ast.ClassDef
    # - func_sub_class_sub_node: sub_node of ast.FunctionDef in ast.ClassDef

    # record the insert_idx and the nodes to be inserted for a later insert.
    insert_idx_and_node = {}

    insert_idx = 0
    for idx, node in enumerate(ast_tree.body):

        # search ast.ImportFrom in ast.Module scope
        # ast.Module -> ast.ImportFrom
        if isinstance(node, ast.ImportFrom):
            insert_idx += 1
            temp_node = check_change_importfrom_node(node)
            if temp_node is not None:
                if len(node.names) == 0:
                    ast_tree.body[idx] = temp_node
                else:
                    insert_idx_and_node[insert_idx] = temp_node
                    insert_idx += 1

        elif isinstance(node, ast.Import):
            insert_idx += 1

        else:
            # search ast.ImportFrom in ast.FunctionDef scope
            # ast.Module -> ast.FunctionDef -> ast.ImportFrom
            if isinstance(node, ast.FunctionDef):
                temp_func_insert_idx = ignore_ast_docstring(node)
                func_need_to_be_removed_nodes = []

                for func_sub_node in node.body:
                    if isinstance(func_sub_node, ast.ImportFrom):
                        temp_node = check_change_importfrom_node(
                            func_sub_node)  # noqa: E501
                        if temp_node is not None:
                            node.body.insert(temp_func_insert_idx, temp_node)

                        # if the importfrom module is empty, the node should be removed  # noqa: E501
                        if len(func_sub_node.names) == 0:
                            func_need_to_be_removed_nodes.append(
                                func_sub_node)  # noqa: E501

                for need_to_be_removed_node in func_need_to_be_removed_nodes:
                    node.body.remove(need_to_be_removed_node)

            # search ast.ImportFrom in ast.ClassDef scope
            # ast.Module -> ast.ClassDef -> ast.ImportFrom
            # -> ast.FunctionDef -> ast.ImportFrom
            elif isinstance(node, ast.ClassDef):
                temp_class_insert_idx = ignore_ast_docstring(node)
                class_need_to_be_removed_nodes = []

                for class_sub_node in node.body:

                    # ast.Module -> ast.ClassDef -> ast.ImportFrom
                    if isinstance(class_sub_node, ast.ImportFrom):
                        temp_node = check_change_importfrom_node(
                            class_sub_node)
                        if temp_node is not None:
                            node.body.insert(temp_class_insert_idx, temp_node)
                        if len(class_sub_node.names) == 0:
                            class_need_to_be_removed_nodes.append(
                                class_sub_node)

                    # ast.Module -> ast.ClassDef -> ast.FunctionDef -> ast.ImportFrom  # noqa: E501
                    elif isinstance(class_sub_node, ast.FunctionDef):
                        temp_class_sub_insert_idx = ignore_ast_docstring(node)
                        func_need_to_be_removed_nodes = []

                        for func_sub_class_sub_node in class_sub_node.body:
                            if isinstance(func_sub_class_sub_node,
                                          ast.ImportFrom):
                                temp_node = check_change_importfrom_node(
                                    func_sub_class_sub_node)
                                if temp_node is not None:
                                    node.body.insert(
                                        temp_class_sub_insert_idx, temp_node)
                                if len(func_sub_class_sub_node.names) == 0:
                                    func_need_to_be_removed_nodes.append(
                                        func_sub_class_sub_node)

                        for need_to_be_removed_node in func_need_to_be_removed_nodes:  # noqa: E501
                            class_sub_node.body.remove(need_to_be_removed_node)

                for class_need_to_be_removed_node in class_need_to_be_removed_nodes:  # noqa: E501
                    node.body.remove(class_need_to_be_removed_node)

    # lazily add the new ast.ImportFrom nodes to ast.Module
    for insert_idx, temp_node in insert_idx_and_node.items():
        ast_tree.body.insert(insert_idx, temp_node)

    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(format_code(ast.unparse(ast_tree)))

def _replace_config_scope_to_pack(cfg: ConfigDict):
|
||||
"""Replace the config scope from "mmxxx" to "pack".
|
||||
|
||||
Args:
|
||||
cfg (ConfigDict): The config dict to be replaced.
|
||||
"""
|
||||
|
||||
for key, value in cfg.items():
|
||||
if key == '_scope_' or key == 'default_scope':
|
||||
cfg[key] = 'pack'
|
||||
elif isinstance(value, dict):
|
||||
_replace_config_scope_to_pack(value)
|
||||
|
||||
|
||||
def _wrapper_all_registries_build_func(export_module_dir: str, scope: str):
|
||||
"""A function to wrap all registries' build_func.
|
||||
|
||||
Args:
|
||||
pack_module_dir (str): The root dir for packing modules.
|
||||
scope (str): The default scope of the config.
|
||||
"""
|
||||
# copy the downstream repo.registry to pack.registry
|
||||
# and change all the registry.locations
|
||||
repo_registries = importlib.import_module('.registry', scope)
|
||||
origin_file = inspect.getfile(repo_registries)
|
||||
registry_path = osp.join(export_module_dir, 'registry.py')
|
||||
shutil.copy(origin_file, registry_path)
|
||||
|
||||
# replace 'repo' name in Registry.locations to 'pack'
|
||||
with open(
|
||||
osp.join(export_module_dir, 'registry.py'), encoding='utf-8') as f:
|
||||
ast_tree = ast.parse(f.read())
|
||||
for node in ast.walk(ast_tree):
|
||||
if isinstance(node, ast.Constant):
|
||||
if scope in node.value:
|
||||
node.value = node.value.replace(scope, 'pack')
|
||||
|
||||
with open(osp.join(export_module_dir, 'registry.py'), 'w') as f:
|
||||
f.write(format_code(ast.unparse(ast_tree)))
|
||||
|
||||
# prevent circular registration
|
||||
Registry._extra_module_set = set()
|
||||
|
||||
# record the exported module for postprocessing the importfrom path
|
||||
Registry._module_path_dict = {}
|
||||
|
||||
# prevent circular wrapper
|
||||
if Registry.build.__name__ == 'wrapper':
|
||||
Registry.build = _wrap_build(Registry.init_build_func,
|
||||
export_module_dir)
|
||||
Registry.get = _wrap_get(Registry.init_get_func, export_module_dir)
|
||||
else:
|
||||
Registry.init_build_func = copy.deepcopy(Registry.build)
|
||||
Registry.init_get_func = copy.deepcopy(Registry.get)
|
||||
Registry.build = _wrap_build(Registry.build, export_module_dir)
|
||||
Registry.get = _wrap_get(Registry.get, export_module_dir)

def ignore_self_cache(func):
    """Cache a method by its positional arguments, ignoring ``self``.

    Works like ``@lru_cache`` except that ``self`` is excluded from the
    cache key, so the method body runs only once per distinct ``args``,
    across all instances.

    Args:
        func (Callable): The method to be cached.

    Returns:
        Callable: The wrapped method.
    """
    cache = {}

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        key = args
        if key not in cache:
            cache[key] = 1
            func(self, *args, **kwargs)

    return wrapper
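A small sketch of the decorator's behaviour (the `Exporter` class below is hypothetical, purely for illustration):

```python
class Exporter:

    @ignore_self_cache
    def export(self, name):
        print(f'exporting {name}')

Exporter().export('ResNet')  # runs: prints 'exporting ResNet'
Exporter().export('ResNet')  # same args: cached, body is skipped
Exporter().export('FPN')     # new args: runs again
```

Note that the cache lives in the decorator's closure, so it is shared across instances; that is exactly what excluding ``self`` from the key implies.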


@ignore_self_cache
def _export_module(self, obj_cls: type, pack_module_dir, obj_type: str):
    """Export module.

    This function will get the object's file and export it to
    ``pack_module_dir``.

    If the object is built by the ``MODELS`` registry, all the top-level
    classes in its file will be iteratively flattened; otherwise the file
    is exported directly.

    The flatten logic is:
        1. get the origin file of the object built by ``MODELS.build()``
        2. get all the classes in the origin file
        3. flatten all these classes, not only the built object
        4. call ``flatten_module()`` to finish the flattening
           according to ``class.mro()``

    Args:
        obj_cls (type): The class of the object to be flattened.
        pack_module_dir (str): The root dir for the exported modules.
        obj_type (str): The type name under which the object is registered.
    """
    # find the file by the obj class
    file_path = inspect.getfile(obj_cls)

    if osp.exists(file_path):
        print_log(
            f'building class: '
            f'{obj_cls.__name__} from file: {file_path}.',
            logger='export',
            level=logging.DEBUG)
    else:
        raise FileNotFoundError(f"file [{file_path}] doesn't exist.")

    # local origin module
    module = obj_cls.__module__
    parent = module.split('.')[0]
    new_module = module.replace(parent, 'pack')

    # Not necessary to export modules implemented in `mmcv` and `mmengine`
    if parent in set(OFFICIAL_MODULES) - {'mmcv', 'mmengine'}:

        with open(file_path, encoding='utf-8') as f:
            top_ast_tree = ast.parse(f.read())

        # deal with relative import
        ImportResolverTransformer(module).visit(top_ast_tree)

        # NOTE: ``MODELS.build()`` means to flatten the model module
        if self.name == 'model':

            # record all the classes that need to be flattened
            need_to_be_flattened_class_names = []
            for node in top_ast_tree.body:
                if isinstance(node, ast.ClassDef):
                    need_to_be_flattened_class_names.append(node.name)

            imported_module = importlib.import_module(obj_cls.__module__)
            for cls_name in need_to_be_flattened_class_names:

                # record the exported module for postprocessing the
                # importfrom path
                self._module_path_dict[cls_name] = new_module

                cls = getattr(imported_module, cls_name)

                for super_cls in cls.__bases__:

                    # the class will only be flattened when:
                    # 1. the super class doesn't exist in this file
                    # 2. and the super class is not a base class
                    # 3. and the super class is not a torch module
                    if super_cls.__name__ \
                            not in need_to_be_flattened_class_names \
                            and (super_cls not in [BaseModule,
                                                   BaseModel,
                                                   BaseDataPreprocessor,
                                                   ImgDataPreprocessor]) \
                            and 'torch' not in super_cls.__module__:

                        print_log(
                            f'need_flatten: {cls_name}\n',
                            logger='export',
                            level=logging.INFO)

                        flatten_inheritance_chain(top_ast_tree, cls)
                        break
            postprocess_super_call(top_ast_tree)

        else:
            self._module_path_dict[obj_cls.__name__] = new_module

        # add ``register_module(force=True)`` to cover the registered modules
        RegisterModuleTransformer().visit(top_ast_tree)

        # unparse the ast tree and save the reformatted code
        new_file_path = new_module.split('.', 1)[1].replace('.',
                                                            osp.sep) + '.py'
        new_file_path = osp.join(pack_module_dir, new_file_path)
        new_dir = osp.dirname(new_file_path)
        mkdir_or_exist(new_dir)

        with open(new_file_path, mode='w') as f:
            f.write(format_code(ast.unparse(top_ast_tree)))

    # A downstream repo could register a torch module into a Registry, such
    # as registering `torch.nn.Linear` into `MODELS`. We need to preserve
    # this code in the exported module.
    elif 'torch' in module.split('.')[0]:

        # get the root registry, because it can reach all the modules
        # that have been registered.
        root_registry = self if self.parent is None else self.parent
        if (obj_type not in self._extra_module_set) and (
                root_registry.init_get_func(obj_type) is None):
            self._extra_module_set.add(obj_type)
            with open(osp.join(pack_module_dir, 'registry.py'), 'a') as f:

                # TODO: When the downstream repo's registry names differ
                # from mmengine's, the module may not be registered to the
                # right registry.
                # For example: `EVALUATOR` in mmengine, `EVALUATORS` in mmdet.
                f.write('\n')
                f.write(f'from {module} import {obj_cls.__name__}\n')
                f.write(
                    f"{REGISTRY_TYPES[self.name]}.register_module('{obj_type}', module={obj_cls.__name__}, force=True)"  # noqa: E501
                )


def _wrap_build(build_func: Callable, pack_module_dir: str):
    """Wrap ``Registry.build()``.

    Args:
        build_func (Callable): ``Registry.build()``, which will be wrapped.
        pack_module_dir (str): Modules export path.
    """

    def wrapper(self, cfg: dict, *args, **kwargs):

        # obj is a class instance
        obj = build_func(self, cfg, *args, **kwargs)
        args = cfg.copy()  # type: ignore
        obj_type = args.pop('type')  # type: ignore
        obj_type = obj_type if isinstance(obj_type, str) else obj_type.__name__

        # modules in a ``torch.nn.Sequential`` should be exported one by one
        if isinstance(obj, torch.nn.Sequential):
            for children in obj.children():
                _export_module(self, children.__class__, pack_module_dir,
                               obj_type)
        else:
            _export_module(self, obj.__class__, pack_module_dir, obj_type)

        return obj

    return wrapper


def _wrap_get(get_func: Callable, pack_module_dir: str):
    """Wrap ``Registry.get()``.

    Args:
        get_func (Callable): ``Registry.get()``, which will be wrapped.
        pack_module_dir (str): Modules export path.
    """

    def wrapper(self, key: str):

        obj_cls = get_func(self, key)

        _export_module(self, obj_cls, pack_module_dir, key)

        return obj_cls

    return wrapper

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mim.commands.list import list_package
from mim.utils.default import OFFICIAL_MODULES


def get_installed_package(ctx, args, incomplete):
@ -19,15 +20,4 @@ def get_downstream_package(ctx, args, incomplete):


def get_official_package(ctx=None, args=None, incomplete=None):
    return [
        'mmcls',
        'mmdet',
        'mmdet3d',
        'mmseg',
        'mmaction2',
        'mmtrack',
        'mmpose',
        'mmedit',
        'mmocr',
        'mmgen',
    ]
    return OFFICIAL_MODULES

@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from typing import List, Optional
import platform
import subprocess
import sys
from typing import List, Optional, Union

import click
import yaml

from mim.click import (
    OptionEatAll,
@ -14,11 +18,14 @@ from mim.click import (
from mim.commands.search import get_model_info
from mim.utils import (
    DEFAULT_CACHE_DIR,
    call_command,
    color_echo,
    download_from_file,
    echo_success,
    get_installed_path,
    highlighted_error,
    is_installed,
    module_full_name,
    split_package_version,
)

@ -33,8 +40,13 @@ from mim.utils import (
    '--config',
    'configs',
    cls=OptionEatAll,
    required=True,
    help='Config ids to download, such as resnet18_8xb16_cifar10')
    help='Config ids to download, such as resnet18_8xb16_cifar10',
    default=None)
@click.option(
    '--dataset',
    'dataset',
    help='dataset name to download, such as coco2017',
    default=None)
@click.option(
    '--ignore-ssl',
    'check_certificate',
@ -44,7 +56,8 @@ from mim.utils import (
@click.option(
    '--dest', 'dest_root', type=str, help='Destination of saving checkpoints.')
def cli(package: str,
        configs: List[str],
        configs: Optional[List[str]],
        dataset: Optional[str],
        dest_root: Optional[str] = None,
        check_certificate: bool = True) -> None:
    """Download checkpoints from url and parse configs from package.

@ -54,31 +67,55 @@ def cli(package: str,
    > mim download mmcls --config resnet18_8xb16_cifar10
    > mim download mmcls --config resnet18_8xb16_cifar10 --dest .
    """
    download(package, configs, dest_root, check_certificate)
    download(package, configs, dest_root, check_certificate, dataset)

def download(package: str,
             configs: List[str],
             configs: Optional[List[str]] = None,
             dest_root: Optional[str] = None,
             check_certificate: bool = True) -> List[str]:
             check_certificate: bool = True,
             dataset: Optional[str] = None) -> Union[List[str], None]:
    """Download checkpoints from url and parse configs from package.

    Args:
        package (str): Name of package.
        configs (List[str]): List of config ids.
        dest_root (Optional[str]): Destination directory to save checkpoint and
        configs (List[str], optional): List of config ids.
        dest_root (str, optional): Destination directory to save checkpoint and
            config. Default: None.
        check_certificate (bool): Whether to check the ssl certificate.
            Default: True.
        dataset (str, optional): The name of the dataset. Default: None.
    """
    full_name = module_full_name(package)
    if full_name == '':
        msg = f"Can't determine a unique package given abbreviation {package}"
        raise ValueError(highlighted_error(msg))
    package = full_name

    if dest_root is None:
        dest_root = DEFAULT_CACHE_DIR

    dest_root = osp.abspath(dest_root)

    if configs is not None and dataset is not None:
        raise ValueError(
            'Cannot download config and dataset at the same time!')
    if configs is None and dataset is None:
        raise ValueError('Please specify config or dataset to download!')

    if configs is not None:
        return _download_configs(package, configs, dest_root,
                                 check_certificate)
    else:
        return _download_dataset(package, dataset, dest_root)  # type: ignore

def _download_configs(package: str,
                      configs: List[str],
                      dest_root: str,
                      check_certificate: bool = True) -> List[str]:
    # Create the destination directory if it does not exist.
    if not osp.exists(dest_root):
        os.makedirs(dest_root)
    os.makedirs(dest_root, exist_ok=True)

    package, version = split_package_version(package)
    if version:
@ -152,3 +189,87 @@ def download(package: str,
            highlighted_error(f'{config_path} is not found.'))

    return checkpoints

def _download_dataset(package: str, dataset: str, dest_root: str) -> None:
    if platform.system() != 'Linux':
        raise RuntimeError('Downloading datasets is only supported on Linux!')

    if not is_installed(package):
        raise RuntimeError(
            f'Please install {package} by `pip install {package}`')

    installed_path = get_installed_path(package)
    mim_path = osp.join(installed_path, '.mim')
    dataset_index_path = osp.join(mim_path, 'dataset-index.yml')

    if not osp.exists(dataset_index_path):
        raise FileNotFoundError(
            f'Cannot find {dataset_index_path}, '
            f'please update {package} to the latest version! If you have '
            f'already updated it and still get this error, please report an '
            f'issue to {package}')
    with open(dataset_index_path) as f:
        dataset_metas = yaml.load(f, Loader=yaml.SafeLoader)

    if dataset not in dataset_metas:
        raise KeyError(f'Cannot find {dataset} in {dataset_index_path}. '
                       'Here are the available datasets: '
                       '{}'.format('\n'.join(dataset_metas.keys())))
    dataset_meta = dataset_metas[dataset]
    # An OpenMMLab repo will define its `dataset-index.yml` like this:
    # openxlab: true
    # voc2007:
    #     dataset: PASCAL_VOC2007
    #     download_root: data
    #     data_root: data
    #     script: tools/dataset_converters/scripts/preprocess_voc2007.sh

    # In this case:
    # `openxlab` means downloading the dataset with the `openxlab` cli,
    # otherwise the `odl` cli is used. Although the `odl` cli will not be
    # maintained in the future, we still keep it here for compatibility.

    # The top-level key "voc2007" is the "Dataset Name" passed
    # to `mim download --dataset {Dataset Name}`

    # The nested field "dataset" is the argument passed to `odl get`.
    # If the value of "dataset" is the same as the "Dataset Name", downstream
    # repos can skip defining "dataset" and the "Dataset Name" will be passed
    # to `odl get`
    use_openxlab = dataset_metas.get('openxlab', False)
    src_name = dataset_meta.get('dataset', dataset)
    # `odl get` downloads the raw dataset to `download_root`, and the script
    # processes the raw data and puts the prepared data into `data_root`
    download_root = dataset_meta['download_root']
    os.makedirs(download_root, exist_ok=True)

    color_echo(f'Start downloading {dataset} to {download_root}...', 'blue')
    if use_openxlab:
        subprocess.check_call(
            ['openxlab', 'dataset', 'get', src_name, '-d', download_root],
            stdin=sys.stdin,
            stdout=sys.stdout)
    else:
        subprocess.check_call(['odl', 'get', src_name, '-d', download_root],
                              stdin=sys.stdin,
                              stdout=sys.stdout)

    if not osp.exists(download_root):
        return

    script_path = dataset_meta.get('script')
    if script_path is None:
        return

    script_path = osp.join(mim_path, script_path)
    color_echo('Preprocess data ...', 'blue')
    if dest_root == osp.abspath(DEFAULT_CACHE_DIR):
        data_root = dataset_meta['data_root']
    else:
        data_root = dest_root
    os.makedirs(data_root, exist_ok=True)
    call_command(['chmod', '+x', script_path])
    call_command([script_path, download_root, data_root])
    echo_success('Finished!')
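With a `dataset-index.yml` entry like the one sketched in the comments above, this whole path is triggered from the command line as follows (assuming the installed package actually ships such an entry for the named dataset):

```bash
mim download mmdet --dataset coco2017 --dest data
```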
124 mim/commands/export.py Normal file
@ -0,0 +1,124 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import sys

import click
from mmengine.config import Config
from mmengine.hub import get_config

from mim._internal.export.pack_cfg import export_from_cfg
from mim.click import CustomCommand

PYTHON = sys.executable


@click.command(
    name='export',
    context_settings=dict(ignore_unknown_options=True),
    cls=CustomCommand)
@click.argument('package', type=str)
@click.argument('config', type=str)
@click.argument('export_dir', type=str, default=None, required=False)
@click.option(
    '-f',
    '--fast-test',
    is_flag=True,
    help='The fast_test mode. To quickly check whether the exported package '
    'has any error, it only uses the first two samples of your datasets '
    'and trains for only 2 iters/epochs.')
@click.option(
    '--save-log',
    is_flag=True,
    help='The flag to keep the export log of the process. The log will be '
    "saved to the directory 'export_log'. By default the log is deleted "
    'automatically after export.')
def cli(config: str,
        package: str,
        export_dir: str,
        fast_test: bool = False,
        save_log: bool = False) -> None:
    """Export package from config file.

    Example:

    \b
    >>> # Export package from downstream config file.
    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
    ...     dab_detr
    >>>
    >>> # Export package from specified config file.
    >>> mim export mmdet mmdetection/configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py dab_detr  # noqa: E501
    >>>
    >>> # It can auto-generate an export dir when not specified,
    >>> # e.g. 'pack_from_mmdet_20231026_052704'.
    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py
    >>>
    >>> # Only export the model of the config file.
    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
    ...     mask_rcnn_package --model-only
    >>>
    >>> # Keep the export log of the process.
    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
    ...     mask_rcnn_package --save-log
    >>>
    >>> # Print the help information of the export command.
    >>> mim export -h
    """

    # get config
    if osp.exists(config):
        config = Config.fromfile(config)  # from local
    else:
        try:
            config = get_config(package + '::' + config)  # from downstream
        except Exception:
            raise FileNotFoundError(
                f"Cannot find config file '{config}' or "
                f"'{package + '::' + config}'.")

    fast_test_mode(config, fast_test)

    export_from_cfg(config, export_root_dir=export_dir, save_log=save_log)


def fast_test_mode(cfg, fast_test: bool = False):
    """Use less data for faster testing.

    Args:
        cfg (Config): Config of the exported package.
        fast_test (bool, optional): Fast testing mode. Defaults to False.
    """
    if fast_test:
        # batch norm needs at least 2 samples
        if 'dataset' in cfg.train_dataloader.dataset:
            cfg.train_dataloader.dataset.dataset.indices = [0, 1]
        else:
            cfg.train_dataloader.dataset.indices = [0, 1]
        cfg.train_dataloader.batch_size = 2

        if cfg.get('test_dataloader') is not None:
            cfg.test_dataloader.dataset.indices = [0, 1]
            cfg.test_dataloader.batch_size = 2

        if cfg.get('val_dataloader') is not None:
            cfg.val_dataloader.dataset.indices = [0, 1]
            cfg.val_dataloader.batch_size = 2

        if (cfg.train_cfg.get('type') == 'IterBasedTrainLoop') \
                or (cfg.train_cfg.get('by_epoch') is None
                    and cfg.train_cfg.get('type') != 'EpochBasedTrainLoop'):
            cfg.train_cfg.max_iters = 2
        else:
            cfg.train_cfg.max_epochs = 2

        cfg.train_cfg.val_interval = 1
        cfg.default_hooks.logger.interval = 1

        if 'param_scheduler' in cfg and cfg.param_scheduler is not None:
            if isinstance(cfg.param_scheduler, list):
                for lr_sc in cfg.param_scheduler:
                    lr_sc.begin = 0
                    lr_sc.end = 2
            else:
                cfg.param_scheduler.begin = 0
                cfg.param_scheduler.end = 2
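A minimal sketch of exercising this truncation directly, assuming a config (like the test config elsewhere in this diff) that defines the standard dataloader and train-loop fields:

```python
from mmengine.config import Config

cfg = Config.fromfile('tests/data/lenet5_mnist_2.0.py')
fast_test_mode(cfg, fast_test=True)
# only 2 samples per split and 2 iters/epochs remain
assert cfg.train_dataloader.batch_size == 2
```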

@ -17,7 +17,8 @@ from mim.utils import (
    DEFAULT_MMCV_BASE_URL,
    PKG2PROJECT,
    echo_warning,
    get_torch_cuda_version,
    exit_with_error,
    get_torch_device_version,
)

@ -160,16 +161,20 @@ def get_mmcv_full_find_link(mmcv_base_url: str) -> str:

    Returns:
        str: The mmcv find links corresponding to the current torch version and
            cuda version.
            CUDA/NPU version.
    """
    torch_v, cuda_v = get_torch_cuda_version()
    torch_v, device, device_v = get_torch_device_version()
    major, minor, *_ = torch_v.split('.')
    torch_v = '.'.join([major, minor, '0'])

    if cuda_v.isdigit():
        cuda_v = f'cu{cuda_v}'
    if device == 'cuda' and device_v.isdigit():
        device_link = f'cu{device_v}'
    elif device == 'ascend':
        device_link = f'ascend{device_v}'
    else:
        device_link = 'cpu'

    find_link = f'{mmcv_base_url}/mmcv/dist/{cuda_v}/torch{torch_v}/index.html'  # noqa: E501
    find_link = f'{mmcv_base_url}/mmcv/dist/{device_link}/torch{torch_v}/index.html'  # noqa: E501
    return find_link
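To make the new branch concrete, assuming `mmcv_base_url` is the default `https://download.openmmlab.com`, the device/version combinations used as examples elsewhere in this diff would resolve to links like:

```python
# cuda, device_v = '102', torch 1.8.x:
#   https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/index.html
# ascend, device_v = '602', torch 1.11.x:
#   https://download.openmmlab.com/mmcv/dist/ascend602/torch1.11.0/index.html
# cpu:
#   https://download.openmmlab.com/mmcv/dist/cpu/torch1.8.0/index.html
```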
@ -244,7 +249,7 @@ def patch_importlib_distribution(index_url: Optional[str] = None) -> Generator:
    if self.canonical_name not in PKG2PROJECT or self.canonical_name == 'mmcv-full':  # noqa: E501
        return deps

    if 'mim' in self.iter_provided_extras:
    if 'mim' in self.iter_provided_extras():
        mim_extra_requires = list(
            origin_iter_dependencies(self, ('mim', )))
        filter_invalid_marker(mim_extra_requires)

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import pickle
import re
import subprocess
import tempfile
import typing
from typing import Any, List, Optional
@ -11,6 +10,7 @@ import click
from modelindex.load_model_index import load
from modelindex.models.ModelIndex import ModelIndex
from pandas import DataFrame, Series
from pip._internal.commands import create_command

from mim.click import (
    OptionEatAll,
@ -19,15 +19,14 @@ from mim.click import (
    param2lowercase,
)
from mim.utils import (
    DEFAULT_CACHE_DIR,
    PKG2PROJECT,
    cast2lowercase,
    echo_success,
    get_github_url,
    echo_warning,
    extract_tar,
    get_installed_path,
    highlighted_error,
    is_installed,
    split_package_version,
    recursively_find,
)

@ -71,6 +70,15 @@ from mim.utils import (
@click.option('--to-dict', 'to_dict', is_flag=True, help='Return metadata.')
@click.option(
    '--local/--remote', default=True, help='Show local or remote packages.')
@click.option(
    '-i',
    '--index-url',
    '--pypi-url',
    'index_url',
    help='Base URL of the Python Package Index (default %default). '
    'This should point to a repository compliant with PEP 503 '
    '(the simple repository API) or a local directory laid out '
    'in the same format.')
@click.option(
    '--display-width', type=int, default=80, help='The display width.')
def cli(packages: List[str],
@ -87,7 +95,8 @@ def cli(packages: List[str],
        json_path: Optional[str] = None,
        to_dict: bool = False,
        local: bool = True,
        display_width: int = 80) -> Any:
        display_width: int = 80,
        index_url: Optional[str] = None) -> Any:
    """Show the information of packages.

    \b
@ -118,7 +127,8 @@ def cli(packages: List[str],
        ascending=ascending,
        shown_fields=shown_fields,
        unshown_fields=unshown_fields,
        local=local)
        local=local,
        index_url=index_url)

    if to_dict or json_path:
        packages_info.update(dataframe.to_dict('index'))  # type: ignore
@ -150,7 +160,8 @@ def get_model_info(package: str,
                   shown_fields: Optional[List[str]] = None,
                   unshown_fields: Optional[List[str]] = None,
                   local: bool = True,
                   to_dict: bool = False) -> Any:
                   to_dict: bool = False,
                   index_url: Optional[str] = None) -> Any:
    """Get model information like metric or dataset.

    Args:
@ -170,8 +181,11 @@ ...
        local (bool): Query from local environment or remote github.
            Default: True.
        to_dict (bool): Convert dataframe into dict. Default: False.
        index_url (str, optional): The pypi index url, if given, will be used
            in the ``pip download`` command for downloading packages when
            local is False. Default: None.
    """
    metadata = load_metadata(package, local)
    metadata = load_metadata(package, local, index_url)
    dataframe = convert2df(metadata)
    dataframe = filter_by_configs(dataframe, configs)
    dataframe = filter_by_conditions(dataframe, filter_conditions)
@ -186,13 +200,18 @@ ...
    return dataframe

def load_metadata(package: str, local: bool = True) -> Optional[ModelIndex]:
def load_metadata(package: str,
                  local: bool = True,
                  index_url: Optional[str] = None) -> Optional[ModelIndex]:
    """Load metadata from local package or remote package.

    Args:
        package (str): Name of package to load metadata.
        local (bool): Query from local environment or remote github.
            Default: True.
        index_url (str, optional): The pypi index url, if given, will be used
            in the ``pip download`` command for downloading packages when
            local is False. Default: None.
    """
    if '=' in package and local:
        raise ValueError(
@ -203,7 +222,7 @@ ...
    if local:
        return load_metadata_from_local(package)
    else:
        return load_metadata_from_remote(package)
        return load_metadata_from_remote(package, index_url)


def load_metadata_from_local(package: str):
@ -241,55 +260,62 @@ ...
            f'install {package}" or use mim search {package} --remote'))


def load_metadata_from_remote(package: str) -> Optional[ModelIndex]:
    """Load metadata from github.
def load_metadata_from_remote(package: str,
                              index_url: Optional[str] = None
                              ) -> Optional[ModelIndex]:
    """Load metadata from PyPI.

    Download the model_zoo directory from github and parse it into metadata.
    Download the model_zoo directory from PyPI and parse it into metadata.

    Args:
        package (str): Name of package to load metadata.
        index_url (str, optional): The pypi index url, if given, will be used
            in the ``pip download`` command for downloading packages.
            Default: None.

    Example:
        >>> # load metadata from master branch
        >>> # load metadata from latest version
        >>> metadata = load_metadata_from_remote('mmcls')
        >>> # load metadata from 0.11.0
        >>> metadata = load_metadata_from_remote('mmcls==0.11.0')
    """
    package, version = split_package_version(package)

    github_url = get_github_url(package)

    pkl_path = osp.join(DEFAULT_CACHE_DIR, f'{package}-{version}.pkl')
    if osp.exists(pkl_path):
        with open(pkl_path, 'rb') as fr:
            return pickle.load(fr)
    if index_url is not None:
        click.echo(f'Loading metadata from PyPI ({index_url}) with '
                   '"pip download" command.')
    else:
        clone_cmd = ['git', 'clone', github_url]
        if version:
            clone_cmd.extend(['-b', f'v{version}'])
        with tempfile.TemporaryDirectory() as temp:
            repo_root = osp.join(temp, PKG2PROJECT[package])
            clone_cmd.append(repo_root)
            subprocess.check_call(
                clone_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
        click.echo('Loading metadata from PyPI with "pip download" command.')

        # rename the model_zoo.yml to model-index.yml but support both of
        # them for backward compatibility
        possible_metadata_paths = [
            osp.join(repo_root, 'model-index.yml'),
            osp.join(repo_root, 'model_zoo.yml'),
        ]
        for metadata_path in possible_metadata_paths:
            if osp.exists(metadata_path):
                metadata = load(metadata_path)
                if version:
                    with open(pkl_path, 'wb') as fw:
                        pickle.dump(metadata, fw)
                return metadata
        raise FileNotFoundError(
            highlighted_error(
                'model-index.yml or model_zoo.yml is not found, please '
                f'upgrade your {package} to support search command'))
    with tempfile.TemporaryDirectory() as temp_dir:
        download_args = [
            package, '-d', temp_dir, '--no-deps', '--no-binary', ':all:', '-q'
        ]
        if index_url is not None:
            download_args += ['-i', index_url]
        status_code = create_command('download').main(download_args)
        if status_code != 0:
            echo_warning(f'pip download failed with args: {download_args}')
            exit(status_code)

        # untar the file and get the package directory
        tar_path = osp.join(temp_dir, os.listdir(temp_dir)[0])
        extract_tar(tar_path, temp_dir)
        # strip the '.tar.gz' suffix (str.rstrip strips a character set,
        # not a suffix, so it could truncate names ending in those letters)
        filename_no_ext = osp.basename(tar_path)[:-len('.tar.gz')]
        package_dir = osp.join(temp_dir, filename_no_ext)

        # rename the model_zoo.yml to model-index.yml but support both of
        # them for backward compatibility
        possible_metadata_paths = recursively_find(package_dir,
                                                   'model-index.yml')
        possible_metadata_paths.extend(
            recursively_find(package_dir, 'model_zoo.yml'))
        for metadata_path in possible_metadata_paths:
            if osp.exists(metadata_path):
                metadata = load(metadata_path)
                return metadata
        raise FileNotFoundError(
            highlighted_error(
                'model-index.yml or model_zoo.yml is not found, please '
                f'upgrade your {package} to support search command'))


def convert2df(metadata: ModelIndex) -> DataFrame:
@ -4,6 +4,7 @@ from .default import (
    DEFAULT_MMCV_BASE_URL,
    DEFAULT_URL,
    MODULE2PKG,
    OFFICIAL_MODULES,
    PKG2MODULE,
    PKG2PROJECT,
    RAW_GITHUB_URL,
@ -32,7 +33,7 @@ from .utils import (
    get_package_info_from_pypi,
    get_package_version,
    get_release_version,
    get_torch_cuda_version,
    get_torch_device_version,
    highlighted_error,
    is_installed,
    is_version_equal,
@ -59,7 +60,7 @@ __all__ = [
    'get_installed_version',
    'get_installed_path',
    'get_latest_version',
    'get_torch_cuda_version',
    'get_torch_device_version',
    'is_installed',
    'parse_url',
    'PKG2PROJECT',
@ -90,4 +91,5 @@ __all__ = [
    'parse_home_page',
    'ensure_installation',
    'rich_progress_bar',
    'OFFICIAL_MODULES',
]

@ -13,9 +13,17 @@ WHEEL_URL = {
    '{torch_version}/index.html',
}
RAW_GITHUB_URL = 'https://raw.githubusercontent.com/{owner}/{repo}/{branch}'

OFFICIAL_MODULES = [
    'mmcls', 'mmdet', 'mmdet3d', 'mmseg', 'mmaction2', 'mmtrack', 'mmpose',
    'mmedit', 'mmocr', 'mmgen', 'mmselfsup', 'mmrotate', 'mmflow', 'mmyolo',
    'mmpretrain', 'mmagic'
]

PKG2PROJECT = {
    'mmcv-full': 'mmcv',
    'mmcls': 'mmclassification',
    'mmpretrain': 'mmpretrain',
    'mmdet': 'mmdetection',
    'mmdet3d': 'mmdetection3d',
    'mmsegmentation': 'mmsegmentation',
@ -29,6 +37,7 @@ PKG2PROJECT = {
    'mmrotate': 'mmrotate',
    'mmflow': 'mmflow',
    'mmyolo': 'mmyolo',
    'mmagic': 'mmagic',
}
# TODO: Should directly infer MODULE name from PKG info
PKG2MODULE = {

@ -12,7 +12,7 @@ import typing
from collections import defaultdict
from email.parser import FeedParser
from pkg_resources import get_distribution, parse_version
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, Union

import click
import requests
@ -23,6 +23,15 @@ from requests.models import Response
from .default import PKG2PROJECT
from .progress_bars import rich_progress_bar

try:
    import torch
    import torch_npu

    IS_NPU_AVAILABLE = hasattr(
        torch, 'npu') and torch.npu.is_available()  # type: ignore
except Exception:
    IS_NPU_AVAILABLE = False


def parse_url(url: str) -> Tuple[str, str]:
    """Parse username and repo from url.
@ -327,12 +336,37 @@ def get_installed_path(package: str) -> str:
    return osp.join(pkg.location, package2module(package))


def get_torch_cuda_version() -> Tuple[str, str]:
    """Get PyTorch version and CUDA version if it is available.
def is_npu_available() -> bool:
    """Returns True if Ascend PyTorch and npu devices exist."""
    return IS_NPU_AVAILABLE


def get_npu_version() -> str:
    """Returns the version of NPU when npu devices exist."""
    if not is_npu_available():
        return ''
    ascend_home_path = os.environ.get('ASCEND_HOME_PATH', None)
    if not ascend_home_path or not os.path.exists(ascend_home_path):
        raise RuntimeError(
            highlighted_error(
                f'ASCEND_HOME_PATH:{ascend_home_path} does not exist when '
                'installing OpenMMLab projects on Ascend NPU. '
                "Please run 'source set_env.sh' in the CANN installation "
                'path.'))
    npu_version = torch.version.cann
    return npu_version


def get_torch_device_version() -> Tuple[str, str, str]:
    """Get PyTorch version and CUDA/NPU version if it is available.

    Example:
        >>> get_torch_cuda_version()
        '1.8.0', '102'
        >>> get_torch_device_version()
        '1.8.0', 'cpu', ''
        >>> get_torch_device_version()
        '1.8.0', 'cuda', '102'
        >>> get_torch_device_version()
        '1.11.0', 'ascend', '602'
    """
    try:
        import torch
@ -344,11 +378,17 @@ def get_torch_cuda_version() -> Tuple[str, str]:
    torch_v = torch_v.split('+')[0]

    if torch.version.cuda is not None:
        device = 'cuda'
        # torch.version.cuda like 10.2 -> 102
        cuda_v = ''.join(torch.version.cuda.split('.'))
        device_v = ''.join(torch.version.cuda.split('.'))
    elif is_npu_available():
        device = 'ascend'
        device_v = get_npu_version()
        device_v = ''.join(device_v.split('.'))
    else:
        cuda_v = 'cpu'
        return torch_v, cuda_v
        device = 'cpu'
        device_v = ''
    return torch_v, device, device_v


def cast2lowercase(input: Union[list, tuple, str]) -> Any:
@ -509,8 +549,11 @@ def get_config(cfg, name):
    name = name.split('.')
    suffix = ''
    for item in name:
        assert item in cfg, f'attribute {item} not cfg{suffix}'
        cfg = cfg[item]
        if isinstance(cfg, Sequence) and not isinstance(cfg, str):
            cfg = cfg[int(item)]
        else:
            assert item in cfg, f'attribute {item} not cfg{suffix}'
            cfg = cfg[item]
        suffix += f'.{item}'
    return cfg

@ -524,8 +567,11 @@ def set_config(cfg, name, value):
    name = name.split('.')
    suffix = ''
    for item in name[:-1]:
        assert item in cfg, f'attribute {item} not cfg{suffix}'
        cfg = cfg[item]
        if isinstance(cfg, Sequence) and not isinstance(cfg, str):
            cfg = cfg[int(item)]
        else:
            assert item in cfg, f'attribute {item} not cfg{suffix}'
            cfg = cfg[item]
        suffix += f'.{item}'

    assert name[-1] in cfg, f'attribute {item} not cfg{suffix}'
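The new ``Sequence`` branch lets dotted names address list elements by integer position, which is what the new gridsearch argument in this change set (``--train_dataloader.dataset.pipeline.0.scale``) relies on. A minimal sketch:

```python
cfg = dict(train_dataloader=dict(dataset=dict(pipeline=[
    dict(type='Resize', scale=32),
    dict(type='PackClsInputs'),
])))

# 'pipeline.0' now indexes into the list:
assert get_config(cfg, 'train_dataloader.dataset.pipeline.0.scale') == 32
set_config(cfg, 'train_dataloader.dataset.pipeline.0.scale', 16)
```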
@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

__version__ = '0.3.4'
__version__ = '0.3.10'


def parse_version_info(version_str):

@ -1,6 +1,7 @@
Click
colorama
model-index
opendatalab
pandas
pip>=19.3
requests

@ -20,4 +20,4 @@ default_section = THIRDPARTY
include_trailing_comma = true

[codespell]
ignore-words-list = te
ignore-words-list = te, cann

@ -12,7 +12,11 @@ model = dict(
dataset_type = 'MNIST'
data_preprocessor = dict(mean=[33.46], std=[78.87])

pipeline = [dict(type='Resize', scale=32), dict(type='PackClsInputs')]
pipeline = [
    dict(type='Resize', scale=32),
    dict(type='Pad', size=(32, 32)),
    dict(type='PackClsInputs')
]

common_data_cfg = dict(
    type=dataset_type, data_prefix='data/mnist', pipeline=pipeline)

@ -57,6 +57,12 @@ def test_gridsearch(gpus, tmp_path):
        f'--work-dir={tmp_path}', '--search-args'
    ]

    args5 = [
        'mmcls', 'tests/data/lenet5_mnist_2.0.py', f'--gpus={gpus}',
        f'--work-dir={tmp_path}', '--search-args',
        '--train_dataloader.dataset.pipeline.0.scale 16 32'
    ]

    result = runner.invoke(gridsearch, args1)
    assert result.exit_code == 0

@ -69,6 +75,9 @@ def test_gridsearch(gpus, tmp_path):
    result = runner.invoke(gridsearch, args4)
    assert result.exit_code != 0

    result = runner.invoke(gridsearch, args5)
    assert result.exit_code == 0


def teardown_module():
    runner = CliRunner()

@ -1,12 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

from click.testing import CliRunner

from mim.commands.install import cli as install
from mim.commands.search import cli as search
from mim.commands.uninstall import cli as uninstall
from mim.utils import DEFAULT_CACHE_DIR


def setup_module():
@ -39,11 +37,9 @@ def test_search():
    result = runner.invoke(search, ['mmaction2', '--remote'])
    assert result.exit_code == 0

    # mim search mmcls==0.11.0 --remote
    result = runner.invoke(search, ['mmcls==0.11.0', '--remote'])
    # mim search mmcls==0.24.0 --remote
    result = runner.invoke(search, ['mmcls==0.24.0', '--remote'])
    assert result.exit_code == 0
    # the metadata of mmcls==0.11.0 will be saved in cache
    assert osp.exists(osp.join(DEFAULT_CACHE_DIR, 'mmcls-0.11.0.pkl'))

    # always test latest mmcls
    result = runner.invoke(uninstall, ['mmcls', '--yes'])

@ -4,6 +4,7 @@ from click.testing import CliRunner
from mim.commands.install import cli as install
from mim.commands.uninstall import cli as uninstall
from mim.utils import get_github_url, parse_home_page
from mim.utils.utils import get_torch_device_version, is_npu_available


def setup_module():
@ -39,6 +40,13 @@ def test_get_github_url():
        'mmcls') == 'https://github.com/open-mmlab/mmclassification.git'


def test_get_torch_device_version():
    torch_v, device, device_v = get_torch_device_version()
    assert torch_v.replace('.', '').isdigit()
    if is_npu_available():
        assert device == 'ascend'


def teardown_module():
    runner = CliRunner()
    result = runner.invoke(uninstall, ['mmcv-full', '--yes'])