fix dist eval ()

pull/1021/head
littletomatodonkey 2020-11-04 11:06:23 +08:00 committed by GitHub
parent 081ca857ce
commit 82e7a90bfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 2 deletions
tools

View File

@ -59,9 +59,14 @@ def main(args, return_dict={}):
paddle.disable_static(place)
strategy = paddle.distributed.init_parallel_env()
use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
config["use_data_parallel"] = use_data_parallel
net = program.create_model(config.ARCHITECTURE, config.classes_num)
net = paddle.DataParallel(net, strategy)
if config["use_data_parallel"]:
strategy = paddle.distributed.init_parallel_env()
net = paddle.DataParallel(net, strategy)
init_model(config, net, optimizer=None)
valid_dataloader = Reader(config, 'valid', places=place)()
net.eval()