diff --git a/tools/train.py b/tools/train.py index 0da14bf0d..4c10136c1 100644 --- a/tools/train.py +++ b/tools/train.py @@ -70,7 +70,9 @@ def main(args): config, parameter_list=net.parameters()) if config["use_data_parallel"]: - net = paddle.DataParallel(net) + find_unused_parameters = config.get("find_unused_parameters", False) + net = paddle.DataParallel( + net, find_unused_parameters=find_unused_parameters) # load model from checkpoint or pretrained model init_model(config, net, optimizer)