fix dygpaph training speed
parent
6cae5aafa1
commit
48afe86f68
|
@ -21,6 +21,7 @@ import time
|
|||
|
||||
from collections import OrderedDict
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
from ppcls.optimizer import LearningRateBuilder
|
||||
|
@ -280,7 +281,7 @@ def mixed_precision_optimizer(config, optimizer):
|
|||
|
||||
|
||||
def create_feeds(batch, use_mix):
|
||||
image = to_variable(batch[0].numpy().astype("float32"))
|
||||
image = batch[0]
|
||||
if use_mix:
|
||||
y_a = to_variable(batch[1].numpy().astype("int64").reshape(-1, 1))
|
||||
y_b = to_variable(batch[2].numpy().astype("int64").reshape(-1, 1))
|
||||
|
|
|
@ -57,13 +57,14 @@ def main(args):
|
|||
|
||||
with fluid.dygraph.guard(place):
|
||||
net = program.create_model(config.ARCHITECTURE, config.classes_num)
|
||||
if config["use_data_parallel"]:
|
||||
strategy = fluid.dygraph.parallel.prepare_context()
|
||||
net = fluid.dygraph.parallel.DataParallel(net, strategy)
|
||||
|
||||
optimizer = program.create_optimizer(
|
||||
config, parameter_list=net.parameters())
|
||||
|
||||
if config["use_data_parallel"]:
|
||||
strategy = fluid.dygraph.parallel.prepare_context()
|
||||
net = fluid.dygraph.parallel.DataParallel(net, strategy)
|
||||
|
||||
# load model from checkpoint or pretrained model
|
||||
init_model(config, net, optimizer)
|
||||
|
||||
|
|
Loading…
Reference in New Issue