[Refactor] Refactor code structure (#395)

* Rename data to structure

* adjust the way to import module

* adjust the way to import module

* rename Structure to Data Structures in docs api

* rename structure to structures

* support using some modules of mmengine without torch

* fix circleci config

* fix circleci config

* fix registry ut

* minor fix

* move init method from model/utils to model/weight_init.py

* move init method from model/utils to model/weight_init.py

* move sync_bn to model

* move functions depending on torch to dl_utils

* format import

* fix logging ut

* add weight init in model/__init__.py

* move get_config and get_model to mmengine/hub

* move log_processor.py to mmengine/runner

* fix ut

* Add TimeCounter in dl_utils/__init__.py
Zaida Zhou 2022-08-24 19:14:07 +08:00 committed by GitHub
parent 486d8cda56
commit 7e1d7af2d9
93 changed files with 1323 additions and 1027 deletions
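
In short: the old mmengine.data package is split, with data elements moving to mmengine.structures and samplers/collate utilities moving to mmengine.dataset, while get_config/get_model move from mmengine.config to mmengine.hub. A minimal migration sketch for downstream imports (paths taken from the diffs below):

# Before this commit:
#   from mmengine.data import BaseDataElement, DefaultSampler, pseudo_collate
#   from mmengine.config import get_config, get_model
# After this commit:
from mmengine.structures import BaseDataElement, InstanceData
from mmengine.dataset import DefaultSampler, pseudo_collate
from mmengine.hub import get_config, get_model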

@@ -17,6 +17,34 @@ jobs:
pip install interrogate
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmengine
build_without_torch:
parameters:
# The python version must match available image tags in
# https://circleci.com/developer/images/image/cimg/python
python:
type: string
default: "3.7.4"
docker:
- image: cimg/python:<< parameters.python >>
resource_class: large
steps:
- checkout
- run:
name: Upgrade pip
command: |
python -V
python -m pip install pip --upgrade
python -m pip --version
- run:
name: Install mmengine dependencies
command: python -m pip install -r requirements.txt
- run:
name: Build and install
command: python -m pip install -e .
- run:
name: Run unit tests
command: python -m pytest tests/test_config tests/test_registry tests/test_fileio tests/test_logging tests/test_utils --ignore=tests/test_utils/test_dl_utils
build_cpu:
parameters:
# The python version must match available image tags in
@@ -101,12 +129,16 @@ workflows:
unit_tests:
jobs:
- lint
- build_without_torch:
requires:
- lint
- build_cpu:
name: build_cpu_th1.8_py3.7
torch: 1.8.0
torchvision: 0.9.0
requires:
- lint
- build_without_torch
- hold:
type: approval # <<< This key-value pair will set your workflow to a status of "On Hold"
requires:

@@ -23,13 +23,13 @@ Optimizer
.. automodule:: mmengine.optim
:members:
Data
--------
.. automodule:: mmengine.data
Data Structures
----------------
.. automodule:: mmengine.structures
:members:
Dataset
--------
------------
.. automodule:: mmengine.dataset
:members:

@@ -23,9 +23,14 @@ Optimizer
.. automodule:: mmengine.optim
:members:
Data
--------
.. automodule:: mmengine.data
Data Structures
----------------
.. automodule:: mmengine.structures
:members:
Dataset
------------
.. automodule:: mmengine.dataset
:members:
Distributed
@@ -42,3 +47,13 @@ Model
--------
.. automodule:: mmengine.model
:members:
Visualization
--------
.. automodule:: mmengine.visualization
:members:
Utils
--------
.. automodule:: mmengine.utils
:members:

@@ -1,14 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .config import *
from .data import *
from .dataset import *
from .device import *
from .fileio import *
from .hooks import *
from .logging import *
from .registry import *
from .runner import *
from .utils import *
from .version import __version__, version_info
from .visualization import *

@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction
from .get_config_model import get_config, get_model
__all__ = ['Config', 'ConfigDict', 'DictAction', 'get_config', 'get_model']
__all__ = ['Config', 'ConfigDict', 'DictAction']

@@ -1,12 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_data_element import BaseDataElement
from .instance_data import InstanceData
from .label_data import LabelData
from .pixel_data import PixelData
from .sampler import DefaultSampler, InfiniteSampler
from .utils import pseudo_collate, worker_init_fn
__all__ = [
'BaseDataElement', 'DefaultSampler', 'InfiniteSampler', 'worker_init_fn',
'pseudo_collate', 'InstanceData', 'LabelData', 'PixelData'
]

@@ -1,4 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .base_dataset import BaseDataset, Compose, force_full_init
from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset
from .sampler import DefaultSampler, InfiniteSampler
from .utils import pseudo_collate, worker_init_fn
__all__ = [
'BaseDataset', 'Compose', 'force_full_init', 'ClassBalancedDataset',
'ConcatDataset', 'RepeatDataset', 'DefaultSampler', 'InfiniteSampler',
'worker_init_fn', 'pseudo_collate'
]

@@ -18,7 +18,7 @@ from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
get_default_group, barrier, get_data_device,
get_comm_device, cast_data_device)
from mmengine.utils.version_utils import digit_version
from mmengine.utils.parrots_wrapper import TORCH_VERSION
from mmengine.utils.dl_utils import TORCH_VERSION
def _get_reduce_op(name: str) -> torch_dist.ReduceOp:

@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator, List, Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.registry import EVALUATOR, METRICS
from mmengine.structures import BaseDataElement
from .metric import BaseMetric

@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Union
from torch import Tensor
from mmengine.data import BaseDataElement
from mmengine.dist import (broadcast_object_list, collect_results,
is_main_process)
from mmengine.fileio import dump
from mmengine.logging import print_log
from mmengine.registry import METRICS
from mmengine.structures import BaseDataElement
class BaseMetric(metaclass=ABCMeta):

@@ -3,8 +3,8 @@ from typing import Optional, Sequence, Union
import torch
from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]

@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.structures import BaseDataElement
DATA_BATCH = Optional[Sequence[dict]]

@@ -2,8 +2,8 @@
import time
from typing import Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]

@@ -4,10 +4,10 @@ import os.path as osp
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.fileio import FileClient, dump
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_tuple_of, scandir
DATA_BATCH = Optional[Sequence[dict]]

@@ -5,10 +5,10 @@ from typing import Optional, Sequence, Tuple
import cv2
import numpy as np
from mmengine.data import BaseDataElement
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.utils.misc import tensor2imgs
from mmengine.structures import BaseDataElement
from mmengine.utils.dl_utils import tensor2imgs
# TODO: Due to interface changes, the current class

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hub import get_config, get_model
__all__ = ['get_config', 'get_model']
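
A usage sketch for the relocated helpers. The 'mmdet::...' config path below is hypothetical and assumes the corresponding downstream package is installed:

from mmengine.hub import get_config, get_model

# '<scope>::<config path>' selects a config from an installed OpenMMLab package
cfg = get_config('mmdet::retinanet/retinanet_r50_fpn_1x_coco.py')
model = get_model('mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', pretrained=True)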

@@ -2,13 +2,13 @@
import importlib
import os.path as osp
import torch.nn as nn
from mmengine.config import Config
from mmengine.config.utils import (_get_cfg_metainfo,
_get_external_cfg_base_path,
_get_package_and_cfg_path)
from mmengine.registry import MODELS, DefaultScope
from mmengine.runner import load_checkpoint
from mmengine.utils import check_install_package, get_installed_path
from .config import Config
from .utils import (_get_cfg_metainfo, _get_external_cfg_base_path,
_get_package_and_cfg_path)
def get_config(cfg_path: str, pretrained: bool = False) -> Config:
@@ -56,7 +56,7 @@ def get_config(cfg_path: str, pretrained: bool = False) -> Config:
return cfg
def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module:
def get_model(cfg_path: str, pretrained: bool = False, **kwargs):
"""Get built model from external package.
Args:
@@ -68,7 +68,6 @@ def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module:
Returns:
nn.Module: Built model.
"""
import mmengine.runner
package = cfg_path.split('::')[0]
with DefaultScope.overwrite_default_scope(package): # type: ignore
cfg = get_config(cfg_path, pretrained)
@@ -76,5 +75,5 @@ def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module:
models_module.register_all_modules() # type: ignore
model = MODELS.build(cfg.model, default_args=kwargs)
if pretrained:
mmengine.runner.load_checkpoint(model, cfg.model_path)
load_checkpoint(model, cfg.model_path)
return model

@@ -1,9 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .history_buffer import HistoryBuffer
from .log_processor import LogProcessor
from .logger import MMLogger, print_log
from .message_hub import MessageHub
__all__ = [
'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor'
]
__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']

@@ -7,7 +7,6 @@ from typing import Optional, Union
from termcolor import colored
from mmengine.dist import get_rank
from mmengine.utils import ManagerMixin
from mmengine.utils.manager import _accquire_lock, _release_lock
@@ -152,7 +151,8 @@ class MMLogger(Logger, ManagerMixin):
Logger.__init__(self, logger_name)
ManagerMixin.__init__(self, name)
# Get rank in DDP mode.
rank = get_rank()
rank = _get_rank()
# Config stream_handler. If `rank != 0`, stream_handler can only
# export ERROR logs.
@@ -289,3 +289,14 @@ def print_log(msg,
raise TypeError(
'`logger` should be either a logging.Logger object, str, '
f'"silent", "current" or None, but got {type(logger)}')
def _get_rank():
"""Support using logging module without torch."""
try:
# requires torch
from mmengine.dist import get_rank
except ImportError:
return 0
else:
return get_rank()
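
With this fallback, mmengine.logging no longer hard-requires torch. A quick sanity check for a torch-free environment (a sketch, not part of the commit):

from mmengine.logging import MMLogger, print_log

logger = MMLogger.get_instance('mmengine')  # rank falls back to 0 without torch
print_log('logging works without torch', logger=logger)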

@@ -2,15 +2,17 @@
import copy
import logging
from collections import OrderedDict
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
import torch
from mmengine.utils import ManagerMixin
from .history_buffer import HistoryBuffer
from .logger import print_log
if TYPE_CHECKING:
import torch
class MessageHub(ManagerMixin):
"""Message hub for component interaction. MessageHub is created and
@@ -99,7 +101,7 @@ class MessageHub(ManagerMixin):
def update_scalar(self,
key: str,
value: Union[int, float, np.ndarray, torch.Tensor],
value: Union[int, float, np.ndarray, 'torch.Tensor'],
count: int = 1,
resumed: bool = True) -> None:
"""Update :attr:_log_scalars.
@@ -315,7 +317,7 @@
return self._runtime_info[key]
def _get_valid_value(
self, value: Union[torch.Tensor, np.ndarray, int, float]) \
self, value: Union['torch.Tensor', np.ndarray, int, float]) \
-> Union[int, float]:
"""Convert value to python built-in type.
@@ -328,11 +330,13 @@
if isinstance(value, np.ndarray):
assert value.size == 1
value = value.item()
elif isinstance(value, torch.Tensor):
assert value.numel() == 1
value = value.item()
elif isinstance(value, (int, float)):
value = value
else:
assert isinstance(value, (int, float))
# check whether value is torch.Tensor but don't want
# to import torch in this file
assert hasattr(value, 'numel') and value.numel() == 1
value = value.item()
return value # type: ignore
def state_dict(self) -> dict:
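
The duck-typed numel() check keeps MessageHub free of a torch import while still accepting one-element tensors; Python and NumPy scalars take the earlier branches. A small sketch of the accepted inputs:

import numpy as np
from mmengine.logging import MessageHub

hub = MessageHub.get_instance('demo')
hub.update_scalar('loss', 0.25)          # int/float are used as-is
hub.update_scalar('acc', np.array(0.9))  # size-1 ndarray -> .item()
# a one-element torch.Tensor is accepted the same way via its numel() attribute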

@@ -1,11 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils.parrots_wrapper import TORCH_VERSION
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .utils import detect_anomalous_params, merge_dict, stack_batch
from .utils import (detect_anomalous_params, merge_dict, revert_sync_batchnorm,
stack_batch)
from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit,
KaimingInit, NormalInit, PretrainedInit,
TruncNormalInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init,
constant_init, initialize, kaiming_init, normal_init,
trunc_normal_init, uniform_init, update_init_info,
xavier_init)
from .wrappers import (MMDistributedDataParallel,
MMSeparateDistributedDataParallel, is_model_wrapper)
@@ -15,7 +23,12 @@ __all__ = [
'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor',
'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule',
'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList',
'ModuleDict', 'Sequential'
'ModuleDict', 'Sequential', 'revert_sync_batchnorm', 'update_init_info',
'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
'Caffe2XavierInit', 'PretrainedInit', 'initialize'
]
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
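
Because the init helpers are now re-exported from mmengine.model, downstream code can import them directly; a brief sketch:

import torch.nn as nn
from mmengine.model import constant_init, kaiming_init

conv = nn.Conv2d(3, 16, 3)
kaiming_init(conv)                  # Kaiming-normal weights, bias filled with 0
constant_init(conv, val=1, bias=0)  # or fill weights/bias with constants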

@@ -6,9 +6,9 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.data import BaseDataElement
from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from ..base_module import BaseModule
from .data_preprocessor import BaseDataPreprocessor

@@ -4,8 +4,8 @@ from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from mmengine.data import BaseDataElement
from mmengine.registry import MODELS
from mmengine.structures import BaseDataElement
from ..utils import stack_batch

@@ -11,6 +11,7 @@ import torch.nn as nn
from mmengine.dist import master_only
from mmengine.logging import MMLogger, print_log
from .weight_init import initialize, update_init_info
class BaseModule(nn.Module, metaclass=ABCMeta):
@@ -92,7 +93,6 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
logger = MMLogger.get_current_instance()
logger_name = logger.instance_name
from .utils import initialize, update_init_info
module_name = self.__class__.__name__
if not self._is_init:
if self.init_cfg:

@@ -1,676 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import math
import warnings
from typing import List, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmengine.logging.logger import MMLogger, print_log
from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg
def update_init_info(module, init_info):
"""Update the `_params_init_info` in the module if the value of parameters
are changed.
Args:
module (obj:`nn.Module`): The module of PyTorch with a user-defined
attribute `_params_init_info` which records the initialization
information.
init_info (str): The string that describes the initialization.
"""
assert hasattr(
module,
'_params_init_info'), f'Can not find `_params_init_info` in {module}'
for name, param in module.named_parameters():
assert param in module._params_init_info, (
f'Find a new :obj:`Parameter` '
f'named `{name}` during executing the '
f'`init_weights` of '
f'`{module.__class__.__name__}`. '
f'Please do not add or '
f'replace parameters during executing '
f'the `init_weights`. ')
# The parameter has been changed during executing the
# `init_weights` of module
mean_value = param.data.mean().cpu()
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
module._params_init_info[param]['init_info'] = init_info
module._params_init_info[param]['tmp_mean_value'] = mean_value
def constant_init(module, val, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
nn.init.xavier_uniform_(module.weight, gain=gain)
else:
nn.init.xavier_normal_(module.weight, gain=gain)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.normal_(module.weight, mean, std)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def trunc_normal_init(module: nn.Module,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
trunc_normal_(module.weight, mean, std, a, b) # type: ignore
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) # type: ignore
def uniform_init(module, a=0, b=1, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.uniform_(module.weight, a, b)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def kaiming_init(module,
a=0,
mode='fan_out',
nonlinearity='relu',
bias=0,
distribution='normal'):
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
nn.init.kaiming_uniform_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
else:
nn.init.kaiming_normal_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def caffe2_xavier_init(module, bias=0):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
kaiming_init(
module,
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
bias=bias,
distribution='uniform')
def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init
def _get_bases_name(m):
return [b.__name__ for b in m.__class__.__bases__]
class BaseInit:
def __init__(self, *, bias=0, bias_prob=None, layer=None):
self.wholemodule = False
if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a number, but got a {type(bias)}')
if bias_prob is not None:
if not isinstance(bias_prob, float):
raise TypeError(f'bias_prob type must be float, \
but got {type(bias_prob)}')
if layer is not None:
if not isinstance(layer, (str, list)):
raise TypeError(f'layer must be a str or a list of str, \
but got a {type(layer)}')
else:
layer = []
if bias_prob is not None:
self.bias = bias_init_with_prob(bias_prob)
else:
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer
def _get_init_info(self):
info = f'{self.__class__.__name__}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
"""Initialize module parameters with constant values.
Args:
val (int | float): the value to fill the weights in the module with
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, val, **kwargs):
super().__init__(**kwargs)
self.val = val
def __call__(self, module):
def init(m):
if self.wholemodule:
constant_init(m, self.val, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
constant_init(m, self.val, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Xavier')
class XavierInit(BaseInit):
r"""Initialize module parameters with values according to the method
described in `Understanding the difficulty of training deep feedforward
neural networks - Glorot, X. & Bengio, Y. (2010).
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
Args:
gain (int | float): an optional scaling factor. Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'``
or ``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, gain=1, distribution='normal', **kwargs):
super().__init__(**kwargs)
self.gain = gain
self.distribution = distribution
def __call__(self, module):
def init(m):
if self.wholemodule:
xavier_init(m, self.gain, self.bias, self.distribution)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
xavier_init(m, self.gain, self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Normal')
class NormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
Args:
mean (int | float):the mean of the normal distribution. Defaults to 0.
std (int | float): the standard deviation of the normal distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, mean=0, std=1, **kwargs):
super().__init__(**kwargs)
self.mean = mean
self.std = std
def __call__(self, module):
def init(m):
if self.wholemodule:
normal_init(m, self.mean, self.std, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
normal_init(m, self.mean, self.std, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: mean={self.mean},' \
f' std={self.std}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
outside :math:`[a, b]`.
Args:
mean (float): the mean of the normal distribution. Defaults to 0.
std (float): the standard deviation of the normal distribution.
Defaults to 1.
a (float): The minimum cutoff value.
b (float): The maximum cutoff value.
bias (float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
**kwargs) -> None:
super().__init__(**kwargs)
self.mean = mean
self.std = std
self.a = a
self.b = b
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
f' mean={self.mean}, std={self.std}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Uniform')
class UniformInit(BaseInit):
r"""Initialize module parameters with values drawn from the uniform
distribution :math:`\mathcal{U}(a, b)`.
Args:
a (int | float): the lower bound of the uniform distribution.
Defaults to 0.
b (int | float): the upper bound of the uniform distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, a=0, b=1, **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
def __call__(self, module):
def init(m):
if self.wholemodule:
uniform_init(m, self.a, self.b, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
uniform_init(m, self.a, self.b, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a},' \
f' b={self.b}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Kaiming')
class KaimingInit(BaseInit):
r"""Initialize module parameters with the values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification - He, K. et al. (2015).
<https://www.cv-foundation.org/openaccess/content_iccv_2015/
papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
Args:
a (int | float): the negative slope of the rectifier used after this
layer (only used with ``'leaky_relu'``). Defaults to 0.
mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
``'fan_in'`` preserves the magnitude of the variance of the weights
in the forward pass. Choosing ``'fan_out'`` preserves the
magnitudes in the backwards pass. Defaults to ``'fan_out'``.
nonlinearity (str): the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
Defaults to 'relu'.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'`` or
``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self,
a=0,
mode='fan_out',
nonlinearity='relu',
distribution='normal',
**kwargs):
super().__init__(**kwargs)
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
self.distribution = distribution
def __call__(self, module):
def init(m):
if self.wholemodule:
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
f'nonlinearity={self.nonlinearity}, ' \
f'distribution ={self.distribution}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Caffe2Xavier')
class Caffe2XavierInit(KaimingInit):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
def __init__(self, **kwargs):
super().__init__(
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
distribution='uniform',
**kwargs)
def __call__(self, module):
super().__call__(module)
@WEIGHT_INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit:
"""Initialize module by loading a pretrained model.
Args:
checkpoint (str): the checkpoint file of the pretrained model to
be loaded.
prefix (str, optional): the prefix of a sub-module in the pretrained
model. It is for loading a part of the pretrained model to
initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None.
map_location (str): map tensors into proper locations.
"""
def __init__(self, checkpoint, prefix=None, map_location=None):
self.checkpoint = checkpoint
self.prefix = prefix
self.map_location = map_location
def __call__(self, module):
from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
load_checkpoint,
load_state_dict)
logger = MMLogger.get_instance('mmengine')
if self.prefix is None:
print_log(f'load model from: {self.checkpoint}', logger=logger)
load_checkpoint(
module,
self.checkpoint,
map_location=self.map_location,
strict=False,
logger=logger)
else:
print_log(
f'load {self.prefix} in model from: {self.checkpoint}',
logger=logger)
state_dict = _load_checkpoint_with_prefix(
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
return info
def _initialize(module, cfg, wholemodule=False):
func = build_from_cfg(cfg, WEIGHT_INITIALIZERS)
# wholemodule flag is for override mode, there is no layer key in override
# and initializer will give init values for the whole module with the name
# in override.
func.wholemodule = wholemodule
func(module)
def _initialize_override(module, override, cfg):
if not isinstance(override, (dict, list)):
raise TypeError(f'override must be a dict or a list of dict, \
but got {type(override)}')
override = [override] if isinstance(override, dict) else override
for override_ in override:
cp_override = copy.deepcopy(override_)
name = cp_override.pop('name', None)
if name is None:
raise ValueError('`override` must contain the key "name",'
f'but got {cp_override}')
# if override only has name key, it means use args in init_cfg
if not cp_override:
cp_override.update(cfg)
# if override has name key and other args except type key, it will
# raise error
elif 'type' not in cp_override.keys():
raise ValueError(
f'`override` need "type" key, but got {cp_override}')
if hasattr(module, name):
_initialize(getattr(module, name), cp_override, wholemodule=True)
else:
raise RuntimeError(f'module did not have attribute {name}, '
f'but init_cfg is {cp_override}.')
def initialize(module, init_cfg):
r"""Initialize a module.
Args:
module (``torch.nn.Module``): the module will be initialized.
init_cfg (dict | list[dict]): initialization configuration dict to
define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, and ``Pretrained``.
Example:
>>> module = nn.Linear(2, 3, bias=True)
>>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
>>> initialize(module, init_cfg)
>>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
>>> # define key ``'layer'`` for initializing layer with different
>>> # configuration
>>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
dict(type='Constant', layer='Linear', val=2)]
>>> initialize(module, init_cfg)
>>> # define key``'override'`` to initialize some specific part in
>>> # module
>>> class FooNet(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> self.feat = nn.Conv2d(3, 16, 3)
>>> self.reg = nn.Conv2d(16, 10, 3)
>>> self.cls = nn.Conv2d(16, 5, 3)
>>> model = FooNet()
>>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
>>> override=dict(type='Constant', name='reg', val=3, bias=4))
>>> initialize(model, init_cfg)
>>> model = ResNet(depth=50)
>>> # Initialize weights with the pretrained model.
>>> init_cfg = dict(type='Pretrained',
checkpoint='torchvision://resnet50')
>>> initialize(model, init_cfg)
>>> # Initialize weights of a sub-module with the specific part of
>>> # a pretrained model by using "prefix".
>>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
>>> 'retinanet_r50_fpn_1x_coco/'\
>>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
>>> init_cfg = dict(type='Pretrained',
checkpoint=url, prefix='backbone.')
"""
if not isinstance(init_cfg, (dict, list)):
raise TypeError(f'init_cfg must be a dict or a list of dict, \
but got {type(init_cfg)}')
if isinstance(init_cfg, dict):
init_cfg = [init_cfg]
for cfg in init_cfg:
# should deeply copy the original config because cfg may be used by
# other modules, e.g., one init_cfg shared by multiple bottleneck
# blocks, the expected cfg will be changed after pop and will change
# the initialization behavior of other modules
cp_cfg = copy.deepcopy(cfg)
override = cp_cfg.pop('override', None)
_initialize(module, cp_cfg)
if override is not None:
cp_cfg.pop('layer', None)
_initialize_override(module, override, cp_cfg)
else:
# All attributes in module have same initialization.
pass
def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
b: float) -> Tensor:
# Method based on
# https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
# Modified from
# https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [lower, upper], then translate
# to [2lower-1, 2upper-1].
tensor.uniform_(2 * lower - 1, 2 * upper - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor: Tensor,
mean: float = 0.,
std: float = 1.,
a: float = -2.,
b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
Args:
tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
mean (float): the mean of the normal distribution.
std (float): the standard deviation of the normal distribution.
a (float): the minimum cutoff value.
b (float): the maximum cutoff value.
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
from mmengine.utils.dl_utils import mmcv_full_available
def stack_batch(tensor_list: List[torch.Tensor],
@@ -746,7 +83,7 @@ def detect_anomalous_params(loss: torch.Tensor, model) -> None:
traverse(grad_fn)
traverse(loss.grad_fn)
from mmengine import MMLogger
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
for n, p in model.named_parameters():
if p not in parameters_in_graph and p.requires_grad:
@@ -799,3 +136,62 @@
except ImportError:
warnings.warn('Cannot import torch.fx, `merge_dict` is a simple function '
'to merge multiple dicts')
class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
"""A general BatchNorm layer without input dimension check.
Reproduced from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
is `_check_input_dim` that is designed for tensor sanity checks.
The check has been bypassed in this class for the convenience of converting
SyncBatchNorm.
"""
def _check_input_dim(self, input: torch.Tensor):
return
def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
"""Helper function to convert all `SyncBatchNorm` (SyncBN) and
`mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
`BatchNormXd` layers.
Adapted from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
Args:
module (nn.Module): The module containing `SyncBatchNorm` layers.
Returns:
module_output: The converted module with `BatchNormXd` layers.
"""
module_output = module
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
if mmcv_full_available():
from mmcv.ops import SyncBatchNorm
module_checklist.append(SyncBatchNorm)
if isinstance(module, tuple(module_checklist)):
module_output = _BatchNormXd(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
# no_grad() may not be needed here but
# just to be consistent with `convert_sync_batchnorm()`
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
module_output.training = module.training
# qconfig exists in quantized models
if hasattr(module, 'qconfig'):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, revert_sync_batchnorm(child))
del module
return module_output
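
A usage sketch for the relocated helper, converting SyncBN layers so a model can run outside distributed training:

import torch
import torch.nn as nn
from mmengine.model import revert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
model = revert_sync_batchnorm(model)   # SyncBatchNorm -> _BatchNormXd
out = model(torch.rand(1, 3, 10, 10))  # works without a process group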

@@ -0,0 +1,679 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import warnings
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from mmengine.logging import MMLogger, print_log
from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg
def update_init_info(module, init_info):
"""Update the `_params_init_info` in the module if the value of parameters
are changed.
Args:
module (obj:`nn.Module`): The module of PyTorch with a user-defined
attribute `_params_init_info` which records the initialization
information.
init_info (str): The string that describes the initialization.
"""
assert hasattr(
module,
'_params_init_info'), f'Can not find `_params_init_info` in {module}'
for name, param in module.named_parameters():
assert param in module._params_init_info, (
f'Find a new :obj:`Parameter` '
f'named `{name}` during executing the '
f'`init_weights` of '
f'`{module.__class__.__name__}`. '
f'Please do not add or '
f'replace parameters during executing '
f'the `init_weights`. ')
# The parameter has been changed during executing the
# `init_weights` of module
mean_value = param.data.mean().cpu()
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
module._params_init_info[param]['init_info'] = init_info
module._params_init_info[param]['tmp_mean_value'] = mean_value
def constant_init(module, val, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
nn.init.xavier_uniform_(module.weight, gain=gain)
else:
nn.init.xavier_normal_(module.weight, gain=gain)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.normal_(module.weight, mean, std)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def trunc_normal_init(module: nn.Module,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
trunc_normal_(module.weight, mean, std, a, b) # type: ignore
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) # type: ignore
def uniform_init(module, a=0, b=1, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.uniform_(module.weight, a, b)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def kaiming_init(module,
a=0,
mode='fan_out',
nonlinearity='relu',
bias=0,
distribution='normal'):
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
nn.init.kaiming_uniform_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
else:
nn.init.kaiming_normal_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def caffe2_xavier_init(module, bias=0):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
kaiming_init(
module,
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
bias=bias,
distribution='uniform')
def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init
def _get_bases_name(m):
return [b.__name__ for b in m.__class__.__bases__]
class BaseInit:
def __init__(self, *, bias=0, bias_prob=None, layer=None):
self.wholemodule = False
if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a number, but got a {type(bias)}')
if bias_prob is not None:
if not isinstance(bias_prob, float):
raise TypeError(f'bias_prob type must be float, \
but got {type(bias_prob)}')
if layer is not None:
if not isinstance(layer, (str, list)):
raise TypeError(f'layer must be a str or a list of str, \
but got a {type(layer)}')
else:
layer = []
if bias_prob is not None:
self.bias = bias_init_with_prob(bias_prob)
else:
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer
def _get_init_info(self):
info = f'{self.__class__.__name__}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
"""Initialize module parameters with constant values.
Args:
val (int | float): the value to fill the weights in the module with
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, val, **kwargs):
super().__init__(**kwargs)
self.val = val
def __call__(self, module):
def init(m):
if self.wholemodule:
constant_init(m, self.val, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
constant_init(m, self.val, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Xavier')
class XavierInit(BaseInit):
r"""Initialize module parameters with values according to the method
described in `Understanding the difficulty of training deep feedforward
neural networks - Glorot, X. & Bengio, Y. (2010).
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
Args:
gain (int | float): an optional scaling factor. Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'``
or ``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, gain=1, distribution='normal', **kwargs):
super().__init__(**kwargs)
self.gain = gain
self.distribution = distribution
def __call__(self, module):
def init(m):
if self.wholemodule:
xavier_init(m, self.gain, self.bias, self.distribution)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
xavier_init(m, self.gain, self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Normal')
class NormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
Args:
mean (int | float):the mean of the normal distribution. Defaults to 0.
std (int | float): the standard deviation of the normal distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, mean=0, std=1, **kwargs):
super().__init__(**kwargs)
self.mean = mean
self.std = std
def __call__(self, module):
def init(m):
if self.wholemodule:
normal_init(m, self.mean, self.std, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
normal_init(m, self.mean, self.std, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: mean={self.mean},' \
f' std={self.std}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
outside :math:`[a, b]`.
Args:
mean (float): the mean of the normal distribution. Defaults to 0.
std (float): the standard deviation of the normal distribution.
Defaults to 1.
a (float): The minimum cutoff value.
b (float): The maximum cutoff value.
bias (float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
**kwargs) -> None:
super().__init__(**kwargs)
self.mean = mean
self.std = std
self.a = a
self.b = b
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
f' mean={self.mean}, std={self.std}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Uniform')
class UniformInit(BaseInit):
r"""Initialize module parameters with values drawn from the uniform
distribution :math:`\mathcal{U}(a, b)`.
Args:
a (int | float): the lower bound of the uniform distribution.
Defaults to 0.
b (int | float): the upper bound of the uniform distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, a=0, b=1, **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
def __call__(self, module):
def init(m):
if self.wholemodule:
uniform_init(m, self.a, self.b, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
uniform_init(m, self.a, self.b, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a},' \
f' b={self.b}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Kaiming')
class KaimingInit(BaseInit):
r"""Initialize module parameters with the values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification - He, K. et al. (2015).
<https://www.cv-foundation.org/openaccess/content_iccv_2015/
papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
Args:
a (int | float): the negative slope of the rectifier used after this
layer (only used with ``'leaky_relu'``). Defaults to 0.
mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
``'fan_in'`` preserves the magnitude of the variance of the weights
in the forward pass. Choosing ``'fan_out'`` preserves the
magnitudes in the backwards pass. Defaults to ``'fan_out'``.
nonlinearity (str): the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
Defaults to 'relu'.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'`` or
``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self,
a=0,
mode='fan_out',
nonlinearity='relu',
distribution='normal',
**kwargs):
super().__init__(**kwargs)
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
self.distribution = distribution
def __call__(self, module):
def init(m):
if self.wholemodule:
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
f'nonlinearity={self.nonlinearity}, ' \
f'distribution ={self.distribution}, bias={self.bias}'
return info
@WEIGHT_INITIALIZERS.register_module(name='Caffe2Xavier')
class Caffe2XavierInit(KaimingInit):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
def __init__(self, **kwargs):
super().__init__(
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
distribution='uniform',
**kwargs)
def __call__(self, module):
super().__call__(module)
@WEIGHT_INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit:
"""Initialize module by loading a pretrained model.
Args:
checkpoint (str): the checkpoint file of the pretrained model to
be loaded.
prefix (str, optional): the prefix of a sub-module in the pretrained
model. It is for loading a part of the pretrained model to
initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None.
map_location (str): map tensors into proper locations.
"""
def __init__(self, checkpoint, prefix=None, map_location=None):
self.checkpoint = checkpoint
self.prefix = prefix
self.map_location = map_location
def __call__(self, module):
from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
load_checkpoint,
load_state_dict)
logger = MMLogger.get_instance('mmengine')
if self.prefix is None:
print_log(f'load model from: {self.checkpoint}', logger=logger)
load_checkpoint(
module,
self.checkpoint,
map_location=self.map_location,
strict=False,
logger=logger)
else:
print_log(
f'load {self.prefix} in model from: {self.checkpoint}',
logger=logger)
state_dict = _load_checkpoint_with_prefix(
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
return info
def _initialize(module, cfg, wholemodule=False):
func = build_from_cfg(cfg, WEIGHT_INITIALIZERS)
# wholemodule flag is for override mode, there is no layer key in override
# and initializer will give init values for the whole module with the name
# in override.
func.wholemodule = wholemodule
func(module)
def _initialize_override(module, override, cfg):
if not isinstance(override, (dict, list)):
raise TypeError(f'override must be a dict or a list of dict, \
but got {type(override)}')
override = [override] if isinstance(override, dict) else override
for override_ in override:
cp_override = copy.deepcopy(override_)
name = cp_override.pop('name', None)
if name is None:
raise ValueError('`override` must contain the key "name",'
f'but got {cp_override}')
# if override only has name key, it means use args in init_cfg
if not cp_override:
cp_override.update(cfg)
# if override has name key and other args except type key, it will
# raise error
elif 'type' not in cp_override.keys():
raise ValueError(
f'`override` need "type" key, but got {cp_override}')
if hasattr(module, name):
_initialize(getattr(module, name), cp_override, wholemodule=True)
else:
raise RuntimeError(f'module did not have attribute {name}, '
f'but init_cfg is {cp_override}.')
def initialize(module, init_cfg):
r"""Initialize a module.
Args:
module (``torch.nn.Module``): the module will be initialized.
init_cfg (dict | list[dict]): initialization configuration dict to
define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, and ``Pretrained``.
Example:
>>> module = nn.Linear(2, 3, bias=True)
>>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
>>> initialize(module, init_cfg)
>>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
>>> # define key ``'layer'`` for initializing layer with different
>>> # configuration
>>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
dict(type='Constant', layer='Linear', val=2)]
>>> initialize(module, init_cfg)
>>> # define key``'override'`` to initialize some specific part in
>>> # module
>>> class FooNet(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> self.feat = nn.Conv2d(3, 16, 3)
>>> self.reg = nn.Conv2d(16, 10, 3)
>>> self.cls = nn.Conv2d(16, 5, 3)
>>> model = FooNet()
>>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
>>> override=dict(type='Constant', name='reg', val=3, bias=4))
>>> initialize(model, init_cfg)
>>> model = ResNet(depth=50)
>>> # Initialize weights with the pretrained model.
>>> init_cfg = dict(type='Pretrained',
checkpoint='torchvision://resnet50')
>>> initialize(model, init_cfg)
>>> # Initialize weights of a sub-module with the specific part of
>>> # a pretrained model by using "prefix".
>>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
>>> 'retinanet_r50_fpn_1x_coco/'\
>>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
>>> init_cfg = dict(type='Pretrained',
checkpoint=url, prefix='backbone.')
"""
if not isinstance(init_cfg, (dict, list)):
raise TypeError(f'init_cfg must be a dict or a list of dict, \
but got {type(init_cfg)}')
if isinstance(init_cfg, dict):
init_cfg = [init_cfg]
for cfg in init_cfg:
        # Deeply copy the original config, because cfg may be shared by
        # multiple modules, e.g., one init_cfg shared by several bottleneck
        # blocks; popping keys from the shared cfg in place would change the
        # initialization behavior of the other modules.
cp_cfg = copy.deepcopy(cfg)
override = cp_cfg.pop('override', None)
_initialize(module, cp_cfg)
if override is not None:
cp_cfg.pop('layer', None)
_initialize_override(module, override, cp_cfg)
else:
            # All attributes in the module share the same initialization.
pass
def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
b: float) -> Tensor:
# Method based on
# https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
# Modified from
# https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [lower, upper], then translate
# to [2lower-1, 2upper-1].
tensor.uniform_(2 * lower - 1, 2 * upper - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor: Tensor,
mean: float = 0.,
std: float = 1.,
a: float = -2.,
b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
Args:
tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
mean (float): the mean of the normal distribution.
std (float): the standard deviation of the normal distribution.
a (float): the minimum cutoff value.
b (float): the maximum cutoff value.
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
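
A quick usage sketch of the two public helpers above (not part of this diff; it assumes ``initialize`` and ``trunc_normal_`` are exported from ``mmengine.model``, which this commit wires up via ``model/__init__.py``):

import torch.nn as nn
from mmengine.model import initialize, trunc_normal_

layer = nn.Linear(4, 8)
# draw weights from N(0, 0.02^2); values outside the default [a, b] = [-2, 2]
# are redrawn until they fall inside the bounds
trunc_normal_(layer.weight, mean=0., std=0.02)
# config-driven constant init, as documented in `initialize` above
initialize(layer, dict(type='Constant', layer='Linear', val=1, bias=2))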

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils.parrots_wrapper import TORCH_VERSION
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
from .distributed import MMDistributedDataParallel
from .seperate_distributed import MMSeparateDistributedDataParallel

View File

@ -4,9 +4,9 @@ from typing import Dict, List
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
from mmengine.data import BaseDataElement
from mmengine.optim import OptimWrapper
from mmengine.registry import MODEL_WRAPPERS
from mmengine.structures import BaseDataElement
from ..utils import detect_anomalous_params

View File

@ -7,9 +7,9 @@ from torch.distributed import ProcessGroup
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch, CPUOffload, FullyShardedDataParallel)
from mmengine.data import BaseDataElement
from mmengine.optim import OptimWrapper
from mmengine.registry import MODEL_WRAPPERS, Registry
from mmengine.structures import BaseDataElement
# support customize fsdp policy
FSDP_WRAP_POLICYS = Registry('fsdp wrap policy')

View File

@ -6,10 +6,10 @@ import torch
import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel
from mmengine.data import BaseDataElement
from mmengine.device import get_device
from mmengine.optim import OptimWrapperDict
from mmengine.registry import MODEL_WRAPPERS
from mmengine.structures import BaseDataElement
from .distributed import MMDistributedDataParallel

View File

@ -6,7 +6,8 @@ import torch.nn as nn
from torch.cuda.amp import GradScaler
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import TORCH_VERSION, digit_version
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from .optimizer_wrapper import OptimWrapper

View File

@ -9,8 +9,9 @@ from torch.nn import GroupNorm, LayerNorm
from mmengine.logging import print_log
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
OPTIMIZERS)
from mmengine.utils import is_list_of, mmcv_full_available
from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from mmengine.utils import is_list_of
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from .optimizer_wrapper import OptimWrapper

View File

@ -10,7 +10,7 @@ from torch.optim import Optimizer
from mmengine.logging import MessageHub, print_log
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import has_batch_norm
from mmengine.utils.dl_utils import has_batch_norm
@OPTIM_WRAPPERS.register_module()

View File

@ -3,15 +3,15 @@ import inspect
import logging
from typing import TYPE_CHECKING, Any, Optional, Union
import torch.nn as nn
from ..config import Config, ConfigDict
from ..utils import ManagerMixin
from mmengine.config import Config, ConfigDict
from mmengine.utils import ManagerMixin
from .registry import Registry
if TYPE_CHECKING:
from ..optim.scheduler import _ParamScheduler
from ..runner import Runner
import torch.nn as nn
from mmengine.optim.scheduler import _ParamScheduler
from mmengine.runner import Runner
def build_from_cfg(
@ -158,6 +158,7 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
Returns:
object: The constructed runner object.
"""
from ..config import Config, ConfigDict
from ..logging import print_log
assert isinstance(
@ -210,10 +211,10 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
def build_model_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
nn.Module:
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, 'ConfigDict', 'Config']] = None
) -> 'nn.Module':
"""Build a PyTorch model from config dict(s). Different from
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
@ -239,10 +240,10 @@ def build_model_from_cfg(
def build_scheduler_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
'_ParamScheduler':
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None
) -> '_ParamScheduler':
"""Builds a ``ParamScheduler`` instance from config.
``ParamScheduler`` supports building instance by its constructor or
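
A hedged sketch of what this builder enables, mirroring the pattern used in the scheduler tests further down (``MultiStepLR`` and the SGD optimizer are only illustrative choices):

import torch.nn as nn
from torch.optim import SGD
from mmengine.registry import PARAM_SCHEDULERS

optimizer = SGD(nn.Conv2d(1, 1, 1).parameters(), lr=0.1)
# default_args supplies the optimizer that every scheduler constructor needs
scheduler = PARAM_SCHEDULERS.build(
    dict(type='MultiStepLR', milestones=[8, 11]),
    default_args=dict(optimizer=optimizer))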

View File

@ -23,7 +23,7 @@ class DefaultScope(ManagerMixin):
scope_name (str): Scope of current task.
Examples:
>>> from mmengine import MODELS
>>> from mmengine.model import MODELS
>>> # Define default scope in runner.
>>> DefaultScope.get_instance('task', scope_name='mmdet')
>>> # Get default scope globally.

View File

@ -7,8 +7,8 @@ from contextlib import contextmanager
from importlib import import_module
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union
from ..config.utils import PKG2PROJECT
from ..utils import is_seq_of
from mmengine.config.utils import PKG2PROJECT
from mmengine.utils import is_seq_of
from .default_scope import DefaultScope
@ -190,8 +190,7 @@ class Registry:
scope (str): The target scope.
Examples:
>>> from mmengine import Registry, DefaultScope
>>> from mmengine.registry import MODELS
>>> from mmengine.registry import Registry, DefaultScope, MODELS
>>> import time
>>> # External Registry
>>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet',

View File

@ -6,6 +6,7 @@ from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
get_mmcls_models, get_state_dict,
get_torchvision_models, load_checkpoint,
load_state_dict, save_checkpoint, weights_to_cpu)
from .log_processor import LogProcessor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
from .runner import Runner
@ -16,5 +17,5 @@ __all__ = [
'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint',
'autocast'
'autocast', 'LogProcessor'
]

View File

@ -7,7 +7,8 @@ import torch
from mmengine.device import get_device
from mmengine.logging import print_log
from mmengine.utils import TORCH_VERSION, digit_version
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
@contextmanager

View File

@ -19,7 +19,8 @@ from mmengine.fileio import FileClient
from mmengine.fileio import load as load_file
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.utils import load_url, mkdir_or_exist
from mmengine.utils import mkdir_or_exist
from mmengine.utils.dl_utils import load_url
# `MMENGINE_HOME` is the highest priority directory to save checkpoints
# downloaded from Internet. If it is not set, as a workaround, using
@ -251,9 +252,8 @@ class CheckpointLoader:
checkpoint_loader = cls._get_checkpoint_loader(filename)
class_name = checkpoint_loader.__name__
mmengine.print_log(
f'{class_name[10:]} loads checkpoint from path: {filename}',
logger)
print_log(f'{class_name[10:]} loads checkpoint from path: {filename}',
logger)
return checkpoint_loader(filename, map_location)

View File

@ -139,8 +139,8 @@ class _InfiniteDataloaderIterator:
It resets the dataloader to continue iterating when the iterator has
iterated over all the data. However, this approach is not efficient, as the
workers need to be restarted every time the dataloader is reset. It is
recommended to use `mmengine.data.InfiniteSampler` to enable the dataloader
to iterate infinitely.
recommended to use `mmengine.dataset.InfiniteSampler` to enable the
dataloader to iterate infinitely.
"""
def __init__(self, dataloader: DataLoader) -> None:
@ -157,8 +157,9 @@ class _InfiniteDataloaderIterator:
except StopIteration:
warnings.warn('Reach the end of the dataloader, it will be '
'restarted and continue to iterate. It is '
'recommended to use `mmengine.data.InfiniteSampler` '
'to enable the dataloader to iterate infinitely.')
'recommended to use '
'`mmengine.dataset.InfiniteSampler` to enable the '
'dataloader to iterate infinitely.')
self._epoch += 1
if hasattr(self._dataloader, 'sampler') and hasattr(
self._dataloader.sampler, 'set_epoch'):
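
As the updated warning suggests, the sampler-based approach avoids restarting workers altogether; a minimal sketch, assuming a map-style dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset
from mmengine.dataset import InfiniteSampler

dataset = TensorDataset(torch.arange(10).float())
# InfiniteSampler yields indices endlessly, so DataLoader workers are
# spawned once instead of being restarted at every epoch boundary.
loader = DataLoader(dataset, batch_size=2, sampler=InfiniteSampler(dataset))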

View File

@ -20,16 +20,16 @@ from torch.utils.data import DataLoader
import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.data import pseudo_collate, worker_init_fn
from mmengine.dataset import pseudo_collate, worker_init_fn
from mmengine.device import get_device
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
is_distributed, master_only, sync_random_seed)
from mmengine.evaluator import Evaluator
from mmengine.fileio import FileClient
from mmengine.hooks import Hook
from mmengine.logging import LogProcessor, MessageHub, MMLogger, print_log
from mmengine.logging import MessageHub, MMLogger, print_log
from mmengine.model import (BaseModel, MMDistributedDataParallel,
is_model_wrapper)
is_model_wrapper, revert_sync_batchnorm)
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
@ -37,14 +37,15 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
VISUALIZERS, DefaultScope,
count_registered_modules)
from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
get_git_hash, is_seq_of, revert_sync_batchnorm,
set_multi_processing)
from mmengine.utils import digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
find_latest_checkpoint, get_state_dict,
save_checkpoint, weights_to_cpu)
from .log_processor import LogProcessor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
@ -181,7 +182,7 @@ class Runner:
Defaults to None.
Examples:
>>> from mmengine import Runner
>>> from mmengine.runner import Runner
>>> cfg = dict(
>>> model=dict(type='ToyModel'),
>>> work_dir='path/of/work_dir',

View File

@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_data_element import BaseDataElement
from .instance_data import InstanceData
from .label_data import LabelData
from .pixel_data import PixelData
__all__ = ['BaseDataElement', 'InstanceData', 'LabelData', 'PixelData']
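
For downstream code, the rename amounts to swapping the import root from ``mmengine.data`` to ``mmengine.structures``; a small sketch using the names in the ``__all__`` above:

import torch
from mmengine.structures import BaseDataElement

sample = BaseDataElement(metainfo=dict(img_shape=(800, 1216, 3)))
sample.scores = torch.rand(5)  # data fields are set as plain attributes
print(sample.img_shape)        # metainfo fields are read the same way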

View File

@ -71,7 +71,7 @@ class BaseDataElement:
model predictions. Defaults to None.
Examples:
>>> from mmengine.data import BaseDataElement
>>> from mmengine.structures import BaseDataElement
>>> gt_instances = BaseDataElement()
>>> bboxes = torch.rand((5, 4))
>>> scores = torch.rand((5,))

View File

@ -52,7 +52,7 @@ class InstanceData(BaseDataElement):
... return new_data
... def __repr__(self):
... return str(self.tmp)
>>> from mmengine.data import InstanceData
>>> from mmengine.structures import InstanceData
>>> import numpy as np
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
>>> instance_data = InstanceData(metainfo=img_meta)

View File

@ -3,7 +3,8 @@ from typing import Any, Callable, Optional, Union
from torch.testing import assert_allclose as _assert_allclose
from mmengine.utils import TORCH_VERSION, digit_version
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
def assert_allclose(

View File

@ -1,31 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .hub import load_url
from .manager import ManagerMeta, ManagerMixin
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
has_batch_norm, has_method, import_modules_from_strings,
is_list_of, is_method_overridden, is_seq_of, is_str,
is_tuple_of, iter_cast, list_cast, mmcv_full_available,
requires_executable, requires_package, slice_list,
to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
tuple_cast)
has_method, import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
to_ntuple, tuple_cast)
from .package_utils import (call_command, check_install_package,
get_installed_path, is_installed)
from .parrots_wrapper import TORCH_VERSION
from .path import (check_file_exist, fopen, is_abs, is_filepath,
mkdir_or_exist, scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .setup_env import set_multi_processing
from .sync_bn import revert_sync_batchnorm
from .timer import Timer, TimerError, check_time
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing
from .version_utils import digit_version, get_git_hash
# TODO: creates intractable circular import issues
# from .time_counter import TimeCounter
__all__ = [
'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of',
'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
@ -33,12 +22,9 @@ __all__ = [
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
'scandir', 'deprecated_api_warning', 'import_modules_from_strings',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden', 'has_method', 'mmcv_full_available',
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
'is_abs', 'is_installed', 'call_command', 'get_installed_path',
'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env',
'Timer', 'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress', 'torch_meshgrid',
'is_jit_tracing'
'is_installed', 'call_command', 'get_installed_path',
'check_install_package', 'is_abs', 'is_method_overridden', 'has_method',
'digit_version', 'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer',
'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress'
]

View File

@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .hub import load_url
from .misc import has_batch_norm, is_norm, mmcv_full_available, tensor2imgs
from .parrots_wrapper import TORCH_VERSION
from .setup_env import set_multi_processing
from .time_counter import TimeCounter
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing
__all__ = [
'load_url', 'TORCH_VERSION', 'set_multi_processing', 'has_batch_norm',
'is_norm', 'tensor2imgs', 'mmcv_full_available', 'collect_env',
'torch_meshgrid', 'is_jit_tracing', 'TimeCounter'
]

View File

@ -4,9 +4,9 @@
# torch >= 1.6.0 but loaded in torch < 1.7.0.
# More details at https://github.com/open-mmlab/mmpose/issues/904
from ..path import mkdir_or_exist
from ..version_utils import digit_version
from .parrots_wrapper import TORCH_VERSION
from .path import mkdir_or_exist
from .version_utils import digit_version
if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
'1.7.0'):

View File

@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pkgutil
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from ..misc import is_tuple_of
from .parrots_wrapper import _BatchNorm, _InstanceNorm
def is_norm(layer: nn.Module,
exclude: Optional[Union[type, Tuple[type]]] = None) -> bool:
"""Check if a layer is a normalization layer.
Args:
layer (nn.Module): The layer to be checked.
exclude (type, tuple[type], optional): Types to be excluded.
Returns:
bool: Whether the layer is a norm layer.
"""
if exclude is not None:
if not isinstance(exclude, tuple):
exclude = (exclude, )
if not is_tuple_of(exclude, type):
raise TypeError(
f'"exclude" must be either None or type or a tuple of types, '
f'but got {type(exclude)}: {exclude}')
if exclude and isinstance(layer, exclude):
return False
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
return isinstance(layer, all_norm_bases)
def tensor2imgs(tensor: torch.Tensor,
mean: Optional[Tuple[float, float, float]] = None,
std: Optional[Tuple[float, float, float]] = None,
to_bgr: bool = True):
"""Convert tensor to 3-channel images or 1-channel gray images.
Args:
tensor (torch.Tensor): Tensor that contains multiple images, shape (
N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format
should be RGB.
mean (tuple[float], optional): Mean of images. If None,
(0, 0, 0) will be used for tensor with 3-channel,
while (0, ) for tensor with 1-channel. Defaults to None.
std (tuple[float], optional): Standard deviation of images. If None,
(1, 1, 1) will be used for tensor with 3-channel,
while (1, ) for tensor with 1-channel. Defaults to None.
        to_bgr (bool): For a 3-channel tensor, convert its format to
            BGR. For a 1-channel tensor, it must be False. Defaults to
            True.
Returns:
list[np.ndarray]: A list that contains multiple images.
"""
assert torch.is_tensor(tensor) and tensor.ndim == 4
channels = tensor.size(1)
assert channels in [1, 3]
if mean is None:
mean = (0, ) * channels
if std is None:
std = (1, ) * channels
assert (channels == len(mean) == len(std) == 3) or \
(channels == len(mean) == len(std) == 1 and not to_bgr)
mean = tensor.new_tensor(mean).view(1, -1)
std = tensor.new_tensor(std).view(1, -1)
tensor = tensor.permute(0, 2, 3, 1) * std + mean
imgs = tensor.detach().cpu().numpy()
if to_bgr and channels == 3:
imgs = imgs[:, :, :, (2, 1, 0)] # RGB2BGR
imgs = [np.ascontiguousarray(img) for img in imgs]
return imgs
def has_batch_norm(model: nn.Module) -> bool:
"""Detect whether model has a BatchNormalization layer.
Args:
model (nn.Module): training model.
Returns:
bool: whether model has a BatchNormalization layer
"""
if isinstance(model, _BatchNorm):
return True
for m in model.children():
if has_batch_norm(m):
return True
return False
def mmcv_full_available() -> bool:
"""Check whether mmcv-full is installed.
Returns:
bool: True if mmcv-full is installed else False.
"""
try:
import mmcv # noqa: F401
except ImportError:
return False
ext_loader = pkgutil.find_loader('mmcv._ext')
return ext_loader is not None
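
A short sketch exercising the helpers in this new ``dl_utils/misc.py`` (it assumes the ``mmengine.utils.dl_utils`` exports listed in the ``__init__`` above):

import torch
import torch.nn as nn
from mmengine.utils.dl_utils import has_batch_norm, is_norm, tensor2imgs

assert is_norm(nn.BatchNorm2d(3))
assert not is_norm(nn.GroupNorm(1, 3), exclude=nn.GroupNorm)
assert has_batch_norm(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
# a batch of two RGB tensors becomes a list of two (32, 32, 3) BGR arrays
imgs = tensor2imgs(torch.rand(2, 3, 32, 32))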

View File

@ -24,7 +24,7 @@ class TimeCounter:
Examples:
>>> import time
>>> from mmengine.utils import TimeCounter
>>> from mmengine.utils.dl_utils import TimeCounter
>>> @TimeCounter()
... def fun1():
... time.sleep(0.1)

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..version_utils import digit_version
from .parrots_wrapper import TORCH_VERSION
from .version_utils import digit_version
_torch_version_meshgrid_indexing = (
'parrots' not in TORCH_VERSION

View File

@ -3,7 +3,7 @@ import warnings
import torch
from .version_utils import digit_version
from ..version_utils import digit_version
def is_jit_tracing() -> bool:

View File

@ -2,20 +2,13 @@
import collections.abc
import functools
import itertools
import pkgutil
import subprocess
import warnings
from collections import abc
from importlib import import_module
from inspect import getfullargspec
from itertools import repeat
from typing import Any, Callable, Optional, Tuple, Type, Union
import numpy as np
import torch
import torch.nn as nn
from .parrots_wrapper import _BatchNorm, _InstanceNorm
from typing import Any, Callable, Optional, Type, Union
# From PyTorch internals
@ -394,103 +387,3 @@ def has_method(obj: object, method: str) -> bool:
bool: True if the object has the method else False.
"""
return hasattr(obj, method) and callable(getattr(obj, method))
def mmcv_full_available() -> bool:
"""Check whether mmcv-full is installed.
Returns:
bool: True if mmcv-full is installed else False.
"""
try:
import mmcv # noqa: F401
except ImportError:
return False
ext_loader = pkgutil.find_loader('mmcv._ext')
return ext_loader is not None
def is_norm(layer: nn.Module,
exclude: Optional[Union[type, Tuple[type]]] = None) -> bool:
"""Check if a layer is a normalization layer.
Args:
layer (nn.Module): The layer to be checked.
exclude (type, tuple[type], optional): Types to be excluded.
Returns:
bool: Whether the layer is a norm layer.
"""
if exclude is not None:
if not isinstance(exclude, tuple):
exclude = (exclude, )
if not is_tuple_of(exclude, type):
raise TypeError(
f'"exclude" must be either None or type or a tuple of types, '
f'but got {type(exclude)}: {exclude}')
if exclude and isinstance(layer, exclude):
return False
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
return isinstance(layer, all_norm_bases)
def tensor2imgs(tensor: torch.Tensor,
mean: Optional[Tuple[float, float, float]] = None,
std: Optional[Tuple[float, float, float]] = None,
to_bgr: bool = True):
"""Convert tensor to 3-channel images or 1-channel gray images.
Args:
tensor (torch.Tensor): Tensor that contains multiple images, shape (
N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format
should be RGB.
mean (tuple[float], optional): Mean of images. If None,
(0, 0, 0) will be used for tensor with 3-channel,
while (0, ) for tensor with 1-channel. Defaults to None.
std (tuple[float], optional): Standard deviation of images. If None,
(1, 1, 1) will be used for tensor with 3-channel,
while (1, ) for tensor with 1-channel. Defaults to None.
to_bgr (bool): For the tensor with 3 channel, convert its format to
BGR. For the tensor with 1 channel, it must be False. Defaults to
True.
Returns:
list[np.ndarray]: A list that contains multiple images.
"""
assert torch.is_tensor(tensor) and tensor.ndim == 4
channels = tensor.size(1)
assert channels in [1, 3]
if mean is None:
mean = (0, ) * channels
if std is None:
std = (1, ) * channels
assert (channels == len(mean) == len(std) == 3) or \
(channels == len(mean) == len(std) == 1 and not to_bgr)
mean = tensor.new_tensor(mean).view(1, -1)
std = tensor.new_tensor(std).view(1, -1)
tensor = tensor.permute(0, 2, 3, 1) * std + mean
imgs = tensor.detach().cpu().numpy()
if to_bgr and channels == 3:
imgs = imgs[:, :, :, (2, 1, 0)] # RGB2BGR
imgs = [np.ascontiguousarray(img) for img in imgs]
return imgs
def has_batch_norm(model: nn.Module) -> bool:
"""Detect whether model has a BatchNormalization layer.
Args:
model (nn.Module): training model.
Returns:
bool: whether model has a BatchNormalization layer
"""
if isinstance(model, _BatchNorm):
return True
for m in model.children():
if has_batch_norm(m):
return True
return False

View File

@ -1,66 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
"""A general BatchNorm layer without input dimension check.
Reproduced from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
is `_check_input_dim` that is designed for tensor sanity checks.
The check has been bypassed in this class for the convenience of converting
SyncBatchNorm.
"""
def _check_input_dim(self, input: torch.Tensor):
return
def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
"""Helper function to convert all `SyncBatchNorm` (SyncBN) and
`mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
`BatchNormXd` layers.
Adapted from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
Args:
module (nn.Module): The module containing `SyncBatchNorm` layers.
Returns:
module_output: The converted module with `BatchNormXd` layers.
"""
module_output = module
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
try:
import mmcv
except ImportError:
pass
else:
if hasattr(mmcv, 'ops'):
module_checklist.append(mmcv.ops.SyncBatchNorm)
if isinstance(module, tuple(module_checklist)):
module_output = _BatchNormXd(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
# no_grad() may not be needed here but
# just to be consistent with `convert_sync_batchnorm()`
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
module_output.training = module.training
# qconfig exists in quantized models
if hasattr(module, 'qconfig'):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, revert_sync_batchnorm(child))
del module
return module_output
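
The helper itself now lives under ``mmengine.model`` (see the test change further down); a minimal sketch of reverting SyncBN so a model can run without distributed initialization:

import torch.nn as nn
from mmengine.model import revert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
model = revert_sync_batchnorm(model)  # SyncBatchNorm -> _BatchNormXd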

View File

@ -15,7 +15,7 @@ from mmengine.config import Config
from mmengine.fileio import dump
from mmengine.logging import MMLogger
from mmengine.registry import VISBACKENDS
from mmengine.utils import TORCH_VERSION
from mmengine.utils.dl_utils import TORCH_VERSION
def force_init_env(old_func: Callable) -> Any:

View File

@ -16,9 +16,9 @@ from matplotlib.patches import Circle
from matplotlib.pyplot import new_figure_manager
from mmengine.config import Config
from mmengine.data import BaseDataElement
from mmengine.dist import master_only
from mmengine.registry import VISBACKENDS, VISUALIZERS
from mmengine.structures import BaseDataElement
from mmengine.utils import ManagerMixin
from mmengine.visualization.utils import (check_type, check_type_and_length,
color_str2rgb, color_val_matplotlib,

View File

@ -6,7 +6,7 @@ from unittest.mock import patch
import numpy as np
import torch
from mmengine.data import DefaultSampler, InfiniteSampler
from mmengine.dataset import DefaultSampler, InfiniteSampler
class TestDefaultSampler(TestCase):
@ -15,7 +15,7 @@ class TestDefaultSampler(TestCase):
self.data_length = 100
self.dataset = list(range(self.data_length))
@patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
def test_non_dist(self, mock):
sampler = DefaultSampler(self.dataset)
self.assertEqual(sampler.world_size, 1)
@ -33,7 +33,7 @@ class TestDefaultSampler(TestCase):
self.assertEqual(sampler.num_samples, self.data_length)
self.assertEqual(list(sampler), list(range(self.data_length)))
@patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(2, 3))
def test_dist(self, mock):
sampler = DefaultSampler(self.dataset)
self.assertEqual(sampler.world_size, 3)
@ -56,8 +56,8 @@ class TestDefaultSampler(TestCase):
self.assertEqual(len(sampler), sampler.num_samples)
self.assertEqual(list(sampler), list(range(self.data_length))[2::3])
@patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
@patch('mmengine.data.sampler.sync_random_seed', return_value=7)
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
@patch('mmengine.dataset.sampler.sync_random_seed', return_value=7)
def test_shuffle(self, mock1, mock2):
# test seed=None
sampler = DefaultSampler(self.dataset, seed=None)
@ -87,7 +87,7 @@ class TestInfiniteSampler(TestCase):
self.data_length = 100
self.dataset = list(range(self.data_length))
@patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
def test_non_dist(self, mock):
sampler = InfiniteSampler(self.dataset)
self.assertEqual(sampler.world_size, 1)
@ -101,7 +101,7 @@ class TestInfiniteSampler(TestCase):
items = [next(sampler_iter) for _ in range(self.data_length * 2)]
self.assertEqual(items, list(range(self.data_length)) * 2)
@patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(2, 3))
def test_dist(self, mock):
sampler = InfiniteSampler(self.dataset)
self.assertEqual(sampler.world_size, 3)
@ -117,8 +117,8 @@ class TestInfiniteSampler(TestCase):
print(samples)
self.assertEqual(samples, targets)
@patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
@patch('mmengine.data.sampler.sync_random_seed', return_value=7)
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
@patch('mmengine.dataset.sampler.sync_random_seed', return_value=7)
def test_shuffle(self, mock1, mock2):
# test seed=None
sampler = InfiniteSampler(self.dataset, seed=None)

View File

@ -13,7 +13,8 @@ import torch.distributed as torch_dist
import mmengine.dist as dist
from mmengine.dist.dist import sync_random_seed
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils import TORCH_VERSION, digit_version
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
class TestDist(TestCase):

View File

@ -7,9 +7,9 @@ from unittest import TestCase
import numpy as np
import torch
from mmengine.data import BaseDataElement
from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
from mmengine.registry import METRICS
from mmengine.structures import BaseDataElement
@METRICS.register_module()

View File

@ -3,8 +3,8 @@ from unittest.mock import Mock
import torch
from mmengine.data import BaseDataElement
from mmengine.hooks import NaiveVisualizationHook
from mmengine.structures import BaseDataElement
class TestNaiveVisualizationHook:

View File

@ -3,7 +3,8 @@ import os.path as osp
import pytest
from mmengine import Config, DefaultScope, get_config, get_model
from mmengine import Config, DefaultScope
from mmengine.hub import get_config, get_model
from mmengine.utils import get_installed_path, is_installed
data_path = osp.join(osp.dirname(osp.dirname(__file__)), 'data/')

View File

@ -1,9 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmengine import HistoryBuffer
from mmengine.logging import HistoryBuffer
array_method = [np.array, lambda x: x]
try:
import torch
except ImportError:
pass
else:
array_method.append(torch.tensor)
class TestLoggerBuffer:
@ -30,8 +37,7 @@ class TestLoggerBuffer:
with pytest.raises(AssertionError):
HistoryBuffer([1, 2], [1])
@pytest.mark.parametrize('array_method',
[torch.tensor, np.array, lambda x: x])
@pytest.mark.parametrize('array_method', array_method)
def test_update(self, array_method):
# test `update` method
log_buffer = HistoryBuffer()

View File

@ -15,9 +15,7 @@ class TestLogger:
stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
# Since `get_rank` has been imported in logger.py, it needs to mock
# `logger.get_rank`
@patch('mmengine.logging.logger.get_rank', lambda: 0)
@patch('mmengine.logging.logger._get_rank', lambda: 0)
def test_init_rank0(self, tmp_path):
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
assert logger.name == 'mmengine'
@ -47,7 +45,7 @@ class TestLogger:
assert logger.instance_name == 'rank0.pkg3'
logging.shutdown()
@patch('mmengine.logging.logger.get_rank', lambda: 1)
@patch('mmengine.logging.logger._get_rank', lambda: 1)
def test_init_rank1(self, tmp_path):
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
tmp_file = tmp_path / 'tmp_file.log'

View File

@ -4,9 +4,9 @@ from collections import OrderedDict
import numpy as np
import pytest
import torch
from mmengine.logging import HistoryBuffer, MessageHub
from mmengine.utils import is_installed
class NoDeepCopy:
@ -84,7 +84,9 @@ class TestMessageHub:
message_hub.update_info('test_value', recorded_dict)
assert message_hub.get_info('test_value') == recorded_dict
@pytest.mark.skipif(not is_installed('torch'), reason='requires torch')
def test_get_scalars(self):
import torch
message_hub = MessageHub.get_instance('mmengine')
log_dict = dict(
loss=1,

View File

@ -4,8 +4,8 @@ from unittest import TestCase
import torch
import torch.nn.functional as F
from mmengine import InstanceData
from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
from mmengine.structures import InstanceData
from mmengine.testing import assert_allclose

View File

@ -3,7 +3,7 @@ import pytest
import torch
import torch.nn as nn
from mmengine.utils import revert_sync_batchnorm
from mmengine.model import revert_sync_batchnorm
@pytest.mark.skipif(

View File

@ -15,7 +15,7 @@ from mmengine.model.averaged_model import ExponentialMovingAverage
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.parrots_wrapper import TORCH_VERSION
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):

View File

@ -11,7 +11,7 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
build_optim_wrapper)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils import mmcv_full_available
from mmengine.utils.dl_utils import mmcv_full_available
MMCV_FULL_AVAILABLE = mmcv_full_available()
if not MMCV_FULL_AVAILABLE:

View File

@ -16,7 +16,8 @@ from mmengine.logging import MessageHub, MMLogger
from mmengine.optim import AmpOptimWrapper, OptimWrapper
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils import TORCH_VERSION, digit_version
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
class ToyModel(nn.Module):

View File

@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch.nn as nn
from torch.optim import SGD
from mmengine import (PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin,
Registry, build_from_cfg, build_model_from_cfg)
from mmengine.utils import is_installed
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
@ -128,7 +127,10 @@ def test_build_from_cfg(cfg_type):
Visualizer.get_current_instance()
@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch')
def test_build_model_from_cfg():
import torch.nn as nn
BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
@BACKBONES.register_module()
@ -186,7 +188,10 @@ def test_build_model_from_cfg():
assert NEW_MODELS.build_func is pseudo_build
@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch')
def test_build_sheduler_from_cfg():
import torch.nn as nn
from torch.optim import SGD
model = nn.Conv2d(1, 1, 1)
optimizer = SGD(model.parameters(), lr=0.1)
cfg = dict(

View File

@ -4,7 +4,9 @@ import time
import pytest
from mmengine.config import Config, ConfigDict # type: ignore
from mmengine.registry import DefaultScope, Registry, build_from_cfg
from mmengine.registry import (DefaultScope, Registry, build_from_cfg,
build_model_from_cfg)
from mmengine.utils import ManagerMixin
class TestRegistry:
@ -473,3 +475,187 @@ class TestRegistry:
"<locals>.Munchkin'>")
repr_str += '})'
assert repr(CATS) == repr_str
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
def test_build_from_cfg(cfg_type):
BACKBONES = Registry('backbone')
@BACKBONES.register_module()
class ResNet:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
@BACKBONES.register_module()
class ResNeXt:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
# test `cfg` parameter
# `cfg` should be a dict, ConfigDict or Config object
with pytest.raises(
TypeError,
match=('cfg should be a dict, ConfigDict or Config, but got '
"<class 'str'>")):
cfg = 'ResNet'
model = build_from_cfg(cfg, BACKBONES)
# `cfg` is a dict, ConfigDict or Config object
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# `cfg` is a dict but it does not contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50, stages=4)
cfg = cfg_type(cfg)
model = build_from_cfg(cfg, BACKBONES)
# cfg['type'] should be a str or class
with pytest.raises(
TypeError,
match="type must be a str or valid type, but got <class 'int'>"):
cfg = dict(type=1000)
cfg = cfg_type(cfg)
model = build_from_cfg(cfg, BACKBONES)
cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = cfg_type(dict(type=ResNet, depth=50))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# non-registered class
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
cfg = cfg_type(dict(type='VGG'))
model = build_from_cfg(cfg, BACKBONES)
# `cfg` contains unexpected arguments
with pytest.raises(TypeError):
cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
model = build_from_cfg(cfg, BACKBONES)
# test `default_args` parameter
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 3
# default_args must be a dict or None
with pytest.raises(TypeError):
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES, default_args=1)
# cfg or default_args should contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = cfg_type(dict(depth=50))
model = build_from_cfg(
cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
# "type" defined using default_args
cfg = cfg_type(dict(depth=50))
model = build_from_cfg(
cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = cfg_type(dict(depth=50))
model = build_from_cfg(
cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# test `registry` parameter
# incorrect registry type
with pytest.raises(
TypeError,
match=('registry must be a mmengine.Registry object, but got '
"<class 'str'>")):
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, 'BACKBONES')
VISUALIZER = Registry('visualizer')
@VISUALIZER.register_module()
class Visualizer(ManagerMixin):
def __init__(self, name):
super().__init__(name)
with pytest.raises(RuntimeError):
Visualizer.get_current_instance()
cfg = dict(type='Visualizer', name='visualizer')
build_from_cfg(cfg, VISUALIZER)
Visualizer.get_current_instance()
def test_build_model_from_cfg():
try:
import torch.nn as nn
except ImportError:
pytest.skip('require torch')
BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
@BACKBONES.register_module()
class ResNet(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
@BACKBONES.register_module()
class ResNeXt(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
cfg = dict(type='ResNet', depth=50)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(type='ResNeXt', depth=50, stages=3)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = [
dict(type='ResNet', depth=50),
dict(type='ResNeXt', depth=50, stages=3)
]
model = BACKBONES.build(cfg)
assert isinstance(model, nn.Sequential)
assert isinstance(model[0], ResNet)
assert model[0].depth == 50 and model[0].stages == 4
assert isinstance(model[1], ResNeXt)
assert model[1].depth == 50 and model[1].stages == 3
# test inherit `build_func` from parent
NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
assert NEW_MODELS.build_func is build_model_from_cfg
# test specify `build_func`
def pseudo_build(cfg):
return cfg
NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
assert NEW_MODELS.build_func is pseudo_build

View File

@ -5,7 +5,8 @@ import torch
import torch.nn as nn
from mmengine.runner import autocast
from mmengine.utils import TORCH_VERSION, digit_version
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
class TestAmp(unittest.TestCase):

View File

@ -6,7 +6,8 @@ import numpy as np
import pytest
import torch
from mmengine.logging import HistoryBuffer, LogProcessor, MessageHub, MMLogger
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
from mmengine.runner import LogProcessor
class TestLogProcessor:

View File

@ -14,12 +14,12 @@ from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, Dataset
from mmengine.config import Config
from mmengine.data import DefaultSampler
from mmengine.dataset import DefaultSampler
from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, ParamSchedulerHook,
RuntimeInfoHook)
from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.logging import MessageHub, MMLogger
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
OptimWrapper, OptimWrapperDict, StepLR)
@ -28,10 +28,11 @@ from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
PARAM_SCHEDULERS, RUNNERS, Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
Runner, TestLoop, ValLoop)
LogProcessor, Runner, TestLoop, ValLoop)
from mmengine.runner.loops import _InfiniteDataloaderIterator
from mmengine.runner.priority import Priority, get_priority
from mmengine.utils import TORCH_VERSION, digit_version, is_list_of
from mmengine.utils import digit_version, is_list_of
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.visualization import Visualizer

View File

@ -6,7 +6,7 @@ import numpy as np
import pytest
import torch
from mmengine.data import BaseDataElement
from mmengine.structures import BaseDataElement
class TestBaseDataElement(TestCase):

View File

@ -7,7 +7,7 @@ import numpy as np
import pytest
import torch
from mmengine.data import BaseDataElement, InstanceData
from mmengine.structures import BaseDataElement, InstanceData
class TmpObject:

View File

@ -4,7 +4,7 @@ from unittest import TestCase
import pytest
import torch
from mmengine.data import LabelData
from mmengine.structures import LabelData
class TestLabelData(TestCase):

View File

@ -6,7 +6,7 @@ import numpy as np
import pytest
import torch
from mmengine.data import PixelData
from mmengine.structures import PixelData
class TestPixelData(TestCase):

View File

@ -5,7 +5,8 @@ from unittest import TestCase
import torch.cuda
import mmengine
from mmengine.utils.collect_env import _get_cuda_home, collect_env
from mmengine.utils.dl_utils import collect_env
from mmengine.utils.dl_utils.parrots_wrapper import _get_cuda_home
class TestCollectEnv(TestCase):

View File

@ -5,7 +5,7 @@ import platform
import cv2
from mmengine.utils import set_multi_processing
from mmengine.utils.dl_utils import set_multi_processing
def test_setup_multi_processes():

View File

@ -2,7 +2,7 @@
import time
import unittest
from mmengine.utils.time_counter import TimeCounter
from mmengine.utils.dl_utils.time_counter import TimeCounter
class TestTimeCounter(unittest.TestCase):

View File

@ -2,7 +2,7 @@
import pytest
import torch
from mmengine.utils import torch_meshgrid
from mmengine.utils.dl_utils import torch_meshgrid
def test_torch_meshgrid():

View File

@ -2,7 +2,8 @@
import pytest
import torch
from mmengine.utils import digit_version, is_jit_tracing
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import is_jit_tracing
@pytest.mark.skipif(