fix dist eval (#364)

pull/1021/head
littletomatodonkey 2020-11-04 11:06:23 +08:00 committed by GitHub
parent 081ca857ce
commit 82e7a90bfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 2 deletions

View File

@ -59,9 +59,14 @@ def main(args, return_dict={}):
paddle.disable_static(place) paddle.disable_static(place)
strategy = paddle.distributed.init_parallel_env() use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
config["use_data_parallel"] = use_data_parallel
net = program.create_model(config.ARCHITECTURE, config.classes_num) net = program.create_model(config.ARCHITECTURE, config.classes_num)
net = paddle.DataParallel(net, strategy) if config["use_data_parallel"]:
strategy = paddle.distributed.init_parallel_env()
net = paddle.DataParallel(net, strategy)
init_model(config, net, optimizer=None) init_model(config, net, optimizer=None)
valid_dataloader = Reader(config, 'valid', places=place)() valid_dataloader = Reader(config, 'valid', places=place)()
net.eval() net.eval()