From 3a6cc4fb1753beda0af110b8ce71484d5614e868 Mon Sep 17 00:00:00 2001 From: Sina Hajimiri Date: Wed, 20 Nov 2024 20:04:34 -0500 Subject: [PATCH] Improve wandb logging --- train.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index aa9db22b..cd8d2f38 100755 --- a/train.py +++ b/train.py @@ -386,6 +386,10 @@ group.add_argument('--use-multi-epochs-loader', action='store_true', default=Fal help='use the multi-epochs-loader to save time at the beginning of every epoch') group.add_argument('--log-wandb', action='store_true', default=False, help='log training and validation metrics to wandb') +group.add_argument('--wandb-tags', default=[], type=str, nargs='+', + help='wandb tags') +group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID', + help='If resuming a run, the id of the run in wandb') def _parse_args(): @@ -814,7 +818,10 @@ def main(): if utils.is_primary(args) and args.log_wandb: if has_wandb: - wandb.init(project=args.experiment, config=args) + assert not args.wandb_resume_id or args.resume + wandb.init(project=args.experiment, config=args, tags=args.wandb_tags, + resume='must' if args.wandb_resume_id else None, + id=args.wandb_resume_id if args.wandb_resume_id else None) else: _logger.warning( "You've requested to log metrics to wandb but package not found. "