diff --git a/mmengine/model/utils.py b/mmengine/model/utils.py index e51ca76c..ce17b58c 100644 --- a/mmengine/model/utils.py +++ b/mmengine/model/utils.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from mmengine.logging import print_log from mmengine.utils.dl_utils import mmcv_full_available @@ -192,7 +193,17 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module: if hasattr(module, 'qconfig'): module_output.qconfig = module.qconfig for name, child in module.named_children(): - module_output.add_module(name, revert_sync_batchnorm(child)) + # Some custom modules or 3rd party implemented modules may raise an + # error when calling `add_module`. Therefore, try to catch the error + # and do not raise it. See https://github.com/open-mmlab/mmengine/issues/638 # noqa: E501 + # for more details. + try: + module_output.add_module(name, revert_sync_batchnorm(child)) + except Exception: + print_log( + F'Failed to convert {child} from SyncBN to BN!', + logger='current', + level=logging.WARNING) del module return module_output diff --git a/tests/test_model/test_model_utils.py b/tests/test_model/test_model_utils.py index df3044c0..0a8d0001 100644 --- a/tests/test_model/test_model_utils.py +++ b/tests/test_model/test_model_utils.py @@ -15,6 +15,16 @@ from mmengine.registry import MODEL_WRAPPERS, Registry from mmengine.utils import is_installed +class ToyModule(nn.Module): + + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(1, 1) + + def add_module(self, name, module): + raise ValueError() + + @pytest.mark.skipif( torch.__version__ == 'parrots', reason='not supported in parrots now') def test_revert_syncbn(): @@ -28,6 +38,12 @@ def test_revert_syncbn(): y = conv(x) assert y.shape == (1, 8, 9, 9) + # TODO, capsys provided by `pytest` cannot capture the error log produced + # by MMLogger. 
Test the error log after refactoring the unit test with + `unittest` + conv = nn.Sequential(ToyModule(), nn.SyncBatchNorm(8)) + revert_sync_batchnorm(conv) + @pytest.mark.skipif( torch.__version__ == 'parrots', reason='not supported in parrots now') @@ -41,10 +57,12 @@ def test_convert_syncbn(): # Test convert to mmcv SyncBatchNorm if is_installed('mmcv'): # MMCV SyncBatchNorm is only supported on distributed training. + # torch 1.6 will throw an AssertionError, and higher versions will + # throw a RuntimeError with pytest.raises((RuntimeError, AssertionError)): convert_sync_batchnorm(conv, implementation='mmcv') - # Test convert to Pytorch SyncBatchNorm + # Test convert BN to Pytorch SyncBatchNorm # Expect a ValueError prompting that SyncBN is not supported on CPU converted_conv = convert_sync_batchnorm(conv) assert isinstance(converted_conv[1], torch.nn.SyncBatchNorm)