diff --git a/tools/static/train.py b/tools/static/train.py index a73a160a6..cd6aaefaa 100644 --- a/tools/static/train.py +++ b/tools/static/train.py @@ -124,7 +124,7 @@ def main(args): # load pretrained models or checkpoints init_model(config, train_prog, exe) - if not config.get("is_distributed", True): + if not config.get("is_distributed", True) and not use_xpu: compiled_train_prog = program.compile( config, train_prog, loss_name=train_fetchs["loss"][0].name) else: