add shared prog
parent
0a0d5bc060
commit
1dee0622cb
|
@ -373,7 +373,7 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
|
|||
return dataloader, fetchs
|
||||
|
||||
|
||||
def compile(config, program, loss_name=None):
|
||||
def compile(config, program, loss_name=None, share_prog=None):
|
||||
"""
|
||||
Compile the program
|
||||
|
||||
|
@ -381,6 +381,7 @@ def compile(config, program, loss_name=None):
|
|||
config(dict): config
|
||||
program(): the program which is wrapped by
|
||||
loss_name(str): loss name
|
||||
share_prog(): the shared program, used for evaluation during training
|
||||
|
||||
Returns:
|
||||
compiled_program(): a compiled program
|
||||
|
@ -392,6 +393,7 @@ def compile(config, program, loss_name=None):
|
|||
exec_strategy.num_iteration_per_drop_scope = 10
|
||||
|
||||
compiled_program = fluid.CompiledProgram(program).with_data_parallel(
|
||||
share_vars_from=share_prog,
|
||||
loss_name=loss_name,
|
||||
build_strategy=build_strategy,
|
||||
exec_strategy=exec_strategy)
|
||||
|
|
|
@ -101,13 +101,14 @@ def main(args):
|
|||
train_reader = Reader(config, 'train')()
|
||||
train_dataloader.set_sample_list_generator(train_reader, places)
|
||||
|
||||
compiled_train_prog = program.compile(config, train_prog,
|
||||
train_fetchs['loss'][0].name)
|
||||
|
||||
if config.validate:
|
||||
valid_reader = Reader(config, 'valid')()
|
||||
valid_dataloader.set_sample_list_generator(valid_reader, places)
|
||||
compiled_valid_prog = program.compile(config, valid_prog)
|
||||
|
||||
compiled_train_prog = program.compile(config, train_prog,
|
||||
train_fetchs['loss'][0].name)
|
||||
compiled_valid_prog = program.compile(
|
||||
config, valid_prog, share_prog=compiled_train_prog)
|
||||
|
||||
if args.vdl_dir:
|
||||
from visualdl import LogWriter
|
||||
|
|
Loading…
Reference in New Issue