added tensorboard

pull/511/head
Veronikkkka 2025-03-10 17:43:16 +00:00
parent e68f8e290a
commit 693b581c24
1 changed files with 8 additions and 1 deletions

View File

@ -131,7 +131,10 @@ def do_test(cfg, model, iteration):
torch.save({"teacher": new_state_dict}, teacher_ckp_path)
from torch.utils.tensorboard import SummaryWriter
def do_train(cfg, model, resume=False):
writer = SummaryWriter(log_dir=cfg.train.output_dir)
model.train()
inputs_dtype = torch.half
fp16_scaler = model.fp16_scaler # for mixed precision training
@ -281,7 +284,10 @@ def do_train(cfg, model, resume=False):
metric_logger.update(last_layer_lr=last_layer_lr)
metric_logger.update(current_batch_size=current_batch_size)
metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)
writer.add_scalar('Loss/Total_Loss', losses_reduced, iteration)
writer.add_scalar('Learning_Rate', lr, iteration)
writer.add_scalar('Weight_Decay', wd, iteration)
writer.add_scalar('Momentum', mom, iteration)
# checkpointing and testing
if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
@ -291,6 +297,7 @@ def do_train(cfg, model, resume=False):
iteration = iteration + 1
metric_logger.synchronize_between_processes()
writer.close()
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
import re