diff --git a/train.py b/train.py index 7d8e9898..5e9ecc9c 100755 --- a/train.py +++ b/train.py @@ -390,6 +390,8 @@ group.add_argument('--use-multi-epochs-loader', action='store_true', default=Fal help='use the multi-epochs-loader to save time at the beginning of every epoch') group.add_argument('--log-wandb', action='store_true', default=False, help='log training and validation metrics to wandb') +group.add_argument('--wandb-project', default=None, type=str, + help='wandb project name') group.add_argument('--wandb-tags', default=[], type=str, nargs='+', help='wandb tags') group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID', @@ -832,20 +834,21 @@ def main(): with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) - if utils.is_primary(args) and args.log_wandb: - if has_wandb: - assert not args.wandb_resume_id or args.resume - wandb.init( - project=args.experiment, - config=args, - tags=args.wandb_tags, - resume='must' if args.wandb_resume_id else None, - id=args.wandb_resume_id if args.wandb_resume_id else None, - ) - else: - _logger.warning( - "You've requested to log metrics to wandb but package not found. " - "Metrics not being logged to wandb, try `pip install wandb`") + if args.log_wandb: + if has_wandb: + assert not args.wandb_resume_id or args.resume + wandb.init( + project=args.wandb_project, + name=exp_name, + config=args, + tags=args.wandb_tags, + resume="must" if args.wandb_resume_id else None, + id=args.wandb_resume_id if args.wandb_resume_id else None, + ) + else: + _logger.warning( + "You've requested to log metrics to wandb but package not found. " + "Metrics not being logged to wandb, try `pip install wandb`") # setup learning rate schedule and starting epoch updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps