From 48afe86f6829d30baaf502a33760b295436d84b3 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Sat, 29 Aug 2020 09:44:30 +0000 Subject: [PATCH] fix dygpaph training speed --- tools/program.py | 3 ++- tools/train.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tools/program.py b/tools/program.py index aaecbe0e3..05a14ec97 100644 --- a/tools/program.py +++ b/tools/program.py @@ -21,6 +21,7 @@ import time from collections import OrderedDict +import paddle import paddle.fluid as fluid from ppcls.optimizer import LearningRateBuilder @@ -280,7 +281,7 @@ def mixed_precision_optimizer(config, optimizer): def create_feeds(batch, use_mix): - image = to_variable(batch[0].numpy().astype("float32")) + image = batch[0] if use_mix: y_a = to_variable(batch[1].numpy().astype("int64").reshape(-1, 1)) y_b = to_variable(batch[2].numpy().astype("int64").reshape(-1, 1)) diff --git a/tools/train.py b/tools/train.py index 7d919ba6b..afb783548 100644 --- a/tools/train.py +++ b/tools/train.py @@ -57,13 +57,14 @@ def main(args): with fluid.dygraph.guard(place): net = program.create_model(config.ARCHITECTURE, config.classes_num) - if config["use_data_parallel"]: - strategy = fluid.dygraph.parallel.prepare_context() - net = fluid.dygraph.parallel.DataParallel(net, strategy) optimizer = program.create_optimizer( config, parameter_list=net.parameters()) + if config["use_data_parallel"]: + strategy = fluid.dygraph.parallel.prepare_context() + net = fluid.dygraph.parallel.DataParallel(net, strategy) + # load model from checkpoint or pretrained model init_model(config, net, optimizer)