save best model
parent
0779613671
commit
2cf8014d8d
|
@ -383,17 +383,21 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
|
|||
tic = time.time()
|
||||
for i, m in enumerate(metrics):
|
||||
metric_list[i].update(m[0], len(batch[0]))
|
||||
fetchs_str = ''.join([str(m.value)+' '
|
||||
for m in metric_list]+ [batch_time.value])
|
||||
fetchs_str = ''.join([str(m.value) + ' '
|
||||
for m in metric_list] + [batch_time.value])
|
||||
if epoch != -1:
|
||||
logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format(
|
||||
epoch, mode, idx, fetchs_str))
|
||||
epoch, mode, idx, fetchs_str))
|
||||
else:
|
||||
logger.info("{:s} step:{:<4d} {:s}s".format(
|
||||
mode, idx, fetchs_str))
|
||||
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
|
||||
|
||||
end_str = ''.join([str(m.mean)+' ' for m in metric_list] + [batch_time.total])
|
||||
if epoch!= -1:
|
||||
end_str = ''.join([str(m.mean) + ' '
|
||||
for m in metric_list] + [batch_time.total])
|
||||
if epoch != -1:
|
||||
logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str))
|
||||
else:
|
||||
logger.info("END {:s} {:s}s".format(mode, end_str))
|
||||
|
||||
# save the best model
|
||||
top1_acc = fetchs["top1"][1].avg
|
||||
return top1_acc
|
||||
|
|
|
@ -26,6 +26,7 @@ from paddle.fluid.incubate.fleet.collective import fleet
|
|||
from ppcls.data import Reader
|
||||
from ppcls.utils.config import get_config
|
||||
from ppcls.utils.save_load import init_model, save_model
|
||||
from ppcls.utils import logger
|
||||
import program
|
||||
|
||||
|
||||
|
@ -61,6 +62,10 @@ def main(args):
|
|||
startup_prog = fluid.Program()
|
||||
train_prog = fluid.Program()
|
||||
|
||||
# best_top1_acc_list[0]: top1 acc
|
||||
# best_top1_acc_list[1]: epoch id
|
||||
best_top1_acc_list = [0.0, 0]
|
||||
|
||||
train_dataloader, train_fetchs = program.build(
|
||||
config, train_prog, startup_prog, is_train=True)
|
||||
|
||||
|
@ -94,8 +99,16 @@ def main(args):
|
|||
epoch_id, 'train')
|
||||
# 2. validate with validate dataset
|
||||
if config.validate and epoch_id % config.valid_interval == 0:
|
||||
program.run(valid_dataloader, exe, compiled_valid_prog,
|
||||
valid_fetchs, epoch_id, 'valid')
|
||||
top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog,
|
||||
valid_fetchs, epoch_id, 'valid')
|
||||
if top1_acc > best_top1_acc_list[0]:
|
||||
best_top1_acc_list[0] = top1_acc
|
||||
best_top1_acc_list[1] = epoch_id
|
||||
logger.info("Best top1 acc: {}, in epoch: {}".format(
|
||||
best_top1_acc_list[0], best_top1_acc_list[1]))
|
||||
model_path = os.path.join(config.model_save_dir,
|
||||
config.ARCHITECTURE["name"])
|
||||
save_model(train_prog, model_path, "best_model")
|
||||
|
||||
# 3. save the persistable model
|
||||
if epoch_id % config.save_interval == 0:
|
||||
|
|
Loading…
Reference in New Issue