Make use of wandb configurable
parent
8e6fb861e4
commit
00c8e0b8bd
21
train.py
21
train.py
|
@ -273,6 +273,10 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
|
|||
help='use the multi-epochs-loader to save time at the beginning of every epoch')
|
||||
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='convert model torchscript for inference')
|
||||
parser.add_argument('--use-wandb', action='store_true', default=False,
|
||||
help='use wandb for training and validation logs')
|
||||
parser.add_argument('--wandb-project-name', type=str, default=None,
|
||||
help='wandb project name to be used')
|
||||
|
||||
|
||||
def _parse_args():
|
||||
|
@ -295,8 +299,13 @@ def _parse_args():
|
|||
def main():
|
||||
setup_default_logging()
|
||||
args, args_text = _parse_args()
|
||||
wandb.init(project='efficientnet_v2', config=args)
|
||||
wandb.run.name = args.model
|
||||
|
||||
if args.use_wandb:
|
||||
if not args.wandb_project_name:
|
||||
args.wandb_project_name = args.model
|
||||
_logger.warning(f"Wandb project name not provided, defaulting to {args.model}")
|
||||
wandb.init(project=args.wandb_project_name, config=args)
|
||||
|
||||
args.prefetcher = not args.no_prefetcher
|
||||
args.distributed = False
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
|
@ -575,14 +584,18 @@ def main():
|
|||
epoch, model, loader_train, optimizer, train_loss_fn, args,
|
||||
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
|
||||
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
|
||||
wandb.log(train_metrics)
|
||||
|
||||
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
||||
if args.local_rank == 0:
|
||||
_logger.info("Distributing BatchNorm running means and vars")
|
||||
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
||||
|
||||
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
|
||||
wandb.log(eval_metrics)
|
||||
|
||||
if args.use_wandb:
|
||||
wandb.log(train_metrics)
|
||||
wandb.log(eval_metrics)
|
||||
|
||||
if model_ema is not None and not args.model_ema_force_cpu:
|
||||
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
||||
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
||||
|
|
Loading…
Reference in New Issue