From c9a5ea6a883cf48a931866b2297ec78ebf2b60a3 Mon Sep 17 00:00:00 2001
From: Guoping Pan <731061720@qq.com>
Date: Tue, 31 Oct 2023 11:08:16 +0800
Subject: [PATCH] [Experimental] Packaging a minimal and flatten algorithm from downstream repos (#229)

---
 mim/_internal/export/README.md | 99 ++
 mim/_internal/export/README_zh-CN.md | 100 ++
 mim/_internal/export/__init__.py | 1 +
 mim/_internal/export/common.py | 75 ++
 mim/_internal/export/flatten_func.py | 1081 +++++++++++++++++
 mim/_internal/export/pack_cfg.py | 312 +++++
 mim/_internal/export/patch_utils/README.md | 75 ++
 .../export/patch_utils/README_zh-CN.md | 75 ++
 mim/_internal/export/patch_utils/__init__.py | 3 +
 .../export/patch_utils/patch_model.py | 82 ++
 .../export/patch_utils/patch_task.py | 73 ++
 mim/_internal/export/utils.py | 555 +++++++++
 mim/click/autocompletion.py | 18 +-
 mim/commands/export.py | 124 ++
 mim/utils/__init__.py | 2 +
 mim/utils/default.py | 7 +
 16 files changed, 2666 insertions(+), 16 deletions(-)
 create mode 100644 mim/_internal/export/README.md
 create mode 100644 mim/_internal/export/README_zh-CN.md
 create mode 100644 mim/_internal/export/__init__.py
 create mode 100644 mim/_internal/export/common.py
 create mode 100644 mim/_internal/export/flatten_func.py
 create mode 100644 mim/_internal/export/pack_cfg.py
 create mode 100644 mim/_internal/export/patch_utils/README.md
 create mode 100644 mim/_internal/export/patch_utils/README_zh-CN.md
 create mode 100644 mim/_internal/export/patch_utils/__init__.py
 create mode 100644 mim/_internal/export/patch_utils/patch_model.py
 create mode 100644 mim/_internal/export/patch_utils/patch_task.py
 create mode 100644 mim/_internal/export/utils.py
 create mode 100644 mim/commands/export.py

diff --git a/mim/_internal/export/README.md b/mim/_internal/export/README.md
new file mode 100644
index 0000000..40f4615
--- /dev/null
+++ b/mim/_internal/export/README.md
@@ -0,0 +1,99 @@
+# Export (experimental)
+
+`mim export` is a new feature of `mim` that exports a minimal trainable and testable model package from a config file.
+
+The minimal package compactly combines the `model`, `datasets`, `engine` and other components, so there is no need to search for each module separately based on the config file. Users who do not install from source can also view the model files at a glance.
+
+In addition, lengthy inheritance chains, such as `CondInstBboxHead -> FCOSHead -> AnchorFreeHead -> BaseDenseHead -> BaseModule` in `mmdetection`, are flattened directly into `CondInstBboxHead -> BaseModule`. Since all functions inherited from parent classes become visible in a single file, there is no need to open and compare multiple model files.
+
+### Usage
+
+```bash
+mim export config_path
+
+# config_path accepts the following two forms:
+
+# 1. a config of a downstream repo, referenced as repo::configs/xxx.py. For example:
+mim export mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py
+
+# 2. a config in a local folder, for example:
+mim export config_dir/mask-rcnn_r101_fpn_1x_coco.py
+```
+
+### Minimal model package directory structure
+
+```
+minimum_package (named as pack_from_{repo}_20231212_121212)
+|- pack
+| |- configs        # configuration folder
+| | |- model_name
+| | |- xxx.py       # configuration file
+| |
+| |- models         # model folder
+| | |- model_file.py
+| | |- ...
+| |
+| |- data           # data folder
+| |
+| |- demo           # demo folder
+| |
+| |- datasets       # dataset class definitions
+| | |- transforms
+| |
+| |- registry.py    # registry
+|
+|- tools
+| |- train.py       # training script
+| |- test.py        # test script
+|
+```
+
+### Limitations
+
+Currently, `mim export` only supports some of the config files of `mmpose`, `mmdetection`, `mmagic` and `mmsegmentation`. Besides, it places some constraints on the downstream repos.
+
+#### For downstream repos
+
+1. It is best to name configs without special symbols; otherwise they cannot be parsed by `mmengine.hub.get_config()`, for example:
+
+   - gn+ws/faster-rcnn_r101_fpn_gn-ws-all_1x_coco.py
+   - legacy_1.x/cascade-mask-rcnn_r50_fpn_1x_coco_v1.py
+
+2. For `mmsegmentation`, before using `mim export` on an `mmseg` config, you should first remove the directory that wraps `registry.py`, i.e. move `mmseg/registry/registry.py` to `mmseg/registry.py`.
+
+3. It is recommended not to rename the downstream Registry objects inherited from mmengine. For example, mmagic renamed `EVALUATOR` to `EVALUATORS`:
+
+   ```python
+   from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
+
+   # Evaluators to define the evaluation process.
+   EVALUATORS = Registry(
+       'evaluator',
+       parent=MMENGINE_EVALUATOR,
+       locations=['mmagic.evaluation'],
+   )
+   ```
+
+4. In addition, if a repo adds a registry that does not exist in mmengine, such as `DIFFUSION_SCHEDULERS` in mmagic, you need to add a key-value pair to `REGISTRY_TYPES` in `mim/_internal/export/common.py` so that `torch` modules can be registered to `DIFFUSION_SCHEDULERS`:
+
+   ```python
+   # "mmagic/mmagic/registry.py"
+   # modules for diffusion models that support adding noise and denoising
+   DIFFUSION_SCHEDULERS = Registry(
+       'diffusion scheduler',
+       locations=['mmagic.models.diffusion_schedulers'],
+   )
+
+   # "mim/_internal/export/common.py"
+   REGISTRY_TYPES = {
+       ...
+       'diffusion scheduler': 'DIFFUSION_SCHEDULERS',
+       ...
+   }
+   ```
+
+#### To be improved
+
+1. Flattening classes with two parent classes (multiple inheritance) is not supported yet. Improvements will be made in the future depending on demand.
+2. An error occurs when the original config contains a file `path` that cannot be found from the current directory. You can avoid such export errors by manually modifying the `path` in the original config.
+3. When `isinstance()` is used and the second argument is an intermediate class of the inheritance chain, the check may return `False` after flattening, because the original inheritance relationship is not retained (see the sketch below).
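+
+The following minimal, self-contained sketch (written for illustration; it is not code from this package) shows why an `isinstance()` check against an intermediate class changes after flattening:
+
+```python
+class Base: ...
+
+
+class B(Base):
+    def forward(self, x):
+        return x + 1
+
+
+class A(B):  # original chain: A -> B -> Base
+    ...
+
+
+class FlatA(Base):  # flattened chain: A -> Base
+    def forward(self, x):  # copied down from B during flattening
+        return x + 1
+
+
+assert isinstance(A(), B)          # True: B is still in A's MRO
+assert not isinstance(FlatA(), B)  # False after flattening: B is gone
+```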
diff --git a/mim/_internal/export/README_zh-CN.md b/mim/_internal/export/README_zh-CN.md
new file mode 100644
index 0000000..e3a2643
--- /dev/null
+++ b/mim/_internal/export/README_zh-CN.md
@@ -0,0 +1,100 @@
+# Export (experimental)
+
+`mim export` is a new feature of `mim`: given only a config file, it can export a minimal trainable and testable model package.
+
+The minimal package compactly combines the `model`, `datasets`, `engine` and other components, so there is no need to look up each module separately from the config file; users who did not install from source can also obtain model files that are clear at a glance.
+
+In addition, lengthy inheritance chains in models, such as `CondInstBboxHead -> FCOSHead -> AnchorFreeHead -> BaseDenseHead -> BaseModule` in `mmdetection`, are flattened directly into `CondInstBboxHead -> BaseModule`, so there is no longer any need to jump back and forth between model files: all functions inherited from parent classes are visible at once.
+
+### Usage
+
+```bash
+mim export config_path
+
+# config_path accepts the following two forms:
+# a config of a downstream repo, referenced as repo::configs/xxx.py, for example:
+mim export mmdet::configs/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py
+
+# a config in a local folder, for example:
+mim export config_dir/mask-rcnn_r101_fpn_1x_coco.py
+```
+
+### Minimal model package directory structure
+
+```
+minimum_package (named as pack_from_{repo}_20231212_121212)
+|- pack
+| |- configs        # configuration folder
+| | |- model_name
+| | |- xxx.py       # configuration file
+| |
+| |- models         # model folder
+| | |- model_file.py
+| | |- ...
+| |
+| |- data           # data folder
+| |
+| |- demo           # demo folder
+| |
+| |- datasets       # dataset class definitions
+| | |- transforms
+| |
+| |- registry.py    # registry
+|
+|- tools
+| |- train.py       # training script
+| |- test.py        # test script
+|
+```
+
+### Limitations
+
+`mim export` currently supports only some of the config files of `mmpose`, `mmdetection`, `mmagic` and `mmsegmentation`, and it places some constraints on the downstream repos.
+
+#### For downstream repos
+
+1. Config names should preferably contain **no special symbols**; otherwise they cannot be parsed by `mmengine.hub.get_config()`, e.g.:
+
+   - gn+ws/faster-rcnn_r101_fpn_gn-ws-all_1x_coco.py
+   - legacy_1.x/cascade-mask-rcnn_r50_fpn_1x_coco_v1.py
+
+2. For `mmsegmentation`, before using `mim export` on an `mmseg` config, you must first remove the directory wrapping `registry.py`, i.e. move `mmseg/registry/registry.py` to `mmseg/registry.py`.
+
+3. It is recommended not to rename the downstream registries inherited from mmengine; for example, mmagic renamed `EVALUATOR` to `EVALUATORS`:
+
+   ```python
+   from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
+
+   # Evaluators to define the evaluation process.
+   EVALUATORS = Registry(
+       'evaluator',
+       parent=MMENGINE_EVALUATOR,
+       locations=['mmagic.evaluation'],
+   )
+   ```
+
+4. In addition, if a registry that does not exist in mmengine is added, such as `DIFFUSION_SCHEDULERS` in mmagic, a key-value pair must be added to `REGISTRY_TYPES` in `mim/_internal/export/common.py` so that `torch` modules can be registered to `DIFFUSION_SCHEDULERS`:
+
+   ```python
+   # "mmagic/mmagic/registry.py"
+   # modules for diffusion models that support adding noise and denoising
+   DIFFUSION_SCHEDULERS = Registry(
+       'diffusion scheduler',
+       locations=['mmagic.models.diffusion_schedulers'],
+   )
+
+   # "mim/_internal/export/common.py"
+   REGISTRY_TYPES = {
+       ...
+       'diffusion scheduler': 'DIFFUSION_SCHEDULERS',
+       ...
+   }
+   ```
+
+#### Aspects of `mim export` to be improved
+
+1. Flattening inheritance from two parent classes is not supported yet; it will be improved later depending on demand.
+
+2. When `isinstance()` is used and the parent class is merely an intermediate class of the inheritance chain, the check may become `False` after flattening, because the original inheritance relationship is not preserved.
+
+3. When the config file contains a `dataset path` that cannot be accessed from the current directory, the export may fail. The current workaround is to save the original config file into the current directory and then manually change the `dataset path` to a path accessible from the current directory, e.g. `data/ADEChallengeData2016/ -> your_data_dir/ADEChallengeData2016/`.
diff --git a/mim/_internal/export/__init__.py b/mim/_internal/export/__init__.py
new file mode 100644
index 0000000..ef101fe
--- /dev/null
+++ b/mim/_internal/export/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/mim/_internal/export/common.py b/mim/_internal/export/common.py
new file mode 100644
index 0000000..32fd9d6
--- /dev/null
+++ b/mim/_internal/export/common.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
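+#
+# Note: this module gathers the shared constants used by `mim export`:
+# ``_init_str`` and ``_import_pack_str`` are source templates written into
+# the exported package, ``OBJECTS_TO_BE_PATCHED`` lists downstream aliases
+# of the MODELS / TASK_UTILS registries that need patching, and
+# ``REGISTRY_TYPES`` maps a registry's name string (e.g. 'hook') to the
+# corresponding variable name (e.g. 'HOOKS') in the exported registry file.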
+
+# initialization template for __init__.py
+_init_str = """
+import os
+import os.path as osp
+
+all_files = os.listdir(osp.dirname(__file__))
+
+for file in all_files:
+    if (file.endswith('.py') and file != '__init__.py') or '.' not in file:
+        exec(f'from .{osp.splitext(file)[0]} import *')
+"""
+
+# import pack path for tools
+_import_pack_str = """
+import os.path as osp
+import sys
+sys.path.append(osp.dirname(osp.dirname(__file__)))
+import pack
+
+"""
+
+OBJECTS_TO_BE_PATCHED = {
+    'MODELS': [
+        'BACKBONES',
+        'NECKS',
+        'HEADS',
+        'LOSSES',
+        'SEGMENTORS',
+        'build_backbone',
+        'build_neck',
+        'build_head',
+        'build_loss',
+        'build_segmentor',
+        'CLASSIFIERS',
+        'RETRIEVER',
+        'build_classifier',
+        'build_retriever',
+        'POSE_ESTIMATORS',
+        'build_pose_estimator',
+        'build_posenet',
+    ],
+    'TASK_UTILS': [
+        'PIXEL_SAMPLERS',
+        'build_pixel_sampler',
+    ]
+}
+
+REGISTRY_TYPES = {
+    'runner': 'RUNNERS',
+    'runner constructor': 'RUNNER_CONSTRUCTORS',
+    'hook': 'HOOKS',
+    'strategy': 'STRATEGIES',
+    'dataset': 'DATASETS',
+    'data sampler': 'DATA_SAMPLERS',
+    'transform': 'TRANSFORMS',
+    'model': 'MODELS',
+    'model wrapper': 'MODEL_WRAPPERS',
+    'weight initializer': 'WEIGHT_INITIALIZERS',
+    'optimizer': 'OPTIMIZERS',
+    'optimizer wrapper': 'OPTIM_WRAPPERS',
+    'optimizer wrapper constructor': 'OPTIM_WRAPPER_CONSTRUCTORS',
+    'parameter scheduler': 'PARAM_SCHEDULERS',
+    'param scheduler': 'PARAM_SCHEDULERS',
+    'metric': 'METRICS',
+    'evaluator': 'EVALUATOR',  # TODO EVALUATORS in mmagic
+    'task utils': 'TASK_UTILS',
+    'loop': 'LOOPS',
+    'visualizer': 'VISUALIZERS',
+    'vis_backend': 'VISBACKENDS',
+    'log processor': 'LOG_PROCESSORS',
+    'inferencer': 'INFERENCERS',
+    'function': 'FUNCTIONS',
+}
diff --git a/mim/_internal/export/flatten_func.py b/mim/_internal/export/flatten_func.py
new file mode 100644
index 0000000..491518a
--- /dev/null
+++ b/mim/_internal/export/flatten_func.py
@@ -0,0 +1,1081 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import ast
+import inspect
+import logging
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+from mmengine.logging import print_log
+from mmengine.model import (
+    BaseDataPreprocessor,
+    BaseModel,
+    BaseModule,
+    ImgDataPreprocessor,
+)
+
+from .common import OBJECTS_TO_BE_PATCHED
+
+
+@dataclass
+class TopClassNodeInfo:
+    """Contains ``top_cls_node`` information which is waiting to be flattened.
+
+    Attributes:
+        "cls_node": The `ast.ClassDef` node to be flattened.
+        "level_cls_name": The class name at the current flatten level.
+        "super_cls_name": The class name of the super class.
+        "end_idx": `len(node.body)`, the last index of all nodes
+            in `ast.ClassDef`.
+        "sub_func": Collects all `ast.FunctionDef` nodes in this class.
+        "sub_assign": Collects all `ast.Assign` nodes in this class.
+    """
+    cls_node: ast.ClassDef
+    level_cls_name: Optional[str] = None
+    super_cls_name: Optional[str] = None
+    end_idx: Optional[int] = None
+    sub_func: Dict[str, ast.FunctionDef] = field(default_factory=dict)
+    sub_assign: Dict[str, ast.Assign] = field(default_factory=dict)
+
+
+@dataclass
+class TopAstInfo:
+    """Contains the initial information of the ``top_ast_tree``.
+
+    Attributes:
+        class_dict (TopClassNodeInfo): Contains ``top_cls_node``
+            information which is waiting to be flattened.
+        importfrom_dict (Dict[str, Tuple[ast.ImportFrom, List[str]]]):
+            Contains the simple alias of `ast.ImportFrom` information.
+        import_list (List[str]): Contains imported module names.
+        assign_list (List[ast.Assign]): Contains global assigns.
+        if_list (List[ast.If]): Contains global `ast.If` nodes.
+        try_list (List[ast.Try]): Contains global `ast.Try` nodes.
+        importfrom_asname_dict (Dict[str, ast.ImportFrom]): Contains the
+            asname alias of `ast.ImportFrom` information.
+    """
+    class_dict: TopClassNodeInfo
+    importfrom_dict: Dict[str, Tuple[ast.ImportFrom,
+                                     List[str]]] = field(default_factory=dict)
+    importfrom_asname_dict: Dict[str,
+                                 ast.ImportFrom] = field(default_factory=dict)
+    import_list: List[str] = field(default_factory=list)
+    try_list: List[ast.Try] = field(default_factory=list)
+    if_list: List[ast.If] = field(default_factory=list)
+    assign_list: List[ast.Assign] = field(default_factory=list)
+
+
+@dataclass
+class ExtraSuperAstInfo:
+    """Contains the extra information of the ``super_ast_tree`` which needs
+    to be considered.
+
+    Attributes:
+        used_module (Dict[ast.AST, Set[str]]): Records each node and the
+            set of module names it uses.
+        extra_importfrom (Dict[str, Tuple[ast.ImportFrom, List[str]]]):
+            Records extra `ast.ImportFrom` nodes and their imported names
+            in the super class file.
+        extra_import (List[ast.alias]): Records the aliases of extra
+            `ast.Import` nodes in the super class file.
+    """
+    used_module: Dict[ast.AST, Set[str]] = field(
+        default_factory=lambda: defaultdict(set))
+    extra_import: List[ast.alias] = field(default_factory=list)
+    extra_importfrom: Dict[str, Tuple[ast.ImportFrom,
+                                      List[str]]] = field(default_factory=dict)
+
+
+@dataclass
+class NeededNodeInfo:
+    """Contains the needed nodes found by comparing ``super_ast_tree``
+    and ``top_ast_tree``.
+
+    Attributes:
+        need_importfrom_nodes (set, optional): Collects the needed
+            `ast.ImportFrom` nodes from ``ExtraSuperAstInfo.extra_importfrom``.
+        need_import_alias_asname (set, optional): Collects the needed
+            `ast.Import` asname aliases from ``ExtraSuperAstInfo.extra_import``.
+        need_import_alias (set, optional): Collects the needed `ast.Import`
+            aliases from ``ExtraSuperAstInfo.extra_import``.
+    """
+    need_importfrom_nodes: Set[ast.ImportFrom] = field(default_factory=set)
+    need_import_alias_asname: Set[ast.alias] = field(default_factory=set)
+    need_import_alias: Set[ast.alias] = field(default_factory=set)
+
+
+def get_len(used_modules_dict_set: Dict[ast.AST, Set[str]]):
+    """Get the total number of used module names.
+
+    Args:
+        used_modules_dict_set (Dict[ast.AST, Set[str]]): Records each node
+            and the set of module names it uses.
+
+    Returns:
+        int: The total number of used module names.
+    """
+    len_sum = 0
+    for name_list in used_modules_dict_set.values():
+        len_sum += len(name_list)
+
+    return len_sum
+
+
+def record_used_node(node: ast.AST, used_module_dict_set: Dict[ast.AST,
+                                                               Set[str]]):
+    """Record that the node has been used and need not be removed.
+
+    Args:
+        node (ast.AST): AST node.
+        used_module_dict_set (Dict[ast.AST, Set[str]]): Records each node
+            and the set of module names it uses.
+
+    Examples:
+        >>> # a = nn.MaxPool1d()
+        >>> Assign(
+        >>>     targets=[
+        >>>         Name(id='a', ctx=Store())],
+        >>>     value=Call(
+        >>>         func=Attribute(
+        >>>             value=Name(id='nn', ctx=Load()),
+        >>>             attr='MaxPool1d',
+        >>>             ctx=Load()),
+        >>>         args=[],
+        >>>         keywords=[]))
+        >>>
+        >>> used_module_dict_set[node] = {'nn', 'a'}
+    """
+    # only traverse the body of FunctionDef to ignore the input args
+    if isinstance(node, ast.FunctionDef):
+        for func_sub_node in node.body:
+            for func_sub_sub_node in ast.walk(func_sub_node):
+                if isinstance(func_sub_sub_node, ast.Name):
+                    used_module_dict_set[node].add(func_sub_sub_node.id)
+
+    # iteratively process the body of ClassDef
+    elif isinstance(node, ast.ClassDef):
+        for class_sub_node in node.body:
+            record_used_node(class_sub_node, used_module_dict_set)
+
+    else:
+        for sub_node in ast.walk(node):
+            if isinstance(sub_node, ast.Name):
+                used_module_dict_set[node].add(sub_node.id)
+
+
+def if_need_remove(node: ast.AST, used_module_dict_set: Dict[ast.AST,
+                                                             Set[str]]):
+    """Judge whether the node should be removed.
+
+    If the node is not actually used, it will be removed.
+
+    Args:
+        node (ast.AST): AST node.
+        used_module_dict_set (Dict[ast.AST, Set[str]]): Records each node
+            and the set of module names it uses.
+
+    Returns:
+        bool: "True" if the node is unused and should be removed,
+            else "False".
+    """
+    if isinstance(node, ast.Assign):
+        if isinstance(node.targets[0], ast.Name):
+            name = node.targets[0].id
+        else:
+            raise TypeError(f'expected the target in ast.Assign to be ast.Name\
+                but got {type(node.targets[0])}')
+    elif isinstance(node, ast.FunctionDef) or isinstance(node, ast.ClassDef):
+        name = node.name
+    else:
+        # HARD CODE: nodes other than the above types are removed directly.
+        return True
+
+    for name_list in used_module_dict_set.values():
+        if name in name_list:
+            return False
+
+    return True
+
+
+def is_in_top_ast_tree(node: ast.AST,
+                       top_ast_info: TopAstInfo,
+                       top_cls_and_func_node_name_list: List[str] = []):
+    """Judge whether the module name already exists in ``top_ast_tree``.
+
+    Args:
+        node (ast.AST): AST node.
+        top_ast_info (TopAstInfo): Contains the initial information
+            of the ``top_ast_tree``.
+        top_cls_and_func_node_name_list (List[str], optional): Contains the
+            `Class` and `Function` names in ``top_ast_tree``. Defaults to "[]".
+
+    Returns:
+        bool: "True" if the module name already exists in ``top_ast_tree``,
+            else "False".
+    """
+    if isinstance(node, ast.Assign):
+        for _assign in top_ast_info.assign_list:
+            if ast.dump(_assign) == ast.dump(node):
+                return True
+
+    elif isinstance(node, ast.Try):
+        for _try in top_ast_info.try_list:
+            if ast.dump(_try) == ast.dump(node):
+                return True
+
+    elif isinstance(node, ast.If):
+        for _if in top_ast_info.if_list:
+            if ast.dump(_if) == ast.dump(node):
+                return True
+
+    elif isinstance(node, ast.FunctionDef) or isinstance(node, ast.ClassDef):
+        if node.name in top_cls_and_func_node_name_list:
+            return True
+
+    return False
+
+
+def ignore_ast_docstring(node: Union[ast.ClassDef, ast.FunctionDef]):
+    """Get the insert index, skipping the docstring.
+
+    Args:
+        node (ast.ClassDef | ast.FunctionDef): AST node.
+
+    Returns:
+        int: The beginning insert position of the node.
+ """ + insert_index = 0 + + for sub_node in node.body: + if isinstance(sub_node, ast.Expr): + insert_index += 1 + + # HARD CODE: prevent from some ast.Expr like warning.warns which + # need module "warning" + for sub_sub_node in ast.walk(sub_node): + if isinstance(sub_sub_node, ast.Name): + return 0 + else: + break + + return insert_index + + +def find_local_import(node: Union[ast.FunctionDef, ast.Assign], + extra_super_ast_info: ExtraSuperAstInfo, + need_node_info: Optional[NeededNodeInfo] = None): + """Find the needed Import and ImportFrom of the node. + + Args: + node (ast.FunctionDef | ast.Assign) + extra_super_ast_info (ExtraSuperAstInfo): Contatins the extra + information of the ``super_ast_tree`` which needed to be consider. + need_node_info (NeededNodeInfo, optional): Contatins the needed node + by comparing ``super_ast_tree`` and ``top_ast_tree``. + + Returns: + need_node_info + """ + if need_node_info is None: + need_node_info = NeededNodeInfo() + + # get all the used modules' name in specific node + used_module = extra_super_ast_info.used_module[node] + + if len(used_module) != 0: + + # record all used ast.ImportFrom nodes + for import_node, alias_list in \ + extra_super_ast_info.extra_importfrom.values(): + for module in used_module: + if module in alias_list: # type: ignore[operator] + need_node_info.need_importfrom_nodes.add( + import_node) # type: ignore[arg-type] # noqa: E501 + continue + + # record all used ast.Import nodes + for alias in extra_super_ast_info.extra_import: + if alias.asname is not None: + if alias.asname in used_module: + need_node_info.need_import_alias_asname.add(alias) + else: + if alias.name in used_module: + need_node_info.need_import_alias.add(alias) + + return need_node_info + + +def add_local_import_to_func(node, need_node_info: NeededNodeInfo): + """Add the needed ast.ImportFrom and ast.Import to ast.Function. + + Args: + node (ast.FunctionDef | ast.Assign) + need_node_info (NeededNodeInfo): Contatins the needed node by + comparing ``super_ast_tree`` and ``top_ast_tree``. + """ + insert_index = ignore_ast_docstring(node) + + for importfrom_node in need_node_info.need_importfrom_nodes: + node.body.insert(insert_index, importfrom_node) + + if len(need_node_info.need_import_alias) != 0: + node.body.insert( + insert_index, + ast.Import( + names=[alias for alias in need_node_info.need_import_alias])) + + for alias in need_node_info.need_import_alias_asname: + node.body.insert(insert_index, ast.Import(names=[alias])) + + +def add_local_import_to_class(cls_node: ast.ClassDef, + extra_super_ast_info: ExtraSuperAstInfo, + new_node_begin_index=-9999): + """Add the needed `ast.ImportFrom` and `ast.Import` to `ast.Class`'s + sub_nodes, including `ast.Assign` and `ast.Function`. + + Traverse `ast.ClassDef` node, recode all the used modules of class + attributes like the `ast.Assign`, and this needed `ast.ImportFrom` and + `ast.Import` will be add to the top of the cls_node.body. More, for each + sub functions in this class, we will process them as glabal functions + by using :func:`find_local_import` and :func:`add_local_import_to_func`. + + Args: + cls_node (ast.ClassDef) + used_module_dict_super (Dict[ast.AST, Set[str]]): Records + the node and a set including module names it uses. + new_node_begin_index (int, optional): The index of the last node + of cls_node.body. 
+ """ + # for later add all the needed ast.ImportFrom and ast.Import nodes for + # class attributes + later_need_node_info = NeededNodeInfo() + + for i, cls_sub_node in enumerate(cls_node.body): + if isinstance(cls_sub_node, ast.Assign): + + find_local_import( + node=cls_sub_node, + extra_super_ast_info=extra_super_ast_info, + need_node_info=later_need_node_info, + ) + + # ``i >= new_node_begin_index`` means only processing those + # newly added nodes. + elif isinstance(cls_sub_node, + ast.FunctionDef) and i >= new_node_begin_index: + need_node_info = find_local_import( + node=cls_sub_node, + extra_super_ast_info=extra_super_ast_info, + ) + + add_local_import_to_func( + node=cls_sub_node, need_node_info=need_node_info) + + # add all the needed ast.ImportFrom and ast.Import nodes for + # class attributes + add_local_import_to_func( + node=cls_node, need_node_info=later_need_node_info) + + +def init_prepare(top_ast_tree: ast.Module, flattened_cls_name: str): + """Collect the initial information of the ``top_ast_tree``. + + Args: + top_ast_tree (ast.Module): Ast tree which will be continuelly updated + contains the class needed to be flattened. + flattened_cls_name (str): The name of the class needed to + be flattened. + + Returns: + top_ast_info (TopClassNodeInfo): Contatins the initial information + of the ``top_ast_tree``. + """ + class_dict = TopClassNodeInfo(None) # type: ignore + top_ast_info = TopAstInfo(class_dict) + + # top_ast_tree scope + for node in top_ast_tree.body: + + # ast.Module -> ast.ImporFrom + if isinstance(node, ast.ImportFrom): + if node.module is not None and node.names[0].asname is not None: + top_ast_info.importfrom_asname_dict[node.module] = node + elif node.module is not None: + # yapf: disable + top_ast_info.importfrom_dict[node.module] = ( + node, + [alias.name for alias in node.names] # type: ignore + ) + # yapf: enable + + # ast.Module -> ast.Import + elif isinstance(node, ast.Import): + top_ast_info.import_list.extend([ + alias.name if alias.asname is None else alias.asname + for alias in node.names + ]) + + # ast.Module -> ast.Assign + elif isinstance(node, ast.Assign): + top_ast_info.assign_list.append(node) + + # ast.Module -> ast.Try + elif isinstance(node, ast.Try): + top_ast_info.try_list.append(node) + + # ast.Module -> ast.If + elif isinstance(node, ast.If): + top_ast_info.if_list.append(node) + + # ast.Module -> specific ast.ClassDef + elif isinstance(node, + ast.ClassDef) and node.name == flattened_cls_name: + + # ``level_cls_name`` is the actual name in mro in this + # flatten level + # + # Examples: + # >>> # level_cls_name = 'A' + # >>> class A(B) class B(C) + # >>> + # >>> # after flattened + # >>> # level_cls_name = 'B' + # >>> class A(C) + top_ast_info.class_dict.cls_node = node + top_ast_info.class_dict.level_cls_name = flattened_cls_name + top_ast_info.class_dict.super_cls_name = node.bases[ + 0].id # type: ignore[attr-defined] + top_ast_info.class_dict.end_idx = len(node.body) + + for sub_node in node.body: + if isinstance(sub_node, ast.FunctionDef): + top_ast_info.class_dict.sub_func[sub_node.name] = sub_node + + elif isinstance(sub_node, ast.Assign): + top_ast_info.class_dict.sub_assign[sub_node.targets[ + 0].id] = sub_node # type: ignore[attr-defined] + + assert top_ast_info.class_dict is not None, \ + f"The class [{flattened_cls_name}] doesn't exist in the ast tree." + + return top_ast_info + + +def collect_needed_node_from_super(super_ast_tree: ast.Module, + top_ast_info: TopAstInfo): + """Flatten specific model class. 
+
+    This function traverses `super_ast_tree` and collects information by
+    comparing with `top_cls_node` in ``top_ast_tree``.
+
+    Need to process `ImportFrom, Import, ClassDef, If, Try, Assign`.
+    - ImportFrom solution: If the node.module already exists in
+        ``top_ast_tree``, we merge its aliases, dealing with asname and
+        simple ImportFrom nodes separately. Otherwise it is considered
+        an extra ImportFrom.
+        1. asname aliases are compared via :func:`ast.dump`.
+        2. simple aliases are merged via :func:`set` union.
+
+    - ClassDef solution: The main part. First we get the `top_cls_node`
+        and replace the :func:`super` calls in it with information from
+        `super_cls_node`. Second, we traverse `super_ast_tree` to get the
+        super class `ast.FunctionDef` and `ast.Assign` nodes that need
+        to be added to `top_cls_node`; the functions called through
+        :func:`super` are renamed. Last, we insert all the needed super
+        nodes into `top_cls_node` and update `top_cls_node.bases`,
+        finishing this level of class flattening.
+
+    Args:
+        super_ast_tree (ast.Module): The super AST tree including the super
+            class in the mro of the specific class to flatten.
+        top_ast_info (TopAstInfo): Contains the initial information
+            of the ``top_ast_tree``.
+
+    Returns:
+        extra_super_ast_info (ExtraSuperAstInfo): Contains the extra
+            information of the ``super_ast_tree`` which needs to be
+            considered.
+    """
+    extra_super_ast_info = ExtraSuperAstInfo()
+
+    # super_ast_tree scope
+    for node in super_ast_tree.body:
+
+        # ast.Module -> ast.ImportFrom
+        if isinstance(node, ast.ImportFrom):
+
+            # HARD CODE: if ast.alias has asname, we consider it only contains
+            # one module
+
+            # Examples:
+            #     >>> # common style
+            #     >>> from abc import a as A
+            #     >>> # not recommended style
+            #     >>> from abc import a as A, B
+            if node.names[0].asname is not None:
+                if node.module in top_ast_info.importfrom_asname_dict:
+                    top_importfrom_node = \
+                        top_ast_info.importfrom_asname_dict.get(node.module)
+
+                    if ast.dump(top_importfrom_node  # type: ignore[arg-type]
+                                ) != ast.dump(node):
+                        # yapf: disable
+                        alias_names = [alias.name
+                                       if alias.asname is None
+                                       else alias.asname
+                                       for alias in node.names]
+                        extra_super_ast_info.extra_importfrom[node.module] = \
+                            (node, alias_names)
+                        # yapf: enable
+
+            # only name
+            else:
+                # ast.alias nodes imported from the same module will be
+                # merged into one ast.ImportFrom
+                if node.module in top_ast_info.importfrom_dict:
+                    (top_importfrom_node, last_names) = \
+                        top_ast_info.importfrom_dict.get(node.module)  # type: ignore # noqa: E501
+
+                    current_names = [alias.name for alias in node.names
+                                     ]  # type: ignore[misc] # noqa: E501
+                    last_names = list(set(last_names + current_names))
+                    top_importfrom_node.names = [
+                        ast.alias(name=name) for name in last_names
+                    ]  # type: ignore[attr-defined]
+
+                    # NOTE: update the information of top_ast_tree
+                    top_ast_info.importfrom_dict[node.module] = (
+                        top_importfrom_node, last_names)
+
+                # ast.ImportFrom nodes that don't exist yet will be added later
+                elif node.module is not None:
+                    # yapf: disable
+                    alias_names = [alias.name
+                                   if alias.asname is None
+                                   else alias.asname
+                                   for alias in node.names]
+                    extra_super_ast_info.extra_importfrom[node.module] = \
+                        (node, alias_names)
+                    # yapf: enable
+
+        # ast.Module -> ast.Import
+        elif isinstance(node, ast.Import):
+            for alias in node.names:
+                if alias.asname is not None:
+                    if alias.asname not in top_ast_info.import_list:
+                        extra_super_ast_info.extra_import.append(alias)
+                else:
+                    if alias.name not in top_ast_info.import_list:
+                        extra_super_ast_info.extra_import.append(alias)
+
+        # 
ast.Module -> ast.Try / ast.Assign / ast.If + elif (isinstance(node, ast.Try) or isinstance(node, ast.Assign) or + isinstance(node, ast.If)) and \ + not is_in_top_ast_tree(node, top_ast_info=top_ast_info): + record_used_node(node, extra_super_ast_info.used_module) + + # ast.Module -> ast.ClassDef + elif isinstance( + node, ast.ClassDef + ) and node.name == top_ast_info.class_dict.super_cls_name: + + # get the specific flattened class node in the top_ast_tree + top_cls_node = top_ast_info.class_dict.cls_node + + # process super, including below circumstances: + # class A(B) and class B(C) + # 1. super().xxx(): directly replace to self.B_xxx() + # 2. super(A, self).xxx(): directly replace to self.B_xxx() + # 3. super(B, self).xxx(): waiting the level_cls_name=B, then + # replace to self.C_xxx() + + # HARD CODE: if B doesn't exist self.xxx(), it will not deal with + # super(A, self).xxx() until the :func:`postprocess_super()` will + # remove all the args in ``super(args)``, then change to + # ``super()``. In another word, if super doesn't replace in the + # correct level, it will be turn to use the root super + # class' method. + super_func = [] + for sub_node in ast.walk(top_cls_node): # type: ignore[arg-type] + + if isinstance(sub_node, ast.Attribute) \ + and hasattr(sub_node, 'value') \ + and isinstance(sub_node.value, ast.Call) \ + and isinstance(sub_node.value.func, ast.Name) \ + and sub_node.value.func.id == 'super': # noqa: E501 + """ + Examples: super().__init__() + >>> Expr( + >>> value=Call( + >>> func=Attribute( + >>> value=Call( + >>> func=Name(id='super', + ctx=Load()), + >>> args=[], + >>> keywords=[]), + >>> attr='__init__', + >>> ctx=Load()), + >>> args=[], + >>> keywords=[]))], + """ + # Only flatten super syntax: + # 1. super().func_call + # 2. 
super(top_cls_name, self).func_call + if len( + sub_node.value.args + ) != 0 and sub_node.value.args[ # type: ignore[attr-defined] # noqa: E501 + 0].id != top_ast_info.class_dict.level_cls_name: + continue + + # search and justify if the .xxx() function in the + # super node + for super_cls_sub_node in node.body: + if isinstance( + super_cls_sub_node, ast.FunctionDef + ) and sub_node.attr == \ + super_cls_sub_node.name: + super_func.append(sub_node.attr) + sub_node.value = \ + sub_node.value.func + sub_node.value.id = 'self' + sub_node.value.args = [ # type: ignore[attr-defined] # noqa: E501 + ] + sub_node.attr = node.name + \ + '_' + sub_node.attr + break + + # record all the needed ast.ClassDef -> ast.FunctionDef + # and ast.ClassDef -> ast.Assign + func_need_append = [] + assign_need_append = [] + for super_cls_sub_node in node.body: + + # ast.Module -> ast.ClassDef -> ast.FunctionDef + if isinstance(super_cls_sub_node, ast.FunctionDef): + + # the function call as super().xxx() should be rename to + # super_cls_name_xxx() + if super_cls_sub_node.name in super_func: + super_cls_sub_node.name = node.name + '_' + \ + super_cls_sub_node.name + func_need_append.append(super_cls_sub_node) + + # NOTE: update the information of top_ast_tree + top_ast_info.class_dict.sub_func[ + super_cls_sub_node.name] = super_cls_sub_node + record_used_node(super_cls_sub_node, + extra_super_ast_info.used_module) + + # the function don't exist in top class node will be + # directly imported + elif super_cls_sub_node.name not in \ + top_ast_info.class_dict.sub_func: + func_need_append.append(super_cls_sub_node) + # if super_cls_sub_node.name == "_init_cls_convs": + # NOTE: update the information of top_ast_tree + top_ast_info.class_dict.sub_func[ + super_cls_sub_node.name] = super_cls_sub_node + record_used_node(super_cls_sub_node, + extra_super_ast_info.used_module) + + # ast.Module -> ast.ClassDef -> ast.Assign + elif isinstance(super_cls_sub_node, ast.Assign): + add_flag = True + + for name in top_ast_info.class_dict.sub_assign.keys(): + if name == super_cls_sub_node.targets[ + 0].id: # type: ignore[attr-defined] + add_flag = False + + if add_flag: + assign_need_append.append(super_cls_sub_node) + + # NOTE: update the information of top_ast_tree + top_ast_info.class_dict.end_idx += 1 # type: ignore + top_ast_info.class_dict.sub_assign[ + super_cls_sub_node. + targets[0]. 
# type: ignore[attr-defined] + id] = super_cls_sub_node + record_used_node(super_cls_sub_node, + extra_super_ast_info.used_module) + + # add all the needed ast.ClassDef -> ast.FunctionDef and + # ast.ClassDef -> ast.Assign to top_cls_node + if len(assign_need_append) != 0: + insert_idx = ignore_ast_docstring( + top_cls_node) # type: ignore[arg-type] # noqa: E501 + + assign_need_append.reverse() + for assign in assign_need_append: + top_cls_node.body.insert( + insert_idx, + assign) # type: ignore[arg-type] # noqa: E501 + + func_name = [func.name for func in func_need_append] + print_log( + f'Add function {func_name}.', + logger='export', + level=logging.DEBUG) + top_cls_node.body.extend( + func_need_append) # type: ignore[arg-type] # noqa: E501 + + # complete this level flatten, change the super class of + # top_cls_node + top_cls_node.bases = node.bases # type: ignore[attr-defined] + + # NOTE: update the information of top_ast_tree + top_ast_info.class_dict.level_cls_name = node.name + # HARD CODE: useless, only for preventing error when ``nn.xxx`` + # as the last super class + top_ast_info.class_dict.super_cls_name = node.bases[ # type: ignore # noqa: E501 + 0].id \ + if isinstance(node.bases[0], ast.Name) else node.bases[0] + + return extra_super_ast_info + + +def postprocess_top_ast_tree( + super_ast_tree: ast.Module, + top_ast_tree: ast.Module, + extra_super_ast_info: ExtraSuperAstInfo, + top_ast_info: TopAstInfo, +): + """Postprocess ``top_ast_tree`` with the information collected by + traversing super_ast_tree. + + This function finishes: + 1. get all the nodes needed by ``top_ast_tree`` and + exist in super_ast_tree + 2. add as local import for the new add function from super_ast_tree + preventing from covering by the same name modules on the top. + 3. add extra Import/ImportFrom of super_ast_tree to + the top of ``top_ast_tree`` + + Args: + super_ast_tree (ast.Module): The super ast tree including the super + class in the specific flatten class's mro. + top_ast_tree (ast.Module): The top ast tree contains the classes + directly called, which is continuelly updated. + extra_super_ast_info (ExtraSuperAstInfo): Contatins the extra + information of the ``super_ast_tree`` which needed to be consider. + top_ast_info (TopClassNodeInfo): Contatins the initial information + of the ``top_ast_tree``. 
+ """ + + # record all the imported module + imported_module_name_upper = set() + for importfrom_node, alias_list in top_ast_info.importfrom_dict.values(): + for alias in alias_list: + imported_module_name_upper.add(alias) + + for name in top_ast_info.import_list: + imported_module_name_upper.add(name) + + # HARD CODE: there will be a situation that the super class and the sub + # class exist in the same file, the super class should + imported_module_name_upper.discard( + top_ast_info.class_dict.super_cls_name) # type: ignore # noqa: E501 + + # find the needed ast.ClassDef or ast.FunctionDef in super_ast_tree + need_append_node_name: Set[str] = set() + if get_len(extra_super_ast_info.used_module) != 0: + + while True: + origin_len = get_len(extra_super_ast_info.used_module) + + # super_ast_tree scope + for node in super_ast_tree.body: + if (isinstance(node, ast.ClassDef) + or isinstance(node, ast.FunctionDef)) \ + and not if_need_remove(node, + extra_super_ast_info.used_module) \ + and node.name not in imported_module_name_upper: + + need_append_node_name.add(node.name) + record_used_node(node, extra_super_ast_info.used_module) + + # if there is no longer extra new module, then search break + if get_len(extra_super_ast_info.used_module) == origin_len: + break + + # record insert_idx and classes and functions' name in top_ast_tree + insert_idx = 0 + top_cls_func_node_name_list = [] + for top_node in top_ast_tree.body: + + if isinstance(top_node, ast.Import) or isinstance( + top_node, ast.ImportFrom): + insert_idx += 1 + elif isinstance(top_node, ast.FunctionDef) or isinstance( + top_node, ast.ClassDef): + top_cls_func_node_name_list.append(top_node.name) + # super_ast_tree scope + for node in super_ast_tree.body: + + # ast.Module -> ast.Try / ast.Assign / ast.If + if (isinstance(node, ast.Try) or isinstance(node, ast.Assign) + or isinstance(node, ast.If)) \ + and not is_in_top_ast_tree(node, + top_ast_info, + top_cls_func_node_name_list): + + # NOTE: postprocess top_ast_tree + top_ast_tree.body.insert(insert_idx, node) + insert_idx += 1 + + # NOTE: update the information of top_ast_tree + if isinstance(node, ast.Try): + top_ast_info.try_list.append(node) + elif isinstance(node, ast.Assign): + top_ast_info.assign_list.append(node) + elif isinstance(node, ast.If): + top_ast_info.if_list.append(node) + + elif not if_need_remove(node, extra_super_ast_info.used_module) \ + and not is_in_top_ast_tree(node, + top_ast_info, + top_cls_func_node_name_list) \ + and node.name in need_append_node_name: # type: ignore[attr-defined] # noqa: E501 + + # ast.Module -> ast.FunctionDef + if isinstance(node, ast.FunctionDef): + + need_node_info = find_local_import( + node=node, extra_super_ast_info=extra_super_ast_info) + + add_local_import_to_func( + node=node, need_node_info=need_node_info) + + # NOTE: postprocess top_ast_tree + top_ast_tree.body.insert(insert_idx, node) + insert_idx += 1 + + # ast.Module -> ast.ClassDef + elif isinstance(node, ast.ClassDef): + add_local_import_to_class( + cls_node=node, extra_super_ast_info=extra_super_ast_info) + + # NOTE: postprocess top_ast_tree + top_ast_tree.body.insert(insert_idx, node) + insert_idx += 1 + + # the newly add functions in top_cls_node also should add local import + top_cls_node = top_ast_info.class_dict.cls_node + add_local_import_to_class( + cls_node=top_cls_node, # type: ignore[arg-type] + extra_super_ast_info=extra_super_ast_info, + new_node_begin_index=top_ast_info.class_dict.end_idx) + + # update the end_idx for next time postprocess + 
top_ast_info.class_dict.end_idx = len(
+        top_cls_node.body)  # type: ignore[attr-defined]
+
+    # postprocess global import
+    # all the extra imports will be inserted at the top of the top_ast_tree
+    need_node_info = NeededNodeInfo()
+
+    for module_name, (
+            sub_node,
+            name_list) in extra_super_ast_info.extra_importfrom.items():
+        need_node_info.need_importfrom_nodes.add(sub_node)
+        top_ast_info.importfrom_dict[module_name] = (sub_node, name_list)
+
+    for alias in extra_super_ast_info.extra_import:  # type: ignore
+        if alias.asname is not None:  # type: ignore
+            need_node_info.need_import_alias_asname.add(alias)
+            top_ast_info.import_list.append(alias.asname)  # type: ignore
+        else:
+            need_node_info.need_import_alias.add(alias)
+            top_ast_info.import_list.append(alias.name)  # type: ignore
+
+    add_local_import_to_func(node=top_ast_tree, need_node_info=need_node_info)
+
+
+def postprocess_super_call(ast_tree: ast.Module):
+    """Postprocess ``super()`` calls that were not successfully handled.
+
+    This is a hard-coded fallback: every ``super(args)`` call with arguments
+    is rewritten to ``super()``.
+
+    Args:
+        ast_tree (ast.Module)
+    """
+    for node in ast_tree.body:
+        if isinstance(node, ast.ClassDef):
+            for sub_node in ast.walk(node):
+                if isinstance(sub_node, ast.Attribute) \
+                        and hasattr(sub_node, 'value') \
+                        and isinstance(sub_node.value, ast.Call) \
+                        and isinstance(sub_node.value.func, ast.Name) \
+                        and sub_node.value.func.id == 'super':
+                    sub_node.value.args = []
+
+
+def flatten_inheritance_chain(top_ast_tree: ast.Module, obj_cls: type):
+    """Flatten the module. (Key Interface)
+
+    The logic of ``flatten_inheritance_chain`` is as follows.
+    First, get the inheritance_chain via ``class.mro()`` and prune it.
+    Second, get the file of the chosen top class and parse it into
+    ``top_ast_tree``.
+    Third, call ``init_prepare()`` to collect the information of
+    ``top_ast_tree``.
+
+    Last, for each super class in the inheritance_chain, we will:
+        1. parse the super class file as ``super_ast_tree`` and
+            preprocess it.
+        2. call ``collect_needed_node_from_super()`` to visit the necessary
+            nodes in ``super_ast_tree``, change the class node to be
+            flattened and record the information needed for flattening.
+        3. call ``postprocess_top_ast_tree()`` with the information
+            collected by ``collect_needed_node_from_super()`` to update
+            the ``top_ast_tree``.
+
+    In summary, ``top_ast_tree`` is the most important AST tree, maintained
+    and updated from the beginning to the end.
+
+    Args:
+        top_ast_tree (ast.Module): The top AST tree containing the classes
+            directly called, which is continually updated.
+        obj_cls (object): The chosen top class to be flattened.
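+
+    Example (a hypothetical usage sketch; it assumes ``DINO`` is a detector
+    class registered in ``mmdet``):
+        >>> import ast, inspect
+        >>> from mmdet.models import DINO
+        >>> with open(inspect.getfile(DINO)) as f:
+        ...     top_ast_tree = ast.parse(f.read())
+        >>> flatten_inheritance_chain(top_ast_tree, DINO)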
+ """ + print_log( + f'------------- Starting flatten model [{obj_cls.__name__}] ' + f'-------------\n' + f'\n *[mro]: {obj_cls.mro()}\n', + logger='export', + level=logging.INFO) + + # get inheritance_chain + inheritance_chain = [] + for cls in obj_cls.mro()[1:]: + if cls in [ + BaseModule, BaseModel, BaseDataPreprocessor, + ImgDataPreprocessor + ] or 'torch' in cls.__module__: + break + inheritance_chain.append(cls) + + # collect the init information of ``top_ast_tree`` + top_ast_info = init_prepare(top_ast_tree, obj_cls.__name__) + + # iteratively deal with the super class + for cls in inheritance_chain: + + modul_pth = inspect.getfile(cls) + with open(modul_pth) as f: + super_ast_tree = ast.parse(f.read()) + + ImportResolverTransformer(cls.__module__).visit(super_ast_tree) + # collect the difference between ``top_ast_tree`` and ``super_ast_tree`` # noqa: E501 + extra_super_ast_info = collect_needed_node_from_super( + super_ast_tree=super_ast_tree, top_ast_info=top_ast_info) + + # update ``top_ast_tree`` + postprocess_top_ast_tree( + super_ast_tree, + top_ast_tree, + extra_super_ast_info=extra_super_ast_info, + top_ast_info=top_ast_info, + ) + + print_log( + f'------------- Ending flatten model [{obj_cls.__name__}] ' + f'-------------\n', + logger='export', + level=logging.INFO) + + +class RegisterModuleTransformer(ast.NodeTransformer): + """Deal with repeatedly registering same module. + + Add "force=True" to register_module(force=True) for covering registered + modules. + """ + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute): + if isinstance(node.func.value, ast.Name): + if node.func.attr == 'register_module': + new_keyword = ast.keyword( + arg='force', value=ast.NameConstant(value=True)) + if node.keywords is None: + node.keywords = [new_keyword] + else: + node.keywords.append(new_keyword) + return node + + +class ImportResolverTransformer(ast.NodeTransformer): + """Deal with the relative import problem. + + Args: + import_prefix (str): The import prefix for the visit ast code + + Examples: + >>> # file_path = '/home/username/miniconda3/envs/env_name/lib' \ + >>> '/python3.9/site-packages/mmdet/models/detectors' \ + >>> '/dino.py' + >>> import_prefix = mmdet.models.detector + """ + + def __init__(self, import_prefix: str): + super().__init__() + self.import_prefix = import_prefix + + def visit_ImportFrom(self, node): + matched = self._match_alias_registry(node) + if matched is not None: + # In an ideal scenario, the `ImportResolverTransformer` would + # modify the import sources of all `Registry` from downstream + # algorithm libraries (`mmdet`) to `pack`, for example, convert + # `from mmdet.models import DETECTORS` to + # `from pack.models import DETECTORS`. + + # However, some algorithm libraries, such as `mmpose`, provide + # aliases for `MODELS`, `TASK_UTILS`, and other registries, + # as seen here: https://github.com/open-mmlab/mmpose/blob/537bd8e543ab463fb55120d5caaa1ae22d6aaf06/mmpose/models/builder.py#L13. # noqa: E501 + + # For these registries with aliases, we cannot directly import from + # `pack.registry` because `pack.registry` is copied from + # `mmpose.registry` and does not contain these aliases. + + # Therefore, we gather all registries with aliases under + # `mim._internal.export.patch_utils` and hardcode the redirection + # of import sources. 
if matched == 'MODELS':
+                node.module = 'mim._internal.export.patch_utils.patch_model'
+            elif matched == 'TASK_UTILS':
+                node.module = 'mim._internal.export.patch_utils.patch_task'
+            node.level = 0
+            return node
+
+        else:
+            # deal with relative import
+            if node.level != 0:
+                import_prefix = '.'.join(
+                    self.import_prefix.split('.')[:-node.level])
+                if node.module is not None:
+                    node.module = import_prefix + '.' + node.module
+                else:
+                    # from . import xxx
+                    node.module = import_prefix
+                node.level = 0
+
+            if 'registry' in node.module \
+                    and not node.module.startswith('mmengine'):
+                node.module = 'pack.registry'
+
+            return node
+
+    def _match_alias_registry(self, node) -> Optional[str]:
+        match_patch_key = None
+        for key, list_value in OBJECTS_TO_BE_PATCHED.items():
+            for alias in node.names:
+                if alias.name in list_value:
+                    match_patch_key = key
+                    break
+
+            if match_patch_key is not None:
+                break
+        return match_patch_key
diff --git a/mim/_internal/export/pack_cfg.py b/mim/_internal/export/pack_cfg.py
new file mode 100644
index 0000000..4a0ee15
--- /dev/null
+++ b/mim/_internal/export/pack_cfg.py
@@ -0,0 +1,312 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import os
+import os.path as osp
+import shutil
+import signal
+import sys
+import tempfile
+import traceback
+from datetime import datetime
+from typing import Optional, Union
+
+from mmengine import MMLogger
+from mmengine.config import Config, ConfigDict
+from mmengine.hub import get_config
+from mmengine.logging import print_log
+from mmengine.registry import init_default_scope
+from mmengine.runner import Runner
+from mmengine.utils import get_installed_path, mkdir_or_exist
+
+from mim.utils import echo_error
+from .common import _import_pack_str, _init_str
+from .utils import (
+    _get_all_files,
+    _postprocess_importfrom_module_to_pack,
+    _postprocess_registry_locations,
+    _replace_config_scope_to_pack,
+    _wrapper_all_registries_build_func,
+)
+
+
+def export_from_cfg(cfg: Union[str, ConfigDict],
+                    export_root_dir: str,
+                    model_only: Optional[bool] = False,
+                    save_log: Optional[bool] = False):
+    """Pack the minimal available package according to a config file.
+
+    Args:
+        cfg (:obj:`ConfigDict` or str): Config file for packing the
+            minimal package.
+        export_root_dir (str, optional): The directory in which to save
+            the packed package.
+        model_only (bool, optional): Export only the model, skipping the
+            dataloader, optimizer and scheduler builds. Defaults to False.
+        save_log (bool, optional): Keep the export log next to the
+            exported package. Defaults to False.
+ """ + # generate temp dir for export + export_root_tempdir = tempfile.TemporaryDirectory() + export_log_tempdir = tempfile.TemporaryDirectory() + + # delete the incomplete export package when keyboard interrupt + signal.signal(signal.SIGINT, lambda sig, frame: keyboardinterupt_handler( + sig, frame, export_root_tempdir, export_log_tempdir) + ) # type: ignore[arg-type] + + # get config + if isinstance(cfg, str): + if '::' in cfg: + cfg = get_config(cfg) + else: + cfg = Config.fromfile(cfg) + + default_scope = cfg.get('default_scope', 'mmengine') + + # automatically generate ``export_root_dir`` + if export_root_dir is None: + export_root_dir = f'pack_from_{default_scope}_' + \ + f"{datetime.now().strftime(r'%Y%m%d_%H%M%S')}" + + # generate ``export_log_dir`` + if osp.sep in export_root_dir: + export_path = osp.dirname(export_root_dir) + else: + export_path = os.getcwd() + export_log_dir = osp.join(export_path, 'export_log') + + export_logger = MMLogger.get_instance( # noqa: F841 + 'export', + log_file=osp.join(export_log_tempdir.name, 'export.log')) + + export_module_tempdir_name = osp.join(export_root_tempdir.name, 'pack') + + # export config + if '.mim' in cfg.filename: + cfg_path = osp.join(export_module_tempdir_name, + cfg.filename[cfg.filename.find('configs'):]) + else: + cfg_path = osp.join( + osp.join(export_module_tempdir_name, 'configs'), + osp.basename(cfg.filename)) + mkdir_or_exist(osp.dirname(cfg_path)) + + # transform to default_scope + init_default_scope(default_scope) + + # wrap ``Registry.build()`` for exporting modules + _wrapper_all_registries_build_func( + export_module_dir=export_module_tempdir_name, scope=default_scope) + + print_log( + f'[ Export Package Name ]: {export_root_dir}\n' + f' package from config: {cfg.filename}\n' + f" from downstream package: '{default_scope}'\n", + logger='export', + level=logging.INFO) + + # creat temp work_dirs for export + cfg['work_dir'] = export_log_tempdir.name + + # use runner to export all needed modules + runner = Runner.from_cfg(cfg) + + # HARD CODE: In order to deal with some module will build in + # ``before_run`` or ``after_run``, we can call them without need + # to call "runner.train())". + + # Example: + # >>> @HOOKS.register_module() + # >>> class EMAHook(Hook): + # >>> ... + # >>> def before_run(self, runner) -> None: + # >>> """Create an ema copy of the model. + # >>> Args: + # >>> runner (Runner): The runner of the training process. + # >>> """ + # >>> model = runner.model + # >>> if is_model_wrapper(model): + # >>> model = model.module + # >>> self.src_model = model + # >>> self.ema_model = MODELS.build( + # >>> self.ema_cfg, default_args=dict(model=self.src_model)) + + # It need to build ``self.ema_model`` in ``before_run``. + + for hook in runner.hooks: + hook.before_run(runner) + hook.after_run(runner) + + def dump(): + cfg['work_dir'] = 'work_dirs' # recover to default. 
+        _replace_config_scope_to_pack(cfg)
+        cfg.dump(cfg_path)
+
+        # copy temp log to export log
+        if save_log:
+            shutil.copytree(
+                export_log_tempdir.name, export_log_dir, dirs_exist_ok=True)
+
+        export_log_tempdir.cleanup()
+
+        # copy temp_package_dir to export_package_dir
+        shutil.copytree(
+            export_root_tempdir.name, export_root_dir, dirs_exist_ok=True)
+        export_root_tempdir.cleanup()
+
+        print_log(
+            f'[ Export Package Name ]: '
+            f'{osp.join(os.getcwd(), export_root_dir)}\n',
+            logger='export',
+            level=logging.INFO)
+
+    if model_only:
+        dump()
+        return 0
+
+    try:
+        runner.build_train_loop(cfg.train_cfg)
+    except FileNotFoundError:
+        error_postprocess(export_log_dir, default_scope,
+                          export_root_tempdir, export_log_tempdir,
+                          osp.basename(cfg_path), 'train_dataloader')
+
+    try:
+        if 'val_cfg' in cfg and cfg.val_cfg is not None:
+            runner.build_val_loop(cfg.val_cfg)
+    except FileNotFoundError:
+        error_postprocess(export_log_dir, default_scope,
+                          export_root_tempdir, export_log_tempdir,
+                          osp.basename(cfg_path), 'val_dataloader')
+
+    try:
+        if 'test_cfg' in cfg and cfg.test_cfg is not None:
+            runner.build_test_loop(cfg.test_cfg)
+    except FileNotFoundError:
+        error_postprocess(export_log_dir, default_scope,
+                          export_root_tempdir, export_log_tempdir,
+                          osp.basename(cfg_path), 'test_dataloader')
+
+    if 'optim_wrapper' in cfg and cfg.optim_wrapper is not None:
+        runner.optim_wrapper = runner.build_optim_wrapper(cfg.optim_wrapper)
+    if 'param_scheduler' in cfg and cfg.param_scheduler is not None:
+        runner.build_param_scheduler(cfg.param_scheduler)
+
+    # add ``__init__.py`` to all dirs so that directories become
+    # importable modules
+    for directory, _, _ in os.walk(export_module_tempdir_name):
+        if not osp.exists(osp.join(directory, '__init__.py')) \
+                and 'configs' not in directory:
+            with open(osp.join(directory, '__init__.py'), 'w') as f:
+                f.write(_init_str)
+
+    # postprocess for ``pack/registry.py``
+    _postprocess_registry_locations(export_root_tempdir.name)
+
+    # postprocess ImportFrom nodes, turning them into imports from the
+    # export path
+    all_export_files = _get_all_files(export_module_tempdir_name)
+    for file in all_export_files:
+        _postprocess_importfrom_module_to_pack(file)
+
+    # copy tool scripts from the installed repo
+    tools_dir = osp.join(export_root_tempdir.name, 'tools')
+    mkdir_or_exist(tools_dir)
+
+    for tool_name in [
+            'train.py', 'test.py', 'dist_train.sh', 'dist_test.sh',
+            'slurm_train.sh', 'slurm_test.sh'
+    ]:
+        pack_tools(
+            tool_name=tool_name,
+            scope=default_scope,
+            tool_dir=tools_dir,
+            auto_import=True)
+
+    # TODO: get demo.py
+
+    dump()
+    return 0
+
+
+def keyboardinterupt_handler(
+    sig: int,
+    frame,
+    export_root_tempdir: tempfile.TemporaryDirectory,
+    export_log_tempdir: tempfile.TemporaryDirectory,
+):
+    """Clean up the incomplete exported package on keyboard interrupt."""
+
+    export_log_tempdir.cleanup()
+    export_root_tempdir.cleanup()
+
+    sys.exit(-1)
+
+
+def error_postprocess(export_log_dir: str, scope: str,
+                      export_root_dir_tempfile: tempfile.TemporaryDirectory,
+                      export_log_dir_tempfile: tempfile.TemporaryDirectory,
+                      cfg_name: str, error_key: str):
+    """Print a debug message when the package cannot be exported because of
+    a missing dataset.
+
+    Args:
+        export_log_dir (str): Directory where the export log and the
+            duplicate config are saved.
+        scope (str): The scope of the downstream repo.
+        export_root_dir_tempfile (tempfile.TemporaryDirectory): Temp dir
+            holding the incomplete exported package; cleaned up here.
+        export_log_dir_tempfile (tempfile.TemporaryDirectory): Temp dir
+            holding the export log; copied to ``export_log_dir``.
+        cfg_name (str): Base name of the exported config file.
+        error_key (str): The config key (e.g. 'train_dataloader') whose
+            data root was not found.
+    """
+    shutil.copytree(
+        export_log_dir_tempfile.name, export_log_dir, dirs_exist_ok=True)
+    export_root_dir_tempfile.cleanup()
+    export_log_dir_tempfile.cleanup()
+
+    traceback.print_exc()
+
+    error_msg = f"{'=' * 20} Debug Message {'=' * 20}"\
+        f"\nThe data root of '{error_key}' is not found. You can"\
+        ' use one of the two methods below to deal with it.\n\n'\
+        "    >>> Method 1: Please modify the 'data_root' in"\
+        f" duplicate config '{export_log_dir}/{cfg_name}'.\n"\
+        "    >>> Method 2: Use '--model_only' to export model only.\n\n"\
+        "After finishing one of the above steps, you can use 'mim export"\
+        f" {scope} {export_log_dir}/{cfg_name} [--model-only]' to export"\
+        ' again.'
+
+    echo_error(error_msg)
+
+    sys.exit(-1)
+
+
+def pack_tools(tool_name: str,
+               scope: str,
+               tool_dir: str,
+               auto_import: Optional[bool] = False):
+    """Pack tools from the installed repo.
+
+    Args:
+        tool_name (str): Tool name in the repo's tool dir.
+        scope (str): The scope of the repo.
+        tool_dir (str): Directory in which to save the tool.
+        auto_import (bool, optional): Automatically add "import pack" to the
+            tool file. Defaults to False.
+    """
+    pkg_root = get_installed_path(scope)
+    path = osp.join(tool_dir, tool_name)
+
+    if os.path.exists(path):
+        os.remove(path)
+
+    # tools will be put in package/.mim in PR #68
+    tool_script = osp.join(pkg_root, '.mim', 'tools', tool_name)
+    if not osp.exists(tool_script):
+        tool_script = osp.join(pkg_root, 'tools', tool_name)
+
+    shutil.copy(tool_script, path)
+
+    # automatically import the pack modules
+    if auto_import:
+        with open(path, 'r+') as f:
+            lines = f.readlines()
+            code = ''.join(lines[:1] + [_import_pack_str] + lines[1:])
+            f.seek(0)
+            f.write(code)
diff --git a/mim/_internal/export/patch_utils/README.md b/mim/_internal/export/patch_utils/README.md
new file mode 100644
index 0000000..d584c3e
--- /dev/null
+++ b/mim/_internal/export/patch_utils/README.md
@@ -0,0 +1,75 @@
+# Patch Utils
+
+## Problem
+
+This patch mainly solves the problem that modules cannot be registered properly when a downstream repo renames a registry, as in this example from `mmsegmentation`:
+
+```python
+# "mmsegmentation/mmseg/structures/sampler/builder.py"
+
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+from mmseg.registry import TASK_UTILS
+
+PIXEL_SAMPLERS = TASK_UTILS
+
+
+def build_pixel_sampler(cfg, **default_args):
+    """Build pixel sampler for segmentation map."""
+    warnings.warn(
+        '``build_pixel_sampler`` would be deprecated soon, please use '
+        '``mmseg.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+```
+
+Some modules may then use the renamed registry, which makes it difficult for `mim export` to recover the original names of the renamed modules:
+
+```python
+# "mmsegmentation/mmseg/structures/sampler/ohem_pixel_sampler.py"
+
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_pixel_sampler import BasePixelSampler
+from .builder import PIXEL_SAMPLERS
+
+
+@PIXEL_SAMPLERS.register_module()
+class OHEMPixelSampler(BasePixelSampler):
+    ...
+```
+
+## Solution
+
+Therefore, we have currently migrated the necessary modules of `mmpose` / `mmdetection` / `mmseg` / `mmpretrain` from the files listed below directly into `patch_utils.patch_model` and `patch_utils.patch_task`,
+
+The migrated builder files are:
+
+```python
+"mmdetection/mmdet/models/task_modules/builder.py"
+"mmdetection/build/lib/mmdet/models/task_modules/builder.py"
+
+"mmsegmentation/mmseg/models/builder.py"
+"mmsegmentation/mmseg/structures/sampler/builder.py"
+"mmsegmentation/build/lib/mmseg/models/builder.py"
+"mmsegmentation/build/lib/mmseg/structures/sampler/builder.py"
+
+"mmpretrain/mmpretrain/datasets/builder.py"
+"mmpretrain/mmpretrain/models/builder.py"
+"mmpretrain/build/lib/mmpretrain/datasets/builder.py"
+"mmpretrain/build/lib/mmpretrain/models/builder.py"
+
+"mmpose/mmpose/datasets/builder.py"
+"mmpose/mmpose/models/builder.py"
+"mmpose/build/lib/mmpose/datasets/builder.py"
+"mmpose/build/lib/mmpose/models/builder.py"
+
+"mmengine/mmengine/evaluator/builder.py"
+"mmengine/mmengine/model/builder.py"
+"mmengine/mmengine/optim/optimizer/builder.py"
+"mmengine/mmengine/visualization/builder.py"
+"mmengine/build/lib/mmengine/evaluator/builder.py"
+"mmengine/build/lib/mmengine/model/builder.py"
+"mmengine/build/lib/mmengine/optim/optimizer/builder.py"
+```
diff --git a/mim/_internal/export/patch_utils/README_zh-CN.md b/mim/_internal/export/patch_utils/README_zh-CN.md
new file mode 100644
index 0000000..e58fcac
--- /dev/null
+++ b/mim/_internal/export/patch_utils/README_zh-CN.md
@@ -0,0 +1,75 @@
+# Patch Utils
+
+## 问题
+
+该补丁主要是为了解决下游 repo 中存在对注册器进行重命名导致模块无法被正确注册的问题,如 `mmsegmentation` 中的一个例子:
+
+```python
+# "mmsegmentation/mmseg/structures/sampler/builder.py"
+
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+from mmseg.registry import TASK_UTILS
+
+PIXEL_SAMPLERS = TASK_UTILS
+
+
+def build_pixel_sampler(cfg, **default_args):
+    """Build pixel sampler for segmentation map."""
+    warnings.warn(
+        '``build_pixel_sampler`` would be deprecated soon, please use '
+        '``mmseg.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+```
+
+然后在某些模块中可能会使用改名后的注册器,这对于导出后处理很难找到重命名模块原来的名字
+
+```python
+# "mmsegmentation/mmseg/structures/sampler/ohem_pixel_sampler.py"
+
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_pixel_sampler import BasePixelSampler
+from .builder import PIXEL_SAMPLERS
+
+
+@PIXEL_SAMPLERS.register_module()
+class OHEMPixelSampler(BasePixelSampler):
+    ...
+``` + +## 解决方案 + +因此我们目前已经将如下 `mmpose / mmdetection / mmseg / mmpretrain` 中必要的模块直接迁移到 `patch_utils.patch_model` 和 `patch_utils.patch_task` 中构建一个包含注册器重命名和特殊模块构造函数的补丁。 + +```python +"mmdetection/mmdet/models/task_modules/builder.py" +"mmdetection/build/lib/mmdet/models/task_modules/builder.py" + +"mmsegmentation/mmseg/models/builder.py" +"mmsegmentation/mmseg/structures/sampler/builder.py" +"mmsegmentation/build/lib/mmseg/models/builder.py" +"mmsegmentation/build/lib/mmseg/structures/sampler/builder.py" + +"mmpretrain/mmpretrain/datasets/builder.py" +"mmpretrain/mmpretrain/models/builder.py" +"mmpretrain/build/lib/mmpretrain/datasets/builder.py" +"mmpretrain/build/lib/mmpretrain/models/builder.py" + +"mmpose/mmpose/datasets/builder.py" +"mmpose/mmpose/models/builder.py" +"mmpose/build/lib/mmpose/datasets/builder.py" +"mmpose/build/lib/mmpose/models/builder.py" + +"mmengine/mmengine/evaluator/builder.py" +"mmengine/mmengine/model/builder.py" +"mmengine/mmengine/optim/optimizer/builder.py" +"mmengine/mmengine/visualization/builder.py" +"mmengine/build/lib/mmengine/evaluator/builder.py" +"mmengine/build/lib/mmengine/model/builder.py" +"mmengine/build/lib/mmengine/optim/optimizer/builder.py" +``` diff --git a/mim/_internal/export/patch_utils/__init__.py b/mim/_internal/export/patch_utils/__init__.py new file mode 100644 index 0000000..c5a04e3 --- /dev/null +++ b/mim/_internal/export/patch_utils/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .patch_model import * # noqa: F401, F403 +from .patch_task import * # noqa: F401, F403 diff --git a/mim/_internal/export/patch_utils/patch_model.py b/mim/_internal/export/patch_utils/patch_model.py new file mode 100644 index 0000000..b0a35af --- /dev/null +++ b/mim/_internal/export/patch_utils/patch_model.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
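+"""Model-registry aliases and builder functions migrated from downstream
+repos (e.g. mmseg, mmpretrain and mmpose), kept here so that exported
+packages remain importable without those repos installed."""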
+import warnings + +from mmengine.registry import MODELS + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +SEGMENTORS = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + warnings.warn('``build_backbone`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + warnings.warn('``build_neck`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + warnings.warn('``build_head`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + warnings.warn('``build_loss`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return LOSSES.build(cfg) + + +def build_segmentor(cfg, train_cfg=None, test_cfg=None): + """Build segmentor.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return SEGMENTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) + + +CLASSIFIERS = MODELS +RETRIEVER = MODELS + + +def build_classifier(cfg): + """Build classifier.""" + return CLASSIFIERS.build(cfg) + + +def build_retriever(cfg): + """Build retriever.""" + return RETRIEVER.build(cfg) + + +POSE_ESTIMATORS = MODELS + + +def build_pose_estimator(cfg): + """Build pose estimator.""" + return POSE_ESTIMATORS.build(cfg) + + +def build_posenet(cfg): + """Build posenet.""" + warnings.warn( + '``build_posenet`` will be deprecated soon, ' + 'please use ``build_pose_estimator`` instead.', DeprecationWarning) + return build_pose_estimator(cfg) diff --git a/mim/_internal/export/patch_utils/patch_task.py b/mim/_internal/export/patch_utils/patch_task.py new file mode 100644 index 0000000..14c3150 --- /dev/null +++ b/mim/_internal/export/patch_utils/patch_task.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
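+"""Task-util registry aliases and builder functions migrated from
+downstream repos (e.g. mmdet and mmseg), kept here so that exported
+packages remain importable without those repos installed."""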
+import warnings
+
+from mmengine.registry import TASK_UTILS
+
+PRIOR_GENERATORS = TASK_UTILS
+ANCHOR_GENERATORS = TASK_UTILS
+BBOX_ASSIGNERS = TASK_UTILS
+BBOX_SAMPLERS = TASK_UTILS
+BBOX_CODERS = TASK_UTILS
+MATCH_COSTS = TASK_UTILS
+IOU_CALCULATORS = TASK_UTILS
+
+
+def build_bbox_coder(cfg, **default_args):
+    """Builder of box coder."""
+    warnings.warn('``build_bbox_coder`` would be deprecated soon, please use '
+                  '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_iou_calculator(cfg, default_args=None):
+    """Builder of IoU calculator."""
+    warnings.warn(
+        '``build_iou_calculator`` would be deprecated soon, please use '
+        '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_match_cost(cfg, default_args=None):
+    """Builder of match cost."""
+    warnings.warn('``build_match_cost`` would be deprecated soon, please use '
+                  '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_assigner(cfg, **default_args):
+    """Builder of box assigner."""
+    warnings.warn('``build_assigner`` would be deprecated soon, please use '
+                  '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_sampler(cfg, **default_args):
+    """Builder of box sampler."""
+    warnings.warn('``build_sampler`` would be deprecated soon, please use '
+                  '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_prior_generator(cfg, default_args=None):
+    """Builder of prior generator."""
+    warnings.warn(
+        '``build_prior_generator`` would be deprecated soon, please use '
+        '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+def build_anchor_generator(cfg, default_args=None):
+    """Builder of anchor generator."""
+    warnings.warn(
+        '``build_anchor_generator`` would be deprecated soon, please use '
+        '``mmdet.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
+
+
+PIXEL_SAMPLERS = TASK_UTILS
+
+
+def build_pixel_sampler(cfg, **default_args):
+    """Build pixel sampler for segmentation map."""
+    warnings.warn(
+        '``build_pixel_sampler`` would be deprecated soon, please use '
+        '``mmseg.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
diff --git a/mim/_internal/export/utils.py b/mim/_internal/export/utils.py
new file mode 100644
index 0000000..477fd02
--- /dev/null
+++ b/mim/_internal/export/utils.py
@@ -0,0 +1,555 @@
+# Copyright (c) OpenMMLab. All rights reserved.
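+"""Utilities for ``mim export``.
+
+These helpers wrap ``Registry.build()`` and ``Registry.get()`` to capture
+every module built from a config, copy the module sources into the exported
+``pack`` package, and rewrite imports and registry locations accordingly.
+"""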
+import ast
+import copy
+import importlib
+import inspect
+import logging
+import os
+import os.path as osp
+import shutil
+from functools import wraps
+from typing import Callable
+
+import torch.nn
+from mmengine import mkdir_or_exist
+from mmengine.config import ConfigDict
+from mmengine.logging import print_log
+from mmengine.model import (
+    BaseDataPreprocessor,
+    BaseModel,
+    BaseModule,
+    ImgDataPreprocessor,
+)
+from mmengine.registry import Registry
+from yapf.yapflib.yapf_api import FormatCode
+
+from mim.utils import OFFICIAL_MODULES
+from .common import REGISTRY_TYPES
+from .flatten_func import *  # noqa: F403, F401
+from .flatten_func import (
+    ImportResolverTransformer,
+    RegisterModuleTransformer,
+    flatten_inheritance_chain,
+    ignore_ast_docstring,
+    postprocess_super_call,
+)
+
+
+def format_code(code_text: str):
+    """Format the code text with yapf."""
+    yapf_style = dict(
+        based_on_style='pep8',
+        blank_line_before_nested_class_or_def=True,
+        split_before_expression_after_opening_paren=True)
+    try:
+        code_text, _ = FormatCode(code_text, style_config=yapf_style)
+    except:  # noqa: E722
+        raise SyntaxError('Failed to format the code, please '
+                          f'check the syntax of: \n{code_text}')
+
+    return code_text
+
+
+def _postprocess_registry_locations(export_root_dir: str):
+    """Remove the entries of ``Registry.locations`` that were not exported.
+
+    Check each location path that a Registry uses to load modules; if the
+    path was not exported, remove the entry, so that the root Registry is
+    used to look a module up until it genuinely cannot be found.
+    """
+    export_module_dir = osp.join(export_root_dir, 'pack')
+
+    with open(
+            osp.join(export_module_dir, 'registry.py'), encoding='utf-8') as f:
+        ast_tree = ast.parse(f.read())
+
+    for node in ast.walk(ast_tree):
+        """Example of the node structure:
+
+        Assign(
+            targets=[Name(id='EVALUATORS', ctx=Store())],
+            value=Call(
+                func=Name(id='Registry', ctx=Load()),
+                args=[Constant(value='evaluator')],
+                keywords=[
+                    keyword(
+                        arg='parent',
+                        value=Name(id='MMENGINE_EVALUATOR', ctx=Load())),
+                    keyword(
+                        arg='locations',
+                        value=List(
+                            elts=[Constant(value='pack.evaluation')],
+                            ctx=Load()))]))
+        """
+        if isinstance(node, ast.Call):
+            need_to_be_remove = None
+
+            for keyword in node.keywords:
+                if keyword.arg == 'locations':
+                    for sub_node in ast.walk(keyword):
+
+                        # the Registry locations were already switched to
+                        # the 'pack' scope before; if the exported path
+                        # does not exist, the entry has to be removed
+                        if isinstance(
+                                sub_node,
+                                ast.Constant) and 'pack' in sub_node.value:
+
+                            path = sub_node.value
+                            if not osp.exists(
+                                    osp.join(export_root_dir, path).replace(
+                                        '.', osp.sep)):
+                                print_log(
+                                    '[ Pass ] Remove Registry.locations '
+                                    f"'{osp.join(export_root_dir, path).replace('.', osp.sep)}', "  # noqa: E501
+                                    'which does not need to be exported.',
+                                    logger='export',
+                                    level=logging.DEBUG)
+                                need_to_be_remove = keyword
+                                break
+
+                    if need_to_be_remove is not None:
+                        break
+
+            if need_to_be_remove is not None:
+                node.keywords.remove(need_to_be_remove)
+
+    with open(
+            osp.join(export_module_dir, 'registry.py'), 'w',
+            encoding='utf-8') as f:
+        f.write(format_code(ast.unparse(ast_tree)))
+
+
+def _get_all_files(directory: str):
+    """Get all files of the directory.
+
+    Args:
+        directory (str): The directory path.
+
+    Returns:
+        List: A list containing all the files in the directory.
+    """
+    file_paths = []
+    for root, dirs, files in os.walk(directory):
+        for file in files:
+            if '__init__' not in file and 'registry.py' not in file:
+                file_paths.append(os.path.join(root, file))
+
+    return file_paths
+
+
+def _postprocess_importfrom_module_to_pack(file_path: str):
+    """Transfer the ImportFrom paths from the downstream repo to the export
+    module.
+
+    Args:
+        file_path (str): The path of the file to be transferred.
+
+    Examples:
+        >>> from mmdet.models.detectors.two_stage import TwoStageDetector
+        >>> # transfers to the following, if "TwoStageDetector" was exported
+        >>> from pack.models.detectors.two_stage import TwoStageDetector
+    """
+    from mmengine import Registry
+
+    # _module_path_dict is a class attribute that already records all the
+    # exported modules and their paths
+    _module_path_dict = Registry._module_path_dict
+
+    with open(file_path, encoding='utf-8') as f:
+        ast_tree = ast.parse(f.read())
+
+    # if an imported module has the same name as an object defined in this
+    # file, its import path won't be changed
+    can_not_change_module = []
+    for node in ast_tree.body:
+        if isinstance(node, ast.FunctionDef) or isinstance(node, ast.ClassDef):
+            can_not_change_module.append(node.name)
+
+    def check_change_importfrom_node(node: ast.ImportFrom):
+        """Check if the ImportFrom node should be changed.
+
+        If some modules in the node were already exported, they are split
+        off into a new ast.ImportFrom node whose module path is the export
+        path.
+
+        Args:
+            node (ast.ImportFrom): ImportFrom node.
+
+        Returns:
+            ast.ImportFrom | None: Return a new ast.ImportFrom node
+            if one of the modules in the node was exported, else ``None``.
+        """
+        export_module_path = None
+        needed_change_alias = []
+
+        for alias in node.names:
+            # if the imported module's name is equal to a class or function
+            # name in this file, it can not be transferred, to avoid a
+            # circular import.
+            if alias.name in _module_path_dict.keys(
+            ) and alias.name not in can_not_change_module:
+
+                if export_module_path is None:
+                    export_module_path = _module_path_dict[alias.name]
+                else:
+                    assert _module_path_dict[alias.name] == \
+                        export_module_path,\
+                        'There are two modules from the same downstream '\
+                        "repo, but they can't be changed to the same "\
+                        'export path.'
+
+                needed_change_alias.append(alias)
+
+        if len(needed_change_alias) != 0:
+            for alias in needed_change_alias:
+                node.names.remove(alias)
+
+            return ast.ImportFrom(
+                module=export_module_path, names=needed_change_alias, level=0)
+
+        return None
+
+    # Naming rules for searching the ast syntax tree
+    #   - node: node of ast.Module
+    #   - func_sub_node: sub_node of ast.FunctionDef
+    #   - class_sub_node: sub_node of ast.ClassDef
+    #   - func_sub_class_sub_node: sub_node of ast.FunctionDef in ast.ClassDef
+
+    # record the insert index and the nodes to be inserted for a later
+    # lazy insert
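+    # For instance (an illustrative sketch, not from any repo): if
+    # ``TwoStageDetector`` was exported but ``some_util`` was not, then
+    #     from mmdet.models import TwoStageDetector, some_util
+    # is rewritten in place to
+    #     from mmdet.models import some_util
+    # while a new node
+    #     from pack.models.detectors.two_stage import TwoStageDetector
+    # is recorded in ``insert_idx_and_node`` and inserted afterwards, so
+    # that the indices into ``ast_tree.body`` stay valid during iteration.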
+    insert_idx_and_node = {}
+
+    insert_idx = 0
+    for idx, node in enumerate(ast_tree.body):
+
+        # search ast.ImportFrom in the ast.Module scope
+        # ast.Module -> ast.ImportFrom
+        if isinstance(node, ast.ImportFrom):
+            insert_idx += 1
+            temp_node = check_change_importfrom_node(node)
+            if temp_node is not None:
+                if len(node.names) == 0:
+                    ast_tree.body[idx] = temp_node
+                else:
+                    insert_idx_and_node[insert_idx] = temp_node
+                    insert_idx += 1
+
+        elif isinstance(node, ast.Import):
+            insert_idx += 1
+
+        else:
+            # search ast.ImportFrom in the ast.FunctionDef scope
+            # ast.Module -> ast.FunctionDef -> ast.ImportFrom
+            if isinstance(node, ast.FunctionDef):
+                temp_func_insert_idx = ignore_ast_docstring(node)
+                func_need_to_be_removed_nodes = []
+
+                for func_sub_node in node.body:
+                    if isinstance(func_sub_node, ast.ImportFrom):
+                        temp_node = check_change_importfrom_node(
+                            func_sub_node)  # noqa: E501
+                        if temp_node is not None:
+                            node.body.insert(temp_func_insert_idx, temp_node)
+
+                        # if the ImportFrom node is empty, it should be removed  # noqa: E501
+                        if len(func_sub_node.names) == 0:
+                            func_need_to_be_removed_nodes.append(
+                                func_sub_node)  # noqa: E501
+
+                for need_to_be_removed_node in func_need_to_be_removed_nodes:
+                    node.body.remove(need_to_be_removed_node)
+
+            # search ast.ImportFrom in the ast.ClassDef scope
+            # ast.Module -> ast.ClassDef -> ast.ImportFrom
+            #            -> ast.FunctionDef -> ast.ImportFrom
+            elif isinstance(node, ast.ClassDef):
+                temp_class_insert_idx = ignore_ast_docstring(node)
+                class_need_to_be_removed_nodes = []
+
+                for class_sub_node in node.body:
+
+                    # ast.Module -> ast.ClassDef -> ast.ImportFrom
+                    if isinstance(class_sub_node, ast.ImportFrom):
+                        temp_node = check_change_importfrom_node(
+                            class_sub_node)
+                        if temp_node is not None:
+                            node.body.insert(temp_class_insert_idx, temp_node)
+                        if len(class_sub_node.names) == 0:
+                            class_need_to_be_removed_nodes.append(
+                                class_sub_node)
+
+                    # ast.Module -> ast.ClassDef -> ast.FunctionDef -> ast.ImportFrom  # noqa: E501
+                    elif isinstance(class_sub_node, ast.FunctionDef):
+                        temp_class_sub_insert_idx = ignore_ast_docstring(node)
+                        func_need_to_be_removed_nodes = []
+
+                        for func_sub_class_sub_node in class_sub_node.body:
+                            if isinstance(func_sub_class_sub_node,
+                                          ast.ImportFrom):
+                                temp_node = check_change_importfrom_node(
+                                    func_sub_class_sub_node)
+                                if temp_node is not None:
+                                    node.body.insert(temp_class_sub_insert_idx,
+                                                     temp_node)
+                                if len(func_sub_class_sub_node.names) == 0:
+                                    func_need_to_be_removed_nodes.append(
+                                        func_sub_class_sub_node)
+
+                        for need_to_be_removed_node in func_need_to_be_removed_nodes:  # noqa: E501
+                            class_sub_node.body.remove(need_to_be_removed_node)
+
+                for class_need_to_be_removed_node in class_need_to_be_removed_nodes:  # noqa: E501
+                    node.body.remove(class_need_to_be_removed_node)
+
+    # lazily add the new ast.ImportFrom nodes to ast.Module
+    for insert_idx, temp_node in insert_idx_and_node.items():
+        ast_tree.body.insert(insert_idx, temp_node)
+
+    with open(file_path, 'w', encoding='utf-8') as f:
+        f.write(format_code(ast.unparse(ast_tree)))
+
+
+def _replace_config_scope_to_pack(cfg: ConfigDict):
+    """Replace the config scope from "mmxxx" to "pack".
+
+    Args:
+        cfg (ConfigDict): The config dict to be replaced.
+    """
+
+    for key, value in cfg.items():
+        if key == '_scope_' or key == 'default_scope':
+            cfg[key] = 'pack'
+        elif isinstance(value, dict):
+            _replace_config_scope_to_pack(value)
+
+
+def _wrapper_all_registries_build_func(export_module_dir: str, scope: str):
+    """A function to wrap all registries' build_func.
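+
+    ``Registry.build()`` and ``Registry.get()`` are replaced with wrappers
+    that export every module they are asked for, so simply building the
+    runner from a config collects all the needed sources.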
+
+    Args:
+        export_module_dir (str): The root dir for the exported modules.
+        scope (str): The default scope of the config.
+    """
+    # copy the downstream repo.registry to pack.registry
+    # and change all the registry.locations
+    repo_registries = importlib.import_module('.registry', scope)
+    origin_file = inspect.getfile(repo_registries)
+    registry_path = osp.join(export_module_dir, 'registry.py')
+    shutil.copy(origin_file, registry_path)
+
+    # replace the 'repo' name in Registry.locations with 'pack'
+    with open(
+            osp.join(export_module_dir, 'registry.py'), encoding='utf-8') as f:
+        ast_tree = ast.parse(f.read())
+        for node in ast.walk(ast_tree):
+            if isinstance(node, ast.Constant):
+                if scope in node.value:
+                    node.value = node.value.replace(scope, 'pack')
+
+    with open(osp.join(export_module_dir, 'registry.py'), 'w') as f:
+        f.write(format_code(ast.unparse(ast_tree)))
+
+    # prevent circular registration
+    Registry._extra_module_set = set()
+
+    # record the exported modules for postprocessing the ImportFrom paths
+    Registry._module_path_dict = {}
+
+    # prevent wrapping twice
+    if Registry.build.__name__ == 'wrapper':
+        Registry.build = _wrap_build(Registry.init_build_func,
+                                     export_module_dir)
+        Registry.get = _wrap_get(Registry.init_get_func, export_module_dir)
+    else:
+        Registry.init_build_func = copy.deepcopy(Registry.build)
+        Registry.init_get_func = copy.deepcopy(Registry.get)
+        Registry.build = _wrap_build(Registry.build, export_module_dir)
+        Registry.get = _wrap_get(Registry.get, export_module_dir)
+
+
+def ignore_self_cache(func):
+    """An ``lru_cache``-like decorator that ignores ``self`` when caching.
+
+    Args:
+        func (Callable): The function to be cached.
+
+    Returns:
+        Callable: The wrapped function, executed at most once per distinct
+        ``args``.
+    """
+    cache = {}
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        key = args
+        if key not in cache:
+            cache[key] = 1
+            func(self, *args, **kwargs)
+        else:
+            return
+
+    return wrapper
+
+
+@ignore_self_cache
+def _export_module(self, obj_cls: type, pack_module_dir, obj_type: str):
+    """Export a module.
+
+    This function gets the object's source file and exports it to
+    ``pack_module_dir``.
+
+    If the object was built by the ``MODELS`` registry, all the top-level
+    classes in its file will be iteratively flattened. Otherwise the file
+    is exported directly.
+
+    The flatten logic is:
+    1. get the origin file of the object, which was built
+       by ``MODELS.build()``
+    2. get all the classes in the origin file
+    3. flatten all the classes, not only the object itself
+    4. call ``flatten_module()`` to finish the flattening
+       according to ``class.mro()``
+
+    Args:
+        obj_cls (type): The class of the object to be exported and, if
+            needed, flattened.
+        pack_module_dir (str): The export dir of the modules.
+        obj_type (str): The type name used to register the object.
+    """
+    # find the file by the obj class
+    file_path = inspect.getfile(obj_cls)
+
+    if osp.exists(file_path):
+        print_log(
+            f'building class: '
+            f'{obj_cls.__name__} from file: {file_path}.',
+            logger='export',
+            level=logging.DEBUG)
+    else:
+        raise FileNotFoundError(f"file [{file_path}] doesn't exist.")
+
+    # locate the origin module
+    module = obj_cls.__module__
+    parent = module.split('.')[0]
+    new_module = module.replace(parent, 'pack')
+
+    # modules implemented in `mmcv` and `mmengine` need not be exported
+    if parent in set(OFFICIAL_MODULES) - {'mmcv', 'mmengine'}:
+
+        with open(file_path, encoding='utf-8') as f:
+            top_ast_tree = ast.parse(f.read())
+
+        # deal with relative import
+        ImportResolverTransformer(module).visit(top_ast_tree)
+
+        # NOTE: ``MODELS.build()`` means to flatten the model module
+        if self.name == 'model':
+
+            # record all the classes needed to be flattened
+            need_to_be_flattened_class_names = []
+            for node in top_ast_tree.body:
+                if isinstance(node, ast.ClassDef):
+                    need_to_be_flattened_class_names.append(node.name)
+
+            imported_module = importlib.import_module(obj_cls.__module__)
+            for cls_name in need_to_be_flattened_class_names:
+
+                # record the exported module for postprocessing the ImportFrom path  # noqa: E501
+                self._module_path_dict[cls_name] = new_module
+
+                cls = getattr(imported_module, cls_name)
+
+                for super_cls in cls.__bases__:
+
+                    # the class will only be flattened when:
+                    # 1. the super class doesn't exist in this file
+                    # 2. and the super class is not a base class
+                    # 3. and the super class is not a torch module
+                    if super_cls.__name__\
+                            not in need_to_be_flattened_class_names \
+                            and (super_cls not in [BaseModule,
+                                                   BaseModel,
+                                                   BaseDataPreprocessor,
+                                                   ImgDataPreprocessor]) \
+                            and 'torch' not in super_cls.__module__:  # noqa: E501
+
+                        print_log(
+                            f'need_flatten: {cls_name}\n',
+                            logger='export',
+                            level=logging.INFO)
+
+                        flatten_inheritance_chain(top_ast_tree, cls)
+                        break
+            postprocess_super_call(top_ast_tree)
+
+        else:
+            self._module_path_dict[obj_cls.__name__] = new_module
+
+        # add ``register_module(force=True)`` to cover the registered modules  # noqa: E501
+        RegisterModuleTransformer().visit(top_ast_tree)
+
+        # unparse the ast tree and save the reformatted code
+        new_file_path = new_module.split('.', 1)[1].replace(
+            '.', osp.sep) + '.py'
+        new_file_path = osp.join(pack_module_dir, new_file_path)
+        new_dir = osp.dirname(new_file_path)
+        mkdir_or_exist(new_dir)
+
+        with open(new_file_path, mode='w') as f:
+            f.write(format_code(ast.unparse(top_ast_tree)))
+
+    # A downstream repo can register a torch module into a Registry, such as
+    # registering `torch.nn.Linear` into `MODELS`. We need to reserve these
+    # codes in the exported module.
+    elif 'torch' in module.split('.')[0]:
+
+        # get the root registry, because it can get all the modules
+        # that have been registered.
+        root_registry = self if self.parent is None else self.parent
+        if (obj_type not in self._extra_module_set) and (
+                root_registry.init_get_func(obj_type) is None):
+            self._extra_module_set.add(obj_type)
+            with open(osp.join(pack_module_dir, 'registry.py'), 'a') as f:
+
+                # TODO: When the downstream repo registries' names differ
+                # from mmengine's, the module may not be registered
+                # to the right registry.
+                # For example: `EVALUATOR` in mmengine, `EVALUATORS` in mmdet.
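+                # e.g. for ``torch.nn.Linear`` registered in ``MODELS``,
+                # the appended lines read:
+                #     from torch.nn.modules.linear import Linear
+                #     MODELS.register_module('Linear', module=Linear, force=True)  # noqa: E501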
+                f.write('\n')
+                f.write(f'from {module} import {obj_cls.__name__}\n')
+                f.write(
+                    f"{REGISTRY_TYPES[self.name]}.register_module('{obj_type}', module={obj_cls.__name__}, force=True)"  # noqa: E501
+                )
+
+
+def _wrap_build(build_func: Callable, pack_module_dir: str):
+    """Wrap ``Registry.build()``.
+
+    Args:
+        build_func (Callable): ``Registry.build()``, which will be wrapped.
+        pack_module_dir (str): The export dir of the modules.
+    """
+
+    def wrapper(self, cfg: dict, *args, **kwargs):
+
+        # obj is a class instance
+        obj = build_func(self, cfg, *args, **kwargs)
+        args = cfg.copy()  # type: ignore
+        obj_type = args.pop('type')  # type: ignore
+        obj_type = obj_type if isinstance(obj_type, str) else obj_type.__name__
+
+        # modules in ``torch.nn.Sequential`` should be exported separately
+        if isinstance(obj, torch.nn.Sequential):
+            for children in obj.children():
+                _export_module(self, children.__class__, pack_module_dir,
+                               obj_type)
+        else:
+            _export_module(self, obj.__class__, pack_module_dir, obj_type)
+
+        return obj
+
+    return wrapper
+
+
+def _wrap_get(get_func: Callable, pack_module_dir: str):
+    """Wrap ``Registry.get()``.
+
+    Args:
+        get_func (Callable): ``Registry.get()``, which will be wrapped.
+        pack_module_dir (str): The export dir of the modules.
+    """
+
+    def wrapper(self, key: str):
+
+        obj_cls = get_func(self, key)
+
+        _export_module(self, obj_cls, pack_module_dir, key)
+
+        return obj_cls
+
+    return wrapper
diff --git a/mim/click/autocompletion.py b/mim/click/autocompletion.py
index e87e3f4..63e5366 100644
--- a/mim/click/autocompletion.py
+++ b/mim/click/autocompletion.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mim.commands.list import list_package
+from mim.utils.default import OFFICIAL_MODULES
 
 
 def get_installed_package(ctx, args, incomplete):
@@ -19,19 +20,4 @@ def get_downstream_package(ctx, args, incomplete):
 
 
 def get_official_package(ctx=None, args=None, incomplete=None):
-    return [
-        'mmcls',
-        'mmdet',
-        'mmdet3d',
-        'mmseg',
-        'mmaction2',
-        'mmtrack',
-        'mmpose',
-        'mmedit',
-        'mmocr',
-        'mmgen',
-        'mmselfsup'
-        'mmrotate',
-        'mmflow',
-        'mmyolo',
-    ]
+    return OFFICIAL_MODULES
diff --git a/mim/commands/export.py b/mim/commands/export.py
new file mode 100644
index 0000000..3f388c3
--- /dev/null
+++ b/mim/commands/export.py
@@ -0,0 +1,124 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import sys
+
+import click
+from mmengine.config import Config
+from mmengine.hub import get_config
+
+from mim._internal.export.pack_cfg import export_from_cfg
+from mim.click import CustomCommand
+
+PYTHON = sys.executable
+
+
+@click.command(
+    name='export',
+    context_settings=dict(ignore_unknown_options=True),
+    cls=CustomCommand)
+@click.argument('package', type=str)
+@click.argument('config', type=str)
+@click.argument('export_dir', type=str, default=None, required=False)
+@click.option(
+    '-f',
+    '--fast-test',
+    is_flag=True,
+    help='Fast test mode. To quickly check whether there is any error in'
+    ' the exported package, it only uses the first two samples of the'
+    ' datasets and trains for only 2 iters/epochs.')
+@click.option(
+    '--model-only',
+    is_flag=True,
+    help='Only export the model of the config file.')
+@click.option(
+    '--save-log',
+    is_flag=True,
+    help='Keep the log of the export process. The log will be saved to the'
+    " directory 'export_log'. By default the log is deleted automatically"
+    ' after the export.')
+def cli(config: str,
+        package: str,
+        export_dir: str,
+        model_only: bool = False,
+        fast_test: bool = False,
+        save_log: bool = False) -> None:
+    """Export package from config file.
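+
+    The positional arguments are PACKAGE, CONFIG and an optional EXPORT_DIR;
+    CONFIG may be a local file path or a config name inside PACKAGE.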
+
+    Example:
+
+    \b
+    >>> # Export package from a downstream config file.
+    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
+    ...     dab_detr
+    >>>
+    >>> # Export package from a specified config file.
+    >>> mim export mmdet mmdetection/configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py dab_detr  # noqa: E501
+    >>>
+    >>> # An export dir is auto-generated when not specified,
+    >>> # like: 'pack_from_mmdet_20231026_052704'.
+    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py
+    >>>
+    >>> # Only export the model of the config file.
+    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
+    ...     mask_rcnn_package --model-only
+    >>>
+    >>> # Keep the log of the export process.
+    >>> mim export mmdet dab_detr/dab-detr_r50_8xb2-50e_coco.py \\
+    ...     mask_rcnn_package --save-log
+    >>>
+    >>> # Print the help information of the export command.
+    >>> mim export -h
+    """
+
+    # get config
+    if osp.exists(config):
+        config = Config.fromfile(config)  # from local
+    else:
+        try:
+            config = get_config(package + '::' + config)  # from downstream
+        except Exception:
+            raise FileNotFoundError(
+                f"Cannot find config file '{config}' locally or "
+                f"'{package + '::' + config}' in the downstream repo.")
+
+    fast_test_mode(config, fast_test)
+
+    export_from_cfg(
+        config,
+        export_root_dir=export_dir,
+        model_only=model_only,
+        save_log=save_log)
+
+
+def fast_test_mode(cfg, fast_test: bool = False):
+    """Use less data for faster testing.
+
+    Args:
+        cfg (Config): Config of the exported package.
+        fast_test (bool, optional): Fast testing mode. Defaults to False.
+    """
+    if fast_test:
+        # batch_norm needs at least 2 samples
+        if 'dataset' in cfg.train_dataloader.dataset:
+            cfg.train_dataloader.dataset.dataset.indices = [0, 1]
+        else:
+            cfg.train_dataloader.dataset.indices = [0, 1]
+        cfg.train_dataloader.batch_size = 2
+
+        if cfg.get('test_dataloader') is not None:
+            cfg.test_dataloader.dataset.indices = [0, 1]
+            cfg.test_dataloader.batch_size = 2
+
+        if cfg.get('val_dataloader') is not None:
+            cfg.val_dataloader.dataset.indices = [0, 1]
+            cfg.val_dataloader.batch_size = 2
+
+        if (cfg.train_cfg.get('type') == 'IterBasedTrainLoop') \
+                or (cfg.train_cfg.get('by_epoch') is None
+                    and cfg.train_cfg.get('type') != 'EpochBasedTrainLoop'):
+            cfg.train_cfg.max_iters = 2
+        else:
+            cfg.train_cfg.max_epochs = 2
+
+        cfg.train_cfg.val_interval = 1
+        cfg.default_hooks.logger.interval = 1
+
+        if 'param_scheduler' in cfg and cfg.param_scheduler is not None:
+            if isinstance(cfg.param_scheduler, list):
+                for lr_sc in cfg.param_scheduler:
+                    lr_sc.begin = 0
+                    lr_sc.end = 2
+            else:
+                cfg.param_scheduler.begin = 0
+                cfg.param_scheduler.end = 2
diff --git a/mim/utils/__init__.py b/mim/utils/__init__.py
index b7ca41a..6f2fe62 100644
--- a/mim/utils/__init__.py
+++ b/mim/utils/__init__.py
@@ -4,6 +4,7 @@ from .default import (
     DEFAULT_MMCV_BASE_URL,
     DEFAULT_URL,
     MODULE2PKG,
+    OFFICIAL_MODULES,
     PKG2MODULE,
     PKG2PROJECT,
     RAW_GITHUB_URL,
@@ -90,4 +91,5 @@ __all__ = [
     'parse_home_page',
     'ensure_installation',
     'rich_progress_bar',
+    'OFFICIAL_MODULES',
 ]
diff --git a/mim/utils/default.py b/mim/utils/default.py
index 35db69e..aa75096 100644
--- a/mim/utils/default.py
+++ b/mim/utils/default.py
@@ -13,6 +13,13 @@ WHEEL_URL = {
     '{torch_version}/index.html',
 }
 RAW_GITHUB_URL = 'https://raw.githubusercontent.com/{owner}/{repo}/{branch}'
+
+OFFICIAL_MODULES = [
+    'mmcls', 'mmdet', 'mmdet3d', 'mmseg', 'mmaction2', 'mmtrack', 'mmpose',
+    'mmedit', 'mmocr', 'mmgen', 'mmselfsup', 'mmrotate', 'mmflow', 'mmyolo',
+    'mmpretrain', 'mmagic'
+]
+
 PKG2PROJECT = {
     'mmcv-full': 'mmcv',
     'mmcls': 'mmclassification',