commit
5736d85b96
|
@ -206,6 +206,8 @@ def mp_reader(params):
|
|||
check_params(params)
|
||||
|
||||
full_lines = get_file_list(params)
|
||||
if params["mode"] == "train":
|
||||
full_lines = shuffle_lines(full_lines, seed=None)
|
||||
|
||||
part_num = 1 if 'num_workers' not in params else params['num_workers']
|
||||
|
||||
|
@ -254,11 +256,10 @@ class Reader:
|
|||
self.batch_ops = create_operators(self.params['mix'])
|
||||
|
||||
def __call__(self):
|
||||
reader = mp_reader(self.params)
|
||||
|
||||
batch_size = int(self.params['batch_size']) // trainers_num
|
||||
|
||||
def wrapper():
|
||||
reader = mp_reader(self.params)
|
||||
batch = []
|
||||
for idx, sample in enumerate(reader()):
|
||||
img, label = sample
|
||||
|
|
|
@ -106,22 +106,20 @@ def load_params(exe, prog, path, ignore_params=[]):
|
|||
fluid.io.set_program_state(prog, state)
|
||||
|
||||
|
||||
def init_model(config, program, exe, prefix=""):
|
||||
def init_model(config, program, exe):
|
||||
"""
|
||||
load model from checkpoint or pretrained_model
|
||||
"""
|
||||
checkpoints = config.get('checkpoints')
|
||||
if checkpoints:
|
||||
path = os.path.join(checkpoints, prefix)
|
||||
fluid.load(program, path, exe)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
fluid.load(program, checkpoints, exe)
|
||||
logger.info("Finish initing model from {}".format(checkpoints))
|
||||
return
|
||||
|
||||
pretrained_model = config.get('pretrained_model')
|
||||
if pretrained_model:
|
||||
path = os.path.join(pretrained_model, prefix)
|
||||
load_params(exe, program, path)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
load_params(exe, program, pretrained_model)
|
||||
logger.info("Finish initing model from {}".format(pretrained_model))
|
||||
|
||||
|
||||
def save_model(program, model_path, epoch_id, prefix='ppcls'):
|
||||
|
|
Loading…
Reference in New Issue