From c3d401b7eae1caffa63860b07763cd3b82ce296b Mon Sep 17 00:00:00 2001 From: liuyuhui <1029880267@qq.com> Date: Wed, 14 Apr 2021 22:31:36 +0800 Subject: [PATCH] add multi xpu support for PaddleClas (#678) --- tools/static/train.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tools/static/train.py b/tools/static/train.py index 973b29d26..8560540fc 100644 --- a/tools/static/train.py +++ b/tools/static/train.py @@ -119,11 +119,12 @@ def main(args): init_model(config, train_prog, exe) if 'AMP' in config and config.AMP.get("use_pure_fp16", False): - optimizer.amp_init(place, - scope=paddle.static.global_scope(), - test_program=valid_prog if config.validate else None) + optimizer.amp_init( + place, + scope=paddle.static.global_scope(), + test_program=valid_prog if config.validate else None) - if not config.get("is_distributed", True) and not use_xpu: + if not config.get("is_distributed", True): compiled_train_prog = program.compile( config, train_prog, loss_name=train_fetchs["loss"][0].name) else: @@ -133,10 +134,7 @@ def main(args): train_dataloader = Reader(config, 'train', places=place)() if config.validate and paddle.distributed.get_rank() == 0: valid_dataloader = Reader(config, 'valid', places=place)() - if use_xpu: - compiled_valid_prog = valid_prog - else: - compiled_valid_prog = program.compile(config, valid_prog) + compiled_valid_prog = program.compile(config, valid_prog) else: assert use_gpu is True, "DALI only support gpu, please set use_gpu to True!" import dali