fix static train (#478)

parent 29b305d228
commit 8fd56a4503

tools/static
tools/static/program.py

@@ -86,16 +86,21 @@ def create_model(architecture, image, classes_num, config, is_train):
     use_pure_fp16 = config.get("use_pure_fp16", False)
     name = architecture["name"]
     params = architecture.get("params", {})
-    data_format = config.get("data_format", "NCHW")
+
+    data_format = "NCHW"
+    if "data_format" in config:
+        params["data_format"] = config["data_format"]
+        data_format = config["data_format"]
     input_image_channel = config.get('image_shape', [3, 224, 224])[0]
     if input_image_channel != 3:
         logger.warning(
             "Input image channel is changed to {}, maybe for better speed-up".
             format(input_image_channel))
+        params["input_image_channel"] = input_image_channel
     if "is_test" in params:
         params['is_test'] = not is_train
-    model = architectures.__dict__[name](
-        class_dim=classes_num,
-        input_image_channel=input_image_channel,
-        data_format=data_format,
-        **params)
+    model = architectures.__dict__[name](class_dim=classes_num, **params)

     if use_pure_fp16 and not config.get("use_dali", False):
         image = image.astype('float16')
     if data_format == "NHWC":
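A minimal, self-contained sketch of the kwargs-selection logic in the new create_model above: data_format and input_image_channel only reach the architecture through params when the config actually asks for them. build_model_params and the sample config/architecture values are hypothetical, included only to show the resulting dict.

    # Illustration only: mirrors the params-building logic of create_model().
    def build_model_params(config, architecture, is_train):
        params = dict(architecture.get("params", {}))
        data_format = "NCHW"
        if "data_format" in config:
            params["data_format"] = config["data_format"]
            data_format = config["data_format"]
        input_image_channel = config.get("image_shape", [3, 224, 224])[0]
        if input_image_channel != 3:
            # only forwarded when the input is not a standard 3-channel image
            params["input_image_channel"] = input_image_channel
        if "is_test" in params:
            params["is_test"] = not is_train
        return params, data_format

    params, data_format = build_model_params(
        {"image_shape": [4, 224, 224], "data_format": "NHWC"},
        {"name": "ResNet50", "params": {}},
        is_train=True)
    print(params)       # {'data_format': 'NHWC', 'input_image_channel': 4}
    print(data_format)  # NHWC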
@@ -352,7 +357,10 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
             and config.get("use_dali", False):
         image_dtype = "float16"
     feeds = create_feeds(
-        config.image_shape, use_mix=use_mix, use_dali=use_dali, dtype = image_dtype)
+        config.image_shape,
+        use_mix=use_mix,
+        use_dali=use_dali,
+        dtype=image_dtype)
     if use_dali and use_mix:
         import dali
         feeds = dali.mix(feeds, config, is_train)
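For context, a rough sketch of the kind of feed declaration the dtype argument above ends up controlling, assuming a typical Paddle 2.x static-mode setup; make_image_feed and the feed name "data" are assumptions, not taken from create_feeds itself.

    # Sketch: an image feed whose dtype follows the pure-fp16 + DALI config.
    import paddle

    paddle.enable_static()

    def make_image_feed(image_shape, dtype="float32"):
        # leading None keeps the batch dimension dynamic
        return paddle.static.data(
            name="data", shape=[None] + list(image_shape), dtype=dtype)

    image = make_image_feed([3, 224, 224], dtype="float16")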
@@ -395,9 +403,11 @@ def compile(config, program, loss_name=None, share_prog=None):
     exec_strategy = paddle.static.ExecutionStrategy()

     exec_strategy.num_threads = 1
-    exec_strategy.num_iteration_per_drop_scope = 10000 if config.get('use_pure_fp16', False) else 10
+    exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
+        'use_pure_fp16', False) else 10

-    fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16', False)
+    fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
+                                                          False)
     fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
     fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
     fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
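For context, a sketch of how flags like these are commonly consumed in Paddle 2.x static mode: the exec-strategy fields mirror the lines above, and the fuse flags are assumed to be copied onto a paddle.static.BuildStrategy later in compile(). make_strategies is a hypothetical helper, not code from this diff.

    # Sketch: building exec/build strategies from the same config keys.
    import paddle

    paddle.enable_static()

    def make_strategies(config):
        fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
                                                              False)
        exec_strategy = paddle.static.ExecutionStrategy()
        exec_strategy.num_threads = 1
        exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
            'use_pure_fp16', False) else 10

        build_strategy = paddle.static.BuildStrategy()
        build_strategy.fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
        build_strategy.fuse_elewise_add_act_ops = config.get(
            'fuse_elewise_add_act_ops', fuse_op)
        build_strategy.fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops',
                                                        fuse_op)
        return build_strategy, exec_strategy

    build_strategy, exec_strategy = make_strategies({'use_pure_fp16': True})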
tools/static/train.py

@@ -65,10 +65,7 @@ def main(args):
     if config.get("is_distributed", True):
         fleet.init(is_collective=True)
-    # assign the place
-    use_gpu = config.get("use_gpu", False)
-    assert use_gpu is True, "gpu must be true in static mode!"
-    place = paddle.set_device("gpu")

+    use_gpu = config.get("use_gpu", True)
     # amp related config
     use_amp = config.get('use_amp', False)
     use_pure_fp16 = config.get('use_pure_fp16', False)
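A minimal sketch of how a use_gpu flag like the one above is typically turned into a device in Paddle 2.x; the actual place assignment sits outside this hunk, so the snippet is an assumption for illustration only.

    # Sketch: choosing the device from the config-driven use_gpu flag.
    import paddle

    paddle.enable_static()
    use_gpu = True  # stands in for config.get("use_gpu", True)
    place = paddle.set_device("gpu" if use_gpu else "cpu")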
@@ -122,7 +119,7 @@ def main(args):
     exe = paddle.static.Executor(place)
     # Parameter initialization
     exe.run(startup_prog)
-    if config.get("use_pure_fp16", False):
+    if config.get("use_pure_fp16", False):
         cast_parameters_to_fp16(place, train_prog, fluid.global_scope())
     # load pretrained models or checkpoints
     init_model(config, train_prog, exe)