add print_interval and refine override
parent
9d3f36b75e
commit
d3bad33f45
|
@ -144,9 +144,14 @@ def override(dl, ks, v):
|
||||||
override(dl[k], ks[1:], v)
|
override(dl[k], ks[1:], v)
|
||||||
else:
|
else:
|
||||||
if len(ks) == 1:
|
if len(ks) == 1:
|
||||||
assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
|
#assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
|
||||||
|
if not ks[0] in dl:
|
||||||
|
logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
|
||||||
dl[ks[0]] = str2num(v)
|
dl[ks[0]] = str2num(v)
|
||||||
else:
|
else:
|
||||||
|
assert ks[0] in dl, (
|
||||||
|
'({}) doesn\'t exist in {}, a new dict field is invalid'.
|
||||||
|
format(ks[0], dl))
|
||||||
override(dl[ks[0]], ks[1:], v)
|
override(dl[ks[0]], ks[1:], v)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ def main(args):
|
||||||
|
|
||||||
compiled_valid_prog = program.compile(config, valid_prog)
|
compiled_valid_prog = program.compile(config, valid_prog)
|
||||||
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
|
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
|
||||||
'eval')
|
'eval', config)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -410,6 +410,7 @@ def run(dataloader,
|
||||||
fetchs,
|
fetchs,
|
||||||
epoch=0,
|
epoch=0,
|
||||||
mode='train',
|
mode='train',
|
||||||
|
config=None,
|
||||||
vdl_writer=None):
|
vdl_writer=None):
|
||||||
"""
|
"""
|
||||||
Feed data to the model and fetch the measures and loss
|
Feed data to the model and fetch the measures and loss
|
||||||
|
@ -443,11 +444,23 @@ def run(dataloader,
|
||||||
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
|
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
|
||||||
total_step += 1
|
total_step += 1
|
||||||
if mode == 'eval':
|
if mode == 'eval':
|
||||||
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
|
if idx % config.get('print_interval', 1) == 0:
|
||||||
|
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx,
|
||||||
|
fetchs_str))
|
||||||
else:
|
else:
|
||||||
epoch_str = "epoch:{:<3d}".format(epoch)
|
epoch_str = "epoch:{:<3d}".format(epoch)
|
||||||
step_str = "{:s} step:{:<4d}".format(mode, idx)
|
step_str = "{:s} step:{:<4d}".format(mode, idx)
|
||||||
|
|
||||||
|
# Keep the first 10 batches statistics, They are important for develop
|
||||||
|
if epoch == 0 and idx < 10:
|
||||||
|
logger.info("{:s} {:s} {:s}".format(
|
||||||
|
logger.coloring(epoch_str, "HEADER")
|
||||||
|
if idx == 0 else epoch_str,
|
||||||
|
logger.coloring(step_str, "PURPLE"),
|
||||||
|
logger.coloring(fetchs_str, 'OKGREEN')))
|
||||||
|
|
||||||
|
else:
|
||||||
|
if idx % config.get('print_interval', 1) == 0:
|
||||||
logger.info("{:s} {:s} {:s}".format(
|
logger.info("{:s} {:s} {:s}".format(
|
||||||
logger.coloring(epoch_str, "HEADER")
|
logger.coloring(epoch_str, "HEADER")
|
||||||
if idx == 0 else epoch_str,
|
if idx == 0 else epoch_str,
|
||||||
|
|
|
@ -5,4 +5,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
python -m paddle.distributed.launch \
|
python -m paddle.distributed.launch \
|
||||||
--selected_gpus="0,1,2,3" \
|
--selected_gpus="0,1,2,3" \
|
||||||
tools/train.py \
|
tools/train.py \
|
||||||
-c ./configs/ResNet/ResNet50.yaml
|
-c ./configs/ResNet/ResNet50.yaml \
|
||||||
|
-o print_interval=10
|
||||||
|
|
|
@ -110,21 +110,21 @@ def main(args):
|
||||||
for epoch_id in range(config.epochs):
|
for epoch_id in range(config.epochs):
|
||||||
# 1. train with train dataset
|
# 1. train with train dataset
|
||||||
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
|
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
|
||||||
epoch_id, 'train', vdl_writer)
|
epoch_id, 'train', config, vdl_writer)
|
||||||
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
||||||
# 2. validate with validate dataset
|
# 2. validate with validate dataset
|
||||||
if config.validate and epoch_id % config.valid_interval == 0:
|
if config.validate and epoch_id % config.valid_interval == 0:
|
||||||
if config.get('use_ema'):
|
if config.get('use_ema'):
|
||||||
logger.info(logger.coloring("EMA validate start..."))
|
logger.info(logger.coloring("EMA validate start..."))
|
||||||
with ema.apply(exe):
|
with ema.apply(exe):
|
||||||
top1_acc = program.run(valid_dataloader, exe,
|
top1_acc = program.run(
|
||||||
compiled_valid_prog,
|
valid_dataloader, exe, compiled_valid_prog,
|
||||||
valid_fetchs, epoch_id, 'valid')
|
valid_fetchs, epoch_id, 'valid', config)
|
||||||
logger.info(logger.coloring("EMA validate over!"))
|
logger.info(logger.coloring("EMA validate over!"))
|
||||||
|
|
||||||
top1_acc = program.run(valid_dataloader, exe,
|
top1_acc = program.run(valid_dataloader, exe,
|
||||||
compiled_valid_prog, valid_fetchs,
|
compiled_valid_prog, valid_fetchs,
|
||||||
epoch_id, 'valid')
|
epoch_id, 'valid', config)
|
||||||
if top1_acc > best_top1_acc:
|
if top1_acc > best_top1_acc:
|
||||||
best_top1_acc = top1_acc
|
best_top1_acc = top1_acc
|
||||||
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
||||||
|
|
Loading…
Reference in New Issue