mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Docs] Refine registry docs (#443)
* [Docs] Refine registry docs * explain how to use _scope_ * refine
This commit is contained in:
parent
dcab0f5055
commit
2f09342663
@ -1,242 +1,184 @@
|
|||||||
# 注册器(Registry)
|
# 注册器(Registry)
|
||||||
|
|
||||||
OpenMMLab 的算法库支持了丰富的算法和数据集,因此实现了很多功能相近的模块。例如 ResNet 和 SE-ResNet 的算法实现分别基于 `ResNet` 和 `SEResNet` 类,这些类有相似的功能和接口,都属于算法库中的模型组件。
|
OpenMMLab 的算法库支持了丰富的算法和数据集,因此实现了很多功能相近的模块。例如 ResNet 和 SE-ResNet 的算法实现分别基于 `ResNet` 和 `SEResNet` 类,这些类有相似的功能和接口,都属于算法库中的模型组件。
|
||||||
为了管理这些功能相似的模块,MMEngine 实现了 [注册器](https://mmengine.readthedocs.io/zh_CN/latest/api.html#mmengine.registry.Registry)。
|
为了管理这些功能相似的模块,MMEngine 实现了 [注册器](mmengine.registry.Registry)。
|
||||||
OpenMMLab 大多数算法库均使用注册器来管理他们的代码模块,包括 [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d),[MMClassification](https://github.com/open-mmlab/mmclassification) 和 [MMEditing](https://github.com/open-mmlab/mmediting) 等。
|
OpenMMLab 大多数算法库均使用注册器来管理它们的代码模块,包括 [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d),[MMClassification](https://github.com/open-mmlab/mmclassification) 和 [MMEditing](https://github.com/open-mmlab/mmediting) 等。
|
||||||
|
|
||||||
## 什么是注册器
|
## 什么是注册器
|
||||||
|
|
||||||
MMEngine 实现的注册器可以看作一个映射表和模块构建方法(build function)的组合。映射表维护了一个字符串到类或者函数的映射,使得用户可以借助字符串查找到相应的类或函数,例如维护字符串 `"ResNet"` 到 `ResNet` 类或函数的映射,使得用户可以通过 `"ResNet"` 找到 `ResNet` 类或函数;
|
MMEngine 实现的[注册器](mmengine.registry.Registry)可以看作一个映射表和模块构建方法(build function)的组合。映射表维护了一个字符串到**类或者函数的映射**,使得用户可以借助字符串查找到相应的类或函数,例如维护字符串 `"ResNet"` 到 `ResNet` 类或函数的映射,使得用户可以通过 `"ResNet"` 找到 `ResNet` 类;
|
||||||
而模块构建方法则定义了如何根据字符串查找到对应的类或函数,并定义了如何实例化这个类或调用这个函数,例如根据规则通过字符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模块。又或者根据规则通过字符串 `"bn"` 找到 `build_batchnorm2d`,并且调用函数获得 `BatchNorm2d` 模块。
|
而模块构建方法则定义了如何根据字符串查找到对应的类或函数以及如何实例化这个类或者调用这个函数,例如,通过字符串 `"bn"` 找到 `nn.BatchNorm2d` 并实例化 `BatchNorm2d` 模块;又或者通过字符串 `"build_batchnorm2d"` 找到 `build_batchnorm2d` 函数并返回该函数的调用结果。
|
||||||
MMEngine 中的注册器默认使用 [build_from_cfg 函数](https://mmengine.readthedocs.io/zh_CN/latest/api.html#mmengine.registry.build_from_cfg) 来查找并实例化字符串对应的类。
|
MMEngine 中的注册器默认使用 [build_from_cfg](mmengine.registry.build_from_cfg) 函数来查找并实例化字符串对应的类或者函数。
|
||||||
|
|
||||||
一个注册器管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象。例如注册器 `Classifier` 可以被视作所有分类网络的抽象,管理了 `ResNet`, `SEResNet` 和 `RegNetX` 等分类网络的类以及 `build_ResNet`, `build_SEResNet` 和 `build_RegNetX` 等分类网络的构建函数。
|
一个注册器管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象。例如注册器 `MODELS` 可以被视作所有模型的抽象,管理了 `ResNet`, `SEResNet` 和 `RegNetX` 等分类网络的类以及 `build_ResNet`, `build_SEResNet` 和 `build_RegNetX` 等分类网络的构建函数。
|
||||||
使用注册器管理功能相似的模块可以显著提高代码的扩展性和灵活性。用户可以跳至`使用注册器提高代码的扩展性`章节了解注册器是如何提高代码拓展性的。
|
|
||||||
|
|
||||||
## 入门用法
|
## 入门用法
|
||||||
|
|
||||||
使用注册器管理代码库中的模块,需要以下三个步骤。
|
使用注册器管理代码库中的模块,需要以下三个步骤。
|
||||||
|
|
||||||
1. 创建注册器
|
1. 创建注册器
|
||||||
2. 创建一个用于实例化类的构建方法(可选,在大多数情况下您可以只使用默认方法)
|
2. 创建一个用于实例化类的构建方法(可选,在大多数情况下可以只使用默认方法)
|
||||||
3. 将模块加入注册器中
|
3. 将模块加入注册器中
|
||||||
|
|
||||||
假设我们要实现一系列数据集转换器(Dataset Converter),将不同格式的数据转换为标准数据格式。我们希望可以实现仅修改配置就能够使用不同的转换器而无需修改代码。
|
假设我们要实现一系列激活模块并且希望仅修改配置就能够使用不同的激活模块而无需修改代码。
|
||||||
|
|
||||||
我们先创建一个名为 `converters` 的目录作为包,在包中我们创建一个文件来实现构建器(builder),
|
首先创建注册器,
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# model/builder.py
|
|
||||||
from mmengine import Registry
|
from mmengine import Registry
|
||||||
# 创建转换器的注册器
|
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
|
||||||
CONVERTERS = Registry('converter')
|
ACTIVATION = Registry('activation', scope='mmengine')
|
||||||
```
|
```
|
||||||
|
|
||||||
然后我们可以实现不同的转换器。例如,在 `converters/converter_cls.py` 中实现 `Converter1` 和 `Converter2`,在 `converters/converter_func.py` 中实现 `converter3`。
|
然后我们可以实现不同的激活模块,例如 `Sigmoid`,`ReLU` 和 `Softmax`。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# converters/converter_cls.py
|
import torch.nn as nn
|
||||||
from .builder import CONVERTERS
|
|
||||||
|
|
||||||
# 使用注册器管理模块
|
# 使用注册器管理模块
|
||||||
@CONVERTERS.register_module()
|
@ACTIVATION.register_module()
|
||||||
class Converter1(object):
|
class Sigmoid(nn.Module):
|
||||||
def __init__(self, a, b):
|
def __init__(self):
|
||||||
self.a = a
|
super().__init__()
|
||||||
self.b = b
|
|
||||||
|
|
||||||
@CONVERTERS.register_module()
|
def forward(self, x):
|
||||||
class Converter2(object):
|
print('call Sigmoid.forward')
|
||||||
def __init__(self, a, b, c):
|
return x
|
||||||
self.a = a
|
|
||||||
self.b = b
|
@ACTIVATION.register_module()
|
||||||
self.c = c
|
class ReLU(nn.Module):
|
||||||
|
def __init__(self, inplace=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
print('call ReLU.forward')
|
||||||
|
return x
|
||||||
|
|
||||||
|
@ACTIVATION.register_module()
|
||||||
|
class Softmax(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
print('call Softmax.forward')
|
||||||
|
return x
|
||||||
```
|
```
|
||||||
|
|
||||||
```python
|
使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `ACTIVATION` 中。通过 `@ACTIVATION.register_module()` 装饰所实现的模块,字符串和类或函数之间的映射就可以由 `ACTIVATION` 构建和维护,我们也可以通过 `ACTIVATION.register_module(module=ReLU)` 实现同样的功能。
|
||||||
# converters/converter_func.py
|
|
||||||
from .builder import CONVERTERS
|
|
||||||
from .converter_cls import Converter1
|
|
||||||
@CONVERTERS.register_module()
|
|
||||||
def converter3(a, b)
|
|
||||||
return Converter1(a, b)
|
|
||||||
```
|
|
||||||
|
|
||||||
使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类或函数之间的映射就可以由 `CONVERTERS` 构建和维护,我们也可以通过 `CONVERTERS.register_module(module=Converter1)` 实现同样的功能。
|
通过注册,我们就可以通过 `ACTIVATION` 建立字符串与类或函数之间的映射,
|
||||||
|
|
||||||
通过注册,我们就可以通过 `CONVERTERS` 建立字符串与类或函数之间的映射,
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
'Converter1' -> <class 'Converter1'>
|
print(ACTIVATION.module_dict)
|
||||||
'Converter2' -> <class 'Converter2'>
|
# {
|
||||||
'Converter3' -> <function 'Converter3'>
|
# 'Sigmoid': __main__.Sigmoid,
|
||||||
|
# 'ReLU': __main__.ReLU,
|
||||||
|
# 'Softmax': __main__.Softmax
|
||||||
|
# }
|
||||||
```
|
```
|
||||||
|
|
||||||
```{note}
|
```{note}
|
||||||
只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 `custom_imports` 字段动态导入该模块进而触发注册机制,详情见 [导入自定义 Python 模块](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/config.html#python).
|
只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 `custom_imports` 字段动态导入该模块进而触发注册机制,详情见[导入自定义 Python 模块](config.md)。
|
||||||
```
|
```
|
||||||
|
|
||||||
模块成功注册后,我们可以通过配置文件使用这个转换器。
|
模块成功注册后,我们可以通过配置文件使用这个激活模块。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# main.py
|
import torch
|
||||||
# 注意,converter_cfg 可以通过解析配置文件得到
|
input = torch.randn(2)
|
||||||
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
|
|
||||||
converter = CONVERTERS.build(converter_cfg)
|
act_cfg = dict(type='Sigmoid')
|
||||||
converter3_cfg = dict(type='converter3', a=a_value, b=b_value)
|
activation = ACTIVATION.build(act_cfg)
|
||||||
# returns the calling result
|
output = activation(input)
|
||||||
converter3 = CONVERTERS.build(converter3_cfg)
|
# call Sigmoid.forward
|
||||||
|
print(output)
|
||||||
|
# tensor([0.0159, 0.0815])
|
||||||
```
|
```
|
||||||
|
|
||||||
如果我们想使用 `Converter2`,仅需修改配置。
|
如果我们想使用 `ReLU`,仅需修改配置。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
converter_cfg = dict(type='Converter2', a=a_value, b=b_value, c=c_value)
|
act_cfg = dict(type='ReLU', inplace=True)
|
||||||
converter = CONVERTERS.build(converter_cfg)
|
activation = ACTIVATION.build(act_cfg)
|
||||||
|
output = activation(input)
|
||||||
|
# call Sigmoid.forward
|
||||||
|
print(output)
|
||||||
|
# tensor([0.0159, 0.0815])
|
||||||
```
|
```
|
||||||
|
|
||||||
假如我们想在创建实例前检查输入参数的类型(或者任何其他操作),我们可以实现一个构建方法并将其传递给注册器从而实现自定义构建流程。
|
如果我们希望在创建实例前检查输入参数的类型(或者任何其他操作),我们可以实现一个构建方法并将其传递给注册器从而实现自定义构建流程。
|
||||||
|
|
||||||
|
创建一个构建方法,
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine import Registry
|
|
||||||
|
|
||||||
# 创建一个构建方法
|
def build_activation(cfg, registry, *args, **kwargs):
|
||||||
def build_converter(cfg, registry, *args, **kwargs):
|
|
||||||
cfg_ = cfg.copy()
|
cfg_ = cfg.copy()
|
||||||
converter_type = cfg_.pop('type')
|
act_type = cfg_.pop('type')
|
||||||
if converter_type not in registry:
|
print(f'build activation: {act_type}')
|
||||||
raise KeyError(f'Unrecognized converter type {converter_type}')
|
act_cls = registry.get(act_type)
|
||||||
else:
|
act = act_cls(*args, **kwargs, **cfg_)
|
||||||
converter_cls = registry.get(converter_type)
|
return act
|
||||||
|
```
|
||||||
|
|
||||||
converter = converter_cls(*args, **kwargs, **cfg_)
|
并将 `build_activation` 传递给 `build_func` 参数
|
||||||
return converter
|
|
||||||
|
|
||||||
# 创建一个用于转换器的注册器,并将 `build_converter` 传递给 `build_func` 参数
|
```python
|
||||||
CONVERTERS = Registry('converter', build_func=build_converter)
|
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine')
|
||||||
|
|
||||||
|
@ACTIVATION.register_module()
|
||||||
|
class Tanh(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
print('call Tanh.forward')
|
||||||
|
return x
|
||||||
|
|
||||||
|
act_cfg = dict(type='Tanh')
|
||||||
|
activation = ACTIVATION.build(act_cfg)
|
||||||
|
output = activation(input)
|
||||||
|
# build activation: Tanh
|
||||||
|
# call Tanh.forward
|
||||||
|
print(output)
|
||||||
|
# tensor([0.0159, 0.0815])
|
||||||
```
|
```
|
||||||
|
|
||||||
```{note}
|
```{note}
|
||||||
在这个例子中,我们演示了如何使用参数:`build_func` 自定义构建类的实例的方法。
|
在这个例子中,我们演示了如何使用参数 `build_func` 自定义构建类的实例的方法。
|
||||||
该功能类似于默认的 `build_from_cfg` 方法。在大多数情况下,使用默认的方法就可以了。
|
该功能类似于默认的 `build_from_cfg` 方法。在大多数情况下,使用默认的方法就可以了。
|
||||||
```
|
```
|
||||||
|
|
||||||
## 使用注册器提高代码的扩展性
|
MMEngine 的注册器除了可以注册类,也可以注册函数。
|
||||||
|
|
||||||
使用注册器管理功能相似的模块可以便利模块的自由组合与灵活拓展。下面通过例子介绍注册器的两个优点。
|
|
||||||
|
|
||||||
### 模块的自由组合
|
|
||||||
|
|
||||||
假设用户实现了一个模块 `ConvBlock`,`ConvBlock` 中定义了一个卷积层和一个激活层。
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch.nn as nn
|
FUNCTION = Registry('function', scope='mmengine')
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
@FUNCTION.register_module()
|
||||||
|
def print_args(**kwargs):
|
||||||
|
print(kwargs)
|
||||||
|
|
||||||
def __init__(self):
|
func_cfg = dict(type='print_args', a=1, b=2)
|
||||||
self.conv = nn.Conv2d()
|
func_res = FUNCTION.build(func_cfg)
|
||||||
self.act = nn.ReLU()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv(x)
|
|
||||||
x = self.act(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
conv_blcok = ConvBlock()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
可以发现,此时 ConvBlock 只支持 `nn.Conv2d` 和 `nn.ReLU` 的组合。如果我们想要让 `ConvBlock` 更加通用,例如让它可以使用其他类型的激活层,在不使用注册器的情况下,需要做如下改动
|
## 进阶用法
|
||||||
|
|
||||||
```python
|
MMEngine 的注册器支持层级注册,利用该功能可实现跨项目调用,即可以在一个项目中使用另一个项目的模块。虽然跨项目调用也有其他方法的可以实现,但 MMEngine 注册器提供了更为简便的方法。
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, act_type):
|
|
||||||
self.conv = nn.Conv2d()
|
|
||||||
if act_type == 'relu':
|
|
||||||
self.act = nn.ReLU()
|
|
||||||
elif act_type == 'gelu':
|
|
||||||
self.act = nn.GELU()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv(x)
|
|
||||||
x = self.act(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
conv_block = ConvBlock()
|
|
||||||
```
|
|
||||||
|
|
||||||
可以发现,上述改动需要枚举模块的各种类型,无法灵活地组合各种模块。而如果使用注册器,该问题可以轻松解决,用户只需要在构建 ConvBlock 的时候设置不同的 `conv_cfg` 和 `act_cfg` 即可达到目的。
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch.nn as nn
|
|
||||||
from mmengine import MODELS
|
|
||||||
|
|
||||||
# 将卷积和激活模块注册到 MODELS
|
|
||||||
MODELS.register_module(module=nn.Conv2d)
|
|
||||||
MODELS.register_module(module=nn.ReLU)
|
|
||||||
MODELS.register_module(module=nn.GELU)
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, conv_cfg, act_cfg):
|
|
||||||
self.conv = MODELS.build(conv_cfg)
|
|
||||||
self.pool = MODELS.build(act_cfg)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv(x)
|
|
||||||
x = self.act(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
# 注意,conv_cfg 和 act_cfg 可以通过解析配置文件得到
|
|
||||||
conv_cfg = dict(type='Conv2d')
|
|
||||||
act_cfg = dict(type='GELU')
|
|
||||||
conv_block = ConvBlock(conv_cfg, act_cfg)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 模块的灵活拓展
|
|
||||||
|
|
||||||
如果我们自定义了一个 `DeformConv2d` 卷积模块,我们只需将该模块注册到 `MODELS`,
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch.nn as nn
|
|
||||||
from mmengine import MODELS
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
|
||||||
class DeformConv2d(nn.Module):
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
就可以通过配置使用该模块。
|
|
||||||
|
|
||||||
```python
|
|
||||||
conv_cfg = dict(type='DeformConv2d')
|
|
||||||
act_cfg = dict(type='GELU')
|
|
||||||
conv_block = ConvBlock(conv_cfg, act_cfg)
|
|
||||||
conv = MODELS.build(cfg)
|
|
||||||
```
|
|
||||||
|
|
||||||
可以看到,添加了 `DeformConv2d` 模块并不需要对 `ConvBlock` 做修改。
|
|
||||||
|
|
||||||
## 通过 Registry 实现模块的跨库调用
|
|
||||||
|
|
||||||
MMEngine 的注册器支持跨项目调用,即可以在一个项目中使用另一个项目的模块。虽然跨项目调用也有其他方法的可以实现,但 MMEngine 注册器提供了更为简便的方法。
|
|
||||||
|
|
||||||
为了方便跨库调用,MMEngine 提供了 20 个根注册器:
|
为了方便跨库调用,MMEngine 提供了 20 个根注册器:
|
||||||
|
|
||||||
- RUNNERS: Runner 的注册器
|
- RUNNERS: Runner 的注册器
|
||||||
- RUNNER_CONSTRUCTORS: Runner 的构造器
|
- RUNNER_CONSTRUCTORS: Runner 的构造器
|
||||||
- LOOPS: 管理训练、验证以及测试流程,如 `EpochBasedTrainRunner`
|
- LOOPS: 管理训练、验证以及测试流程,如 `EpochBasedTrainLoop`
|
||||||
- HOOKS: 钩子,如 `CheckpointHook`, `ProfilerHook`
|
- HOOKS: 钩子,如 `CheckpointHook`, `ParamSchedulerHook`
|
||||||
- DATASETS: 数据集
|
- DATASETS: 数据集
|
||||||
- DATA_SAMPLERS: `Dataloader` 的 `sampler`,用于采样数据
|
- DATA_SAMPLERS: `DataLoader` 的 `Sampler`,用于采样数据
|
||||||
- TRANSFORMS: 各种数据预处理,如 `Resize`, `Reshape`
|
- TRANSFORMS: 各种数据预处理,如 `Resize`, `Reshape`
|
||||||
- MODELS: 模型的各种模块
|
- MODELS: 模型的各种模块
|
||||||
- MODEL_WRAPPERS: 模型的包装器,如 `MMDistributedDataParallel`,用于对分布式数据并行
|
- MODEL_WRAPPERS: 模型的包装器,如 `MMDistributedDataParallel`,用于对分布式数据并行
|
||||||
- WEIGHT_INITIALIZERS: 权重初始化的工具
|
- WEIGHT_INITIALIZERS: 权重初始化的工具
|
||||||
- OPTIMIZERS: 注册了 PyTorch 中所有的 `optimizer` 以及自定义的 `optimizer`
|
- OPTIMIZERS: 注册了 PyTorch 中所有的 `Optimizer` 以及自定义的 `Optimizer`
|
||||||
- OPTIM_WRAPPER: 对 Optimizer 相关操作的封装,如 `OptimWrapper`,`AmpOptimWrapper`
|
- OPTIM_WRAPPER: 对 Optimizer 相关操作的封装,如 `OptimWrapper`,`AmpOptimWrapper`
|
||||||
- OPTIM_WRAPPER_CONSTRUCTORS: optimizer wrapper 的构造器
|
- OPTIM_WRAPPER_CONSTRUCTORS: optimizer wrapper 的构造器
|
||||||
- PARAM_SCHEDULERS: 各种参数调度器,如 `MultiStepLR`
|
- PARAM_SCHEDULERS: 各种参数调度器,如 `MultiStepLR`
|
||||||
@ -247,146 +189,119 @@ MMEngine 的注册器支持跨项目调用,即可以在一个项目中使用
|
|||||||
- VISBACKENDS: 存储训练日志的后端,如 `LocalVisBackend`, `TensorboardVisBackend`
|
- VISBACKENDS: 存储训练日志的后端,如 `LocalVisBackend`, `TensorboardVisBackend`
|
||||||
- LOG_PROCESSORS: 控制日志的统计窗口和统计方法,默认使用 `LogProcessor`,如有特殊需求可自定义 `LogProcessor`
|
- LOG_PROCESSORS: 控制日志的统计窗口和统计方法,默认使用 `LogProcessor`,如有特殊需求可自定义 `LogProcessor`
|
||||||
|
|
||||||
下面我们以 OpenMMLab 开源项目为例介绍如何跨项目调用模块。
|
|
||||||
|
|
||||||
### 调用父节点的模块
|
### 调用父节点的模块
|
||||||
|
|
||||||
`MMEngine` 中定义了模块 `Conv2d`,
|
`MMEngine` 中定义模块 `RReLU`,并往 `MODELS` 根注册器注册。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
import torch.nn as nn
|
||||||
from mmengine import Registry, MODELS
|
from mmengine import Registry, MODELS
|
||||||
|
|
||||||
MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class Conv2d(nn.Module):
|
class RReLU(nn.Module):
|
||||||
pass
|
def __init__(self, lower=0.125, upper=0.333, inplace=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
print('call RReLU.forward')
|
||||||
|
return x
|
||||||
```
|
```
|
||||||
|
|
||||||
`MMDetection` 中定义了模块 `RetinaNet`,
|
假设有个项目叫 `MMAlpha`,它也定义了 `MODELS`,并设置其父节点为 `MMEngine` 的 `MODELS`,这样就建立了层级结构。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine import Registry, MODELS as MMENGINE_MODELS
|
from mmengine import Registry, MODELS as MMENGINE_MODELS
|
||||||
# parent 参数表示当前节点的父节点,通过 parent 参数实现层级结构
|
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha')
|
||||||
# scope 参数可以理解为当前节点的标志。如果不传入该参数,则 scope 被推导为当前文件所在
|
|
||||||
# 包的包名,这里为 mmdet
|
|
||||||
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmdet')
|
|
||||||
|
|
||||||
@MMDET_MODELS.register_module()
|
|
||||||
class RetinaNet(nn.Module):
|
|
||||||
pass
|
|
||||||
```
|
```
|
||||||
|
|
||||||
下图是 `MMEngine`, `MMDetection` 两个项目的注册器层级结构。
|
下图是 `MMEngine` 和 `MMAlpha` 的注册器层级结构。
|
||||||
|
|
||||||

|
<div align="center">
|
||||||
|
<img src="https://user-images.githubusercontent.com/58739961/185307159-26dc5771-df77-4d03-9203-9c4c3197befa.png"/>
|
||||||
|
</div>
|
||||||
|
|
||||||
我们可以在 `MMDetection` 中调用 `MMEngine` 中的模块。
|
可以调用 [count_registered_modules](mmengine.registry.count_registered_modules) 函数打印已注册到 MMEngine 的模块以及层级结构。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmdet.models import MODELS
|
from mmengine.registry import count_registered_modules
|
||||||
# 创建 RetinaNet 实例
|
count_registered_modules()
|
||||||
model = MODELS.build(cfg=dict(type='RetinaNet'))
|
|
||||||
# 也可以加 mmdet 前缀
|
|
||||||
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
|
|
||||||
# 创建 Conv2d 实例
|
|
||||||
model = MODELS.build(cfg=dict(type='mmengine.Conv2d'))
|
|
||||||
# 也可以不加 mmengine 前缀
|
|
||||||
model = MODELS.build(cfg=dict(type='Conv2d'))
|
|
||||||
```
|
```
|
||||||
|
|
||||||
如果不加前缀,`build` 方法首先查找当前节点是否存在该模块,如果存在则返回该模块,否则会继续向上查找父节点甚至祖先节点直到找到该模块,因此,如果当前节点和父节点存在同一模块并且希望调用父节点的模块,我们需要指定 `scope` 前缀。需要注意的是,向上查找父节点甚至祖先节点的**前提是父节点或者祖先节点的模块已通过某种方式被导入进而完成注册**。例如,在上面这个示例中,之所以没有显示导入父节点 `mmengine` 中的 `MODELS`,是因为通过 `from mmdet.models import MODELS` 间接触发 `mmengine.MODELS` 完成模块的注册。
|
在 `MMAlpha` 中定义模块 `LogSoftmax`,并往 `MMAlpha` 的 `MODELS` 注册。
|
||||||
|
|
||||||
上面展示了如何使用子节点注册器构建模块,但有时候我们希望不填加前缀也能在父节点注册器中构建子节点的模块,目的是提供通用的代码,避免下游算法库重复造轮子,该如何实现呢?
|
|
||||||
|
|
||||||
假设 MMEngine 中有一个 `build_model` 函数,该方法用于构建模型。
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine.registry import MODELS
|
@MODELS.register_module()
|
||||||
|
class LogSoftmax(nn.Module):
|
||||||
|
def __init__(self, dim=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def build_model(cfg):
|
def forward(self, x):
|
||||||
model = MODELS.build(cfg)
|
print('call LogSoftmax.forward')
|
||||||
|
return x
|
||||||
```
|
```
|
||||||
|
|
||||||
如果我们希望在 MMDetection 中调用该函数构建 MMDetection 注册的模块,那么我们需要先获取一个 scope_name 为 'mmdet' 的 [DefaultScope](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.registry.DefaultScope) 实例,该实例全局唯一。
|
在 `MMAlpha` 中使用配置调用 `LogSoftmax`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine import build_model
|
model = MODELS.build(cfg=dict(type='LogSoftmax'))
|
||||||
import mmdet.models # 通过 import 的方式将 mmdet 中的模块导入注册器进而完成注册
|
|
||||||
|
|
||||||
default_scope = DefaultScope.get_instance('my_experiment', scope_name='mmdet')
|
|
||||||
model = build_model(cfg=dict(type='RetinaNet'))
|
|
||||||
```
|
```
|
||||||
|
|
||||||
获取 `DefaultScope` 实例的目的是使 Registry 的 build 方法会将 DefaultScope 名称(mmdet)注册器节点作为注册器的起点,才能在配置中不填加 mmdet 前缀的情况下在 MMDetection 的注册器节点中找到 RetinaNet 模块,如若不然,程序会报找不到 RetinaNet 错误。
|
也可以在 `MMAlpha` 中调用父节点 `MMEngine` 的模块。
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = MODELS.build(cfg=dict(type='RReLU', lower=0.2))
|
||||||
|
# 也可以加 scope
|
||||||
|
model = MODELS.build(cfg=dict(type='mmengine.RReLU'))
|
||||||
|
```
|
||||||
|
|
||||||
|
如果不加前缀,`build` 方法首先查找当前节点是否存在该模块,如果存在则返回该模块,否则会继续向上查找父节点甚至祖先节点直到找到该模块,因此,如果当前节点和父节点存在同一模块并且希望调用父节点的模块,我们需要指定 `scope` 前缀。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
input = torch.randn(2)
|
||||||
|
output = model(input)
|
||||||
|
# call RReLU.forward
|
||||||
|
print(output)
|
||||||
|
# tensor([-1.5774, -0.5850])
|
||||||
|
```
|
||||||
|
|
||||||
### 调用兄弟节点的模块
|
### 调用兄弟节点的模块
|
||||||
|
|
||||||
除了可以调用父节点的模块,也可以调用兄弟节点的模块。
|
除了可以调用父节点的模块,也可以调用兄弟节点的模块。
|
||||||
|
|
||||||
`MMClassification` 中定义了模块 `ResNet`,
|
假设有另一个项目叫 `MMBeta`,它和 `MMAlpha` 一样,定义了 `MODELS` 以及设置其父节点为 `MMEngine` 的 `MODELS`。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine.registry import Registry, MODELS
|
from mmengine import Registry, MODELS as MMENGINE_MODELS
|
||||||
MODELS = Registry('model', parent=MMENGINE_MODELS)
|
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmbeta')
|
||||||
|
|
||||||
@MODELS.register_module()
|
|
||||||
class ResNet(nn.Module):
|
|
||||||
pass
|
|
||||||
```
|
```
|
||||||
|
|
||||||
下图是 `MMEngine`, `MMDetection`, `MMClassification` 三个项目的注册器层级结构。
|
下图是 MMEngine,MMAlpha 和 MMBeta 的注册器层级结构。
|
||||||
|
|
||||||

|
<div align="center">
|
||||||
|
<img src="https://user-images.githubusercontent.com/58739961/185307738-9ddbce2d-f8b5-40c4-bf8f-603830ccc0dc.png"/>
|
||||||
|
</div>
|
||||||
|
|
||||||
我们可以在 `MMDetection` 中调用 `MMClassification` 定义的模块,
|
在 `MMBeta` 中调用兄弟节点 `MMAlpha` 的模块,
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmdet.models import MODELS
|
model = MODELS.build(cfg=dict(type='mmalpha.LogSoftmax'))
|
||||||
model = MODELS.build(cfg=dict(type='mmcls.ResNet'))
|
output = model(input)
|
||||||
|
# call LogSoftmax.forward
|
||||||
|
print(output)
|
||||||
|
# tensor([-1.5774, -0.5850])
|
||||||
```
|
```
|
||||||
|
|
||||||
也可以在 `MMClassification` 中调用 `MMDetection` 定义的模块。
|
调用兄弟节点的模块需要在 `type` 中指定 `scope` 前缀,所以上面的配置需要加前缀 `mmalpha`。
|
||||||
|
|
||||||
|
如果需要调用兄弟节点的数个模块,每个模块都加前缀,这需要做大量的修改。于是 `MMEngine` 引入了 [DefaultScope](mmengine.registry.DefaultScope),`Registry` 借助它可以很方便地支持临时切换当前节点为指定的节点。
|
||||||
|
|
||||||
|
如果需要临时切换当前节点为指定的节点,只需在 `cfg` 设置 `_scope_` 为指定节点的作用域。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmcls.models import MODELS
|
model = MODELS.build(cfg=dict(type='LogSoftmax', _scope_='mmalpha'))
|
||||||
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
|
output = model(input)
|
||||||
```
|
# call LogSoftmax.forward
|
||||||
|
print(output)
|
||||||
调用非本节点或父节点的模块需要在 `type` 中指定 `scope` 前缀。
|
# tensor([-1.5774, -0.5850])
|
||||||
|
|
||||||
注册器除了支持两层结构,三层甚至更多层结构也是支持的。
|
|
||||||
|
|
||||||
假设我们新建了一个项目 `DetPlus`,它的 `MODELS` 注册器继承自 `MMDetection` 的 `MODELS`,并且它会用到 `MMClassification` 中的 `ResNet` 模块。
|
|
||||||
|
|
||||||
`DetPlus` 中定义了模块 `MetaNet`,
|
|
||||||
|
|
||||||
```python
|
|
||||||
from mmengine.registry import Registry
|
|
||||||
from mmdet.model import MODELS as MMDET_MODELS
|
|
||||||
MODELS = Registry('model', parent=MMDET_MODELS, scope='det_plus')
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
|
||||||
class MetaNet(nn.Module):
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
下图是 `MMEngine`, `MMDetection`, `MMClassification` 以及 `DetPlus` 四个项目的注册器层级结构。
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
我们可以在 `DetPlus` 中调用 `MMDetection` 或者 `MMClassification` 中的模块,
|
|
||||||
|
|
||||||
```python
|
|
||||||
from detplus.model import MODELS
|
|
||||||
# 可以不提供 mmdet 前缀,如果在 detplus 找不到则会向上在 mmdet 中查找
|
|
||||||
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
|
|
||||||
# 调用兄弟节点的模块需提供 mmcls 前缀,但也可以设置 default_scope 参数
|
|
||||||
model = MODELS.build(cfg=dict(type='mmcls.ResNet'))
|
|
||||||
```
|
|
||||||
|
|
||||||
也可以在 `MMClassification` 中调用 `DetPlus` 的模块。
|
|
||||||
|
|
||||||
```python
|
|
||||||
from mmcls.models import MODELS
|
|
||||||
# 需要注意前缀的顺序,'detplus.mmdet.ResNet' 是不正确的
|
|
||||||
model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
|
|
||||||
```
|
```
|
||||||
|
@ -208,18 +208,18 @@ class Registry:
|
|||||||
>>> DefaultScope.get_current_instance().scope_name
|
>>> DefaultScope.get_current_instance().scope_name
|
||||||
custom
|
custom
|
||||||
>>> # Switch to mmcls scope and get `MMCLS_MODELS` registry.
|
>>> # Switch to mmcls scope and get `MMCLS_MODELS` registry.
|
||||||
>>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry: # noqa: E501
|
>>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry:
|
||||||
>>> DefaultScope.get_current_instance().scope_name
|
>>> DefaultScope.get_current_instance().scope_name
|
||||||
mmcls
|
mmcls
|
||||||
>>> registry.scope
|
>>> registry.scope
|
||||||
mmcls
|
mmcls
|
||||||
>>> # Nested switch scope
|
>>> # Nested switch scope
|
||||||
>>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry: # noqa: E501
|
>>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry:
|
||||||
>>> DefaultScope.get_current_instance().scope_name
|
>>> DefaultScope.get_current_instance().scope_name
|
||||||
mmdet
|
mmdet
|
||||||
>>> mmdet_registry.scope
|
>>> mmdet_registry.scope
|
||||||
mmdet
|
mmdet
|
||||||
>>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry: # noqa: E501
|
>>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry:
|
||||||
>>> DefaultScope.get_current_instance().scope_name
|
>>> DefaultScope.get_current_instance().scope_name
|
||||||
mmcls
|
mmcls
|
||||||
>>> mmcls_registry.scope
|
>>> mmcls_registry.scope
|
||||||
@ -228,7 +228,7 @@ class Registry:
|
|||||||
>>> # Check switch back to original scope.
|
>>> # Check switch back to original scope.
|
||||||
>>> DefaultScope.get_current_instance().scope_name
|
>>> DefaultScope.get_current_instance().scope_name
|
||||||
custom
|
custom
|
||||||
"""
|
""" # noqa: E501
|
||||||
from ..logging import print_log
|
from ..logging import print_log
|
||||||
|
|
||||||
# Switch to the given scope temporarily. If the corresponding registry
|
# Switch to the given scope temporarily. If the corresponding registry
|
||||||
|
@ -54,10 +54,20 @@ def count_registered_modules(save_path: Optional[str] = None,
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
save_path (str, optional): Path to save the json file.
|
save_path (str, optional): Path to save the json file.
|
||||||
verbose (bool): Whether to print log. Default: True
|
verbose (bool): Whether to print log. Defaults to True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Statistic results of all registered modules.
|
dict: Statistic results of all registered modules.
|
||||||
"""
|
"""
|
||||||
|
# import modules to trigger registering
|
||||||
|
import mmengine.dataset
|
||||||
|
import mmengine.evaluator
|
||||||
|
import mmengine.hooks
|
||||||
|
import mmengine.model
|
||||||
|
import mmengine.optim
|
||||||
|
import mmengine.runner
|
||||||
|
import mmengine.visualization # noqa: F401
|
||||||
|
|
||||||
registries_info = {}
|
registries_info = {}
|
||||||
# traverse all registries in MMEngine
|
# traverse all registries in MMEngine
|
||||||
for item in dir(root):
|
for item in dir(root):
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from unittest import TestCase
|
from unittest import TestCase, skipIf
|
||||||
|
|
||||||
from mmengine.registry import (Registry, count_registered_modules, root,
|
from mmengine.registry import (Registry, count_registered_modules, root,
|
||||||
traverse_registry_tree)
|
traverse_registry_tree)
|
||||||
|
from mmengine.utils import is_installed
|
||||||
|
|
||||||
|
|
||||||
class TestUtils(TestCase):
|
class TestUtils(TestCase):
|
||||||
@ -42,6 +43,7 @@ class TestUtils(TestCase):
|
|||||||
# result from any node should be the same
|
# result from any node should be the same
|
||||||
self.assertEqual(result, result_leaf)
|
self.assertEqual(result, result_leaf)
|
||||||
|
|
||||||
|
@skipIf(not is_installed('torch'), 'tests requires torch')
|
||||||
def test_count_all_registered_modules(self):
|
def test_count_all_registered_modules(self):
|
||||||
temp_dir = TemporaryDirectory()
|
temp_dir = TemporaryDirectory()
|
||||||
results = count_registered_modules(temp_dir.name, verbose=True)
|
results = count_registered_modules(temp_dir.name, verbose=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user