imporve msg info
parent
4d01af9e87
commit
a0b125d99a
|
@ -58,7 +58,6 @@ def main(args):
|
||||||
if use_gpu:
|
if use_gpu:
|
||||||
gpu_id = ParallelEnv().dev_id
|
gpu_id = ParallelEnv().dev_id
|
||||||
place = paddle.CUDAPlace(gpu_id)
|
place = paddle.CUDAPlace(gpu_id)
|
||||||
print("[gry debug]gpu_id: ", gpu_id)
|
|
||||||
else:
|
else:
|
||||||
place = paddle.CPUPlace()
|
place = paddle.CPUPlace()
|
||||||
|
|
||||||
|
@ -85,6 +84,7 @@ def main(args):
|
||||||
valid_dataloader = Reader(config, 'valid', places=place)()
|
valid_dataloader = Reader(config, 'valid', places=place)()
|
||||||
|
|
||||||
best_top1_acc = 0.0 # best top1 acc record
|
best_top1_acc = 0.0 # best top1 acc record
|
||||||
|
best_top1_epoch = 0
|
||||||
for epoch_id in range(config.epochs):
|
for epoch_id in range(config.epochs):
|
||||||
net.train()
|
net.train()
|
||||||
# 1. train with train dataset
|
# 1. train with train dataset
|
||||||
|
@ -99,14 +99,14 @@ def main(args):
|
||||||
None, epoch_id, 'valid')
|
None, epoch_id, 'valid')
|
||||||
if top1_acc > best_top1_acc:
|
if top1_acc > best_top1_acc:
|
||||||
best_top1_acc = top1_acc
|
best_top1_acc = top1_acc
|
||||||
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
best_top1_epoch = epoch_id
|
||||||
best_top1_acc, epoch_id)
|
|
||||||
logger.info("{:s}".format(logger.coloring(message, "RED")))
|
|
||||||
if epoch_id % config.save_interval == 0:
|
if epoch_id % config.save_interval == 0:
|
||||||
|
|
||||||
model_path = os.path.join(config.model_save_dir,
|
model_path = os.path.join(config.model_save_dir,
|
||||||
config.ARCHITECTURE["name"])
|
config.ARCHITECTURE["name"])
|
||||||
save_model(net, optimizer, model_path, "best_model")
|
save_model(net, optimizer, model_path, "best_model")
|
||||||
|
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
||||||
|
best_top1_acc, epoch_id)
|
||||||
|
logger.info("{:s}".format(logger.coloring(message, "RED")))
|
||||||
|
|
||||||
# 3. save the persistable model
|
# 3. save the persistable model
|
||||||
if epoch_id % config.save_interval == 0:
|
if epoch_id % config.save_interval == 0:
|
||||||
|
|
Loading…
Reference in New Issue