add multi xpu support for PaddleClas (#678)
parent a7aa14525c
commit c3d401b7ea
tools/static
@@ -119,11 +119,12 @@ def main(args):
     init_model(config, train_prog, exe)
 
     if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
-        optimizer.amp_init(place,
-                           scope=paddle.static.global_scope(),
-                           test_program=valid_prog if config.validate else None)
+        optimizer.amp_init(
+            place,
+            scope=paddle.static.global_scope(),
+            test_program=valid_prog if config.validate else None)
 
-    if not config.get("is_distributed", True) and not use_xpu:
+    if not config.get("is_distributed", True):
         compiled_train_prog = program.compile(
             config, train_prog, loss_name=train_fetchs["loss"][0].name)
     else:
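Note on the first hunk: the amp_init call is only reflowed; the substantive change is dropping `and not use_xpu`, so a non-distributed XPU run now compiles the training program the same way a GPU run does. For context, below is a minimal sketch of where amp_init fits in a pure-fp16 static-graph job. It is a generic Paddle 2.x stand-in, not PaddleClas code: the fc network, SGD optimizer, and CUDAPlace are illustrative assumptions.

import paddle
import paddle.static.amp as amp

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    loss = paddle.mean(paddle.static.nn.fc(x, size=4))
    opt = paddle.optimizer.SGD(learning_rate=0.01)
    # decorate() wraps the optimizer so minimize() inserts the fp16 casts;
    # use_pure_fp16=True keeps the whole forward/backward pass in fp16.
    opt = amp.decorate(opt, use_pure_fp16=True)
    opt.minimize(loss)

place = paddle.CUDAPlace(0)  # a GPU (or XPU) device is assumed here
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# As in the hunk above, amp_init runs only after parameters exist
# (PaddleClas gets there via init_model): it casts fp32 params to fp16.
opt.amp_init(place, scope=paddle.static.global_scope())

PaddleClas reaches the same point through init_model, which is why amp_init appears immediately after it in the hunk.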
@@ -133,10 +134,7 @@ def main(args):
         train_dataloader = Reader(config, 'train', places=place)()
         if config.validate and paddle.distributed.get_rank() == 0:
             valid_dataloader = Reader(config, 'valid', places=place)()
-            if use_xpu:
-                compiled_valid_prog = valid_prog
-            else:
-                compiled_valid_prog = program.compile(config, valid_prog)
+            compiled_valid_prog = program.compile(config, valid_prog)
     else:
         assert use_gpu is True, "DALI only support gpu, please set use_gpu to True!"
         import dali
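The second hunk removes the matching special case on the validation side, so compiled_valid_prog is always produced by program.compile and the device type no longer changes the compile path. Below is a minimal sketch of the selection logic after this commit; compile_program and the plain-dict config are stand-ins for PaddleClas's program.compile and its parsed YAML config, and the else branch (compiled_train_prog = train_prog) is an inference, since the hunk cuts off at else:.

# Sketch of the post-commit selection logic (stand-ins, not PaddleClas code).

def compile_program(config, prog, loss_name=None):
    # Stand-in for program.compile, which wraps `prog` for multi-card
    # execution (e.g. in a paddle.static.CompiledProgram).
    return ("compiled", prog, loss_name)

def select_programs(config, train_prog, valid_prog, loss_name):
    # After this commit the only reason to skip local compilation is a
    # distributed run; XPU follows the same path as GPU.
    if not config.get("is_distributed", True):
        compiled_train_prog = compile_program(config, train_prog, loss_name)
    else:
        compiled_train_prog = train_prog  # inferred else-branch
    # The validation program is now compiled unconditionally (second hunk).
    compiled_valid_prog = compile_program(config, valid_prog)
    return compiled_train_prog, compiled_valid_prog

# Example: a single-card run (GPU or XPU) compiles both programs.
train, valid = select_programs({"is_distributed": False},
                               "train_prog", "valid_prog", "loss")
assert train[0] == "compiled" and valid[0] == "compiled"

In a distributed run is_distributed is True and the raw program is handed over uncompiled, which is presumably how fleet-managed multi-XPU training is meant to flow through this code.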