[Enhance] Enhance compatibility of revert_sync_batchnorm (#695)

Squashed commit history:

* [Enhance] Enhance revert_sync_batchnorm and convert_sync_batchnorm
* [Enhance] Enhance revert_sync_batchnorm and convert_sync_batchnorm
* Fix unit test
* Add comments
* Refine comments
* Clean the code
* Revert convert_sync_batchnorm
* Revert convert_sync_batchnorm
* Refine comment
* Fix CI
* Fix CI
parent 9b4dbb3131
commit ded73f3a56
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from mmengine.logging import print_log
 from mmengine.utils.dl_utils import mmcv_full_available
 
 
@@ -192,7 +193,17 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
     if hasattr(module, 'qconfig'):
         module_output.qconfig = module.qconfig
     for name, child in module.named_children():
-        module_output.add_module(name, revert_sync_batchnorm(child))
+        # Some custom modules or 3rd party implemented modules may raise an
+        # error when calling `add_module`. Therefore, try to catch the error
+        # and do not raise it. See https://github.com/open-mmlab/mmengine/issues/638 # noqa: E501
+        # for more details.
+        try:
+            module_output.add_module(name, revert_sync_batchnorm(child))
+        except Exception:
+            print_log(
+                F'Failed to convert {child} from SyncBN to BN!',
+                logger='current',
+                level=logging.WARNING)
     del module
     return module_output
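For context: `revert_sync_batchnorm` recurses through `named_children()` and re-attaches each converted child with `add_module`. A module that overrides `add_module` (as some third-party code does) used to abort the whole conversion at that point; the guard above downgrades the failure to a logged warning so the remaining children are still converted (see issue #638). A minimal sketch of the resulting behavior, assuming the public `mmengine.model` export; the model layout is illustrative, and `ToyModule` mirrors the one added to the unit tests below:

import torch.nn as nn

from mmengine.model import revert_sync_batchnorm


class ToyModule(nn.Module):
    """Mimics a 3rd-party module whose ``add_module`` always fails."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)

    def add_module(self, name, module):
        raise ValueError()


model = nn.Sequential(ToyModule(), nn.SyncBatchNorm(8))

# Before this commit, the ValueError from ToyModule.add_module propagated
# and the whole conversion failed; now a warning is logged for the failing
# child and the SyncBatchNorm sibling is still reverted to a plain BN layer.
reverted = revert_sync_batchnorm(model)
assert not isinstance(reverted[1], nn.SyncBatchNorm)

The unit-test diff below adds exactly this regression case.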
@@ -15,6 +15,16 @@ from mmengine.registry import MODEL_WRAPPERS, Registry
 from mmengine.utils import is_installed
 
 
+class ToyModule(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.layer1 = nn.Linear(1, 1)
+
+    def add_module(self, name, module):
+        raise ValueError()
+
+
 @pytest.mark.skipif(
     torch.__version__ == 'parrots', reason='not supported in parrots now')
 def test_revert_syncbn():
@@ -28,6 +38,12 @@ def test_revert_syncbn():
     y = conv(x)
     assert y.shape == (1, 8, 9, 9)
 
+    # TODO, capsys provided by `pytest` cannot capture the error log produced
+    # by MMLogger. Test the error log after refactoring the unit test with
+    # `unittest`
+    conv = nn.Sequential(ToyModule(), nn.SyncBatchNorm(8))
+    revert_sync_batchnorm(conv)
+
 
 @pytest.mark.skipif(
     torch.__version__ == 'parrots', reason='not supported in parrots now')
@@ -41,10 +57,12 @@ def test_convert_syncbn():
     # Test convert to mmcv SyncBatchNorm
     if is_installed('mmcv'):
         # MMCV SyncBatchNorm is only supported on distributed training.
         # torch 1.6 will throw an AssertionError, and higher version will
         # throw an RuntimeError
         with pytest.raises((RuntimeError, AssertionError)):
             convert_sync_batchnorm(conv, implementation='mmcv')
 
-    # Test convert to Pytorch SyncBatchNorm
+    # Test convert BN to Pytorch SyncBatchNorm
+    # Expect a ValueError prompting that SyncBN is not supported on CPU
     converted_conv = convert_sync_batchnorm(conv)
     assert isinstance(converted_conv[1], torch.nn.SyncBatchNorm)
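For the converse path exercised by `test_convert_syncbn`, a short sketch (again assuming the public `mmengine.model` export; the model layout is illustrative, and the `implementation` keyword is the one shown in the diff):

import torch.nn as nn

from mmengine.model import convert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 2), nn.BatchNorm2d(8))

# Default PyTorch implementation: BatchNorm2d is replaced by
# torch.nn.SyncBatchNorm. Construction succeeds even on CPU, which is why
# the test can assert the layer type directly; a distributed setup is only
# needed once a forward pass is run.
converted = convert_sync_batchnorm(model)
assert isinstance(converted[1], nn.SyncBatchNorm)

# MMCV implementation: mmcv's SyncBatchNorm can only be built in a
# distributed environment, so this raises (AssertionError on torch 1.6,
# RuntimeError on newer versions), matching pytest.raises in the test.
# convert_sync_batchnorm(model, implementation='mmcv')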