fix dist eval (#364)
parent
081ca857ce
commit
82e7a90bfe
|
@ -59,9 +59,14 @@ def main(args, return_dict={}):
|
|||
|
||||
paddle.disable_static(place)
|
||||
|
||||
strategy = paddle.distributed.init_parallel_env()
|
||||
use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
|
||||
config["use_data_parallel"] = use_data_parallel
|
||||
|
||||
net = program.create_model(config.ARCHITECTURE, config.classes_num)
|
||||
if config["use_data_parallel"]:
|
||||
strategy = paddle.distributed.init_parallel_env()
|
||||
net = paddle.DataParallel(net, strategy)
|
||||
|
||||
init_model(config, net, optimizer=None)
|
||||
valid_dataloader = Reader(config, 'valid', places=place)()
|
||||
net.eval()
|
||||
|
|
Loading…
Reference in New Issue