fix error that sync bn should not be used on cpu

pull/2492/head
gaotingquan 2022-11-15 03:28:34 +00:00 committed by Wei Shengyu
parent 2e41b5de5e
commit 89696c7bac
1 changed files with 5 additions and 1 deletions

View File

@ -39,7 +39,11 @@ def build_model(config, mode="train"):
mod = importlib.import_module(__name__)
arch = getattr(mod, model_type)(**arch_config)
if use_sync_bn:
arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
if config["Global"]["device"] == "gpu":
arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
else:
msg = "SyncBatchNorm can only be used on GPU device. The releated setting has been ignored."
logger.warning(msg)
if isinstance(arch, TheseusLayer):
prune_model(config, arch)