fix tools/train.py
parent
70f1782339
commit
a9bce0409d
|
@ -109,47 +109,21 @@ def main(args):
|
|||
program.run(train_dataloader, config, dp_net, optimizer,
|
||||
lr_scheduler, epoch_id, 'train', vdl_writer)
|
||||
|
||||
if use_xpu:
|
||||
if paddle.distributed.get_rank() == 0:
|
||||
# 2. validate with validate dataset
|
||||
if config.validate and epoch_id % config.valid_interval == 0:
|
||||
net.eval()
|
||||
top1_acc = program.run(valid_dataloader, config, net,
|
||||
None, None, epoch_id, 'valid')
|
||||
if top1_acc > best_top1_acc:
|
||||
best_top1_acc = top1_acc
|
||||
best_top1_epoch = epoch_id
|
||||
if epoch_id % config.save_interval == 0:
|
||||
model_path = os.path.join(
|
||||
config.model_save_dir,
|
||||
config.ARCHITECTURE["name"])
|
||||
save_model(net, optimizer, model_path,
|
||||
"best_model")
|
||||
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
||||
best_top1_acc, best_top1_epoch)
|
||||
logger.info("{:s}".format(
|
||||
logger.coloring(message, "RED")))
|
||||
|
||||
else:
|
||||
# 2. validate with validate dataset
|
||||
if paddle.distributed.get_rank() == 0:
|
||||
if config.validate and epoch_id % config.valid_interval == 0:
|
||||
net.eval()
|
||||
with paddle.no_grad():
|
||||
top1_acc = program.run(valid_dataloader, config,
|
||||
net, None, None, epoch_id,
|
||||
'valid', vdl_writer)
|
||||
if top1_acc > best_top1_acc:
|
||||
best_top1_acc = top1_acc
|
||||
best_top1_epoch = epoch_id
|
||||
model_path = os.path.join(
|
||||
config.model_save_dir,
|
||||
config.ARCHITECTURE["name"])
|
||||
save_model(net, optimizer, model_path,
|
||||
"best_model")
|
||||
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
||||
best_top1_acc, best_top1_epoch)
|
||||
logger.info(message)
|
||||
# 2. validate with validate dataset
|
||||
if config.validate and epoch_id % config.valid_interval == 0:
|
||||
net.eval()
|
||||
with paddle.no_grad():
|
||||
top1_acc = program.run(valid_dataloader, config, net, None,
|
||||
None, epoch_id, 'valid', vdl_writer)
|
||||
if top1_acc > best_top1_acc:
|
||||
best_top1_acc = top1_acc
|
||||
best_top1_epoch = epoch_id
|
||||
model_path = os.path.join(config.model_save_dir,
|
||||
config.ARCHITECTURE["name"])
|
||||
save_model(net, optimizer, model_path, "best_model")
|
||||
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
||||
best_top1_acc, best_top1_epoch)
|
||||
logger.info(message)
|
||||
|
||||
# 3. save the persistable model
|
||||
if epoch_id % config.save_interval == 0:
|
||||
|
|
Loading…
Reference in New Issue