Mirror of https://github.com/PaddlePaddle/PaddleClas.git, synced 2025-06-03 21:55:06 +08:00
Merge pull request #223 from littletomatodonkey/fix_single_card_dyg
fix single card dygraph train process
This commit is contained in:
commit
0e7bea5183
@@ -329,9 +329,13 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
         feeds = create_feeds(batch, use_mix)
         fetchs = create_fetchs(feeds, net, config, mode)
         if mode == 'train':
-            avg_loss = net.scale_loss(fetchs['loss'])
-            avg_loss.backward()
-            net.apply_collective_grads()
+            if config["use_data_parallel"]:
+                avg_loss = net.scale_loss(fetchs['loss'])
+                avg_loss.backward()
+                net.apply_collective_grads()
+            else:
+                avg_loss = fetchs['loss']
+                avg_loss.backward()
 
             optimizer.minimize(avg_loss)
             net.clear_gradients()
@@ -52,10 +52,14 @@ def main(args):
     gpu_id = fluid.dygraph.parallel.Env().dev_id
     place = fluid.CUDAPlace(gpu_id)
 
+    use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
+    config["use_data_parallel"] = use_data_parallel
+
     with fluid.dygraph.guard(place):
-        strategy = fluid.dygraph.parallel.prepare_context()
         net = program.create_model(config.ARCHITECTURE, config.classes_num)
-        net = fluid.dygraph.parallel.DataParallel(net, strategy)
+        if config["use_data_parallel"]:
+            strategy = fluid.dygraph.parallel.prepare_context()
+            net = fluid.dygraph.parallel.DataParallel(net, strategy)
 
         optimizer = program.create_optimizer(
             config, parameter_list=net.parameters())
@@ -79,7 +83,8 @@ def main(args):
             program.run(train_dataloader, config, net, optimizer, epoch_id,
                         'train')
 
-            if fluid.dygraph.parallel.Env().local_rank == 0:
+            if not config["use_data_parallel"] or fluid.dygraph.parallel.Env(
+            ).local_rank == 0:
                 # 2. validate with validate dataset
                 if config.validate and epoch_id % config.valid_interval == 0:
                     net.eval()
@@ -108,4 +113,4 @@ def main(args):
 
 if __name__ == '__main__':
     args = parse_args()
-    main(args)
+    main(args)
Loading…
x
Reference in New Issue
Block a user