mirror of https://github.com/open-mmlab/mmengine.git
[Feat] Support revert syncbn (#326)
* [Feat] Support revert syncbn
* use logger.info instead of logger.warning
* fix info string
parent 312f264ecd
commit e18832f046
mmengine/runner/runner.py
@@ -37,7 +37,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
                                count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, digit_version, get_git_hash,
-                            is_list_of, set_multi_processing)
+                            is_list_of, revert_sync_batchnorm,
+                            set_multi_processing)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@@ -830,6 +831,11 @@ class Runner:
         model = model.to(get_device())

         if not self.distributed:
+            self.logger.info(
+                'Distributed training is not used, all SyncBatchNorm (SyncBN) '
+                'layers in the model will be automatically reverted to '
+                'BatchNormXd layers if they are used.')
+            model = revert_sync_batchnorm(model)
             return model

         if model_wrapper_cfg is None:
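For orientation, here is a minimal sketch (my own illustration, not code from this commit) of what the new non-distributed branch amounts to from the user's side: when `self.distributed` is false, `wrap_model` now logs the notice above and passes the model through `revert_sync_batchnorm`, so a config written for multi-GPU SyncBN training still runs on a single device. The helper name `maybe_revert_syncbn` below is hypothetical.

# Sketch only: mirrors the branch added to Runner.wrap_model above.
# `maybe_revert_syncbn` is a hypothetical stand-in for that branch.
import torch.nn as nn

from mmengine.utils import revert_sync_batchnorm


def maybe_revert_syncbn(model: nn.Module, distributed: bool) -> nn.Module:
    if not distributed:
        # Runner emits this notice via self.logger.info before converting.
        print('Distributed training is not used, reverting SyncBN layers.')
        model = revert_sync_batchnorm(model)
    return model


toy = nn.Sequential(nn.Conv2d(3, 8, 2), nn.SyncBatchNorm(8))
toy = maybe_revert_syncbn(toy, distributed=False)
assert not any(isinstance(m, nn.SyncBatchNorm) for m in toy.modules())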
mmengine/utils/__init__.py
@@ -12,6 +12,7 @@ from .parrots_wrapper import TORCH_VERSION
 from .path import (check_file_exist, fopen, is_abs, is_filepath,
                    mkdir_or_exist, scandir, symlink)
 from .setup_env import set_multi_processing
+from .sync_bn import revert_sync_batchnorm
 from .version_utils import digit_version, get_git_hash

 # TODO: creates intractable circular import issues
@@ -27,5 +28,5 @@ __all__ = [
     'is_method_overridden', 'has_method', 'mmcv_full_available',
     'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
     'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
-    'is_abs'
+    'is_abs', 'revert_sync_batchnorm'
 ]
mmengine/utils/sync_bn.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
+    """A general BatchNorm layer without input dimension check.
+
+    Reproduced from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
+    is `_check_input_dim` that is designed for tensor sanity checks.
+    The check has been bypassed in this class for the convenience of converting
+    SyncBatchNorm.
+    """
+
+    def _check_input_dim(self, input: torch.Tensor):
+        return
+
+
+def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
+    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
+    `BatchNormXd` layers.
+
+    Adapted from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+    Args:
+        module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+    Returns:
+        module_output: The converted module with `BatchNormXd` layers.
+    """
+    module_output = module
+    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+    if isinstance(module, tuple(module_checklist)):
+        module_output = _BatchNormXd(module.num_features, module.eps,
+                                     module.momentum, module.affine,
+                                     module.track_running_stats)
+        if module.affine:
+            # no_grad() may not be needed here but
+            # just to be consistent with `convert_sync_batchnorm()`
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        module_output.training = module.training
+        # qconfig exists in quantized models
+        if hasattr(module, 'qconfig'):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, revert_sync_batchnorm(child))
+    del module
+    return module_output
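A quick check of the behavior above (my own sketch, not part of the commit): the converted layer keeps the affine parameters and running statistics that `revert_sync_batchnorm` copies over; it only loses the distributed synchronization.

# Sketch only: verify that conversion preserves weights and running stats.
import torch
import torch.nn as nn

from mmengine.utils import revert_sync_batchnorm

sync_bn = nn.SyncBatchNorm(8)
with torch.no_grad():
    sync_bn.weight.fill_(2.0)        # pretend these values were learned
    sync_bn.running_mean.fill_(0.5)

plain_bn = revert_sync_batchnorm(sync_bn)

assert not isinstance(plain_bn, nn.SyncBatchNorm)
assert torch.equal(plain_bn.weight, sync_bn.weight)
assert torch.equal(plain_bn.running_mean, sync_bn.running_mean)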
tests/test_runner/test_runner.py
@@ -65,6 +65,30 @@ class ToyModel1(ToyModel):
         super().__init__()


+@MODELS.register_module()
+class ToySyncBNModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2d(3, 8, 2)
+        self.bn = nn.SyncBatchNorm(8)
+
+    def forward(self, batch_inputs, labels, mode='tensor'):
+        labels = torch.stack(labels)
+        outputs = self.conv(batch_inputs)
+        outputs = self.bn(outputs)
+
+        if mode == 'tensor':
+            return outputs
+        elif mode == 'loss':
+            loss = (labels - outputs).sum()
+            outputs = dict(loss=loss)
+            return outputs
+        elif mode == 'predict':
+            outputs = dict(log_vars=dict(a=1, b=0.5))
+            return outputs
+
+
 @MODELS.register_module()
 class TopGANModel(BaseModel):

@@ -683,6 +707,14 @@ class TestRunner(TestCase):
         self.assertFalse(model.initiailzed)

     def test_wrap_model(self):
+        # revert sync batchnorm
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_revert_syncbn'
+        cfg.model = dict(type='ToySyncBNModel')
+        runner = Runner.from_cfg(cfg)
+        self.assertIsInstance(runner.model, BaseModel)
+        assert not isinstance(runner.model.bn, nn.SyncBatchNorm)
+
         # custom model wrapper
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_wrap_model'
tests/test_utils/test_revert_syncbn.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+import torch.nn as nn
+
+from mmengine.utils import revert_sync_batchnorm
+
+
+@pytest.mark.skipif(
+    torch.__version__ == 'parrots', reason='not supported in parrots now')
+def test_revert_syncbn():
+    # conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
+    conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.SyncBatchNorm(8))
+    x = torch.randn(1, 3, 10, 10)
+    # Expect a ValueError prompting that SyncBN is not supported on CPU
+    with pytest.raises(ValueError):
+        y = conv(x)
+    conv = revert_sync_batchnorm(conv)
+    y = conv(x)
+    assert y.shape == (1, 8, 9, 9)
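As a closing note, the typical downstream pattern (again a sketch of my own, not from the commit) is converting a SyncBN model for single-device inference, which is exactly the situation the CPU ValueError in the test above guards against.

# Sketch only: convert a SyncBN model so it can run inference on CPU.
import torch
import torch.nn as nn

from mmengine.utils import revert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SyncBatchNorm(8),
                      nn.ReLU())
model = revert_sync_batchnorm(model).eval()

with torch.no_grad():
    out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 8, 32, 32])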