diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 637540d6..fec759be 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -37,7 +37,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
                                count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, digit_version, get_git_hash,
-                            is_list_of, set_multi_processing)
+                            is_list_of, revert_sync_batchnorm,
+                            set_multi_processing)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@@ -830,6 +831,11 @@ class Runner:
         model = model.to(get_device())
 
         if not self.distributed:
+            self.logger.info(
+                'Distributed training is not used, all SyncBatchNorm (SyncBN) '
+                'layers in the model will be automatically reverted to '
+                'BatchNormXd layers if they are used.')
+            model = revert_sync_batchnorm(model)
             return model
 
         if model_wrapper_cfg is None:
diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py
index 0375578b..df080378 100644
--- a/mmengine/utils/__init__.py
+++ b/mmengine/utils/__init__.py
@@ -12,6 +12,7 @@ from .parrots_wrapper import TORCH_VERSION
 from .path import (check_file_exist, fopen, is_abs, is_filepath,
                    mkdir_or_exist, scandir, symlink)
 from .setup_env import set_multi_processing
+from .sync_bn import revert_sync_batchnorm
 from .version_utils import digit_version, get_git_hash
 
 # TODO: creates intractable circular import issues
@@ -27,5 +28,5 @@ __all__ = [
     'is_method_overridden', 'has_method', 'mmcv_full_available',
     'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
     'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
-    'is_abs'
+    'is_abs', 'revert_sync_batchnorm'
 ]
diff --git a/mmengine/utils/sync_bn.py b/mmengine/utils/sync_bn.py
new file mode 100644
index 00000000..bcbe0165
--- /dev/null
+++ b/mmengine/utils/sync_bn.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
+    """A general BatchNorm layer without input dimension check.
+
+    Reproduced from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
+    is `_check_input_dim` that is designed for tensor sanity checks.
+    The check has been bypassed in this class for the convenience of converting
+    SyncBatchNorm.
+    """
+
+    def _check_input_dim(self, input: torch.Tensor):
+        return
+
+
+def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
+    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
+    `BatchNormXd` layers.
+
+    Adapted from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+    Args:
+        module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+    Returns:
+        module_output: The converted module with `BatchNormXd` layers.
+    """
+    module_output = module
+    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+    if isinstance(module, tuple(module_checklist)):
+        module_output = _BatchNormXd(module.num_features, module.eps,
+                                     module.momentum, module.affine,
+                                     module.track_running_stats)
+        if module.affine:
+            # no_grad() may not be needed here but
+            # just to be consistent with `convert_sync_batchnorm()`
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        module_output.training = module.training
+        # qconfig exists in quantized models
+        if hasattr(module, 'qconfig'):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, revert_sync_batchnorm(child))
+    del module
+    return module_output
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 43fed5b3..39454511 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -65,6 +65,30 @@ class ToyModel1(ToyModel):
         super().__init__()
 
 
+@MODELS.register_module()
+class ToySyncBNModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2d(3, 8, 2)
+        self.bn = nn.SyncBatchNorm(8)
+
+    def forward(self, batch_inputs, labels, mode='tensor'):
+        labels = torch.stack(labels)
+        outputs = self.conv(batch_inputs)
+        outputs = self.bn(outputs)
+
+        if mode == 'tensor':
+            return outputs
+        elif mode == 'loss':
+            loss = (labels - outputs).sum()
+            outputs = dict(loss=loss)
+            return outputs
+        elif mode == 'predict':
+            outputs = dict(log_vars=dict(a=1, b=0.5))
+            return outputs
+
+
 @MODELS.register_module()
 class ToyGANModel(BaseModel):
 
@@ -683,6 +707,14 @@ class TestRunner(TestCase):
         self.assertFalse(model.initiailzed)
 
     def test_wrap_model(self):
+        # revert sync batchnorm
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_revert_syncbn'
+        cfg.model = dict(type='ToySyncBNModel')
+        runner = Runner.from_cfg(cfg)
+        self.assertIsInstance(runner.model, BaseModel)
+        assert not isinstance(runner.model.bn, nn.SyncBatchNorm)
+
         # custom model wrapper
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_wrap_model'
diff --git a/tests/test_utils/test_revert_syncbn.py b/tests/test_utils/test_revert_syncbn.py
new file mode 100644
index 00000000..6b7e4846
--- /dev/null
+++ b/tests/test_utils/test_revert_syncbn.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+import torch.nn as nn
+
+from mmengine.utils import revert_sync_batchnorm
+
+
+@pytest.mark.skipif(
+    torch.__version__ == 'parrots', reason='not supported in parrots now')
+def test_revert_syncbn():
+    # conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
+    conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.SyncBatchNorm(8))
+    x = torch.randn(1, 3, 10, 10)
+    # expect a ValueError prompting that SyncBN is not supported on CPU
+    with pytest.raises(ValueError):
+        y = conv(x)
+    conv = revert_sync_batchnorm(conv)
+    y = conv(x)
+    assert y.shape == (1, 8, 9, 9)
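
For reference, a minimal usage sketch of the new helper on a single, non-distributed device. It mirrors the added tests/test_utils/test_revert_syncbn.py above and assumes no API beyond `revert_sync_batchnorm`; the model and shapes are taken from that test.

    import torch
    import torch.nn as nn

    from mmengine.utils import revert_sync_batchnorm

    # SyncBatchNorm needs a distributed process group, so calling this model
    # directly on CPU would raise a ValueError.
    model = nn.Sequential(nn.Conv2d(3, 8, 2), nn.SyncBatchNorm(8))

    # Replace every SyncBatchNorm with a dimension-agnostic _BatchNormXd,
    # copying its affine parameters and running statistics.
    model = revert_sync_batchnorm(model)

    out = model(torch.randn(1, 3, 10, 10))  # now runs on CPU
    assert out.shape == (1, 8, 9, 9)

The Runner performs this conversion automatically in `wrap_model` when `self.distributed` is False, so single-device users do not need to call the helper themselves.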