add multi xpu support for PaddleClas ()

pull/683/head^2
liuyuhui 2021-04-14 22:31:36 +08:00 committed by GitHub
parent a7aa14525c
commit c3d401b7ea
1 changed file with 6 additions and 8 deletions
tools/static

@@ -119,11 +119,12 @@ def main(args):
     init_model(config, train_prog, exe)
 
     if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
-        optimizer.amp_init(place,
-                           scope=paddle.static.global_scope(),
-                           test_program=valid_prog if config.validate else None)
+        optimizer.amp_init(
+            place,
+            scope=paddle.static.global_scope(),
+            test_program=valid_prog if config.validate else None)
 
-    if not config.get("is_distributed", True) and not use_xpu:
+    if not config.get("is_distributed", True):
         compiled_train_prog = program.compile(
             config, train_prog, loss_name=train_fetchs["loss"][0].name)
     else:
@@ -133,10 +134,7 @@ def main(args):
         train_dataloader = Reader(config, 'train', places=place)()
         if config.validate and paddle.distributed.get_rank() == 0:
             valid_dataloader = Reader(config, 'valid', places=place)()
-            if use_xpu:
-                compiled_valid_prog = valid_prog
-            else:
-                compiled_valid_prog = program.compile(config, valid_prog)
+            compiled_valid_prog = program.compile(config, valid_prog)
     else:
         assert use_gpu is True, "DALI only support gpu, please set use_gpu to True!"
         import dali
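
With the use_xpu special cases removed, multi-XPU training takes the same compile path as GPU: program.compile() now handles the validation program on XPU as well. For illustration, a minimal sketch of how per-device places can be selected in Paddle 2.x static mode; the use_xpu flag below is a hypothetical stand-in for the script's YAML-driven device config, and paddle.static.xpu_places() requires a Paddle build with Kunlun XPU support:

import paddle

paddle.enable_static()

use_xpu = True  # hypothetical flag; the real script reads the device from its config

if use_xpu:
    # one XPUPlace per visible Kunlun device; under a multi-card launcher
    # (e.g. paddle.distributed.launch) each process usually sees one device
    places = paddle.static.xpu_places()
else:
    places = paddle.static.cuda_places()

place = places[0]
print("rank", paddle.distributed.get_rank(), "uses", place)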