mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhancement] Improve revert_sync_batchnorm to support mmcv SyncBN (#448)
This commit is contained in:
parent
e907931fb8
commit
e8ee1926b8
@ -34,6 +34,15 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
|
||||
"""
|
||||
module_output = module
|
||||
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
|
||||
|
||||
try:
|
||||
import mmcv
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
if hasattr(mmcv, 'ops'):
|
||||
module_checklist.append(mmcv.ops.SyncBatchNorm)
|
||||
|
||||
if isinstance(module, tuple(module_checklist)):
|
||||
module_output = _BatchNormXd(module.num_features, module.eps,
|
||||
module.momentum, module.affine,
|
||||
|
Loading…
x
Reference in New Issue
Block a user