Add missing training flag to convert_sync_batchnorm
parent
cb4cea561a
commit
84631cb5c6
|
@ -176,6 +176,7 @@ def convert_sync_batchnorm(module, process_group=None):
|
||||||
module_output.running_mean = module.running_mean
|
module_output.running_mean = module.running_mean
|
||||||
module_output.running_var = module.running_var
|
module_output.running_var = module.running_var
|
||||||
module_output.num_batches_tracked = module.num_batches_tracked
|
module_output.num_batches_tracked = module.num_batches_tracked
|
||||||
|
module_output.training = module.training
|
||||||
if hasattr(module, "qconfig"):
|
if hasattr(module, "qconfig"):
|
||||||
module_output.qconfig = module.qconfig
|
module_output.qconfig = module.qconfig
|
||||||
for name, child in module.named_children():
|
for name, child in module.named_children():
|
||||||
|
|
Loading…
Reference in New Issue