added tensorboard
parent
e68f8e290a
commit
693b581c24
|
@ -131,7 +131,10 @@ def do_test(cfg, model, iteration):
|
|||
torch.save({"teacher": new_state_dict}, teacher_ckp_path)
|
||||
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
def do_train(cfg, model, resume=False):
|
||||
writer = SummaryWriter(log_dir=cfg.train.output_dir)
|
||||
model.train()
|
||||
inputs_dtype = torch.half
|
||||
fp16_scaler = model.fp16_scaler # for mixed precision training
|
||||
|
@ -281,7 +284,10 @@ def do_train(cfg, model, resume=False):
|
|||
metric_logger.update(last_layer_lr=last_layer_lr)
|
||||
metric_logger.update(current_batch_size=current_batch_size)
|
||||
metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)
|
||||
|
||||
writer.add_scalar('Loss/Total_Loss', losses_reduced, iteration)
|
||||
writer.add_scalar('Learning_Rate', lr, iteration)
|
||||
writer.add_scalar('Weight_Decay', wd, iteration)
|
||||
writer.add_scalar('Momentum', mom, iteration)
|
||||
# checkpointing and testing
|
||||
|
||||
if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
|
||||
|
@ -291,6 +297,7 @@ def do_train(cfg, model, resume=False):
|
|||
|
||||
iteration = iteration + 1
|
||||
metric_logger.synchronize_between_processes()
|
||||
writer.close()
|
||||
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
||||
|
||||
import re
|
||||
|
|
Loading…
Reference in New Issue