log to wandb only if using using wandb
parent
00c8e0b8bd
commit
624c9b6949
timm/utils
|
@ -23,10 +23,12 @@ def get_outdir(path, *paths, inc=False):
|
||||||
return outdir
|
return outdir
|
||||||
|
|
||||||
|
|
||||||
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
|
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
|
||||||
rowd = OrderedDict(epoch=epoch)
|
rowd = OrderedDict(epoch=epoch)
|
||||||
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
|
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
|
||||||
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
|
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
|
||||||
|
if log_wandb:
|
||||||
|
wandb.log(rowd)
|
||||||
with open(filename, mode='a') as cf:
|
with open(filename, mode='a') as cf:
|
||||||
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
|
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
|
||||||
if write_header: # first iteration (epoch == 1 can't be used)
|
if write_header: # first iteration (epoch == 1 can't be used)
|
||||||
|
|
6
train.py
6
train.py
|
@ -592,10 +592,6 @@ def main():
|
||||||
|
|
||||||
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
|
||||||
|
|
||||||
if args.use_wandb:
|
|
||||||
wandb.log(train_metrics)
|
|
||||||
wandb.log(eval_metrics)
|
|
||||||
|
|
||||||
if model_ema is not None and not args.model_ema_force_cpu:
|
if model_ema is not None and not args.model_ema_force_cpu:
|
||||||
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
||||||
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
||||||
|
@ -609,7 +605,7 @@ def main():
|
||||||
|
|
||||||
update_summary(
|
update_summary(
|
||||||
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
|
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
|
||||||
write_header=best_metric is None)
|
write_header=best_metric is None, log_wandb=args.use_wandb)
|
||||||
|
|
||||||
if saver is not None:
|
if saver is not None:
|
||||||
# save proper checkpoint with eval metric
|
# save proper checkpoint with eval metric
|
||||||
|
|
Loading…
Reference in New Issue