mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhancement] Improve revert_sync_batchnorm to support mmcv SyncBN (#448)
This commit is contained in:
parent
e907931fb8
commit
e8ee1926b8
@ -34,6 +34,15 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
|
|||||||
"""
|
"""
|
||||||
module_output = module
|
module_output = module
|
||||||
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
|
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
|
||||||
|
|
||||||
|
try:
|
||||||
|
import mmcv
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if hasattr(mmcv, 'ops'):
|
||||||
|
module_checklist.append(mmcv.ops.SyncBatchNorm)
|
||||||
|
|
||||||
if isinstance(module, tuple(module_checklist)):
|
if isinstance(module, tuple(module_checklist)):
|
||||||
module_output = _BatchNormXd(module.num_features, module.eps,
|
module_output = _BatchNormXd(module.num_features, module.eps,
|
||||||
module.momentum, module.affine,
|
module.momentum, module.affine,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user