[Fix] fix inconsistent training/eval state after SyncBN->BN (#453)

This commit is contained in:
Tong Gao 2021-08-25 13:14:03 +08:00 committed by GitHub
parent 961fbb6ca5
commit 7c1bf45c63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 5 deletions

View File

@ -21,7 +21,7 @@ def revert_sync_batchnorm(module):
"""Helper function to convert all `SyncBatchNorm` layers in the model to
`BatchNormXd` layers.
Reproduced from @kapily's work:
Adapted from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
Args:
@ -42,6 +42,7 @@ def revert_sync_batchnorm(module):
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
module_output.training = module.training
if hasattr(module, 'qconfig'):
module_output.qconfig = module.qconfig
for name, child in module.named_children():

View File

@ -7,10 +7,16 @@ from mmocr.utils import revert_sync_batchnorm
def test_revert_sync_batchnorm():
    """Check that revert_sync_batchnorm converts SyncBN to a CPU-runnable BN
    and preserves the module's training/eval state."""
    conv_syncbn = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu')
    conv_syncbn.train()
    x = torch.randn(1, 3, 10, 10)
    # Will raise a ValueError saying SyncBN does not run on CPU
    with pytest.raises(ValueError):
        y = conv_syncbn(x)
    conv_bn = revert_sync_batchnorm(conv_syncbn)
    # After conversion the module must run on CPU and keep the output shape.
    y = conv_bn(x)
    assert y.shape == (1, 8, 9, 9)
    # The converted module must inherit the source module's train state ...
    assert conv_bn.training == conv_syncbn.training
    # ... and its eval state as well.
    conv_syncbn.eval()
    conv_bn = revert_sync_batchnorm(conv_syncbn)
    assert conv_bn.training == conv_syncbn.training