[Feat] Support revert syncbn (#326)

* [Feat] Support revert syncbn

* use logger.info but not warning

* fix info string
This commit is contained in:
Alex Yang 2022-06-22 19:50:54 +08:00 committed by GitHub
parent 312f264ecd
commit e18832f046
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 118 additions and 2 deletions

View File

@ -37,7 +37,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
count_registered_modules)
from mmengine.registry.root import LOG_PROCESSORS
from mmengine.utils import (TORCH_VERSION, digit_version, get_git_hash,
is_list_of, set_multi_processing)
is_list_of, revert_sync_batchnorm,
set_multi_processing)
from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@ -830,6 +831,11 @@ class Runner:
model = model.to(get_device())
if not self.distributed:
self.logger.info(
'Distributed training is not used, all SyncBatchNorm (SyncBN) '
'layers in the model will be automatically reverted to '
'BatchNormXd layers if they are used.')
model = revert_sync_batchnorm(model)
return model
if model_wrapper_cfg is None:

View File

@ -12,6 +12,7 @@ from .parrots_wrapper import TORCH_VERSION
from .path import (check_file_exist, fopen, is_abs, is_filepath,
mkdir_or_exist, scandir, symlink)
from .setup_env import set_multi_processing
from .sync_bn import revert_sync_batchnorm
from .version_utils import digit_version, get_git_hash
# TODO: creates intractable circular import issues
@ -27,5 +28,5 @@ __all__ = [
'is_method_overridden', 'has_method', 'mmcv_full_available',
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
'is_abs'
'is_abs', 'revert_sync_batchnorm'
]

57
mmengine/utils/sync_bn.py Normal file
View File

@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
"""A general BatchNorm layer without input dimension check.
Reproduced from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
is `_check_input_dim` that is designed for tensor sanity checks.
The check has been bypassed in this class for the convenience of converting
SyncBatchNorm.
"""
def _check_input_dim(self, input: torch.Tensor):
return
def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
"""Helper function to convert all `SyncBatchNorm` (SyncBN) and
`mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
`BatchNormXd` layers.
Adapted from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
Args:
module (nn.Module): The module containing `SyncBatchNorm` layers.
Returns:
module_output: The converted module with `BatchNormXd` layers.
"""
module_output = module
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
if isinstance(module, tuple(module_checklist)):
module_output = _BatchNormXd(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
# no_grad() may not be needed here but
# just to be consistent with `convert_sync_batchnorm()`
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
module_output.training = module.training
# qconfig exists in quantized models
if hasattr(module, 'qconfig'):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, revert_sync_batchnorm(child))
del module
return module_output

View File

@ -65,6 +65,30 @@ class ToyModel1(ToyModel):
super().__init__()
@MODELS.register_module()
class ToySyncBNModel(BaseModel):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 8, 2)
self.bn = nn.SyncBatchNorm(8)
def forward(self, batch_inputs, labels, mode='tensor'):
labels = torch.stack(labels)
outputs = self.conv(batch_inputs)
outputs = self.bn(outputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (labels - outputs).sum()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
outputs = dict(log_vars=dict(a=1, b=0.5))
return outputs
@MODELS.register_module()
class TopGANModel(BaseModel):
@ -683,6 +707,14 @@ class TestRunner(TestCase):
self.assertFalse(model.initiailzed)
def test_wrap_model(self):
# revert sync batchnorm
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_revert_syncbn'
cfg.model = dict(type='ToySyncBNModel')
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, BaseModel)
assert not isinstance(runner.model.bn, nn.SyncBatchNorm)
# custom model wrapper
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_wrap_model'

View File

@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmengine.utils import revert_sync_batchnorm
@pytest.mark.skipif(
torch.__version__ == 'parrots', reason='not supported in parrots now')
def test_revert_syncbn():
# conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.SyncBatchNorm(8))
x = torch.randn(1, 3, 10, 10)
# Expect a ValueError prompting that SyncBN is not supported on CPU
with pytest.raises(ValueError):
y = conv(x)
conv = revert_sync_batchnorm(conv)
y = conv(x)
assert y.shape == (1, 8, 9, 9)