Merge pull request #45 from littletomatodonkey/fix_shuffle

fix reader shuffle
pull/46/head
dyning 2020-04-19 12:56:41 +08:00 committed by GitHub
commit 5736d85b96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 9 deletions

View File

@ -206,6 +206,8 @@ def mp_reader(params):
check_params(params)
full_lines = get_file_list(params)
if params["mode"] == "train":
full_lines = shuffle_lines(full_lines, seed=None)
part_num = 1 if 'num_workers' not in params else params['num_workers']
@ -254,11 +256,10 @@ class Reader:
self.batch_ops = create_operators(self.params['mix'])
def __call__(self):
reader = mp_reader(self.params)
batch_size = int(self.params['batch_size']) // trainers_num
def wrapper():
reader = mp_reader(self.params)
batch = []
for idx, sample in enumerate(reader()):
img, label = sample

View File

@ -106,22 +106,20 @@ def load_params(exe, prog, path, ignore_params=[]):
fluid.io.set_program_state(prog, state)
def init_model(config, program, exe, prefix=""):
def init_model(config, program, exe):
"""
load model from checkpoint or pretrained_model
"""
checkpoints = config.get('checkpoints')
if checkpoints:
path = os.path.join(checkpoints, prefix)
fluid.load(program, path, exe)
logger.info("Finish initing model from {}".format(path))
fluid.load(program, checkpoints, exe)
logger.info("Finish initing model from {}".format(checkpoints))
return
pretrained_model = config.get('pretrained_model')
if pretrained_model:
path = os.path.join(pretrained_model, prefix)
load_params(exe, program, path)
logger.info("Finish initing model from {}".format(path))
load_params(exe, program, pretrained_model)
logger.info("Finish initing model from {}".format(pretrained_model))
def save_model(program, model_path, epoch_id, prefix='ppcls'):