commit
5736d85b96
|
@ -206,6 +206,8 @@ def mp_reader(params):
|
||||||
check_params(params)
|
check_params(params)
|
||||||
|
|
||||||
full_lines = get_file_list(params)
|
full_lines = get_file_list(params)
|
||||||
|
if params["mode"] == "train":
|
||||||
|
full_lines = shuffle_lines(full_lines, seed=None)
|
||||||
|
|
||||||
part_num = 1 if 'num_workers' not in params else params['num_workers']
|
part_num = 1 if 'num_workers' not in params else params['num_workers']
|
||||||
|
|
||||||
|
@ -254,11 +256,10 @@ class Reader:
|
||||||
self.batch_ops = create_operators(self.params['mix'])
|
self.batch_ops = create_operators(self.params['mix'])
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
reader = mp_reader(self.params)
|
|
||||||
|
|
||||||
batch_size = int(self.params['batch_size']) // trainers_num
|
batch_size = int(self.params['batch_size']) // trainers_num
|
||||||
|
|
||||||
def wrapper():
|
def wrapper():
|
||||||
|
reader = mp_reader(self.params)
|
||||||
batch = []
|
batch = []
|
||||||
for idx, sample in enumerate(reader()):
|
for idx, sample in enumerate(reader()):
|
||||||
img, label = sample
|
img, label = sample
|
||||||
|
|
|
@ -106,22 +106,20 @@ def load_params(exe, prog, path, ignore_params=[]):
|
||||||
fluid.io.set_program_state(prog, state)
|
fluid.io.set_program_state(prog, state)
|
||||||
|
|
||||||
|
|
||||||
def init_model(config, program, exe, prefix=""):
|
def init_model(config, program, exe):
|
||||||
"""
|
"""
|
||||||
load model from checkpoint or pretrained_model
|
load model from checkpoint or pretrained_model
|
||||||
"""
|
"""
|
||||||
checkpoints = config.get('checkpoints')
|
checkpoints = config.get('checkpoints')
|
||||||
if checkpoints:
|
if checkpoints:
|
||||||
path = os.path.join(checkpoints, prefix)
|
fluid.load(program, checkpoints, exe)
|
||||||
fluid.load(program, path, exe)
|
logger.info("Finish initing model from {}".format(checkpoints))
|
||||||
logger.info("Finish initing model from {}".format(path))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
pretrained_model = config.get('pretrained_model')
|
pretrained_model = config.get('pretrained_model')
|
||||||
if pretrained_model:
|
if pretrained_model:
|
||||||
path = os.path.join(pretrained_model, prefix)
|
load_params(exe, program, pretrained_model)
|
||||||
load_params(exe, program, path)
|
logger.info("Finish initing model from {}".format(pretrained_model))
|
||||||
logger.info("Finish initing model from {}".format(path))
|
|
||||||
|
|
||||||
|
|
||||||
def save_model(program, model_path, epoch_id, prefix='ppcls'):
|
def save_model(program, model_path, epoch_id, prefix='ppcls'):
|
||||||
|
|
Loading…
Reference in New Issue