Merge pull request #188 from shippingwang/refine_save
add print_interval and refine overridepull/192/head^2
commit
856628d4f1
|
@ -144,9 +144,14 @@ def override(dl, ks, v):
|
|||
override(dl[k], ks[1:], v)
|
||||
else:
|
||||
if len(ks) == 1:
|
||||
assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
|
||||
#assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
|
||||
if not ks[0] in dl:
|
||||
logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
|
||||
dl[ks[0]] = str2num(v)
|
||||
else:
|
||||
assert ks[0] in dl, (
|
||||
'({}) doesn\'t exist in {}, a new dict field is invalid'.
|
||||
format(ks[0], dl))
|
||||
override(dl[ks[0]], ks[1:], v)
|
||||
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ def main(args):
|
|||
|
||||
compiled_valid_prog = program.compile(config, valid_prog)
|
||||
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
|
||||
'eval')
|
||||
'eval', config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -410,6 +410,7 @@ def run(dataloader,
|
|||
fetchs,
|
||||
epoch=0,
|
||||
mode='train',
|
||||
config=None,
|
||||
vdl_writer=None):
|
||||
"""
|
||||
Feed data to the model and fetch the measures and loss
|
||||
|
@ -443,16 +444,28 @@ def run(dataloader,
|
|||
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
|
||||
total_step += 1
|
||||
if mode == 'eval':
|
||||
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
|
||||
if idx % config.get('print_interval', 10) == 0:
|
||||
logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
|
||||
fetchs_str))
|
||||
else:
|
||||
epoch_str = "epoch:{:<3d}".format(epoch)
|
||||
step_str = "{:s} step:{:<4d}".format(mode, idx)
|
||||
|
||||
logger.info("{:s} {:s} {:s}".format(
|
||||
logger.coloring(epoch_str, "HEADER")
|
||||
if idx == 0 else epoch_str,
|
||||
logger.coloring(step_str, "PURPLE"),
|
||||
logger.coloring(fetchs_str, 'OKGREEN')))
|
||||
# Keep the first 10 batches statistics, They are important for develop
|
||||
if epoch == 0 and idx < 10:
|
||||
logger.info("{:s} {:s} {:s}".format(
|
||||
logger.coloring(epoch_str, "HEADER")
|
||||
if idx == 0 else epoch_str,
|
||||
logger.coloring(step_str, "PURPLE"),
|
||||
logger.coloring(fetchs_str, 'OKGREEN')))
|
||||
|
||||
else:
|
||||
if idx % config.get('print_interval', 10) == 0:
|
||||
logger.info("{:s} {:s} {:s}".format(
|
||||
logger.coloring(epoch_str, "HEADER")
|
||||
if idx == 0 else epoch_str,
|
||||
logger.coloring(step_str, "PURPLE"),
|
||||
logger.coloring(fetchs_str, 'OKGREEN')))
|
||||
|
||||
end_str = ''.join([str(m.mean) + ' '
|
||||
for m in metric_list] + [batch_time.total]) + 's'
|
||||
|
|
|
@ -5,4 +5,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH
|
|||
python -m paddle.distributed.launch \
|
||||
--selected_gpus="0,1,2,3" \
|
||||
tools/train.py \
|
||||
-c ./configs/ResNet/ResNet50.yaml
|
||||
-c ./configs/ResNet/ResNet50.yaml \
|
||||
-o print_interval=10
|
||||
|
|
|
@ -110,21 +110,21 @@ def main(args):
|
|||
for epoch_id in range(config.epochs):
|
||||
# 1. train with train dataset
|
||||
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
|
||||
epoch_id, 'train', vdl_writer)
|
||||
epoch_id, 'train', config, vdl_writer)
|
||||
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
||||
# 2. validate with validate dataset
|
||||
if config.validate and epoch_id % config.valid_interval == 0:
|
||||
if config.get('use_ema'):
|
||||
logger.info(logger.coloring("EMA validate start..."))
|
||||
with ema.apply(exe):
|
||||
top1_acc = program.run(valid_dataloader, exe,
|
||||
compiled_valid_prog,
|
||||
valid_fetchs, epoch_id, 'valid')
|
||||
top1_acc = program.run(
|
||||
valid_dataloader, exe, compiled_valid_prog,
|
||||
valid_fetchs, epoch_id, 'valid', config)
|
||||
logger.info(logger.coloring("EMA validate over!"))
|
||||
|
||||
top1_acc = program.run(valid_dataloader, exe,
|
||||
compiled_valid_prog, valid_fetchs,
|
||||
epoch_id, 'valid')
|
||||
epoch_id, 'valid', config)
|
||||
if top1_acc > best_top1_acc:
|
||||
best_top1_acc = top1_acc
|
||||
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
||||
|
|
Loading…
Reference in New Issue