fix static train (#478)

parent 29b305d228
commit 8fd56a4503
@@ -86,16 +86,21 @@ def create_model(architecture, image, classes_num, config, is_train):
     use_pure_fp16 = config.get("use_pure_fp16", False)
     name = architecture["name"]
     params = architecture.get("params", {})
-    data_format = config.get("data_format", "NCHW")
+    data_format = "NCHW"
+    if "data_format" in config:
+        params["data_format"] = config["data_format"]
+        data_format = config["data_format"]
     input_image_channel = config.get('image_shape', [3, 224, 224])[0]
+    if input_image_channel != 3:
+        logger.warning(
+            "Input image channel is changed to {}, maybe for better speed-up".
+            format(input_image_channel))
+        params["input_image_channel"] = input_image_channel
+
     if "is_test" in params:
         params['is_test'] = not is_train
-    model = architectures.__dict__[name](
-        class_dim=classes_num,
-        input_image_channel=input_image_channel,
-        data_format=data_format,
-        **params)
+    model = architectures.__dict__[name](class_dim=classes_num, **params)

     if use_pure_fp16 and not config.get("use_dali", False):
         image = image.astype('float16')
     if data_format == "NHWC":
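Note on the create_model hunk above: data_format and input_image_channel now reach the architectures through params, and only when they deviate from the defaults, so a network whose constructor does not accept those keywords no longer breaks on the plain **params call. A minimal standalone sketch of that keyword handling; DemoNet and the config dicts are hypothetical stand-ins, not the repo's code:

def build_params(config, architecture):
    params = dict(architecture.get("params", {}))

    data_format = "NCHW"
    if "data_format" in config:
        # forwarded through params instead of an explicit keyword argument
        params["data_format"] = config["data_format"]
        data_format = config["data_format"]

    input_image_channel = config.get("image_shape", [3, 224, 224])[0]
    if input_image_channel != 3:
        # stands in for the logger.warning in the diff
        print("Input image channel is changed to {}, maybe for better speed-up"
              .format(input_image_channel))
        params["input_image_channel"] = input_image_channel
    return params, data_format


class DemoNet:
    """Accepts only class_dim; any extra keyword would raise TypeError."""

    def __init__(self, class_dim=1000):
        self.class_dim = class_dim


# default 3-channel NCHW config: params stay empty, so DemoNet constructs fine
params, fmt = build_params({"image_shape": [3, 224, 224]}, {"name": "DemoNet"})
model = DemoNet(class_dim=1000, **params)
assert fmt == "NCHW" and "data_format" not in params

# NHWC, 4-channel config: both keys now travel through params
params, fmt = build_params(
    {"data_format": "NHWC", "image_shape": [4, 224, 224]}, {"name": "DemoNet"})
assert fmt == "NHWC" and params["input_image_channel"] == 4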
@@ -352,7 +357,10 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
             and config.get("use_dali", False):
         image_dtype = "float16"
     feeds = create_feeds(
-        config.image_shape, use_mix=use_mix, use_dali=use_dali, dtype = image_dtype)
+        config.image_shape,
+        use_mix=use_mix,
+        use_dali=use_dali,
+        dtype=image_dtype)
     if use_dali and use_mix:
         import dali
         feeds = dali.mix(feeds, config, is_train)
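The create_feeds call here is only reflowed, but the dtype argument is the functional bit: with pure fp16 plus DALI the image placeholder is declared as float16 up front. A hedged sketch of what a create_feeds-style helper does under the static graph API; the helper name, feed names, and the mixup branch are illustrative assumptions, not the repo's exact code:

import paddle

paddle.enable_static()

def make_feeds(image_shape, use_mix=False, dtype="float32"):
    # declare input placeholders; dtype is "float16" for pure fp16 + DALI runs
    feeds = {
        "image": paddle.static.data(
            name="feed_image", shape=[None] + list(image_shape), dtype=dtype)
    }
    if use_mix:
        # mixup-style training typically feeds two labels plus a mixing ratio
        feeds["y_a"] = paddle.static.data(
            name="feed_y_a", shape=[None, 1], dtype="int64")
        feeds["y_b"] = paddle.static.data(
            name="feed_y_b", shape=[None, 1], dtype="int64")
        feeds["lam"] = paddle.static.data(
            name="feed_lam", shape=[None, 1], dtype="float32")
    else:
        feeds["label"] = paddle.static.data(
            name="feed_label", shape=[None, 1], dtype="int64")
    return feeds

feeds = make_feeds([3, 224, 224], dtype="float16")  # fp16 image placeholder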
@@ -395,9 +403,11 @@ def compile(config, program, loss_name=None, share_prog=None):
     exec_strategy = paddle.static.ExecutionStrategy()

     exec_strategy.num_threads = 1
-    exec_strategy.num_iteration_per_drop_scope = 10000 if config.get('use_pure_fp16', False) else 10
+    exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
+        'use_pure_fp16', False) else 10

-    fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16', False)
+    fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
+                                                         False)
     fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
     fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
     fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
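These knobs feed into program compilation: use_amp or use_pure_fp16 switches all three fused kernels on by default and keeps local scopes alive far longer (dropped every 10000 iterations instead of every 10). A sketch of how the pieces usually plug together under Paddle's 2.0-era static API; the function name and the with_data_parallel wiring around the strategies are assumptions about the surrounding code:

import paddle

paddle.enable_static()

def compile_sketch(config, program, loss_name=None, share_prog=None):
    exec_strategy = paddle.static.ExecutionStrategy()
    exec_strategy.num_threads = 1
    # pure fp16 runs drop local scopes far less often
    exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
        'use_pure_fp16', False) else 10

    # AMP / pure fp16 enables the fused kernels unless overridden per key
    fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16', False)
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
    build_strategy.fuse_elewise_add_act_ops = config.get(
        'fuse_elewise_add_act_ops', fuse_op)
    build_strategy.fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops',
                                                    fuse_op)

    return paddle.static.CompiledProgram(program).with_data_parallel(
        loss_name=loss_name,
        share_vars_from=share_prog,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)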
@@ -65,10 +65,7 @@ def main(args):
     if config.get("is_distributed", True):
         fleet.init(is_collective=True)
     # assign the place
-    use_gpu = config.get("use_gpu", False)
-    assert use_gpu is True, "gpu must be true in static mode!"
-    place = paddle.set_device("gpu")
-
+    use_gpu = config.get("use_gpu", True)
     # amp related config
     use_amp = config.get('use_amp', False)
     use_pure_fp16 = config.get('use_pure_fp16', False)
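Here use_gpu now defaults to True, and the hard assert plus the fixed place assignment are dropped. Judging by the Executor(place) call that survives in the next hunk, the place is presumably resolved from the flag elsewhere, along these lines; a sketch, with the config fragment and placement logic assumed rather than taken from the diff:

import paddle

paddle.enable_static()

config = {"use_gpu": True}  # illustrative config fragment

use_gpu = config.get("use_gpu", True)  # new default matches the diff
# resolve the place from the flag instead of hard-coding "gpu" and asserting
place = paddle.set_device("gpu" if use_gpu else "cpu")
exe = paddle.static.Executor(place)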
@@ -122,7 +119,7 @@ def main(args):
     exe = paddle.static.Executor(place)
     # Parameter initialization
     exe.run(startup_prog)
     if config.get("use_pure_fp16", False):
         cast_parameters_to_fp16(place, train_prog, fluid.global_scope())
     # load pretrained models or checkpoints
     init_model(config, train_prog, exe)
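The ordering in this last hunk is what makes pure fp16 startup work: parameters come out of the startup program in fp32, are then cast to fp16 in the global scope, and only afterwards are pretrained weights or checkpoints loaded, so restored values take priority. A minimal sketch of that sequence; in recent Paddle releases cast_parameters_to_fp16 is exposed under paddle.static.amp, while the fluid-era code in the diff may import it differently, so treat the import path as an assumption:

import paddle
from paddle.static.amp import cast_parameters_to_fp16  # import path assumed

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    image = paddle.static.data(name="image", shape=[None, 8], dtype="float32")
    out = paddle.static.nn.fc(image, size=4)

place = paddle.set_device("gpu")  # pure fp16 assumes a GPU place
exe = paddle.static.Executor(place)

exe.run(startup_prog)  # 1) parameters are initialized in fp32 first
cast_parameters_to_fp16(  # 2) cast weights to fp16, mirroring the diff's call
    place, main_prog, paddle.static.global_scope())
# 3) a real run would load pretrained weights / checkpoints only now,
#    as init_model(...) does above, so restored values overwrite the cast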