mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Enhance compatibility of revert_sync_batchnorm
(#695)
* [Enhance] Enhance revert_sync_batchnorm and convert_sync_batchnorm * [Enhance] Enhance revert_sync_batchnorm and convert_sync_batchnorm * Fix unit test * Add comments * Refine comments * clean the code * revert convert_sync_batchnorm * revert convert_sync_batchnorm * refine comment * fix CI * fix CI
This commit is contained in:
parent
9b4dbb3131
commit
ded73f3a56
@ -7,6 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from mmengine.logging import print_log
|
||||||
from mmengine.utils.dl_utils import mmcv_full_available
|
from mmengine.utils.dl_utils import mmcv_full_available
|
||||||
|
|
||||||
|
|
||||||
@ -192,7 +193,17 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
|
|||||||
if hasattr(module, 'qconfig'):
|
if hasattr(module, 'qconfig'):
|
||||||
module_output.qconfig = module.qconfig
|
module_output.qconfig = module.qconfig
|
||||||
for name, child in module.named_children():
|
for name, child in module.named_children():
|
||||||
module_output.add_module(name, revert_sync_batchnorm(child))
|
# Some custom modules or 3rd party implemented modules may raise an
|
||||||
|
# error when calling `add_module`. Therefore, try to catch the error
|
||||||
|
# and do not raise it. See https://github.com/open-mmlab/mmengine/issues/638 # noqa: E501
|
||||||
|
# for more details.
|
||||||
|
try:
|
||||||
|
module_output.add_module(name, revert_sync_batchnorm(child))
|
||||||
|
except Exception:
|
||||||
|
print_log(
|
||||||
|
F'Failed to convert {child} from SyncBN to BN!',
|
||||||
|
logger='current',
|
||||||
|
level=logging.WARNING)
|
||||||
del module
|
del module
|
||||||
return module_output
|
return module_output
|
||||||
|
|
||||||
|
@ -15,6 +15,16 @@ from mmengine.registry import MODEL_WRAPPERS, Registry
|
|||||||
from mmengine.utils import is_installed
|
from mmengine.utils import is_installed
|
||||||
|
|
||||||
|
|
||||||
|
class ToyModule(nn.Module):
    """Test fixture: a module whose ``add_module`` always raises.

    It stands in for custom or third-party modules that raise an error
    when ``add_module`` is called on them.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)

    def add_module(self, name, module):
        """Reject every registration attempt.

        Args:
            name: Name under which the child would have been registered.
            module: Child module that would have been registered (unused).

        Raises:
            ValueError: Always.
        """
        # Include the offending name so a failure is traceable in logs.
        raise ValueError(
            f'{type(self).__name__} does not support add_module '
            f'(attempted to add {name!r})')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
torch.__version__ == 'parrots', reason='not supported in parrots now')
|
torch.__version__ == 'parrots', reason='not supported in parrots now')
|
||||||
def test_revert_syncbn():
|
def test_revert_syncbn():
|
||||||
@ -28,6 +38,12 @@ def test_revert_syncbn():
|
|||||||
y = conv(x)
|
y = conv(x)
|
||||||
assert y.shape == (1, 8, 9, 9)
|
assert y.shape == (1, 8, 9, 9)
|
||||||
|
|
||||||
|
# TODO, capsys provided by `pytest` cannot capture the error log produced
|
||||||
|
# by MMLogger. Test the error log after refactoring the unit test with
|
||||||
|
# `unittest`
|
||||||
|
conv = nn.Sequential(ToyModule(), nn.SyncBatchNorm(8))
|
||||||
|
revert_sync_batchnorm(conv)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
torch.__version__ == 'parrots', reason='not supported in parrots now')
|
torch.__version__ == 'parrots', reason='not supported in parrots now')
|
||||||
@ -41,10 +57,12 @@ def test_convert_syncbn():
|
|||||||
# Test convert to mmcv SyncBatchNorm
|
# Test convert to mmcv SyncBatchNorm
|
||||||
if is_installed('mmcv'):
|
if is_installed('mmcv'):
|
||||||
# MMCV SyncBatchNorm is only supported on distributed training.
|
# MMCV SyncBatchNorm is only supported on distributed training.
|
||||||
|
# torch 1.6 will throw an AssertionError, and higher versions will
|
||||||
|
# throw a RuntimeError
|
||||||
with pytest.raises((RuntimeError, AssertionError)):
|
with pytest.raises((RuntimeError, AssertionError)):
|
||||||
convert_sync_batchnorm(conv, implementation='mmcv')
|
convert_sync_batchnorm(conv, implementation='mmcv')
|
||||||
|
|
||||||
# Test convert to Pytorch SyncBatchNorm
|
# Test convert BN to Pytorch SyncBatchNorm
|
||||||
# Expect a ValueError prompting that SyncBN is not supported on CPU
|
# Expect a ValueError prompting that SyncBN is not supported on CPU
|
||||||
converted_conv = convert_sync_batchnorm(conv)
|
converted_conv = convert_sync_batchnorm(conv)
|
||||||
assert isinstance(converted_conv[1], torch.nn.SyncBatchNorm)
|
assert isinstance(converted_conv[1], torch.nn.SyncBatchNorm)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user