fix dist eval (#364)
parent
081ca857ce
commit
82e7a90bfe
|
@ -59,9 +59,14 @@ def main(args, return_dict={}):
|
||||||
|
|
||||||
paddle.disable_static(place)
|
paddle.disable_static(place)
|
||||||
|
|
||||||
strategy = paddle.distributed.init_parallel_env()
|
use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
|
||||||
|
config["use_data_parallel"] = use_data_parallel
|
||||||
|
|
||||||
net = program.create_model(config.ARCHITECTURE, config.classes_num)
|
net = program.create_model(config.ARCHITECTURE, config.classes_num)
|
||||||
net = paddle.DataParallel(net, strategy)
|
if config["use_data_parallel"]:
|
||||||
|
strategy = paddle.distributed.init_parallel_env()
|
||||||
|
net = paddle.DataParallel(net, strategy)
|
||||||
|
|
||||||
init_model(config, net, optimizer=None)
|
init_model(config, net, optimizer=None)
|
||||||
valid_dataloader = Reader(config, 'valid', places=place)()
|
valid_dataloader = Reader(config, 'valid', places=place)()
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
Loading…
Reference in New Issue