[Enhancement] Improve revert_sync_batchnorm to support mmcv SyncBN (#448)

This commit is contained in:
Zaida Zhou 2022-08-21 14:54:52 +08:00 committed by GitHub
parent e907931fb8
commit e8ee1926b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -34,6 +34,15 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
"""
module_output = module
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
try:
import mmcv
except ImportError:
pass
else:
if hasattr(mmcv, 'ops'):
module_checklist.append(mmcv.ops.SyncBatchNorm)
if isinstance(module, tuple(module_checklist)):
module_output = _BatchNormXd(module.num_features, module.eps,
module.momentum, module.affine,