[Refactor] Refactor code structure (#395)

* Rename data to structure
* Adjust the way to import modules
* Rename Structure to Data Structures in docs api
* Rename structure to structures
* Support using some modules of mmengine without torch
* Fix circleci config
* Fix registry ut
* Minor fix
* Move init methods from model/utils to model/weight_init.py
* Move sync_bn to model
* Move functions depending on torch to dl_utils
* Format imports
* Fix logging ut
* Add weight init in model/__init__.py
* Move get_config and get_model to mmengine/hub
* Move log_processor.py to mmengine/runner
* Fix ut
* Add TimeCounter in dl_utils/__init__.py

parent 486d8cda56
commit 7e1d7af2d9
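The squashed messages above amount to a handful of import-path moves. As a quick orientation before the hunks, a hedged sketch of how downstream imports change; the module names are taken from the diff below, while grouping them into one snippet is purely illustrative:

# Illustrative only: old paths kept as comments, new paths as live imports.
# from mmengine.data import BaseDataElement            # before this PR
from mmengine.structures import BaseDataElement        # data -> structures

# from mmengine.data import DefaultSampler, pseudo_collate
from mmengine.dataset import DefaultSampler, pseudo_collate  # samplers/collate now under dataset

# from mmengine.config import get_config, get_model
from mmengine.hub import get_config, get_model         # hub entry points get their own package

from mmengine.model import constant_init, initialize   # weight-init helpers re-exported from model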
@@ -17,6 +17,34 @@ jobs:
            pip install interrogate
            interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmengine

  build_without_torch:
    parameters:
      # The python version must match available image tags in
      # https://circleci.com/developer/images/image/cimg/python
      python:
        type: string
        default: "3.7.4"
    docker:
      - image: cimg/python:<< parameters.python >>
    resource_class: large
    steps:
      - checkout
      - run:
          name: Upgrade pip
          command: |
            python -V
            python -m pip install pip --upgrade
            python -m pip --version
      - run:
          name: Install mmengine dependencies
          command: python -m pip install -r requirements.txt
      - run:
          name: Build and install
          command: python -m pip install -e .
      - run:
          name: Run unit tests
          command: python -m pytest tests/test_config tests/test_registry tests/test_fileio tests/test_logging tests/test_utils --ignore=tests/test_utils/test_dl_utils

  build_cpu:
    parameters:
      # The python version must match available image tags in
@@ -101,12 +129,16 @@ workflows:
  unit_tests:
    jobs:
      - lint
      - build_without_torch:
          requires:
            - lint
      - build_cpu:
          name: build_cpu_th1.8_py3.7
          torch: 1.8.0
          torchvision: 0.9.0
          requires:
            - lint
            - build_without_torch
      - hold:
          type: approval # <<< This key-value pair will set your workflow to a status of "On Hold"
          requires:
@@ -23,13 +23,13 @@ Optimizer
.. automodule:: mmengine.optim
    :members:

Data
--------
.. automodule:: mmengine.data
Data Structures
----------------
.. automodule:: mmengine.structures
    :members:

Dataset
--------
------------
.. automodule:: mmengine.dataset
    :members:

@@ -23,9 +23,14 @@ Optimizer
.. automodule:: mmengine.optim
    :members:

Data
--------
.. automodule:: mmengine.data
Data Structures
----------------
.. automodule:: mmengine.structures
    :members:

Dataset
------------
.. automodule:: mmengine.dataset
    :members:

Distributed

@@ -42,3 +47,13 @@ Model
--------
.. automodule:: mmengine.model
    :members:

Visualization
--------
.. automodule:: mmengine.visualization
    :members:

Utils
--------
.. automodule:: mmengine.utils
    :members:

@@ -1,14 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .config import *
from .data import *
from .dataset import *
from .device import *
from .fileio import *
from .hooks import *
from .logging import *
from .registry import *
from .runner import *
from .utils import *
from .version import __version__, version_info
from .visualization import *

@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction
from .get_config_model import get_config, get_model

__all__ = ['Config', 'ConfigDict', 'DictAction', 'get_config', 'get_model']
__all__ = ['Config', 'ConfigDict', 'DictAction']

@@ -1,12 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_data_element import BaseDataElement
from .instance_data import InstanceData
from .label_data import LabelData
from .pixel_data import PixelData
from .sampler import DefaultSampler, InfiniteSampler
from .utils import pseudo_collate, worker_init_fn

__all__ = [
    'BaseDataElement', 'DefaultSampler', 'InfiniteSampler', 'worker_init_fn',
    'pseudo_collate', 'InstanceData', 'LabelData', 'PixelData'
]
@@ -1,4 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .base_dataset import BaseDataset, Compose, force_full_init
from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset
from .sampler import DefaultSampler, InfiniteSampler
from .utils import pseudo_collate, worker_init_fn

__all__ = [
    'BaseDataset', 'Compose', 'force_full_init', 'ClassBalancedDataset',
    'ConcatDataset', 'RepeatDataset', 'DefaultSampler', 'InfiniteSampler',
    'worker_init_fn', 'pseudo_collate'
]

@@ -18,7 +18,7 @@ from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
                    get_default_group, barrier, get_data_device,
                    get_comm_device, cast_data_device)
from mmengine.utils.version_utils import digit_version
from mmengine.utils.parrots_wrapper import TORCH_VERSION
from mmengine.utils.dl_utils import TORCH_VERSION


def _get_reduce_op(name: str) -> torch_dist.ReduceOp:

@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator, List, Optional, Sequence, Union

from mmengine.data import BaseDataElement
from mmengine.registry import EVALUATOR, METRICS
from mmengine.structures import BaseDataElement
from .metric import BaseMetric

@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Union

from torch import Tensor

from mmengine.data import BaseDataElement
from mmengine.dist import (broadcast_object_list, collect_results,
                           is_main_process)
from mmengine.fileio import dump
from mmengine.logging import print_log
from mmengine.registry import METRICS
from mmengine.structures import BaseDataElement


class BaseMetric(metaclass=ABCMeta):

@@ -3,8 +3,8 @@ from typing import Optional, Sequence, Union

import torch

from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from .hook import Hook

DATA_BATCH = Optional[Sequence[dict]]

@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Union

from mmengine.data import BaseDataElement
from mmengine.structures import BaseDataElement

DATA_BATCH = Optional[Sequence[dict]]

@@ -2,8 +2,8 @@
import time
from typing import Optional, Sequence, Union

from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from .hook import Hook

DATA_BATCH = Optional[Sequence[dict]]

@@ -4,10 +4,10 @@ import os.path as osp
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

from mmengine.data import BaseDataElement
from mmengine.fileio import FileClient, dump
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_tuple_of, scandir

DATA_BATCH = Optional[Sequence[dict]]

@@ -5,10 +5,10 @@ from typing import Optional, Sequence, Tuple
import cv2
import numpy as np

from mmengine.data import BaseDataElement
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.utils.misc import tensor2imgs
from mmengine.structures import BaseDataElement
from mmengine.utils.dl_utils import tensor2imgs


# TODO: Due to interface changes, the current class

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hub import get_config, get_model

__all__ = ['get_config', 'get_model']
@@ -2,13 +2,13 @@
import importlib
import os.path as osp

import torch.nn as nn

from mmengine.config import Config
from mmengine.config.utils import (_get_cfg_metainfo,
                                   _get_external_cfg_base_path,
                                   _get_package_and_cfg_path)
from mmengine.registry import MODELS, DefaultScope
from mmengine.runner import load_checkpoint
from mmengine.utils import check_install_package, get_installed_path
from .config import Config
from .utils import (_get_cfg_metainfo, _get_external_cfg_base_path,
                    _get_package_and_cfg_path)


def get_config(cfg_path: str, pretrained: bool = False) -> Config:

@@ -56,7 +56,7 @@ def get_config(cfg_path: str, pretrained: bool = False) -> Config:
    return cfg


def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module:
def get_model(cfg_path: str, pretrained: bool = False, **kwargs):
    """Get built model from external package.

    Args:

@@ -68,7 +68,6 @@ def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module:
    Returns:
        nn.Module: Built model.
    """
    import mmengine.runner
    package = cfg_path.split('::')[0]
    with DefaultScope.overwrite_default_scope(package):  # type: ignore
        cfg = get_config(cfg_path, pretrained)

@@ -76,5 +75,5 @@
        models_module.register_all_modules()  # type: ignore
        model = MODELS.build(cfg.model, default_args=kwargs)
        if pretrained:
            mmengine.runner.load_checkpoint(model, cfg.model_path)
            load_checkpoint(model, cfg.model_path)
        return model
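For context, a minimal usage sketch of the relocated hub helpers. Only the import location and signatures come from this diff; the `cfg_path` value is hypothetical (any '<package>::<config path>' string understood by the external package would do):

from mmengine.hub import get_config, get_model

# Hypothetical config path of the form '<package>::<relative config path>'.
cfg_path = 'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'

cfg = get_config(cfg_path, pretrained=True)   # returns an mmengine Config
model = get_model(cfg_path, pretrained=True)  # builds the model and loads cfg.model_path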
|
|
@@ -1,9 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .history_buffer import HistoryBuffer
from .log_processor import LogProcessor
from .logger import MMLogger, print_log
from .message_hub import MessageHub

__all__ = [
    'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor'
]
__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']

@@ -7,7 +7,6 @@ from typing import Optional, Union

from termcolor import colored

from mmengine.dist import get_rank
from mmengine.utils import ManagerMixin
from mmengine.utils.manager import _accquire_lock, _release_lock


@@ -152,7 +151,8 @@ class MMLogger(Logger, ManagerMixin):
        Logger.__init__(self, logger_name)
        ManagerMixin.__init__(self, name)
        # Get rank in DDP mode.
        rank = get_rank()

        rank = _get_rank()

        # Config stream_handler. If `rank != 0`. stream_handler can only
        # export ERROR logs.

@@ -289,3 +289,14 @@ def print_log(msg,
        raise TypeError(
            '`logger` should be either a logging.Logger object, str, '
            f'"silent", "current" or None, but got {type(logger)}')


def _get_rank():
    """Support using logging module without torch."""
    try:
        # requires torch
        from mmengine.dist import get_rank
    except ImportError:
        return 0
    else:
        return get_rank()
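The guarded import in `_get_rank` is the general pattern this PR uses to keep the logging module importable without torch. A minimal sketch of the same idiom applied to another torch-backed helper; the function name here is made up for illustration:

def _world_size_or_one():
    """Fall back gracefully when the torch-backed dist module is unavailable."""
    try:
        from mmengine.dist import get_world_size  # requires torch
    except ImportError:
        return 1
    else:
        return get_world_size()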
|
||||
|
|
|
@@ -2,15 +2,17 @@
import copy
import logging
from collections import OrderedDict
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import torch

from mmengine.utils import ManagerMixin
from .history_buffer import HistoryBuffer
from .logger import print_log

if TYPE_CHECKING:
    import torch


class MessageHub(ManagerMixin):
    """Message hub for component interaction. MessageHub is created and

@@ -99,7 +101,7 @@ class MessageHub(ManagerMixin):

    def update_scalar(self,
                      key: str,
                      value: Union[int, float, np.ndarray, torch.Tensor],
                      value: Union[int, float, np.ndarray, 'torch.Tensor'],
                      count: int = 1,
                      resumed: bool = True) -> None:
        """Update :attr:_log_scalars.

@@ -315,7 +317,7 @@ class MessageHub(ManagerMixin):
        return self._runtime_info[key]

    def _get_valid_value(
            self, value: Union[torch.Tensor, np.ndarray, int, float]) \
            self, value: Union['torch.Tensor', np.ndarray, int, float]) \
            -> Union[int, float]:
        """Convert value to python built-in type.

@@ -328,11 +330,13 @@ class MessageHub(ManagerMixin):
        if isinstance(value, np.ndarray):
            assert value.size == 1
            value = value.item()
        elif isinstance(value, torch.Tensor):
            assert value.numel() == 1
            value = value.item()
        elif isinstance(value, (int, float)):
            value = value
        else:
            assert isinstance(value, (int, float))
            # check whether value is torch.Tensor but don't want
            # to import torch in this file
            assert hasattr(value, 'numel') and value.numel() == 1
            value = value.item()
        return value  # type: ignore

    def state_dict(self) -> dict:
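The `TYPE_CHECKING` guard plus the quoted `'torch.Tensor'` annotation above is what lets this file type-check against torch without importing it at runtime, and the `hasattr(value, 'numel')` branch then duck-types tensors. A small self-contained sketch of the same idiom (names are illustrative, not from mmengine):

from typing import TYPE_CHECKING, Union

import numpy as np

if TYPE_CHECKING:
    import torch  # only evaluated by type checkers, never at runtime


def to_scalar(value: Union[int, float, np.ndarray, 'torch.Tensor']) -> Union[int, float]:
    """Convert a one-element array/tensor (or a plain number) to a Python scalar."""
    if isinstance(value, np.ndarray):
        assert value.size == 1
        return value.item()
    if isinstance(value, (int, float)):
        return value
    # Duck-type torch.Tensor so torch stays an optional dependency.
    assert hasattr(value, 'numel') and value.numel() == 1
    return value.item()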
|
||||
|
|
|
@@ -1,11 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils.parrots_wrapper import TORCH_VERSION
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
                             MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .utils import detect_anomalous_params, merge_dict, stack_batch
from .utils import (detect_anomalous_params, merge_dict, revert_sync_batchnorm,
                    stack_batch)
from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit,
                          KaimingInit, NormalInit, PretrainedInit,
                          TruncNormalInit, UniformInit, XavierInit,
                          bias_init_with_prob, caffe2_xavier_init,
                          constant_init, initialize, kaiming_init, normal_init,
                          trunc_normal_init, uniform_init, update_init_info,
                          xavier_init)
from .wrappers import (MMDistributedDataParallel,
                       MMSeparateDistributedDataParallel, is_model_wrapper)

@@ -15,7 +23,12 @@ __all__ = [
    'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor',
    'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule',
    'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList',
    'ModuleDict', 'Sequential'
    'ModuleDict', 'Sequential', 'revert_sync_batchnorm', 'update_init_info',
    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
    'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
    'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
    'Caffe2XavierInit', 'PretrainedInit', 'initialize'
]

if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
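With the weight-init helpers now exported from `mmengine.model`, the `initialize`/`init_cfg` workflow documented in weight_init.py can be driven entirely from this package. A minimal sketch, mirroring the example in the `initialize` docstring further down this diff (the toy module and values are arbitrary):

import torch.nn as nn

from mmengine.model import initialize

module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1, 2))
# Different constant fills per layer type, resolved by the 'layer' key.
init_cfg = [
    dict(type='Constant', layer='Conv1d', val=1),
    dict(type='Constant', layer='Linear', val=2),
]
initialize(module, init_cfg)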
|
||||
|
|
|
@@ -6,9 +6,9 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn

from mmengine.data import BaseDataElement
from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from ..base_module import BaseModule
from .data_preprocessor import BaseDataPreprocessor

@@ -4,8 +4,8 @@ from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn

from mmengine.data import BaseDataElement
from mmengine.registry import MODELS
from mmengine.structures import BaseDataElement
from ..utils import stack_batch

@@ -11,6 +11,7 @@ import torch.nn as nn

from mmengine.dist import master_only
from mmengine.logging import MMLogger, print_log
from .weight_init import initialize, update_init_info


class BaseModule(nn.Module, metaclass=ABCMeta):

@@ -92,7 +93,6 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
        logger = MMLogger.get_current_instance()
        logger_name = logger.instance_name

        from .utils import initialize, update_init_info
        module_name = self.__class__.__name__
        if not self._is_init:
            if self.init_cfg:

@ -1,676 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from mmengine.logging.logger import MMLogger, print_log
|
||||
from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg
|
||||
|
||||
|
||||
def update_init_info(module, init_info):
|
||||
"""Update the `_params_init_info` in the module if the value of parameters
|
||||
are changed.
|
||||
|
||||
Args:
|
||||
module (obj:`nn.Module`): The module of PyTorch with a user-defined
|
||||
attribute `_params_init_info` which records the initialization
|
||||
information.
|
||||
init_info (str): The string that describes the initialization.
|
||||
"""
|
||||
assert hasattr(
|
||||
module,
|
||||
'_params_init_info'), f'Can not find `_params_init_info` in {module}'
|
||||
for name, param in module.named_parameters():
|
||||
|
||||
assert param in module._params_init_info, (
|
||||
f'Find a new :obj:`Parameter` '
|
||||
f'named `{name}` during executing the '
|
||||
f'`init_weights` of '
|
||||
f'`{module.__class__.__name__}`. '
|
||||
f'Please do not add or '
|
||||
f'replace parameters during executing '
|
||||
f'the `init_weights`. ')
|
||||
|
||||
# The parameter has been changed during executing the
|
||||
# `init_weights` of module
|
||||
mean_value = param.data.mean().cpu()
|
||||
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
|
||||
module._params_init_info[param]['init_info'] = init_info
|
||||
module._params_init_info[param]['tmp_mean_value'] = mean_value
|
||||
|
||||
|
||||
def constant_init(module, val, bias=0):
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
nn.init.constant_(module.weight, val)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def xavier_init(module, gain=1, bias=0, distribution='normal'):
|
||||
assert distribution in ['uniform', 'normal']
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
if distribution == 'uniform':
|
||||
nn.init.xavier_uniform_(module.weight, gain=gain)
|
||||
else:
|
||||
nn.init.xavier_normal_(module.weight, gain=gain)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def normal_init(module, mean=0, std=1, bias=0):
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
nn.init.normal_(module.weight, mean, std)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def trunc_normal_init(module: nn.Module,
|
||||
mean: float = 0,
|
||||
std: float = 1,
|
||||
a: float = -2,
|
||||
b: float = 2,
|
||||
bias: float = 0) -> None:
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
trunc_normal_(module.weight, mean, std, a, b) # type: ignore
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias) # type: ignore
|
||||
|
||||
|
||||
def uniform_init(module, a=0, b=1, bias=0):
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
nn.init.uniform_(module.weight, a, b)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def kaiming_init(module,
|
||||
a=0,
|
||||
mode='fan_out',
|
||||
nonlinearity='relu',
|
||||
bias=0,
|
||||
distribution='normal'):
|
||||
assert distribution in ['uniform', 'normal']
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
if distribution == 'uniform':
|
||||
nn.init.kaiming_uniform_(
|
||||
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
|
||||
else:
|
||||
nn.init.kaiming_normal_(
|
||||
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def caffe2_xavier_init(module, bias=0):
|
||||
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
|
||||
# Acknowledgment to FAIR's internal code
|
||||
kaiming_init(
|
||||
module,
|
||||
a=1,
|
||||
mode='fan_in',
|
||||
nonlinearity='leaky_relu',
|
||||
bias=bias,
|
||||
distribution='uniform')
|
||||
|
||||
|
||||
def bias_init_with_prob(prior_prob):
|
||||
"""initialize conv/fc bias value according to a given probability value."""
|
||||
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
|
||||
return bias_init
|
||||
|
||||
|
||||
def _get_bases_name(m):
|
||||
return [b.__name__ for b in m.__class__.__bases__]
|
||||
|
||||
|
||||
class BaseInit:
|
||||
|
||||
def __init__(self, *, bias=0, bias_prob=None, layer=None):
|
||||
self.wholemodule = False
|
||||
if not isinstance(bias, (int, float)):
|
||||
raise TypeError(f'bias must be a number, but got a {type(bias)}')
|
||||
|
||||
if bias_prob is not None:
|
||||
if not isinstance(bias_prob, float):
|
||||
raise TypeError(f'bias_prob type must be float, \
|
||||
but got {type(bias_prob)}')
|
||||
|
||||
if layer is not None:
|
||||
if not isinstance(layer, (str, list)):
|
||||
raise TypeError(f'layer must be a str or a list of str, \
|
||||
but got a {type(layer)}')
|
||||
else:
|
||||
layer = []
|
||||
|
||||
if bias_prob is not None:
|
||||
self.bias = bias_init_with_prob(bias_prob)
|
||||
else:
|
||||
self.bias = bias
|
||||
self.layer = [layer] if isinstance(layer, str) else layer
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Constant')
|
||||
class ConstantInit(BaseInit):
|
||||
"""Initialize module parameters with constant values.
|
||||
|
||||
Args:
|
||||
val (int | float): the value to fill the weights in the module with
|
||||
bias (int | float): the value to fill the bias. Defaults to 0.
|
||||
bias_prob (float, optional): the probability for bias initialization.
|
||||
Defaults to None.
|
||||
layer (str | list[str], optional): the layer will be initialized.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, val, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.val = val
|
||||
|
||||
def __call__(self, module):
|
||||
|
||||
def init(m):
|
||||
if self.wholemodule:
|
||||
constant_init(m, self.val, self.bias)
|
||||
else:
|
||||
layername = m.__class__.__name__
|
||||
basesname = _get_bases_name(m)
|
||||
if len(set(self.layer) & set([layername] + basesname)):
|
||||
constant_init(m, self.val, self.bias)
|
||||
|
||||
module.apply(init)
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Xavier')
|
||||
class XavierInit(BaseInit):
|
||||
r"""Initialize module parameters with values according to the method
|
||||
described in `Understanding the difficulty of training deep feedforward
|
||||
neural networks - Glorot, X. & Bengio, Y. (2010).
|
||||
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
|
||||
Args:
|
||||
gain (int | float): an optional scaling factor. Defaults to 1.
|
||||
bias (int | float): the value to fill the bias. Defaults to 0.
|
||||
bias_prob (float, optional): the probability for bias initialization.
|
||||
Defaults to None.
|
||||
distribution (str): distribution either be ``'normal'``
|
||||
or ``'uniform'``. Defaults to ``'normal'``.
|
||||
layer (str | list[str], optional): the layer will be initialized.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, gain=1, distribution='normal', **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.gain = gain
|
||||
self.distribution = distribution
|
||||
|
||||
def __call__(self, module):
|
||||
|
||||
def init(m):
|
||||
if self.wholemodule:
|
||||
xavier_init(m, self.gain, self.bias, self.distribution)
|
||||
else:
|
||||
layername = m.__class__.__name__
|
||||
basesname = _get_bases_name(m)
|
||||
if len(set(self.layer) & set([layername] + basesname)):
|
||||
xavier_init(m, self.gain, self.bias, self.distribution)
|
||||
|
||||
module.apply(init)
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
|
||||
f'distribution={self.distribution}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Normal')
|
||||
class NormalInit(BaseInit):
|
||||
r"""Initialize module parameters with the values drawn from the normal
|
||||
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
|
||||
Args:
|
||||
mean (int | float):the mean of the normal distribution. Defaults to 0.
|
||||
std (int | float): the standard deviation of the normal distribution.
|
||||
Defaults to 1.
|
||||
bias (int | float): the value to fill the bias. Defaults to 0.
|
||||
bias_prob (float, optional): the probability for bias initialization.
|
||||
Defaults to None.
|
||||
layer (str | list[str], optional): the layer will be initialized.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, mean=0, std=1, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, module):
|
||||
|
||||
def init(m):
|
||||
if self.wholemodule:
|
||||
normal_init(m, self.mean, self.std, self.bias)
|
||||
else:
|
||||
layername = m.__class__.__name__
|
||||
basesname = _get_bases_name(m)
|
||||
if len(set(self.layer) & set([layername] + basesname)):
|
||||
normal_init(m, self.mean, self.std, self.bias)
|
||||
|
||||
module.apply(init)
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: mean={self.mean},' \
|
||||
f' std={self.std}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='TruncNormal')
|
||||
class TruncNormalInit(BaseInit):
|
||||
r"""Initialize module parameters with the values drawn from the normal
|
||||
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
|
||||
outside :math:`[a, b]`.
|
||||
Args:
|
||||
mean (float): the mean of the normal distribution. Defaults to 0.
|
||||
std (float): the standard deviation of the normal distribution.
|
||||
Defaults to 1.
|
||||
a (float): The minimum cutoff value.
|
||||
b ( float): The maximum cutoff value.
|
||||
bias (float): the value to fill the bias. Defaults to 0.
|
||||
bias_prob (float, optional): the probability for bias initialization.
|
||||
Defaults to None.
|
||||
layer (str | list[str], optional): the layer will be initialized.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mean: float = 0,
|
||||
std: float = 1,
|
||||
a: float = -2,
|
||||
b: float = 2,
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
def __call__(self, module: nn.Module) -> None:
|
||||
|
||||
def init(m):
|
||||
if self.wholemodule:
|
||||
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
|
||||
self.bias)
|
||||
else:
|
||||
layername = m.__class__.__name__
|
||||
basesname = _get_bases_name(m)
|
||||
if len(set(self.layer) & set([layername] + basesname)):
|
||||
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
|
||||
self.bias)
|
||||
|
||||
module.apply(init)
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
|
||||
f' mean={self.mean}, std={self.std}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Uniform')
|
||||
class UniformInit(BaseInit):
|
||||
r"""Initialize module parameters with values drawn from the uniform
|
||||
distribution :math:`\mathcal{U}(a, b)`.
|
||||
Args:
|
||||
a (int | float): the lower bound of the uniform distribution.
|
||||
Defaults to 0.
|
||||
b (int | float): the upper bound of the uniform distribution.
|
||||
Defaults to 1.
|
||||
bias (int | float): the value to fill the bias. Defaults to 0.
|
||||
bias_prob (float, optional): the probability for bias initialization.
|
||||
Defaults to None.
|
||||
layer (str | list[str], optional): the layer will be initialized.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, a=0, b=1, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
def __call__(self, module):
|
||||
|
||||
def init(m):
|
||||
if self.wholemodule:
|
||||
uniform_init(m, self.a, self.b, self.bias)
|
||||
else:
|
||||
layername = m.__class__.__name__
|
||||
basesname = _get_bases_name(m)
|
||||
if len(set(self.layer) & set([layername] + basesname)):
|
||||
uniform_init(m, self.a, self.b, self.bias)
|
||||
|
||||
module.apply(init)
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: a={self.a},' \
|
||||
f' b={self.b}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Kaiming')
|
||||
class KaimingInit(BaseInit):
|
||||
r"""Initialize module parameters with the values according to the method
|
||||
described in `Delving deep into rectifiers: Surpassing human-level
|
||||
performance on ImageNet classification - He, K. et al. (2015).
|
||||
<https://www.cv-foundation.org/openaccess/content_iccv_2015/
|
||||
papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
|
||||
Args:
|
||||
a (int | float): the negative slope of the rectifier used after this
|
||||
layer (only used with ``'leaky_relu'``). Defaults to 0.
|
||||
mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
|
||||
``'fan_in'`` preserves the magnitude of the variance of the weights
|
||||
in the forward pass. Choosing ``'fan_out'`` preserves the
|
||||
magnitudes in the backwards pass. Defaults to ``'fan_out'``.
|
||||
nonlinearity (str): the non-linear function (`nn.functional` name),
|
||||
recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
|
||||
Defaults to 'relu'.
|
||||
bias (int | float): the value to fill the bias. Defaults to 0.
|
||||
bias_prob (float, optional): the probability for bias initialization.
|
||||
Defaults to None.
|
||||
distribution (str): distribution either be ``'normal'`` or
|
||||
``'uniform'``. Defaults to ``'normal'``.
|
||||
layer (str | list[str], optional): the layer will be initialized.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
a=0,
|
||||
mode='fan_out',
|
||||
nonlinearity='relu',
|
||||
distribution='normal',
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.a = a
|
||||
self.mode = mode
|
||||
self.nonlinearity = nonlinearity
|
||||
self.distribution = distribution
|
||||
|
||||
def __call__(self, module):
|
||||
|
||||
def init(m):
|
||||
if self.wholemodule:
|
||||
kaiming_init(m, self.a, self.mode, self.nonlinearity,
|
||||
self.bias, self.distribution)
|
||||
else:
|
||||
layername = m.__class__.__name__
|
||||
basesname = _get_bases_name(m)
|
||||
if len(set(self.layer) & set([layername] + basesname)):
|
||||
kaiming_init(m, self.a, self.mode, self.nonlinearity,
|
||||
self.bias, self.distribution)
|
||||
|
||||
module.apply(init)
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
|
||||
f'nonlinearity={self.nonlinearity}, ' \
|
||||
f'distribution ={self.distribution}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Caffe2Xavier')
|
||||
class Caffe2XavierInit(KaimingInit):
|
||||
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
|
||||
# Acknowledgment to FAIR's internal code
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(
|
||||
a=1,
|
||||
mode='fan_in',
|
||||
nonlinearity='leaky_relu',
|
||||
distribution='uniform',
|
||||
**kwargs)
|
||||
|
||||
def __call__(self, module):
|
||||
super().__call__(module)
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Pretrained')
|
||||
class PretrainedInit:
|
||||
"""Initialize module by loading a pretrained model.
|
||||
|
||||
Args:
|
||||
checkpoint (str): the checkpoint file of the pretrained model should
|
||||
be load.
|
||||
prefix (str, optional): the prefix of a sub-module in the pretrained
|
||||
model. it is for loading a part of the pretrained model to
|
||||
initialize. For example, if we would like to only load the
|
||||
backbone of a detector model, we can set ``prefix='backbone.'``.
|
||||
Defaults to None.
|
||||
map_location (str): map tensors into proper locations.
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint, prefix=None, map_location=None):
|
||||
self.checkpoint = checkpoint
|
||||
self.prefix = prefix
|
||||
self.map_location = map_location
|
||||
|
||||
def __call__(self, module):
|
||||
from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
|
||||
load_checkpoint,
|
||||
load_state_dict)
|
||||
logger = MMLogger.get_instance('mmengine')
|
||||
if self.prefix is None:
|
||||
print_log(f'load model from: {self.checkpoint}', logger=logger)
|
||||
load_checkpoint(
|
||||
module,
|
||||
self.checkpoint,
|
||||
map_location=self.map_location,
|
||||
strict=False,
|
||||
logger=logger)
|
||||
else:
|
||||
print_log(
|
||||
f'load {self.prefix} in model from: {self.checkpoint}',
|
||||
logger=logger)
|
||||
state_dict = _load_checkpoint_with_prefix(
|
||||
self.prefix, self.checkpoint, map_location=self.map_location)
|
||||
load_state_dict(module, state_dict, strict=False, logger=logger)
|
||||
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
|
||||
return info
|
||||
|
||||
|
||||
def _initialize(module, cfg, wholemodule=False):
|
||||
func = build_from_cfg(cfg, WEIGHT_INITIALIZERS)
|
||||
# wholemodule flag is for override mode, there is no layer key in override
|
||||
# and initializer will give init values for the whole module with the name
|
||||
# in override.
|
||||
func.wholemodule = wholemodule
|
||||
func(module)
|
||||
|
||||
|
||||
def _initialize_override(module, override, cfg):
|
||||
if not isinstance(override, (dict, list)):
|
||||
raise TypeError(f'override must be a dict or a list of dict, \
|
||||
but got {type(override)}')
|
||||
|
||||
override = [override] if isinstance(override, dict) else override
|
||||
|
||||
for override_ in override:
|
||||
|
||||
cp_override = copy.deepcopy(override_)
|
||||
name = cp_override.pop('name', None)
|
||||
if name is None:
|
||||
raise ValueError('`override` must contain the key "name",'
|
||||
f'but got {cp_override}')
|
||||
# if override only has name key, it means use args in init_cfg
|
||||
if not cp_override:
|
||||
cp_override.update(cfg)
|
||||
# if override has name key and other args except type key, it will
|
||||
# raise error
|
||||
elif 'type' not in cp_override.keys():
|
||||
raise ValueError(
|
||||
f'`override` need "type" key, but got {cp_override}')
|
||||
|
||||
if hasattr(module, name):
|
||||
_initialize(getattr(module, name), cp_override, wholemodule=True)
|
||||
else:
|
||||
raise RuntimeError(f'module did not have attribute {name}, '
|
||||
f'but init_cfg is {cp_override}.')
|
||||
|
||||
|
||||
def initialize(module, init_cfg):
|
||||
r"""Initialize a module.
|
||||
Args:
|
||||
module (``torch.nn.Module``): the module will be initialized.
|
||||
init_cfg (dict | list[dict]): initialization configuration dict to
|
||||
define initializer. OpenMMLab has implemented 6 initializers
|
||||
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
|
||||
``Kaiming``, and ``Pretrained``.
|
||||
Example:
|
||||
>>> module = nn.Linear(2, 3, bias=True)
|
||||
>>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
|
||||
>>> initialize(module, init_cfg)
|
||||
>>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
|
||||
>>> # define key ``'layer'`` for initializing layer with different
|
||||
>>> # configuration
|
||||
>>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
|
||||
dict(type='Constant', layer='Linear', val=2)]
|
||||
>>> initialize(module, init_cfg)
|
||||
>>> # define key``'override'`` to initialize some specific part in
|
||||
>>> # module
|
||||
>>> class FooNet(nn.Module):
|
||||
>>> def __init__(self):
|
||||
>>> super().__init__()
|
||||
>>> self.feat = nn.Conv2d(3, 16, 3)
|
||||
>>> self.reg = nn.Conv2d(16, 10, 3)
|
||||
>>> self.cls = nn.Conv2d(16, 5, 3)
|
||||
>>> model = FooNet()
|
||||
>>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
|
||||
>>> override=dict(type='Constant', name='reg', val=3, bias=4))
|
||||
>>> initialize(model, init_cfg)
|
||||
>>> model = ResNet(depth=50)
|
||||
>>> # Initialize weights with the pretrained model.
|
||||
>>> init_cfg = dict(type='Pretrained',
|
||||
checkpoint='torchvision://resnet50')
|
||||
>>> initialize(model, init_cfg)
|
||||
>>> # Initialize weights of a sub-module with the specific part of
|
||||
>>> # a pretrained model by using "prefix".
|
||||
>>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
|
||||
>>> 'retinanet_r50_fpn_1x_coco/'\
|
||||
>>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
|
||||
>>> init_cfg = dict(type='Pretrained',
|
||||
checkpoint=url, prefix='backbone.')
|
||||
"""
|
||||
if not isinstance(init_cfg, (dict, list)):
|
||||
raise TypeError(f'init_cfg must be a dict or a list of dict, \
|
||||
but got {type(init_cfg)}')
|
||||
|
||||
if isinstance(init_cfg, dict):
|
||||
init_cfg = [init_cfg]
|
||||
|
||||
for cfg in init_cfg:
|
||||
# should deeply copy the original config because cfg may be used by
|
||||
# other modules, e.g., one init_cfg shared by multiple bottleneck
|
||||
# blocks, the expected cfg will be changed after pop and will change
|
||||
# the initialization behavior of other modules
|
||||
cp_cfg = copy.deepcopy(cfg)
|
||||
override = cp_cfg.pop('override', None)
|
||||
_initialize(module, cp_cfg)
|
||||
|
||||
if override is not None:
|
||||
cp_cfg.pop('layer', None)
|
||||
_initialize_override(module, override, cp_cfg)
|
||||
else:
|
||||
# All attributes in module have same initialization.
|
||||
pass
|
||||
|
||||
|
||||
def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
|
||||
b: float) -> Tensor:
|
||||
# Method based on
|
||||
# https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
# Modified from
|
||||
# https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
warnings.warn(
|
||||
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
||||
'The distribution of values may be incorrect.',
|
||||
stacklevel=2)
|
||||
|
||||
with torch.no_grad():
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
lower = norm_cdf((a - mean) / std)
|
||||
upper = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [lower, upper], then translate
|
||||
# to [2lower-1, 2upper-1].
|
||||
tensor.uniform_(2 * lower - 1, 2 * upper - 1)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
tensor.erfinv_()
|
||||
|
||||
# Transform to proper mean, std
|
||||
tensor.mul_(std * math.sqrt(2.))
|
||||
tensor.add_(mean)
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
tensor.clamp_(min=a, max=b)
|
||||
return tensor
|
||||
|
||||
|
||||
def trunc_normal_(tensor: Tensor,
|
||||
mean: float = 0.,
|
||||
std: float = 1.,
|
||||
a: float = -2.,
|
||||
b: float = 2.) -> Tensor:
|
||||
r"""Fills the input Tensor with values drawn from a truncated
|
||||
normal distribution. The values are effectively drawn from the
|
||||
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||
with values outside :math:`[a, b]` redrawn until they are within
|
||||
the bounds. The method used for generating the random values works
|
||||
best when :math:`a \leq \text{mean} \leq b`.
|
||||
Modified from
|
||||
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
|
||||
Args:
|
||||
tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
|
||||
mean (float): the mean of the normal distribution.
|
||||
std (float): the standard deviation of the normal distribution.
|
||||
a (float): the minimum cutoff value.
|
||||
b (float): the maximum cutoff value.
|
||||
"""
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||
from mmengine.utils.dl_utils import mmcv_full_available
|
||||
|
||||
|
||||
def stack_batch(tensor_list: List[torch.Tensor],
|
||||
|
@@ -746,7 +83,7 @@ def detect_anomalous_params(loss: torch.Tensor, model) -> None:
                    traverse(grad_fn)

    traverse(loss.grad_fn)
    from mmengine import MMLogger
    from mmengine.logging import MMLogger
    logger = MMLogger.get_current_instance()
    for n, p in model.named_parameters():
        if p not in parameters_in_graph and p.requires_grad:

@@ -799,3 +136,62 @@ try:
except ImportError:
    warnings.warn('Cannot import torch.fx, `merge_dict` is a simple function '
                  'to merge multiple dicts')


class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
    """A general BatchNorm layer without input dimension check.

    Reproduced from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
    is `_check_input_dim` that is designed for tensor sanity checks.
    The check has been bypassed in this class for the convenience of converting
    SyncBatchNorm.
    """

    def _check_input_dim(self, input: torch.Tensor):
        return


def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
    `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
    `BatchNormXd` layers.

    Adapted from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    Args:
        module (nn.Module): The module containing `SyncBatchNorm` layers.

    Returns:
        module_output: The converted module with `BatchNormXd` layers.
    """
    module_output = module
    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]

    if mmcv_full_available():
        from mmcv.ops import SyncBatchNorm
        module_checklist.append(SyncBatchNorm)

    if isinstance(module, tuple(module_checklist)):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            # no_grad() may not be needed here but
            # just to be consistent with `convert_sync_batchnorm()`
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        module_output.training = module.training
        # qconfig exists in quantized models
        if hasattr(module, 'qconfig'):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    del module
    return module_output
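A short usage sketch for the relocated `revert_sync_batchnorm` helper, e.g. for running a SyncBN-trained model on a single device; the toy network is illustrative, only the helper and its export from mmengine.model come from this diff:

import torch
import torch.nn as nn

from mmengine.model import revert_sync_batchnorm

# A toy network wrapped with SyncBatchNorm, standing in for a DDP-trained model.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net = nn.SyncBatchNorm.convert_sync_batchnorm(net)

# SyncBN layers need a distributed process group; convert them back to
# regular BatchNorm so the model runs on CPU or a single GPU.
net = revert_sync_batchnorm(net)
out = net(torch.randn(1, 3, 32, 32))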
|
||||
|
|
|
@ -0,0 +1,679 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg
|
||||
|
||||
|
||||
def update_init_info(module, init_info):
|
||||
"""Update the `_params_init_info` in the module if the value of parameters
|
||||
are changed.
|
||||
|
||||
Args:
|
||||
module (obj:`nn.Module`): The module of PyTorch with a user-defined
|
||||
attribute `_params_init_info` which records the initialization
|
||||
information.
|
||||
init_info (str): The string that describes the initialization.
|
||||
"""
|
||||
assert hasattr(
|
||||
module,
|
||||
'_params_init_info'), f'Can not find `_params_init_info` in {module}'
|
||||
for name, param in module.named_parameters():
|
||||
|
||||
assert param in module._params_init_info, (
|
||||
f'Find a new :obj:`Parameter` '
|
||||
f'named `{name}` during executing the '
|
||||
f'`init_weights` of '
|
||||
f'`{module.__class__.__name__}`. '
|
||||
f'Please do not add or '
|
||||
f'replace parameters during executing '
|
||||
f'the `init_weights`. ')
|
||||
|
||||
# The parameter has been changed during executing the
|
||||
# `init_weights` of module
|
||||
mean_value = param.data.mean().cpu()
|
||||
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
|
||||
module._params_init_info[param]['init_info'] = init_info
|
||||
module._params_init_info[param]['tmp_mean_value'] = mean_value
|
||||
|
||||
|
||||
def constant_init(module, val, bias=0):
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
nn.init.constant_(module.weight, val)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def xavier_init(module, gain=1, bias=0, distribution='normal'):
|
||||
assert distribution in ['uniform', 'normal']
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
if distribution == 'uniform':
|
||||
nn.init.xavier_uniform_(module.weight, gain=gain)
|
||||
else:
|
||||
nn.init.xavier_normal_(module.weight, gain=gain)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def normal_init(module, mean=0, std=1, bias=0):
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
nn.init.normal_(module.weight, mean, std)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def trunc_normal_init(module: nn.Module,
|
||||
mean: float = 0,
|
||||
std: float = 1,
|
||||
a: float = -2,
|
||||
b: float = 2,
|
||||
bias: float = 0) -> None:
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
trunc_normal_(module.weight, mean, std, a, b) # type: ignore
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias) # type: ignore
|
||||
|
||||
|
||||
def uniform_init(module, a=0, b=1, bias=0):
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
nn.init.uniform_(module.weight, a, b)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def kaiming_init(module,
|
||||
a=0,
|
||||
mode='fan_out',
|
||||
nonlinearity='relu',
|
||||
bias=0,
|
||||
distribution='normal'):
|
||||
assert distribution in ['uniform', 'normal']
|
||||
if hasattr(module, 'weight') and module.weight is not None:
|
||||
if distribution == 'uniform':
|
||||
nn.init.kaiming_uniform_(
|
||||
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
|
||||
else:
|
||||
nn.init.kaiming_normal_(
|
||||
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def caffe2_xavier_init(module, bias=0):
|
||||
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
|
||||
# Acknowledgment to FAIR's internal code
|
||||
kaiming_init(
|
||||
module,
|
||||
a=1,
|
||||
mode='fan_in',
|
||||
nonlinearity='leaky_relu',
|
||||
bias=bias,
|
||||
distribution='uniform')
|
||||
|
||||
|
||||
def bias_init_with_prob(prior_prob):
|
||||
"""initialize conv/fc bias value according to a given probability value."""
|
||||
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
|
||||
return bias_init
|
||||
|
||||
|
||||
def _get_bases_name(m):
|
||||
return [b.__name__ for b in m.__class__.__bases__]
|
||||
|
||||
|
||||
class BaseInit:
|
||||
|
||||
def __init__(self, *, bias=0, bias_prob=None, layer=None):
|
||||
self.wholemodule = False
|
||||
if not isinstance(bias, (int, float)):
|
||||
raise TypeError(f'bias must be a number, but got a {type(bias)}')
|
||||
|
||||
if bias_prob is not None:
|
||||
if not isinstance(bias_prob, float):
|
||||
raise TypeError(f'bias_prob type must be float, \
|
||||
but got {type(bias_prob)}')
|
||||
|
||||
if layer is not None:
|
||||
if not isinstance(layer, (str, list)):
|
||||
raise TypeError(f'layer must be a str or a list of str, \
|
||||
but got a {type(layer)}')
|
||||
else:
|
||||
layer = []
|
||||
|
||||
if bias_prob is not None:
|
||||
self.bias = bias_init_with_prob(bias_prob)
|
||||
else:
|
||||
self.bias = bias
|
||||
self.layer = [layer] if isinstance(layer, str) else layer
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Constant')
|
||||
class ConstantInit(BaseInit):
|
||||
"""Initialize module parameters with constant values.
|
||||
|
||||
Args:
|
||||
val (int | float): the value to fill the weights in the module with
|
||||
bias (int | float): the value to fill the bias. Defaults to 0.
|
||||
bias_prob (float, optional): the probability for bias initialization.
|
||||
Defaults to None.
|
||||
layer (str | list[str], optional): the layer will be initialized.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, val, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.val = val
|
||||
|
||||
def __call__(self, module):
|
||||
|
||||
def init(m):
|
||||
if self.wholemodule:
|
||||
constant_init(m, self.val, self.bias)
|
||||
else:
|
||||
layername = m.__class__.__name__
|
||||
basesname = _get_bases_name(m)
|
||||
if len(set(self.layer) & set([layername] + basesname)):
|
||||
constant_init(m, self.val, self.bias)
|
||||
|
||||
module.apply(init)
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Xavier')
|
||||
class XavierInit(BaseInit):
|
||||
r"""Initialize module parameters with values according to the method
|
||||
described in `Understanding the difficulty of training deep feedforward
|
||||
neural networks - Glorot, X. & Bengio, Y. (2010).
|
||||
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
|
||||
|
||||
Args:
|
||||
gain (int | float): an optional scaling factor. Defaults to 1.
|
||||
bias (int | float): the value to fill the bias. Defaults to 0.
|
||||
bias_prob (float, optional): the probability for bias initialization.
|
||||
Defaults to None.
|
||||
distribution (str): distribution either be ``'normal'``
|
||||
or ``'uniform'``. Defaults to ``'normal'``.
|
||||
layer (str | list[str], optional): the layer will be initialized.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, gain=1, distribution='normal', **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.gain = gain
|
||||
self.distribution = distribution
|
||||
|
||||
def __call__(self, module):
|
||||
|
||||
def init(m):
|
||||
if self.wholemodule:
|
||||
xavier_init(m, self.gain, self.bias, self.distribution)
|
||||
else:
|
||||
layername = m.__class__.__name__
|
||||
basesname = _get_bases_name(m)
|
||||
if len(set(self.layer) & set([layername] + basesname)):
|
||||
xavier_init(m, self.gain, self.bias, self.distribution)
|
||||
|
||||
module.apply(init)
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
|
||||
f'distribution={self.distribution}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='Normal')
|
||||
class NormalInit(BaseInit):
|
||||
r"""Initialize module parameters with the values drawn from the normal
|
||||
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
|
||||
|
||||
Args:
|
||||
mean (int | float):the mean of the normal distribution. Defaults to 0.
|
||||
std (int | float): the standard deviation of the normal distribution.
|
||||
Defaults to 1.
|
||||
bias (int | float): the value to fill the bias. Defaults to 0.
|
||||
bias_prob (float, optional): the probability for bias initialization.
|
||||
Defaults to None.
|
||||
layer (str | list[str], optional): the layer will be initialized.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, mean=0, std=1, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, module):
|
||||
|
||||
def init(m):
|
||||
if self.wholemodule:
|
||||
normal_init(m, self.mean, self.std, self.bias)
|
||||
else:
|
||||
layername = m.__class__.__name__
|
||||
basesname = _get_bases_name(m)
|
||||
if len(set(self.layer) & set([layername] + basesname)):
|
||||
normal_init(m, self.mean, self.std, self.bias)
|
||||
|
||||
module.apply(init)
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
||||
def _get_init_info(self):
|
||||
info = f'{self.__class__.__name__}: mean={self.mean},' \
|
||||
f' std={self.std}, bias={self.bias}'
|
||||
return info
|
||||
|
||||
|
||||
@WEIGHT_INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
    r"""Initialize module parameters with the values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
    outside :math:`[a, b]` redrawn until they are within the bounds.

    Args:
        mean (float): the mean of the normal distribution. Defaults to 0.
        std (float): the standard deviation of the normal distribution.
            Defaults to 1.
        a (float): the minimum cutoff value. Defaults to -2.
        b (float): the maximum cutoff value. Defaults to 2.
        bias (float): the value to fill the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self,
                 mean: float = 0,
                 std: float = 1,
                 a: float = -2,
                 b: float = 2,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.mean = mean
        self.std = std
        self.a = a
        self.b = b

    def __call__(self, module: nn.Module) -> None:

        def init(m):
            if self.wholemodule:
                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                  self.bias)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                      self.bias)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
               f' mean={self.mean}, std={self.std}, bias={self.bias}'
        return info


@WEIGHT_INITIALIZERS.register_module(name='Uniform')
class UniformInit(BaseInit):
    r"""Initialize module parameters with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        a (int | float): the lower bound of the uniform distribution.
            Defaults to 0.
        b (int | float): the upper bound of the uniform distribution.
            Defaults to 1.
        bias (int | float): the value to fill the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self, a=0, b=1, **kwargs):
        super().__init__(**kwargs)
        self.a = a
        self.b = b

    def __call__(self, module):

        def init(m):
            if self.wholemodule:
                uniform_init(m, self.a, self.b, self.bias)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    uniform_init(m, self.a, self.b, self.bias)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: a={self.a},' \
               f' b={self.b}, bias={self.bias}'
        return info


@WEIGHT_INITIALIZERS.register_module(name='Kaiming')
class KaimingInit(BaseInit):
    r"""Initialize module parameters with the values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification - He, K. et al. (2015).
    <https://www.cv-foundation.org/openaccess/content_iccv_2015/
    papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_

    Args:
        a (int | float): the negative slope of the rectifier used after this
            layer (only used with ``'leaky_relu'``). Defaults to 0.
        mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
            ``'fan_in'`` preserves the magnitude of the variance of the
            weights in the forward pass. Choosing ``'fan_out'`` preserves the
            magnitudes in the backwards pass. Defaults to ``'fan_out'``.
        nonlinearity (str): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'``.
            Defaults to 'relu'.
        bias (int | float): the value to fill the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        distribution (str): distribution to sample from, either ``'normal'``
            or ``'uniform'``. Defaults to ``'normal'``.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 distribution='normal',
                 **kwargs):
        super().__init__(**kwargs)
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity
        self.distribution = distribution

    def __call__(self, module):

        def init(m):
            if self.wholemodule:
                kaiming_init(m, self.a, self.mode, self.nonlinearity,
                             self.bias, self.distribution)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    kaiming_init(m, self.a, self.mode, self.nonlinearity,
                                 self.bias, self.distribution)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
               f'nonlinearity={self.nonlinearity}, ' \
               f'distribution={self.distribution}, bias={self.bias}'
        return info


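# Illustrative sketch (not part of the diff): the defaults above
# (mode='fan_out', nonlinearity='relu') match the common conv-backbone
# recipe; the module and import below are example values, assuming the
# model/__init__.py re-export from this refactor.
# >>> import torch.nn as nn
# >>> from mmengine.model import initialize
# >>> backbone = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(),
# >>>                          nn.Conv2d(16, 16, 3))
# >>> initialize(backbone, dict(type='Kaiming', layer='Conv2d',
# >>>                           mode='fan_out', nonlinearity='relu'))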
@WEIGHT_INITIALIZERS.register_module(name='Caffe2Xavier')
class Caffe2XavierInit(KaimingInit):
    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
    # Acknowledgment to FAIR's internal code
    def __init__(self, **kwargs):
        super().__init__(
            a=1,
            mode='fan_in',
            nonlinearity='leaky_relu',
            distribution='uniform',
            **kwargs)

    def __call__(self, module):
        super().__call__(module)


@WEIGHT_INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit:
    """Initialize module by loading a pretrained model.

    Args:
        checkpoint (str): the checkpoint file of the pretrained model to be
            loaded.
        prefix (str, optional): the prefix of a sub-module in the pretrained
            model. It is for loading a part of the pretrained model to
            initialize. For example, if we would like to only load the
            backbone of a detector model, we can set ``prefix='backbone.'``.
            Defaults to None.
        map_location (str): map tensors into proper locations.
    """

    def __init__(self, checkpoint, prefix=None, map_location=None):
        self.checkpoint = checkpoint
        self.prefix = prefix
        self.map_location = map_location

    def __call__(self, module):
        from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
                                                load_checkpoint,
                                                load_state_dict)
        logger = MMLogger.get_instance('mmengine')
        if self.prefix is None:
            print_log(f'load model from: {self.checkpoint}', logger=logger)
            load_checkpoint(
                module,
                self.checkpoint,
                map_location=self.map_location,
                strict=False,
                logger=logger)
        else:
            print_log(
                f'load {self.prefix} in model from: {self.checkpoint}',
                logger=logger)
            state_dict = _load_checkpoint_with_prefix(
                self.prefix, self.checkpoint, map_location=self.map_location)
            load_state_dict(module, state_dict, strict=False, logger=logger)

        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: load from {self.checkpoint}'
        return info


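# Illustrative sketch (not part of the diff): loading only a sub-module's
# weights with `prefix`; the checkpoint path and prefix below are
# placeholders, and `initialize` is assumed re-exported from mmengine.model.
# >>> from mmengine.model import initialize
# >>> initialize(model, dict(type='Pretrained',
# >>>                        checkpoint='work_dirs/epoch_12.pth',
# >>>                        prefix='backbone.'))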
def _initialize(module, cfg, wholemodule=False):
    func = build_from_cfg(cfg, WEIGHT_INITIALIZERS)
    # wholemodule flag is for override mode, there is no layer key in override
    # and initializer will give init values for the whole module with the name
    # in override.
    func.wholemodule = wholemodule
    func(module)


def _initialize_override(module, override, cfg):
    if not isinstance(override, (dict, list)):
        raise TypeError('override must be a dict or a list of dict, '
                        f'but got {type(override)}')

    override = [override] if isinstance(override, dict) else override

    for override_ in override:

        cp_override = copy.deepcopy(override_)
        name = cp_override.pop('name', None)
        if name is None:
            raise ValueError('`override` must contain the key "name", '
                             f'but got {cp_override}')
        # if override only has name key, it means use args in init_cfg
        if not cp_override:
            cp_override.update(cfg)
        # if override has name key and other args except type key, it will
        # raise error
        elif 'type' not in cp_override.keys():
            raise ValueError(
                f'`override` need "type" key, but got {cp_override}')

        if hasattr(module, name):
            _initialize(getattr(module, name), cp_override, wholemodule=True)
        else:
            raise RuntimeError(f'module did not have attribute {name}, '
                               f'but init_cfg is {cp_override}.')


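# Illustrative sketch (not part of the diff) of the override semantics
# handled above: the base config initializes every matching `layer`, while
# the named sub-module ('cls_head' is a placeholder) is re-initialized as a
# whole (wholemodule=True) with its own settings.
# >>> init_cfg = dict(
# >>>     type='Normal', std=0.01, layer='Conv2d',
# >>>     override=dict(type='Normal', name='cls_head', std=0.001, bias=0))
# >>> initialize(model, init_cfg)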
def initialize(module, init_cfg):
    r"""Initialize a module.

    Args:
        module (``torch.nn.Module``): the module to be initialized.
        init_cfg (dict | list[dict]): initialization configuration dict to
            define initializer. OpenMMLab has implemented 6 initializers
            including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
            ``Kaiming``, and ``Pretrained``.

    Example:
        >>> module = nn.Linear(2, 3, bias=True)
        >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
        >>> initialize(module, init_cfg)

        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1, 2))
        >>> # define key ``'layer'`` for initializing layer with different
        >>> # configuration
        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
                dict(type='Constant', layer='Linear', val=2)]
        >>> initialize(module, init_cfg)

        >>> # define key ``'override'`` to initialize some specific part in
        >>> # module
        >>> class FooNet(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.feat = nn.Conv2d(3, 16, 3)
        >>>         self.reg = nn.Conv2d(16, 10, 3)
        >>>         self.cls = nn.Conv2d(16, 5, 3)
        >>> model = FooNet()
        >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
        >>> initialize(model, init_cfg)

        >>> model = ResNet(depth=50)
        >>> # Initialize weights with the pretrained model.
        >>> init_cfg = dict(type='Pretrained',
                checkpoint='torchvision://resnet50')
        >>> initialize(model, init_cfg)

        >>> # Initialize weights of a sub-module with the specific part of
        >>> # a pretrained model by using "prefix".
        >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
        >>>     'retinanet_r50_fpn_1x_coco/'\
        >>>     'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
        >>> init_cfg = dict(type='Pretrained',
                checkpoint=url, prefix='backbone.')
    """
    if not isinstance(init_cfg, (dict, list)):
        raise TypeError('init_cfg must be a dict or a list of dict, '
                        f'but got {type(init_cfg)}')

    if isinstance(init_cfg, dict):
        init_cfg = [init_cfg]

    for cfg in init_cfg:
        # should deeply copy the original config because cfg may be used by
        # other modules, e.g., one init_cfg shared by multiple bottleneck
        # blocks, the expected cfg will be changed after pop and will change
        # the initialization behavior of other modules
        cp_cfg = copy.deepcopy(cfg)
        override = cp_cfg.pop('override', None)
        _initialize(module, cp_cfg)

        if override is not None:
            cp_cfg.pop('layer', None)
            _initialize_override(module, override, cp_cfg)
        else:
            # All attributes in module have same initialization.
            pass


def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
                           b: float) -> Tensor:
    # Method based on
    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # Modified from
    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], then
        # translate to [2lower-1, 2upper-1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor: Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Modified from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

    Args:
        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
        mean (float): the mean of the normal distribution.
        std (float): the standard deviation of the normal distribution.
        a (float): the minimum cutoff value.
        b (float): the maximum cutoff value.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

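A quick sanity-check sketch for ``trunc_normal_`` (the values are illustrative; the import path follows this refactor's ``mmengine/model/weight_init.py`` location):

>>> import torch
>>> from mmengine.model.weight_init import trunc_normal_
>>> w = torch.empty(256, 256)
>>> _ = trunc_normal_(w, mean=0., std=0.02, a=-0.04, b=0.04)
>>> bool(w.min() >= -0.04 and w.max() <= 0.04)
True
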
@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils.parrots_wrapper import TORCH_VERSION
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
from .distributed import MMDistributedDataParallel
from .seperate_distributed import MMSeparateDistributedDataParallel

@ -4,9 +4,9 @@ from typing import Dict, List
|
|||
import torch
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.registry import MODEL_WRAPPERS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from ..utils import detect_anomalous_params
|
||||
|
||||
|
||||
|
|
|
@ -7,9 +7,9 @@ from torch.distributed import ProcessGroup
|
|||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
BackwardPrefetch, CPUOffload, FullyShardedDataParallel)
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.registry import MODEL_WRAPPERS, Registry
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
# support customize fsdp policy
|
||||
FSDP_WRAP_POLICYS = Registry('fsdp wrap policy')
|
||||
|
|
|
@ -6,10 +6,10 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.device import get_device
|
||||
from mmengine.optim import OptimWrapperDict
|
||||
from mmengine.registry import MODEL_WRAPPERS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from .distributed import MMDistributedDataParallel
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,8 @@ import torch.nn as nn
|
|||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from mmengine.registry import OPTIM_WRAPPERS
|
||||
from mmengine.utils import TORCH_VERSION, digit_version
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
from .optimizer_wrapper import OptimWrapper
|
||||
|
||||
|
||||
|
|
|
@ -9,8 +9,9 @@ from torch.nn import GroupNorm, LayerNorm
|
|||
from mmengine.logging import print_log
|
||||
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
|
||||
OPTIMIZERS)
|
||||
from mmengine.utils import is_list_of, mmcv_full_available
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||
from mmengine.utils import is_list_of
|
||||
from mmengine.utils.dl_utils import mmcv_full_available
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||
from .optimizer_wrapper import OptimWrapper
|
||||
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from torch.optim import Optimizer
|
|||
|
||||
from mmengine.logging import MessageHub, print_log
|
||||
from mmengine.registry import OPTIM_WRAPPERS
|
||||
from mmengine.utils import has_batch_norm
|
||||
from mmengine.utils.dl_utils import has_batch_norm
|
||||
|
||||
|
||||
@OPTIM_WRAPPERS.register_module()
|
||||
|
|
|
@ -3,15 +3,15 @@ import inspect
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..config import Config, ConfigDict
|
||||
from ..utils import ManagerMixin
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.utils import ManagerMixin
|
||||
from .registry import Registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..optim.scheduler import _ParamScheduler
|
||||
from ..runner import Runner
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.optim.scheduler import _ParamScheduler
|
||||
from mmengine.runner import Runner
|
||||
|
||||
|
||||
def build_from_cfg(
|
||||
|
@ -158,6 +158,7 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
|
|||
Returns:
|
||||
object: The constructed runner object.
|
||||
"""
|
||||
from ..config import Config, ConfigDict
|
||||
from ..logging import print_log
|
||||
|
||||
assert isinstance(
|
||||
|
@ -210,10 +211,10 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
|
|||
|
||||
|
||||
def build_model_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
|
||||
nn.Module:
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
default_args: Optional[Union[dict, 'ConfigDict', 'Config']] = None
|
||||
) -> 'nn.Module':
|
||||
"""Build a PyTorch model from config dict(s). Different from
|
||||
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
|
||||
|
||||
|
@ -239,10 +240,10 @@ def build_model_from_cfg(
|
|||
|
||||
|
||||
def build_scheduler_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
|
||||
'_ParamScheduler':
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
default_args: Optional[Union[dict, ConfigDict, Config]] = None
|
||||
) -> '_ParamScheduler':
|
||||
"""Builds a ``ParamScheduler`` instance from config.
|
||||
|
||||
``ParamScheduler`` supports building instance by its constructor or
|
||||
|
|
|
@ -23,7 +23,7 @@ class DefaultScope(ManagerMixin):
|
|||
scope_name (str): Scope of current task.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine import MODELS
|
||||
>>> from mmengine.model import MODELS
|
||||
>>> # Define default scope in runner.
|
||||
>>> DefaultScope.get_instance('task', scope_name='mmdet')
|
||||
>>> # Get default scope globally.
|
||||
|
|
|
@ -7,8 +7,8 @@ from contextlib import contextmanager
|
|||
from importlib import import_module
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ..config.utils import PKG2PROJECT
|
||||
from ..utils import is_seq_of
|
||||
from mmengine.config.utils import PKG2PROJECT
|
||||
from mmengine.utils import is_seq_of
|
||||
from .default_scope import DefaultScope
|
||||
|
||||
|
||||
|
@ -190,8 +190,7 @@ class Registry:
|
|||
scope (str): The target scope.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine import Registry, DefaultScope
|
||||
>>> from mmengine.registry import MODELS
|
||||
>>> from mmengine.registry import Registry, DefaultScope, MODELS
|
||||
>>> import time
|
||||
>>> # External Registry
|
||||
>>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet',
|
||||
|
|
|
@ -6,6 +6,7 @@ from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
|
|||
get_mmcls_models, get_state_dict,
|
||||
get_torchvision_models, load_checkpoint,
|
||||
load_state_dict, save_checkpoint, weights_to_cpu)
|
||||
from .log_processor import LogProcessor
|
||||
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
|
||||
from .priority import Priority, get_priority
|
||||
from .runner import Runner
|
||||
|
@ -16,5 +17,5 @@ __all__ = [
|
|||
'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
|
||||
'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
|
||||
'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint',
|
||||
'autocast'
|
||||
'autocast', 'LogProcessor'
|
||||
]
|
||||
|
|
|
@ -7,7 +7,8 @@ import torch
|
|||
|
||||
from mmengine.device import get_device
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.utils import TORCH_VERSION, digit_version
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
|
@ -19,7 +19,8 @@ from mmengine.fileio import FileClient
|
|||
from mmengine.fileio import load as load_file
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model import is_model_wrapper
|
||||
from mmengine.utils import load_url, mkdir_or_exist
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
from mmengine.utils.dl_utils import load_url
|
||||
|
||||
# `MMENGINE_HOME` is the highest priority directory to save checkpoints
|
||||
# downloaded from Internet. If it is not set, as a workaround, using
|
||||
|
@ -251,9 +252,8 @@ class CheckpointLoader:
|
|||
|
||||
checkpoint_loader = cls._get_checkpoint_loader(filename)
|
||||
class_name = checkpoint_loader.__name__
|
||||
mmengine.print_log(
|
||||
f'{class_name[10:]} loads checkpoint from path: {filename}',
|
||||
logger)
|
||||
print_log(f'{class_name[10:]} loads checkpoint from path: {filename}',
|
||||
logger)
|
||||
return checkpoint_loader(filename, map_location)
|
||||
|
||||
|
||||
|
|
|
@ -139,8 +139,8 @@ class _InfiniteDataloaderIterator:
|
|||
It resets the dataloader to continue iterating when the iterator has
|
||||
iterated over all the data. However, this approach is not efficient, as the
|
||||
workers need to be restarted every time the dataloader is reset. It is
|
||||
recommended to use `mmengine.data.InfiniteSampler` to enable the dataloader
|
||||
to iterate infinitely.
|
||||
recommended to use `mmengine.dataset.InfiniteSampler` to enable the
|
||||
dataloader to iterate infinitely.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader: DataLoader) -> None:
|
||||
|
@ -157,8 +157,9 @@ class _InfiniteDataloaderIterator:
|
|||
except StopIteration:
|
||||
warnings.warn('Reach the end of the dataloader, it will be '
|
||||
'restarted and continue to iterate. It is '
|
||||
'recommended to use `mmengine.data.InfiniteSampler` '
|
||||
'to enable the dataloader to iterate infinitely.')
|
||||
'recommended to use '
|
||||
'`mmengine.dataset.InfiniteSampler` to enable the '
|
||||
'dataloader to iterate infinitely.')
|
||||
self._epoch += 1
|
||||
if hasattr(self._dataloader, 'sampler') and hasattr(
|
||||
self._dataloader.sampler, 'set_epoch'):
|
||||
|
|
|
@ -20,16 +20,16 @@ from torch.utils.data import DataLoader
|
|||
|
||||
import mmengine
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.data import pseudo_collate, worker_init_fn
|
||||
from mmengine.dataset import pseudo_collate, worker_init_fn
|
||||
from mmengine.device import get_device
|
||||
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
|
||||
is_distributed, master_only, sync_random_seed)
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import LogProcessor, MessageHub, MMLogger, print_log
|
||||
from mmengine.logging import MessageHub, MMLogger, print_log
|
||||
from mmengine.model import (BaseModel, MMDistributedDataParallel,
|
||||
is_model_wrapper)
|
||||
is_model_wrapper, revert_sync_batchnorm)
|
||||
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
||||
build_optim_wrapper)
|
||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
|
||||
|
@ -37,14 +37,15 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
|
|||
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
|
||||
VISUALIZERS, DefaultScope,
|
||||
count_registered_modules)
|
||||
from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
|
||||
get_git_hash, is_seq_of, revert_sync_batchnorm,
|
||||
set_multi_processing)
|
||||
from mmengine.utils import digit_version, get_git_hash, is_seq_of
|
||||
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
|
||||
set_multi_processing)
|
||||
from mmengine.visualization import Visualizer
|
||||
from .base_loop import BaseLoop
|
||||
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
||||
find_latest_checkpoint, get_state_dict,
|
||||
save_checkpoint, weights_to_cpu)
|
||||
from .log_processor import LogProcessor
|
||||
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
|
||||
from .priority import Priority, get_priority
|
||||
|
||||
|
@ -181,7 +182,7 @@ class Runner:
|
|||
Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine import Runner
|
||||
>>> from mmengine.runner import Runner
|
||||
>>> cfg = dict(
|
||||
>>> model=dict(type='ToyModel'),
|
||||
>>> work_dir='path/of/work_dir',
|
||||
|
|
|
@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_data_element import BaseDataElement
from .instance_data import InstanceData
from .label_data import LabelData
from .pixel_data import PixelData

__all__ = ['BaseDataElement', 'InstanceData', 'LabelData', 'PixelData']

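A minimal sketch of the renamed import path (``mmengine.data`` becomes ``mmengine.structures``), mirroring the docstring examples updated below:

>>> import torch
>>> from mmengine.structures import InstanceData
>>> pred_instances = InstanceData(metainfo=dict(img_shape=(800, 1196, 3)))
>>> pred_instances.bboxes = torch.rand((5, 4))
>>> pred_instances.scores = torch.rand((5, ))
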
@ -71,7 +71,7 @@ class BaseDataElement:
|
|||
model predictions. Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine.data import BaseDataElement
|
||||
>>> from mmengine.structures import BaseDataElement
|
||||
>>> gt_instances = BaseDataElement()
|
||||
>>> bboxes = torch.rand((5, 4))
|
||||
>>> scores = torch.rand((5,))
|
|
@ -52,7 +52,7 @@ class InstanceData(BaseDataElement):
|
|||
... return new_data
|
||||
... def __repr__(self):
|
||||
... return str(self.tmp)
|
||||
>>> from mmengine.data import InstanceData
|
||||
>>> from mmengine.structures import InstanceData
|
||||
>>> import numpy as np
|
||||
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
|
||||
>>> instance_data = InstanceData(metainfo=img_meta)
|
|
@ -3,7 +3,8 @@ from typing import Any, Callable, Optional, Union
|
|||
|
||||
from torch.testing import assert_allclose as _assert_allclose
|
||||
|
||||
from mmengine.utils import TORCH_VERSION, digit_version
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
|
||||
def assert_allclose(
|
||||
|
|
|
@ -1,31 +1,20 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .collect_env import collect_env
|
||||
from .hub import load_url
|
||||
from .manager import ManagerMeta, ManagerMixin
|
||||
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
||||
has_batch_norm, has_method, import_modules_from_strings,
|
||||
is_list_of, is_method_overridden, is_seq_of, is_str,
|
||||
is_tuple_of, iter_cast, list_cast, mmcv_full_available,
|
||||
requires_executable, requires_package, slice_list,
|
||||
to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
|
||||
tuple_cast)
|
||||
has_method, import_modules_from_strings, is_list_of,
|
||||
is_method_overridden, is_seq_of, is_str, is_tuple_of,
|
||||
iter_cast, list_cast, requires_executable, requires_package,
|
||||
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
|
||||
to_ntuple, tuple_cast)
|
||||
from .package_utils import (call_command, check_install_package,
|
||||
get_installed_path, is_installed)
|
||||
from .parrots_wrapper import TORCH_VERSION
|
||||
from .path import (check_file_exist, fopen, is_abs, is_filepath,
|
||||
mkdir_or_exist, scandir, symlink)
|
||||
from .progressbar import (ProgressBar, track_iter_progress,
|
||||
track_parallel_progress, track_progress)
|
||||
from .setup_env import set_multi_processing
|
||||
from .sync_bn import revert_sync_batchnorm
|
||||
from .timer import Timer, TimerError, check_time
|
||||
from .torch_ops import torch_meshgrid
|
||||
from .trace import is_jit_tracing
|
||||
from .version_utils import digit_version, get_git_hash
|
||||
|
||||
# TODO: creates intractable circular import issues
|
||||
# from .time_counter import TimeCounter
|
||||
|
||||
__all__ = [
|
||||
'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of',
|
||||
'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
|
||||
|
@ -33,12 +22,9 @@ __all__ = [
|
|||
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
|
||||
'scandir', 'deprecated_api_warning', 'import_modules_from_strings',
|
||||
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
|
||||
'is_method_overridden', 'has_method', 'mmcv_full_available',
|
||||
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
||||
'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
|
||||
'is_abs', 'is_installed', 'call_command', 'get_installed_path',
|
||||
'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env',
|
||||
'Timer', 'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress',
|
||||
'track_parallel_progress', 'track_progress', 'torch_meshgrid',
|
||||
'is_jit_tracing'
|
||||
'is_installed', 'call_command', 'get_installed_path',
|
||||
'check_install_package', 'is_abs', 'is_method_overridden', 'has_method',
|
||||
'digit_version', 'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer',
|
||||
'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress',
|
||||
'track_parallel_progress', 'track_progress'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .collect_env import collect_env
from .hub import load_url
from .misc import has_batch_norm, is_norm, mmcv_full_available, tensor2imgs
from .parrots_wrapper import TORCH_VERSION
from .setup_env import set_multi_processing
from .time_counter import TimeCounter
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing

__all__ = [
    'load_url', 'TORCH_VERSION', 'set_multi_processing', 'has_batch_norm',
    'is_norm', 'tensor2imgs', 'mmcv_full_available', 'collect_env',
    'torch_meshgrid', 'is_jit_tracing', 'TimeCounter'
]

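A short sketch of the new ``dl_utils`` import locations exercised by the updated modules in this commit (the version string shown is illustrative):

>>> from mmengine.utils import digit_version
>>> from mmengine.utils.dl_utils import TORCH_VERSION, collect_env
>>> TORCH_VERSION            # e.g. '1.8.0', or 'parrots' on Parrots builds
>>> digit_version('1.8.0')   # version string parsed into a comparable tuple
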
@ -4,9 +4,9 @@
|
|||
# torch >= 1.6.0 but loaded in torch < 1.7.0.
|
||||
# More details at https://github.com/open-mmlab/mmpose/issues/904
|
||||
|
||||
from ..path import mkdir_or_exist
|
||||
from ..version_utils import digit_version
|
||||
from .parrots_wrapper import TORCH_VERSION
|
||||
from .path import mkdir_or_exist
|
||||
from .version_utils import digit_version
|
||||
|
||||
if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
|
||||
'1.7.0'):
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pkgutil
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..misc import is_tuple_of
|
||||
from .parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||
|
||||
|
||||
def is_norm(layer: nn.Module,
|
||||
exclude: Optional[Union[type, Tuple[type]]] = None) -> bool:
|
||||
"""Check if a layer is a normalization layer.
|
||||
|
||||
Args:
|
||||
layer (nn.Module): The layer to be checked.
|
||||
exclude (type, tuple[type], optional): Types to be excluded.
|
||||
|
||||
Returns:
|
||||
bool: Whether the layer is a norm layer.
|
||||
"""
|
||||
if exclude is not None:
|
||||
if not isinstance(exclude, tuple):
|
||||
exclude = (exclude, )
|
||||
if not is_tuple_of(exclude, type):
|
||||
raise TypeError(
|
||||
f'"exclude" must be either None or type or a tuple of types, '
|
||||
f'but got {type(exclude)}: {exclude}')
|
||||
|
||||
if exclude and isinstance(layer, exclude):
|
||||
return False
|
||||
|
||||
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
|
||||
return isinstance(layer, all_norm_bases)
|
||||
|
||||
|
||||
def tensor2imgs(tensor: torch.Tensor,
|
||||
mean: Optional[Tuple[float, float, float]] = None,
|
||||
std: Optional[Tuple[float, float, float]] = None,
|
||||
to_bgr: bool = True):
|
||||
"""Convert tensor to 3-channel images or 1-channel gray images.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Tensor that contains multiple images, shape (
|
||||
N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format
|
||||
should be RGB.
|
||||
mean (tuple[float], optional): Mean of images. If None,
|
||||
(0, 0, 0) will be used for tensor with 3-channel,
|
||||
while (0, ) for tensor with 1-channel. Defaults to None.
|
||||
std (tuple[float], optional): Standard deviation of images. If None,
|
||||
(1, 1, 1) will be used for tensor with 3-channel,
|
||||
while (1, ) for tensor with 1-channel. Defaults to None.
|
||||
to_bgr (bool): For the tensor with 3 channel, convert its format to
|
||||
BGR. For the tensor with 1 channel, it must be False. Defaults to
|
||||
True.
|
||||
|
||||
Returns:
|
||||
list[np.ndarray]: A list that contains multiple images.
|
||||
"""
|
||||
|
||||
assert torch.is_tensor(tensor) and tensor.ndim == 4
|
||||
channels = tensor.size(1)
|
||||
assert channels in [1, 3]
|
||||
if mean is None:
|
||||
mean = (0, ) * channels
|
||||
if std is None:
|
||||
std = (1, ) * channels
|
||||
assert (channels == len(mean) == len(std) == 3) or \
|
||||
(channels == len(mean) == len(std) == 1 and not to_bgr)
|
||||
mean = tensor.new_tensor(mean).view(1, -1)
|
||||
std = tensor.new_tensor(std).view(1, -1)
|
||||
tensor = tensor.permute(0, 2, 3, 1) * std + mean
|
||||
imgs = tensor.detach().cpu().numpy()
|
||||
if to_bgr and channels == 3:
|
||||
imgs = imgs[:, :, :, (2, 1, 0)] # RGB2BGR
|
||||
imgs = [np.ascontiguousarray(img) for img in imgs]
|
||||
return imgs
|
||||
|
||||
|
||||
def has_batch_norm(model: nn.Module) -> bool:
|
||||
"""Detect whether model has a BatchNormalization layer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): training model.
|
||||
|
||||
Returns:
|
||||
bool: whether model has a BatchNormalization layer
|
||||
"""
|
||||
if isinstance(model, _BatchNorm):
|
||||
return True
|
||||
for m in model.children():
|
||||
if has_batch_norm(m):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def mmcv_full_available() -> bool:
|
||||
"""Check whether mmcv-full is installed.
|
||||
|
||||
Returns:
|
||||
bool: True if mmcv-full is installed else False.
|
||||
"""
|
||||
try:
|
||||
import mmcv # noqa: F401
|
||||
except ImportError:
|
||||
return False
|
||||
ext_loader = pkgutil.find_loader('mmcv._ext')
|
||||
return ext_loader is not None
|
|
@ -24,7 +24,7 @@ class TimeCounter:
|
|||
|
||||
Examples:
|
||||
>>> import time
|
||||
>>> from mmengine.utils import TimeCounter
|
||||
>>> from mmengine.utils.dl_utils import TimeCounter
|
||||
>>> @TimeCounter()
|
||||
... def fun1():
|
||||
... time.sleep(0.1)
|
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from ..version_utils import digit_version
|
||||
from .parrots_wrapper import TORCH_VERSION
|
||||
from .version_utils import digit_version
|
||||
|
||||
_torch_version_meshgrid_indexing = (
|
||||
'parrots' not in TORCH_VERSION
|
|
@ -3,7 +3,7 @@ import warnings
|
|||
|
||||
import torch
|
||||
|
||||
from .version_utils import digit_version
|
||||
from ..version_utils import digit_version
|
||||
|
||||
|
||||
def is_jit_tracing() -> bool:
|
|
@ -2,20 +2,13 @@
|
|||
import collections.abc
|
||||
import functools
|
||||
import itertools
|
||||
import pkgutil
|
||||
import subprocess
|
||||
import warnings
|
||||
from collections import abc
|
||||
from importlib import import_module
|
||||
from inspect import getfullargspec
|
||||
from itertools import repeat
|
||||
from typing import Any, Callable, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||
from typing import Any, Callable, Optional, Type, Union
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
|
@ -394,103 +387,3 @@ def has_method(obj: object, method: str) -> bool:
|
|||
bool: True if the object has the method else False.
|
||||
"""
|
||||
return hasattr(obj, method) and callable(getattr(obj, method))
|
||||
|
||||
|
||||
def mmcv_full_available() -> bool:
|
||||
"""Check whether mmcv-full is installed.
|
||||
|
||||
Returns:
|
||||
bool: True if mmcv-full is installed else False.
|
||||
"""
|
||||
try:
|
||||
import mmcv # noqa: F401
|
||||
except ImportError:
|
||||
return False
|
||||
ext_loader = pkgutil.find_loader('mmcv._ext')
|
||||
return ext_loader is not None
|
||||
|
||||
|
||||
def is_norm(layer: nn.Module,
|
||||
exclude: Optional[Union[type, Tuple[type]]] = None) -> bool:
|
||||
"""Check if a layer is a normalization layer.
|
||||
|
||||
Args:
|
||||
layer (nn.Module): The layer to be checked.
|
||||
exclude (type, tuple[type], optional): Types to be excluded.
|
||||
|
||||
Returns:
|
||||
bool: Whether the layer is a norm layer.
|
||||
"""
|
||||
if exclude is not None:
|
||||
if not isinstance(exclude, tuple):
|
||||
exclude = (exclude, )
|
||||
if not is_tuple_of(exclude, type):
|
||||
raise TypeError(
|
||||
f'"exclude" must be either None or type or a tuple of types, '
|
||||
f'but got {type(exclude)}: {exclude}')
|
||||
|
||||
if exclude and isinstance(layer, exclude):
|
||||
return False
|
||||
|
||||
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
|
||||
return isinstance(layer, all_norm_bases)
|
||||
|
||||
|
||||
def tensor2imgs(tensor: torch.Tensor,
|
||||
mean: Optional[Tuple[float, float, float]] = None,
|
||||
std: Optional[Tuple[float, float, float]] = None,
|
||||
to_bgr: bool = True):
|
||||
"""Convert tensor to 3-channel images or 1-channel gray images.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Tensor that contains multiple images, shape (
|
||||
N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format
|
||||
should be RGB.
|
||||
mean (tuple[float], optional): Mean of images. If None,
|
||||
(0, 0, 0) will be used for tensor with 3-channel,
|
||||
while (0, ) for tensor with 1-channel. Defaults to None.
|
||||
std (tuple[float], optional): Standard deviation of images. If None,
|
||||
(1, 1, 1) will be used for tensor with 3-channel,
|
||||
while (1, ) for tensor with 1-channel. Defaults to None.
|
||||
to_bgr (bool): For the tensor with 3 channel, convert its format to
|
||||
BGR. For the tensor with 1 channel, it must be False. Defaults to
|
||||
True.
|
||||
|
||||
Returns:
|
||||
list[np.ndarray]: A list that contains multiple images.
|
||||
"""
|
||||
|
||||
assert torch.is_tensor(tensor) and tensor.ndim == 4
|
||||
channels = tensor.size(1)
|
||||
assert channels in [1, 3]
|
||||
if mean is None:
|
||||
mean = (0, ) * channels
|
||||
if std is None:
|
||||
std = (1, ) * channels
|
||||
assert (channels == len(mean) == len(std) == 3) or \
|
||||
(channels == len(mean) == len(std) == 1 and not to_bgr)
|
||||
mean = tensor.new_tensor(mean).view(1, -1)
|
||||
std = tensor.new_tensor(std).view(1, -1)
|
||||
tensor = tensor.permute(0, 2, 3, 1) * std + mean
|
||||
imgs = tensor.detach().cpu().numpy()
|
||||
if to_bgr and channels == 3:
|
||||
imgs = imgs[:, :, :, (2, 1, 0)] # RGB2BGR
|
||||
imgs = [np.ascontiguousarray(img) for img in imgs]
|
||||
return imgs
|
||||
|
||||
|
||||
def has_batch_norm(model: nn.Module) -> bool:
|
||||
"""Detect whether model has a BatchNormalization layer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): training model.
|
||||
|
||||
Returns:
|
||||
bool: whether model has a BatchNormalization layer
|
||||
"""
|
||||
if isinstance(model, _BatchNorm):
|
||||
return True
|
||||
for m in model.children():
|
||||
if has_batch_norm(m):
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -1,66 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
|
||||
"""A general BatchNorm layer without input dimension check.
|
||||
|
||||
Reproduced from @kapily's work:
|
||||
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
|
||||
The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
|
||||
is `_check_input_dim` that is designed for tensor sanity checks.
|
||||
The check has been bypassed in this class for the convenience of converting
|
||||
SyncBatchNorm.
|
||||
"""
|
||||
|
||||
def _check_input_dim(self, input: torch.Tensor):
|
||||
return
|
||||
|
||||
|
||||
def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
|
||||
"""Helper function to convert all `SyncBatchNorm` (SyncBN) and
|
||||
`mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
|
||||
`BatchNormXd` layers.
|
||||
|
||||
Adapted from @kapily's work:
|
||||
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
|
||||
|
||||
Args:
|
||||
module (nn.Module): The module containing `SyncBatchNorm` layers.
|
||||
|
||||
Returns:
|
||||
module_output: The converted module with `BatchNormXd` layers.
|
||||
"""
|
||||
module_output = module
|
||||
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
|
||||
|
||||
try:
|
||||
import mmcv
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
if hasattr(mmcv, 'ops'):
|
||||
module_checklist.append(mmcv.ops.SyncBatchNorm)
|
||||
|
||||
if isinstance(module, tuple(module_checklist)):
|
||||
module_output = _BatchNormXd(module.num_features, module.eps,
|
||||
module.momentum, module.affine,
|
||||
module.track_running_stats)
|
||||
if module.affine:
|
||||
# no_grad() may not be needed here but
|
||||
# just to be consistent with `convert_sync_batchnorm()`
|
||||
with torch.no_grad():
|
||||
module_output.weight = module.weight
|
||||
module_output.bias = module.bias
|
||||
module_output.running_mean = module.running_mean
|
||||
module_output.running_var = module.running_var
|
||||
module_output.num_batches_tracked = module.num_batches_tracked
|
||||
module_output.training = module.training
|
||||
# qconfig exists in quantized models
|
||||
if hasattr(module, 'qconfig'):
|
||||
module_output.qconfig = module.qconfig
|
||||
for name, child in module.named_children():
|
||||
module_output.add_module(name, revert_sync_batchnorm(child))
|
||||
del module
|
||||
return module_output
|
|
@ -15,7 +15,7 @@ from mmengine.config import Config
|
|||
from mmengine.fileio import dump
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.registry import VISBACKENDS
|
||||
from mmengine.utils import TORCH_VERSION
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
|
||||
def force_init_env(old_func: Callable) -> Any:
|
||||
|
|
|
@ -16,9 +16,9 @@ from matplotlib.patches import Circle
|
|||
from matplotlib.pyplot import new_figure_manager
|
||||
|
||||
from mmengine.config import Config
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.registry import VISBACKENDS, VISUALIZERS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from mmengine.utils import ManagerMixin
|
||||
from mmengine.visualization.utils import (check_type, check_type_and_length,
|
||||
color_str2rgb, color_val_matplotlib,
|
||||
|
|
|
@ -6,7 +6,7 @@ from unittest.mock import patch
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.data import DefaultSampler, InfiniteSampler
|
||||
from mmengine.dataset import DefaultSampler, InfiniteSampler
|
||||
|
||||
|
||||
class TestDefaultSampler(TestCase):
|
||||
|
@ -15,7 +15,7 @@ class TestDefaultSampler(TestCase):
|
|||
self.data_length = 100
|
||||
self.dataset = list(range(self.data_length))
|
||||
|
||||
@patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
|
||||
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
|
||||
def test_non_dist(self, mock):
|
||||
sampler = DefaultSampler(self.dataset)
|
||||
self.assertEqual(sampler.world_size, 1)
|
||||
|
@ -33,7 +33,7 @@ class TestDefaultSampler(TestCase):
|
|||
self.assertEqual(sampler.num_samples, self.data_length)
|
||||
self.assertEqual(list(sampler), list(range(self.data_length)))
|
||||
|
||||
@patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
|
||||
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(2, 3))
|
||||
def test_dist(self, mock):
|
||||
sampler = DefaultSampler(self.dataset)
|
||||
self.assertEqual(sampler.world_size, 3)
|
||||
|
@ -56,8 +56,8 @@ class TestDefaultSampler(TestCase):
|
|||
self.assertEqual(len(sampler), sampler.num_samples)
|
||||
self.assertEqual(list(sampler), list(range(self.data_length))[2::3])
|
||||
|
||||
@patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
|
||||
@patch('mmengine.data.sampler.sync_random_seed', return_value=7)
|
||||
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
|
||||
@patch('mmengine.dataset.sampler.sync_random_seed', return_value=7)
|
||||
def test_shuffle(self, mock1, mock2):
|
||||
# test seed=None
|
||||
sampler = DefaultSampler(self.dataset, seed=None)
|
||||
|
@ -87,7 +87,7 @@ class TestInfiniteSampler(TestCase):
|
|||
self.data_length = 100
|
||||
self.dataset = list(range(self.data_length))
|
||||
|
||||
@patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
|
||||
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
|
||||
def test_non_dist(self, mock):
|
||||
sampler = InfiniteSampler(self.dataset)
|
||||
self.assertEqual(sampler.world_size, 1)
|
||||
|
@ -101,7 +101,7 @@ class TestInfiniteSampler(TestCase):
|
|||
items = [next(sampler_iter) for _ in range(self.data_length * 2)]
|
||||
self.assertEqual(items, list(range(self.data_length)) * 2)
|
||||
|
||||
@patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
|
||||
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(2, 3))
|
||||
def test_dist(self, mock):
|
||||
sampler = InfiniteSampler(self.dataset)
|
||||
self.assertEqual(sampler.world_size, 3)
|
||||
|
@ -117,8 +117,8 @@ class TestInfiniteSampler(TestCase):
|
|||
print(samples)
|
||||
self.assertEqual(samples, targets)
|
||||
|
||||
@patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
|
||||
@patch('mmengine.data.sampler.sync_random_seed', return_value=7)
|
||||
@patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
|
||||
@patch('mmengine.dataset.sampler.sync_random_seed', return_value=7)
|
||||
def test_shuffle(self, mock1, mock2):
|
||||
# test seed=None
|
||||
sampler = InfiniteSampler(self.dataset, seed=None)
|
|
@ -13,7 +13,8 @@ import torch.distributed as torch_dist
|
|||
import mmengine.dist as dist
|
||||
from mmengine.dist.dist import sync_random_seed
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
from mmengine.utils import TORCH_VERSION, digit_version
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
|
||||
class TestDist(TestCase):
|
||||
|
|
|
@ -7,9 +7,9 @@ from unittest import TestCase
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
|
||||
from mmengine.registry import METRICS
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
|
|
|
@ -3,8 +3,8 @@ from unittest.mock import Mock
|
|||
|
||||
import torch
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.hooks import NaiveVisualizationHook
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
|
||||
class TestNaiveVisualizationHook:
|
||||
|
|
|
@ -3,7 +3,8 @@ import os.path as osp
|
|||
|
||||
import pytest
|
||||
|
||||
from mmengine import Config, DefaultScope, get_config, get_model
|
||||
from mmengine import Config, DefaultScope
|
||||
from mmengine.hub import get_config, get_model
|
||||
from mmengine.utils import get_installed_path, is_installed
|
||||
|
||||
data_path = osp.join(osp.dirname(osp.dirname(__file__)), 'data/')
|
|
@ -1,9 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine import HistoryBuffer
|
||||
from mmengine.logging import HistoryBuffer
|
||||
|
||||
array_method = [np.array, lambda x: x]
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
array_method.append(torch.tensor)
|
||||
|
||||
|
||||
class TestLoggerBuffer:
|
||||
|
@ -30,8 +37,7 @@ class TestLoggerBuffer:
|
|||
with pytest.raises(AssertionError):
|
||||
HistoryBuffer([1, 2], [1])
|
||||
|
||||
@pytest.mark.parametrize('array_method',
|
||||
[torch.tensor, np.array, lambda x: x])
|
||||
@pytest.mark.parametrize('array_method', array_method)
|
||||
def test_update(self, array_method):
|
||||
# test `update` method
|
||||
log_buffer = HistoryBuffer()
|
||||
|
|
|
@ -15,9 +15,7 @@ class TestLogger:
|
|||
stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
|
||||
file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
|
||||
|
||||
# Since `get_rank` has been imported in logger.py, it needs to mock
|
||||
# `logger.get_rank`
|
||||
@patch('mmengine.logging.logger.get_rank', lambda: 0)
|
||||
@patch('mmengine.logging.logger._get_rank', lambda: 0)
|
||||
def test_init_rank0(self, tmp_path):
|
||||
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
|
||||
assert logger.name == 'mmengine'
|
||||
|
@ -47,7 +45,7 @@ class TestLogger:
|
|||
assert logger.instance_name == 'rank0.pkg3'
|
||||
logging.shutdown()
|
||||
|
||||
@patch('mmengine.logging.logger.get_rank', lambda: 1)
|
||||
@patch('mmengine.logging.logger._get_rank', lambda: 1)
|
||||
def test_init_rank1(self, tmp_path):
|
||||
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
|
||||
tmp_file = tmp_path / 'tmp_file.log'
|
||||
|
|
|
@ -4,9 +4,9 @@ from collections import OrderedDict
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.logging import HistoryBuffer, MessageHub
|
||||
from mmengine.utils import is_installed
|
||||
|
||||
|
||||
class NoDeepCopy:
|
||||
|
@ -84,7 +84,9 @@ class TestMessageHub:
|
|||
message_hub.update_info('test_value', recorded_dict)
|
||||
assert message_hub.get_info('test_value') == recorded_dict
|
||||
|
||||
@pytest.mark.skipif(not is_installed('torch'), reason='requires torch')
|
||||
def test_get_scalars(self):
|
||||
import torch
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
log_dict = dict(
|
||||
loss=1,
|
||||
|
|
|
@ -4,8 +4,8 @@ from unittest import TestCase
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmengine import InstanceData
|
||||
from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
|
||||
from mmengine.structures import InstanceData
|
||||
from mmengine.testing import assert_allclose
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import pytest
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.utils import revert_sync_batchnorm
|
||||
from mmengine.model import revert_sync_batchnorm
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
|
@ -15,7 +15,7 @@ from mmengine.model.averaged_model import ExponentialMovingAverage
|
|||
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
|
||||
from mmengine.testing import assert_allclose
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
from mmengine.utils.parrots_wrapper import TORCH_VERSION
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
from mmengine.utils.version_utils import digit_version
|
||||
|
||||
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
|
||||
|
|
|
@ -11,7 +11,7 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
|
|||
build_optim_wrapper)
|
||||
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
|
||||
from mmengine.registry import build_from_cfg
|
||||
from mmengine.utils import mmcv_full_available
|
||||
from mmengine.utils.dl_utils import mmcv_full_available
|
||||
|
||||
MMCV_FULL_AVAILABLE = mmcv_full_available()
|
||||
if not MMCV_FULL_AVAILABLE:
|
||||
|
|
|
@ -16,7 +16,8 @@ from mmengine.logging import MessageHub, MMLogger
|
|||
from mmengine.optim import AmpOptimWrapper, OptimWrapper
|
||||
from mmengine.testing import assert_allclose
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
from mmengine.utils import TORCH_VERSION, digit_version
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
|
||||
class ToyModel(nn.Module):
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmengine import (PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin,
|
||||
Registry, build_from_cfg, build_model_from_cfg)
|
||||
from mmengine.utils import is_installed
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
|
||||
|
@ -128,7 +127,10 @@ def test_build_from_cfg(cfg_type):
|
|||
Visualizer.get_current_instance()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch')
|
||||
def test_build_model_from_cfg():
|
||||
import torch.nn as nn
|
||||
|
||||
BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
|
||||
|
||||
@BACKBONES.register_module()
|
||||
|
@ -186,7 +188,10 @@ def test_build_model_from_cfg():
|
|||
assert NEW_MODELS.build_func is pseudo_build
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch')
|
||||
def test_build_sheduler_from_cfg():
|
||||
import torch.nn as nn
|
||||
from torch.optim import SGD
|
||||
model = nn.Conv2d(1, 1, 1)
|
||||
optimizer = SGD(model.parameters(), lr=0.1)
|
||||
cfg = dict(
|
||||
|
|
|
@@ -4,7 +4,9 @@ import time
 import pytest
 
 from mmengine.config import Config, ConfigDict  # type: ignore
-from mmengine.registry import DefaultScope, Registry, build_from_cfg
+from mmengine.registry import (DefaultScope, Registry, build_from_cfg,
+                               build_model_from_cfg)
+from mmengine.utils import ManagerMixin
 
 
 class TestRegistry:

@@ -473,3 +475,187 @@ class TestRegistry:
                      "<locals>.Munchkin'>")
         repr_str += '})'
         assert repr(CATS) == repr_str
+
+
+@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
+def test_build_from_cfg(cfg_type):
+    BACKBONES = Registry('backbone')
+
+    @BACKBONES.register_module()
+    class ResNet:
+
+        def __init__(self, depth, stages=4):
+            self.depth = depth
+            self.stages = stages
+
+    @BACKBONES.register_module()
+    class ResNeXt:
+
+        def __init__(self, depth, stages=4):
+            self.depth = depth
+            self.stages = stages
+
+    # test `cfg` parameter
+    # `cfg` should be a dict, ConfigDict or Config object
+    with pytest.raises(
+            TypeError,
+            match=('cfg should be a dict, ConfigDict or Config, but got '
+                   "<class 'str'>")):
+        cfg = 'ResNet'
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # `cfg` is a dict, ConfigDict or Config object
+    cfg = cfg_type(dict(type='ResNet', depth=50))
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # `cfg` is a dict but it does not contain the key "type"
+    with pytest.raises(KeyError, match='must contain the key "type"'):
+        cfg = dict(depth=50, stages=4)
+        cfg = cfg_type(cfg)
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # cfg['type'] should be a str or class
+    with pytest.raises(
+            TypeError,
+            match="type must be a str or valid type, but got <class 'int'>"):
+        cfg = dict(type=1000)
+        cfg = cfg_type(cfg)
+        model = build_from_cfg(cfg, BACKBONES)
+
+    cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNeXt)
+    assert model.depth == 50 and model.stages == 3
+
+    cfg = cfg_type(dict(type=ResNet, depth=50))
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # non-registered class
+    with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
+        cfg = cfg_type(dict(type='VGG'))
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # `cfg` contains unexpected arguments
+    with pytest.raises(TypeError):
+        cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # test `default_args` parameter
+    cfg = cfg_type(dict(type='ResNet', depth=50))
+    model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 3
+
+    # default_args must be a dict or None
+    with pytest.raises(TypeError):
+        cfg = cfg_type(dict(type='ResNet', depth=50))
+        model = build_from_cfg(cfg, BACKBONES, default_args=1)
+
+    # cfg or default_args should contain the key "type"
+    with pytest.raises(KeyError, match='must contain the key "type"'):
+        cfg = cfg_type(dict(depth=50))
+        model = build_from_cfg(
+            cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
+
+    # "type" defined using default_args
+    cfg = cfg_type(dict(depth=50))
+    model = build_from_cfg(
+        cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    cfg = cfg_type(dict(depth=50))
+    model = build_from_cfg(
+        cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # test `registry` parameter
+    # incorrect registry type
+    with pytest.raises(
+            TypeError,
+            match=('registry must be a mmengine.Registry object, but got '
+                   "<class 'str'>")):
+        cfg = cfg_type(dict(type='ResNet', depth=50))
+        model = build_from_cfg(cfg, 'BACKBONES')
+
+    VISUALIZER = Registry('visualizer')
+
+    @VISUALIZER.register_module()
+    class Visualizer(ManagerMixin):
+
+        def __init__(self, name):
+            super().__init__(name)
+
+    with pytest.raises(RuntimeError):
+        Visualizer.get_current_instance()
+    cfg = dict(type='Visualizer', name='visualizer')
+    build_from_cfg(cfg, VISUALIZER)
+    Visualizer.get_current_instance()
+
+
+def test_build_model_from_cfg():
+    try:
+        import torch.nn as nn
+    except ImportError:
+        pytest.skip('require torch')
+
+    BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
+
+    @BACKBONES.register_module()
+    class ResNet(nn.Module):
+
+        def __init__(self, depth, stages=4):
+            super().__init__()
+            self.depth = depth
+            self.stages = stages
+
+        def forward(self, x):
+            return x
+
+    @BACKBONES.register_module()
+    class ResNeXt(nn.Module):
+
+        def __init__(self, depth, stages=4):
+            super().__init__()
+            self.depth = depth
+            self.stages = stages
+
+        def forward(self, x):
+            return x
+
+    cfg = dict(type='ResNet', depth=50)
+    model = BACKBONES.build(cfg)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    cfg = dict(type='ResNeXt', depth=50, stages=3)
+    model = BACKBONES.build(cfg)
+    assert isinstance(model, ResNeXt)
+    assert model.depth == 50 and model.stages == 3
+
+    cfg = [
+        dict(type='ResNet', depth=50),
+        dict(type='ResNeXt', depth=50, stages=3)
+    ]
+    model = BACKBONES.build(cfg)
+    assert isinstance(model, nn.Sequential)
+    assert isinstance(model[0], ResNet)
+    assert model[0].depth == 50 and model[0].stages == 4
+    assert isinstance(model[1], ResNeXt)
+    assert model[1].depth == 50 and model[1].stages == 3
+
+    # test inherit `build_func` from parent
+    NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
+    assert NEW_MODELS.build_func is build_model_from_cfg
+
+    # test specify `build_func`
+    def pseudo_build(cfg):
+        return cfg
+
+    NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
+    assert NEW_MODELS.build_func is pseudo_build

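One behaviour the relocated test pins down is how build_from_cfg merges default_args: any key missing from cfg, including "type" itself, may be supplied there. A condensed, hypothetical sketch of just that case (registry and class names are illustrative):

    from mmengine.registry import Registry, build_from_cfg

    TOYS = Registry('toy')  # hypothetical registry for illustration


    @TOYS.register_module()
    class Toy:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages


    # "type" comes from default_args; unset kwargs fall back to the defaults.
    toy = build_from_cfg(dict(depth=50), TOYS, default_args=dict(type='Toy'))
    assert isinstance(toy, Toy) and toy.depth == 50 and toy.stages == 4
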
@@ -5,7 +5,8 @@ import torch
 import torch.nn as nn
 
 from mmengine.runner import autocast
-from mmengine.utils import TORCH_VERSION, digit_version
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
 
 
 class TestAmp(unittest.TestCase):

@@ -6,7 +6,8 @@ import numpy as np
 import pytest
 import torch
 
-from mmengine.logging import HistoryBuffer, LogProcessor, MessageHub, MMLogger
+from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
+from mmengine.runner import LogProcessor
 
 
 class TestLogProcessor:

@@ -14,12 +14,12 @@ from torch.optim import SGD, Adam
 from torch.utils.data import DataLoader, Dataset
 
 from mmengine.config import Config
-from mmengine.data import DefaultSampler
+from mmengine.dataset import DefaultSampler
 from mmengine.evaluator import BaseMetric, Evaluator
 from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
                             IterTimerHook, LoggerHook, ParamSchedulerHook,
                             RuntimeInfoHook)
-from mmengine.logging import LogProcessor, MessageHub, MMLogger
+from mmengine.logging import MessageHub, MMLogger
 from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
 from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
                             OptimWrapper, OptimWrapperDict, StepLR)

@@ -28,10 +28,11 @@ from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
                                OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                                PARAM_SCHEDULERS, RUNNERS, Registry)
 from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
-                             Runner, TestLoop, ValLoop)
+                             LogProcessor, Runner, TestLoop, ValLoop)
 from mmengine.runner.loops import _InfiniteDataloaderIterator
 from mmengine.runner.priority import Priority, get_priority
-from mmengine.utils import TORCH_VERSION, digit_version, is_list_of
+from mmengine.utils import digit_version, is_list_of
+from mmengine.utils.dl_utils import TORCH_VERSION
 from mmengine.visualization import Visualizer
 
 

@@ -6,7 +6,7 @@ import numpy as np
 import pytest
 import torch
 
-from mmengine.data import BaseDataElement
+from mmengine.structures import BaseDataElement
 
 
 class TestBaseDataElement(TestCase):

@@ -7,7 +7,7 @@ import numpy as np
 import pytest
 import torch
 
-from mmengine.data import BaseDataElement, InstanceData
+from mmengine.structures import BaseDataElement, InstanceData
 
 
 class TmpObject:

@@ -4,7 +4,7 @@ from unittest import TestCase
 import pytest
 import torch
 
-from mmengine.data import LabelData
+from mmengine.structures import LabelData
 
 
 class TestLabelData(TestCase):

@@ -6,7 +6,7 @@ import numpy as np
 import pytest
 import torch
 
-from mmengine.data import PixelData
+from mmengine.structures import PixelData
 
 
 class TestPixelData(TestCase):

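The four hunks above are the mechanical side of renaming mmengine.data to mmengine.structures: the data-element classes keep their names and behaviour, only the import path changes. A minimal before/after sketch for downstream code (illustrative, not part of the patch):

    # Old path (before this refactor):
    # from mmengine.data import BaseDataElement, InstanceData, LabelData, PixelData

    # New path:
    from mmengine.structures import (BaseDataElement, InstanceData, LabelData,
                                     PixelData)

    sample = BaseDataElement(metainfo=dict(img_shape=(800, 1333)))
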
@@ -5,7 +5,8 @@ from unittest import TestCase
 import torch.cuda
 
 import mmengine
-from mmengine.utils.collect_env import _get_cuda_home, collect_env
+from mmengine.utils.dl_utils import collect_env
+from mmengine.utils.dl_utils.parrots_wrapper import _get_cuda_home
 
 
 class TestCollectEnv(TestCase):

@@ -5,7 +5,7 @@ import platform
 
 import cv2
 
-from mmengine.utils import set_multi_processing
+from mmengine.utils.dl_utils import set_multi_processing
 
 
 def test_setup_multi_processes():

@@ -2,7 +2,7 @@
 import time
 import unittest
 
-from mmengine.utils.time_counter import TimeCounter
+from mmengine.utils.dl_utils.time_counter import TimeCounter
 
 
 class TestTimeCounter(unittest.TestCase):

@@ -2,7 +2,7 @@
 import pytest
 import torch
 
-from mmengine.utils import torch_meshgrid
+from mmengine.utils.dl_utils import torch_meshgrid
 
 
 def test_torch_meshgrid():

@@ -2,7 +2,8 @@
 import pytest
 import torch
 
-from mmengine.utils import digit_version, is_jit_tracing
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import is_jit_tracing
 
 
 @pytest.mark.skipif(