mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix] fix inconsisent training/eval state after SyncBN->BN (#453)
This commit is contained in:
parent
961fbb6ca5
commit
7c1bf45c63
@ -21,7 +21,7 @@ def revert_sync_batchnorm(module):
|
||||
"""Helper function to convert all `SyncBatchNorm` layers in the model to
|
||||
`BatchNormXd` layers.
|
||||
|
||||
Reproduced from @kapily's work:
|
||||
Adapted from @kapily's work:
|
||||
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
|
||||
|
||||
Args:
|
||||
@ -42,6 +42,7 @@ def revert_sync_batchnorm(module):
|
||||
module_output.running_mean = module.running_mean
|
||||
module_output.running_var = module.running_var
|
||||
module_output.num_batches_tracked = module.num_batches_tracked
|
||||
module_output.training = module.training
|
||||
if hasattr(module, 'qconfig'):
|
||||
module_output.qconfig = module.qconfig
|
||||
for name, child in module.named_children():
|
||||
|
@ -7,10 +7,16 @@ from mmocr.utils import revert_sync_batchnorm
|
||||
|
||||
|
||||
def test_revert_sync_batchnorm():
|
||||
conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
|
||||
conv_syncbn = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu')
|
||||
conv_syncbn.train()
|
||||
x = torch.randn(1, 3, 10, 10)
|
||||
# Will raise an ValueError saying SyncBN does not run on CPU
|
||||
with pytest.raises(ValueError):
|
||||
y = conv(x)
|
||||
conv = revert_sync_batchnorm(conv)
|
||||
y = conv(x)
|
||||
y = conv_syncbn(x)
|
||||
conv_bn = revert_sync_batchnorm(conv_syncbn)
|
||||
y = conv_bn(x)
|
||||
assert y.shape == (1, 8, 9, 9)
|
||||
assert conv_bn.training == conv_syncbn.training
|
||||
conv_syncbn.eval()
|
||||
conv_bn = revert_sync_batchnorm(conv_syncbn)
|
||||
assert conv_bn.training == conv_syncbn.training
|
||||
|
Loading…
x
Reference in New Issue
Block a user