From e8ee1926b8d6f04a90dc9b0294a2ee81b06f5384 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sun, 21 Aug 2022 14:54:52 +0800 Subject: [PATCH] [Enhancement] Improve revert_sync_batchnorm to support mmcv SyncBN (#448) --- mmengine/utils/sync_bn.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mmengine/utils/sync_bn.py b/mmengine/utils/sync_bn.py index bcbe0165..f92a849f 100644 --- a/mmengine/utils/sync_bn.py +++ b/mmengine/utils/sync_bn.py @@ -34,6 +34,15 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module: """ module_output = module module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm] + + try: + import mmcv + except ImportError: + pass + else: + if hasattr(mmcv, 'ops'): + module_checklist.append(mmcv.ops.SyncBatchNorm) + if isinstance(module, tuple(module_checklist)): module_output = _BatchNormXd(module.num_features, module.eps, module.momentum, module.affine,