mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Merge branch 'main' into adapt
This commit is contained in:
commit
a39d959eeb
122
.circleci/config.yml
Normal file
122
.circleci/config.yml
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
version: 2.1
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
docker:
|
||||||
|
- image: cimg/python:3.7.4
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: Install pre-commit hook
|
||||||
|
command: |
|
||||||
|
sudo apt-add-repository ppa:brightbox/ruby-ng -y
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y ruby2.7
|
||||||
|
pip install pre-commit
|
||||||
|
pre-commit install
|
||||||
|
- run:
|
||||||
|
name: Linting
|
||||||
|
command: pre-commit run --all-files
|
||||||
|
- run:
|
||||||
|
name: Check docstring coverage
|
||||||
|
command: |
|
||||||
|
pip install interrogate
|
||||||
|
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmengine
|
||||||
|
|
||||||
|
build_cpu:
|
||||||
|
parameters:
|
||||||
|
# The python version must match available image tags in
|
||||||
|
# https://circleci.com/developer/images/image/cimg/python
|
||||||
|
python:
|
||||||
|
type: string
|
||||||
|
default: "3.7.4"
|
||||||
|
torch:
|
||||||
|
type: string
|
||||||
|
torchvision:
|
||||||
|
type: string
|
||||||
|
docker:
|
||||||
|
- image: cimg/python:<< parameters.python >>
|
||||||
|
resource_class: large
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: Upgrade pip
|
||||||
|
command: |
|
||||||
|
python -V
|
||||||
|
python -m pip install pip --upgrade
|
||||||
|
python -m pip --version
|
||||||
|
- run:
|
||||||
|
name: Install PyTorch
|
||||||
|
command: python -m pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
- run:
|
||||||
|
name: Install mmcv-full
|
||||||
|
command: python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.8.0/index.html
|
||||||
|
- run:
|
||||||
|
name: Install mmengine dependencies
|
||||||
|
command: python -m pip install -r requirements.txt
|
||||||
|
- run:
|
||||||
|
name: Build and install
|
||||||
|
command: python -m pip install -e .
|
||||||
|
- run:
|
||||||
|
name: Run unit tests
|
||||||
|
command: python -m pytest tests/
|
||||||
|
|
||||||
|
build_cu102:
|
||||||
|
machine:
|
||||||
|
image: ubuntu-1604-cuda-10.1:201909-23 # the actual version of cuda is 10.2
|
||||||
|
resource_class: gpu.nvidia.small
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
# https://github.com/pytorch/vision/issues/2921
|
||||||
|
name: Install dependency of torchvision when using pyenv
|
||||||
|
command: sudo apt-get install -y liblzma-dev
|
||||||
|
- run:
|
||||||
|
# python3.7 should be re-installed due to the issue https://github.com/pytorch/vision/issues/2921
|
||||||
|
name: Select python3.7
|
||||||
|
command: |
|
||||||
|
pyenv uninstall -f 3.7.0
|
||||||
|
pyenv install 3.7.0
|
||||||
|
pyenv global 3.7.0
|
||||||
|
- run:
|
||||||
|
name: Upgrade pip
|
||||||
|
command: |
|
||||||
|
python -V
|
||||||
|
python -m pip install pip --upgrade
|
||||||
|
python -m pip --version
|
||||||
|
- run:
|
||||||
|
name: Install PyTorch
|
||||||
|
command: python -m pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
- run:
|
||||||
|
name: Install mmcv-full
|
||||||
|
command: python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/index.html
|
||||||
|
- run:
|
||||||
|
name: Install mmengine dependencies
|
||||||
|
command: python -m pip install -r requirements.txt
|
||||||
|
- run:
|
||||||
|
name: Build and install
|
||||||
|
command: python -m pip install -e .
|
||||||
|
- run:
|
||||||
|
name: Run unit tests
|
||||||
|
command: |
|
||||||
|
python -m coverage run --branch --source mmengine -m pytest tests/
|
||||||
|
python -m coverage xml
|
||||||
|
python -m coverage report -m
|
||||||
|
|
||||||
|
workflows:
|
||||||
|
unit_tests:
|
||||||
|
jobs:
|
||||||
|
- lint
|
||||||
|
- build_cpu:
|
||||||
|
name: build_cpu_th1.8_py3.7
|
||||||
|
torch: 1.8.0
|
||||||
|
torchvision: 0.9.0
|
||||||
|
requires:
|
||||||
|
- lint
|
||||||
|
- hold:
|
||||||
|
type: approval # <<< This key-value pair will set your workflow to a status of "On Hold"
|
||||||
|
requires:
|
||||||
|
- build_cpu_th1.8_py3.7
|
||||||
|
- build_cu102:
|
||||||
|
requires:
|
||||||
|
- hold
|
@ -311,11 +311,15 @@ from mmcls.models import MODELS
|
|||||||
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
|
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
|
||||||
```
|
```
|
||||||
|
|
||||||
调用兄弟节点的模块需要指定在 `type` 中指定 `scope` 前缀,如果不想指定,我们可以将 `build` 方法中的 `default_scope` 参数设置为 'mmdet',它会将 `default_scope` 对应的 `registry` 作为当前 `Registry` 并调用 `build` 方法。
|
调用非本节点的模块需要指定在 `type` 中指定 `scope` 前缀,如果不想指定,我们可以创建一个全局变量 `default_scope` 并将 `scope_name` 设置为 'mmdet',`Registry` 会将 `scope_name` 对应的 `registry` 作为当前 `Registry` 并调用 `build` 方法。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmcls.models import MODELS
|
from mmengine.registry import DefaultScope, MODELS
|
||||||
model = MODELS.build(cfg=dict(type='RetinaNet'), default_scope='mmdet')
|
|
||||||
|
# 调用注册在 mmdet 中的 RetinaNet
|
||||||
|
default_scope = DefaultScope.get_instance(
|
||||||
|
'my_experiment', scope_name='mmdet')
|
||||||
|
model = MODELS.build(cfg=dict(type='RetinaNet'))
|
||||||
```
|
```
|
||||||
|
|
||||||
注册器除了支持两层结构,三层甚至更多层结构也是支持的。
|
注册器除了支持两层结构,三层甚至更多层结构也是支持的。
|
||||||
@ -325,7 +329,7 @@ model = MODELS.build(cfg=dict(type='RetinaNet'), default_scope='mmdet')
|
|||||||
`DetPlus` 中定义了模块 `MetaNet`,
|
`DetPlus` 中定义了模块 `MetaNet`,
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine.model import Registry
|
from mmengine.registry import Registry
|
||||||
from mmdet.model import MODELS as MMDET_MODELS
|
from mmdet.model import MODELS as MMDET_MODELS
|
||||||
MODELS = Registry('model', parent=MMDET_MODELS, scope='det_plus')
|
MODELS = Registry('model', parent=MMDET_MODELS, scope='det_plus')
|
||||||
|
|
||||||
@ -354,6 +358,10 @@ model = MODELS.build(cfg=dict(type='mmcls.ResNet'))
|
|||||||
from mmcls.models import MODELS
|
from mmcls.models import MODELS
|
||||||
# 需要注意前缀的顺序,'detplus.mmdet.ResNet' 是不正确的
|
# 需要注意前缀的顺序,'detplus.mmdet.ResNet' 是不正确的
|
||||||
model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
|
model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
|
||||||
# 当然,更简单的方法是直接设置 default_scope
|
|
||||||
|
# 如果希望默认从 detplus 构建模型,设置可以 default_scope
|
||||||
|
from mmengine.registry import DefaultScope
|
||||||
|
default_scope = DefaultScope.get_instance(
|
||||||
|
'my_experiment', scope_name='detplus')
|
||||||
model = MODELS.build(cfg=dict(type='MetaNet', default_scope='detplus'))
|
model = MODELS.build(cfg=dict(type='MetaNet', default_scope='detplus'))
|
||||||
```
|
```
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Optional, Union
|
from typing import Union
|
||||||
|
|
||||||
from ..registry import EVALUATORS
|
from ..registry import EVALUATORS
|
||||||
from .base import BaseEvaluator
|
from .base import BaseEvaluator
|
||||||
@ -7,9 +7,7 @@ from .composed_evaluator import ComposedEvaluator
|
|||||||
|
|
||||||
|
|
||||||
def build_evaluator(
|
def build_evaluator(
|
||||||
cfg: Union[dict, list],
|
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
|
||||||
default_scope: Optional[str] = None
|
|
||||||
) -> Union[BaseEvaluator, ComposedEvaluator]:
|
|
||||||
"""Build function of evaluator.
|
"""Build function of evaluator.
|
||||||
|
|
||||||
When the evaluator config is a list, it will automatically build composed
|
When the evaluator config is a list, it will automatically build composed
|
||||||
@ -18,16 +16,12 @@ def build_evaluator(
|
|||||||
Args:
|
Args:
|
||||||
cfg (dict | list): Config of evaluator. When the config is a list, it
|
cfg (dict | list): Config of evaluator. When the config is a list, it
|
||||||
will automatically build composed evaluators.
|
will automatically build composed evaluators.
|
||||||
default_scope (str, optional): The ``default_scope`` is used to
|
|
||||||
reset the current registry. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BaseEvaluator or ComposedEvaluator: The built evaluator.
|
BaseEvaluator or ComposedEvaluator: The built evaluator.
|
||||||
"""
|
"""
|
||||||
if isinstance(cfg, list):
|
if isinstance(cfg, list):
|
||||||
evaluators = [
|
evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
|
||||||
EVALUATORS.build(_cfg, default_scope=default_scope) for _cfg in cfg
|
|
||||||
]
|
|
||||||
return ComposedEvaluator(evaluators=evaluators)
|
return ComposedEvaluator(evaluators=evaluators)
|
||||||
else:
|
else:
|
||||||
return EVALUATORS.build(cfg, default_scope=default_scope)
|
return EVALUATORS.build(cfg)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -30,10 +30,7 @@ def register_torch_optimizers() -> List[str]:
|
|||||||
TORCH_OPTIMIZERS = register_torch_optimizers()
|
TORCH_OPTIMIZERS = register_torch_optimizers()
|
||||||
|
|
||||||
|
|
||||||
def build_optimizer(
|
def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
|
||||||
model: nn.Module,
|
|
||||||
cfg: dict,
|
|
||||||
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
|
|
||||||
"""Build function of optimizer.
|
"""Build function of optimizer.
|
||||||
|
|
||||||
If ``constructor`` is set in the ``cfg``, this method will build an
|
If ``constructor`` is set in the ``cfg``, this method will build an
|
||||||
@ -58,7 +55,6 @@ def build_optimizer(
|
|||||||
dict(
|
dict(
|
||||||
type=constructor_type,
|
type=constructor_type,
|
||||||
optimizer_cfg=optimizer_cfg,
|
optimizer_cfg=optimizer_cfg,
|
||||||
paramwise_cfg=paramwise_cfg),
|
paramwise_cfg=paramwise_cfg))
|
||||||
default_scope=default_scope)
|
optimizer = optim_constructor(model)
|
||||||
optimizer = optim_constructor(model, default_scope=default_scope)
|
|
||||||
return optimizer
|
return optimizer
|
||||||
|
@ -241,9 +241,7 @@ class DefaultOptimizerConstructor:
|
|||||||
prefix=child_prefix,
|
prefix=child_prefix,
|
||||||
is_dcn_module=is_dcn_module)
|
is_dcn_module=is_dcn_module)
|
||||||
|
|
||||||
def __call__(self,
|
def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
|
||||||
model: nn.Module,
|
|
||||||
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
|
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model = model.module
|
model = model.module
|
||||||
|
|
||||||
@ -251,11 +249,11 @@ class DefaultOptimizerConstructor:
|
|||||||
# if no paramwise option is specified, just use the global setting
|
# if no paramwise option is specified, just use the global setting
|
||||||
if not self.paramwise_cfg:
|
if not self.paramwise_cfg:
|
||||||
optimizer_cfg['params'] = model.parameters()
|
optimizer_cfg['params'] = model.parameters()
|
||||||
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
|
return OPTIMIZERS.build(optimizer_cfg)
|
||||||
|
|
||||||
# set param-wise lr and weight decay recursively
|
# set param-wise lr and weight decay recursively
|
||||||
params: List = []
|
params: List = []
|
||||||
self.add_params(params, model)
|
self.add_params(params, model)
|
||||||
optimizer_cfg['params'] = params
|
optimizer_cfg['params'] = params
|
||||||
|
|
||||||
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
|
return OPTIMIZERS.build(optimizer_cfg)
|
||||||
|
@ -25,8 +25,6 @@ class DefaultScope(ManagerMixin):
|
|||||||
>>> DefaultScope.get_instance('task', scope_name='mmdet')
|
>>> DefaultScope.get_instance('task', scope_name='mmdet')
|
||||||
>>> # Get default scope globally.
|
>>> # Get default scope globally.
|
||||||
>>> scope_name = DefaultScope.get_instance('task').scope_name
|
>>> scope_name = DefaultScope.get_instance('task').scope_name
|
||||||
>>> # build model from cfg.
|
|
||||||
>>> model = MODELS.build(model_cfg, default_scope=scope_name)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name: str, scope_name: str):
|
def __init__(self, name: str, scope_name: str):
|
||||||
|
@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|||||||
|
|
||||||
from ..config import Config, ConfigDict
|
from ..config import Config, ConfigDict
|
||||||
from ..utils import is_seq_of
|
from ..utils import is_seq_of
|
||||||
|
from .default_scope import DefaultScope
|
||||||
|
|
||||||
|
|
||||||
def build_from_cfg(
|
def build_from_cfg(
|
||||||
@ -354,19 +355,13 @@ class Registry:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def build(self,
|
def build(self, *args, **kwargs) -> Any:
|
||||||
*args,
|
|
||||||
default_scope: Optional[str] = None,
|
|
||||||
**kwargs) -> Any:
|
|
||||||
"""Build an instance.
|
"""Build an instance.
|
||||||
|
|
||||||
Build an instance by calling :attr:`build_func`. If
|
Build an instance by calling :attr:`build_func`. If the global
|
||||||
:attr:`default_scope` is given, :meth:`build` will firstly get the
|
variable default scope (:obj:`DefaultScope`) exists ,
|
||||||
responding registry and then call its own :meth:`build`.
|
:meth:`build` will firstly get the responding registry and then call
|
||||||
|
its own :meth:`build`.
|
||||||
Args:
|
|
||||||
default_scope (str, optional): The ``default_scope`` is used to
|
|
||||||
reset the current registry. Defaults to None.
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> from mmengine import Registry
|
>>> from mmengine import Registry
|
||||||
@ -379,9 +374,11 @@ class Registry:
|
|||||||
>>> cfg = dict(type='ResNet', depth=50)
|
>>> cfg = dict(type='ResNet', depth=50)
|
||||||
>>> model = MODELS.build(cfg)
|
>>> model = MODELS.build(cfg)
|
||||||
"""
|
"""
|
||||||
|
# get the global default scope
|
||||||
|
default_scope = DefaultScope.get_current_instance()
|
||||||
if default_scope is not None:
|
if default_scope is not None:
|
||||||
root = self._get_root_registry()
|
root = self._get_root_registry()
|
||||||
registry = root._search_child(default_scope)
|
registry = root._search_child(default_scope.scope_name)
|
||||||
if registry is None:
|
if registry is None:
|
||||||
# if `default_scope` can not be found, fallback to use self
|
# if `default_scope` can not be found, fallback to use self
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -675,8 +675,7 @@ class Runner:
|
|||||||
if isinstance(model, nn.Module):
|
if isinstance(model, nn.Module):
|
||||||
return model
|
return model
|
||||||
elif isinstance(model, dict):
|
elif isinstance(model, dict):
|
||||||
return MODELS.build(
|
return MODELS.build(model)
|
||||||
model, default_scope=self.default_scope.scope_name)
|
|
||||||
else:
|
else:
|
||||||
raise TypeError('model should be a nn.Module object or dict, '
|
raise TypeError('model should be a nn.Module object or dict, '
|
||||||
f'but got {model}')
|
f'but got {model}')
|
||||||
@ -726,9 +725,7 @@ class Runner:
|
|||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
else:
|
else:
|
||||||
model = MODEL_WRAPPERS.build(
|
model = MODEL_WRAPPERS.build(
|
||||||
model_wrapper_cfg,
|
model_wrapper_cfg, default_args=dict(model=self.model))
|
||||||
default_scope=self.default_scope.scope_name,
|
|
||||||
default_args=dict(model=self.model))
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -750,10 +747,7 @@ class Runner:
|
|||||||
if isinstance(optimizer, Optimizer):
|
if isinstance(optimizer, Optimizer):
|
||||||
return optimizer
|
return optimizer
|
||||||
elif isinstance(optimizer, dict):
|
elif isinstance(optimizer, dict):
|
||||||
optimizer = build_optimizer(
|
optimizer = build_optimizer(self.model, optimizer)
|
||||||
self.model,
|
|
||||||
optimizer,
|
|
||||||
default_scope=self.default_scope.scope_name)
|
|
||||||
return optimizer
|
return optimizer
|
||||||
else:
|
else:
|
||||||
raise TypeError('optimizer should be an Optimizer object or dict, '
|
raise TypeError('optimizer should be an Optimizer object or dict, '
|
||||||
@ -801,7 +795,6 @@ class Runner:
|
|||||||
param_schedulers.append(
|
param_schedulers.append(
|
||||||
PARAM_SCHEDULERS.build(
|
PARAM_SCHEDULERS.build(
|
||||||
_scheduler,
|
_scheduler,
|
||||||
default_scope=self.default_scope.scope_name,
|
|
||||||
default_args=dict(optimizer=self.optimizer)))
|
default_args=dict(optimizer=self.optimizer)))
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -837,9 +830,7 @@ class Runner:
|
|||||||
if isinstance(evaluator, (BaseEvaluator, ComposedEvaluator)):
|
if isinstance(evaluator, (BaseEvaluator, ComposedEvaluator)):
|
||||||
return evaluator
|
return evaluator
|
||||||
elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
||||||
return build_evaluator(
|
return build_evaluator(evaluator) # type: ignore
|
||||||
evaluator,
|
|
||||||
default_scope=self.default_scope.scope_name) # type: ignore
|
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
'evaluator should be one of dict, list of dict, BaseEvaluator '
|
'evaluator should be one of dict, list of dict, BaseEvaluator '
|
||||||
@ -880,8 +871,7 @@ class Runner:
|
|||||||
# build dataset
|
# build dataset
|
||||||
dataset_cfg = dataloader_cfg.pop('dataset')
|
dataset_cfg = dataloader_cfg.pop('dataset')
|
||||||
if isinstance(dataset_cfg, dict):
|
if isinstance(dataset_cfg, dict):
|
||||||
dataset = DATASETS.build(
|
dataset = DATASETS.build(dataset_cfg)
|
||||||
dataset_cfg, default_scope=self.default_scope.scope_name)
|
|
||||||
else:
|
else:
|
||||||
# fallback to raise error in dataloader
|
# fallback to raise error in dataloader
|
||||||
# if `dataset_cfg` is not a valid type
|
# if `dataset_cfg` is not a valid type
|
||||||
@ -891,9 +881,7 @@ class Runner:
|
|||||||
sampler_cfg = dataloader_cfg.pop('sampler')
|
sampler_cfg = dataloader_cfg.pop('sampler')
|
||||||
if isinstance(sampler_cfg, dict):
|
if isinstance(sampler_cfg, dict):
|
||||||
sampler = DATA_SAMPLERS.build(
|
sampler = DATA_SAMPLERS.build(
|
||||||
sampler_cfg,
|
sampler_cfg, default_args=dict(dataset=dataset))
|
||||||
default_scope=self.default_scope.scope_name,
|
|
||||||
default_args=dict(dataset=dataset))
|
|
||||||
else:
|
else:
|
||||||
# fallback to raise error in dataloader
|
# fallback to raise error in dataloader
|
||||||
# if `sampler_cfg` is not a valid type
|
# if `sampler_cfg` is not a valid type
|
||||||
@ -961,7 +949,6 @@ class Runner:
|
|||||||
if 'type' in loop_cfg:
|
if 'type' in loop_cfg:
|
||||||
loop = LOOPS.build(
|
loop = LOOPS.build(
|
||||||
loop_cfg,
|
loop_cfg,
|
||||||
default_scope=self.default_scope.scope_name,
|
|
||||||
default_args=dict(
|
default_args=dict(
|
||||||
runner=self, dataloader=self.train_dataloader))
|
runner=self, dataloader=self.train_dataloader))
|
||||||
else:
|
else:
|
||||||
@ -1012,7 +999,6 @@ class Runner:
|
|||||||
if 'type' in loop_cfg:
|
if 'type' in loop_cfg:
|
||||||
loop = LOOPS.build(
|
loop = LOOPS.build(
|
||||||
loop_cfg,
|
loop_cfg,
|
||||||
default_scope=self.default_scope.scope_name,
|
|
||||||
default_args=dict(
|
default_args=dict(
|
||||||
runner=self,
|
runner=self,
|
||||||
dataloader=self.val_dataloader,
|
dataloader=self.val_dataloader,
|
||||||
@ -1059,7 +1045,6 @@ class Runner:
|
|||||||
if 'type' in loop_cfg:
|
if 'type' in loop_cfg:
|
||||||
loop = LOOPS.build(
|
loop = LOOPS.build(
|
||||||
loop_cfg,
|
loop_cfg,
|
||||||
default_scope=self.default_scope.scope_name,
|
|
||||||
default_args=dict(
|
default_args=dict(
|
||||||
runner=self,
|
runner=self,
|
||||||
dataloader=self.test_dataloader,
|
dataloader=self.test_dataloader,
|
||||||
|
@ -6,6 +6,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as torch_dist
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
import mmengine.dist as dist
|
import mmengine.dist as dist
|
||||||
@ -108,9 +109,16 @@ def init_process(rank, world_size, functions, backend='gloo'):
|
|||||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||||
os.environ['MASTER_PORT'] = '29505'
|
os.environ['MASTER_PORT'] = '29505'
|
||||||
os.environ['RANK'] = str(rank)
|
os.environ['RANK'] = str(rank)
|
||||||
dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
|
|
||||||
|
|
||||||
device = 'cpu' if backend == 'gloo' else 'cuda'
|
if backend == 'nccl':
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
torch.cuda.set_device(rank % num_gpus)
|
||||||
|
device = 'cuda'
|
||||||
|
else:
|
||||||
|
device = 'cpu'
|
||||||
|
|
||||||
|
torch_dist.init_process_group(
|
||||||
|
backend=backend, rank=rank, world_size=world_size)
|
||||||
|
|
||||||
for func in functions:
|
for func in functions:
|
||||||
func(device)
|
func(device)
|
||||||
|
@ -55,7 +55,13 @@ def init_process(rank, world_size, functions, backend='gloo'):
|
|||||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||||
os.environ['MASTER_PORT'] = '29501'
|
os.environ['MASTER_PORT'] = '29501'
|
||||||
os.environ['RANK'] = str(rank)
|
os.environ['RANK'] = str(rank)
|
||||||
dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
|
|
||||||
|
if backend == 'nccl':
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
torch.cuda.set_device(rank % num_gpus)
|
||||||
|
|
||||||
|
torch_dist.init_process_group(
|
||||||
|
backend=backend, rank=rank, world_size=world_size)
|
||||||
dist.init_local_group(0, world_size)
|
dist.init_local_group(0, world_size)
|
||||||
|
|
||||||
for func in functions:
|
for func in functions:
|
||||||
|
@ -79,10 +79,9 @@ def generate_test_results(size, batch_size, pred, label):
|
|||||||
bs_residual = size % batch_size
|
bs_residual = size % batch_size
|
||||||
for i in range(num_batch):
|
for i in range(num_batch):
|
||||||
bs = bs_residual if i == num_batch - 1 else batch_size
|
bs = bs_residual if i == num_batch - 1 else batch_size
|
||||||
data_batch = [(np.zeros(
|
data_batch = [(np.zeros((3, 10, 10)), BaseDataElement(label=label))
|
||||||
(3, 10, 10)), BaseDataElement(data={'label': label}))
|
|
||||||
for _ in range(bs)]
|
for _ in range(bs)]
|
||||||
predictions = [BaseDataElement(data={'pred': pred}) for _ in range(bs)]
|
predictions = [BaseDataElement(pred=pred) for _ in range(bs)]
|
||||||
yield (data_batch, predictions)
|
yield (data_batch, predictions)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import time
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mmengine.config import Config, ConfigDict # type: ignore
|
from mmengine.config import Config, ConfigDict # type: ignore
|
||||||
from mmengine.registry import Registry, build_from_cfg
|
from mmengine.registry import DefaultScope, Registry, build_from_cfg
|
||||||
|
|
||||||
|
|
||||||
class TestRegistry:
|
class TestRegistry:
|
||||||
@ -342,11 +344,15 @@ class TestRegistry:
|
|||||||
|
|
||||||
# test `default_scope`
|
# test `default_scope`
|
||||||
# switch the current registry to another registry
|
# switch the current registry to another registry
|
||||||
dog = LITTLE_HOUNDS.build(b_cfg, default_scope='mid_hound')
|
DefaultScope.get_instance(
|
||||||
|
f'test-{time.time()}', scope_name='mid_hound')
|
||||||
|
dog = LITTLE_HOUNDS.build(b_cfg)
|
||||||
assert isinstance(dog, Beagle)
|
assert isinstance(dog, Beagle)
|
||||||
|
|
||||||
# `default_scope` can not be found
|
# `default_scope` can not be found
|
||||||
dog = MID_HOUNDS.build(b_cfg, default_scope='scope-not-found')
|
DefaultScope.get_instance(
|
||||||
|
f'test2-{time.time()}', scope_name='scope-not-found')
|
||||||
|
dog = MID_HOUNDS.build(b_cfg)
|
||||||
assert isinstance(dog, Beagle)
|
assert isinstance(dog, Beagle)
|
||||||
|
|
||||||
def test_repr(self):
|
def test_repr(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user