add multi xpu support for PaddleClas (#678)
parent a7aa14525c
commit c3d401b7ea
@@ -119,11 +119,12 @@ def main(args):
     init_model(config, train_prog, exe)
 
     if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
-        optimizer.amp_init(place,
+        optimizer.amp_init(
+            place,
             scope=paddle.static.global_scope(),
             test_program=valid_prog if config.validate else None)
 
-    if not config.get("is_distributed", True) and not use_xpu:
+    if not config.get("is_distributed", True):
         compiled_train_prog = program.compile(
             config, train_prog, loss_name=train_fetchs["loss"][0].name)
     else:
@@ -133,9 +134,6 @@ def main(args):
         train_dataloader = Reader(config, 'train', places=place)()
         if config.validate and paddle.distributed.get_rank() == 0:
             valid_dataloader = Reader(config, 'valid', places=place)()
-            if use_xpu:
-                compiled_valid_prog = valid_prog
-            else:
-                compiled_valid_prog = program.compile(config, valid_prog)
+            compiled_valid_prog = program.compile(config, valid_prog)
     else:
         assert use_gpu is True, "DALI only support gpu, please set use_gpu to True!"
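For context, below is a minimal sketch of the static-graph setup this diff plugs into. It is not part of the commit: use_xpu, place, and the CompiledProgram wrapping are assumptions standing in for the surrounding PaddleClas code (program.compile in the diff is assumed to build a CompiledProgram roughly as shown). The point of the change is that the XPU path no longer bypasses compilation, so non-distributed XPU runs go through program.compile just like GPU runs.

    import paddle

    paddle.enable_static()

    # Assumption: in PaddleClas this flag would come from the training YAML config.
    use_xpu = False
    # One device per process; paddle.XPUPlace requires a Paddle build with
    # Kunlun (XPU) support, paddle.CUDAPlace a build with CUDA support.
    place = paddle.XPUPlace(0) if use_xpu else paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)

    train_prog = paddle.static.default_main_program()
    # After this commit, the non-distributed branch compiles the program for
    # both device types instead of returning the raw program on XPU.
    compiled_train_prog = paddle.static.CompiledProgram(train_prog)

The amp_init hunk only reflows the call across lines; its arguments (place, scope, test_program) are unchanged.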