diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py
index 6cb506a5..caf68fad 100644
--- a/timm/scheduler/scheduler_factory.py
+++ b/timm/scheduler/scheduler_factory.py
@@ -1,7 +1,7 @@
 """ Scheduler Factory
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import List, Union
+from typing import List, Optional, Union
 
 from torch.optim import Optimizer
 
@@ -13,12 +13,15 @@ from .step_lr import StepLRScheduler
 from .tanh_lr import TanhLRScheduler
 
 
-def scheduler_kwargs(cfg):
+def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
     """ cfg/argparse to kwargs helper
     Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
     """
     eval_metric = getattr(cfg, 'eval_metric', 'top1')
-    plateau_mode = 'min' if 'loss' in eval_metric else 'max'
+    if decreasing_metric is not None:
+        plateau_mode = 'min' if decreasing_metric else 'max'
+    else:
+        plateau_mode = 'min' if 'loss' in eval_metric else 'max'
     kwargs = dict(
         sched=cfg.sched,
         num_epochs=getattr(cfg, 'epochs', 100),
diff --git a/timm/utils/summary.py b/timm/utils/summary.py
index c377a75f..eccbb941 100644
--- a/timm/utils/summary.py
+++ b/timm/utils/summary.py
@@ -38,7 +38,8 @@ def update_summary(
 ):
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
-    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
+    if eval_metrics:
+        rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
     if lr is not None:
         rowd['lr'] = lr
     if log_wandb:
diff --git a/train.py b/train.py
index fceba87e..077d0316 100755
--- a/train.py
+++ b/train.py
@@ -614,6 +614,7 @@ def main():
         input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
     else:
         input_img_mode = args.input_img_mode
+
     dataset_train = create_dataset(
         args.dataset,
         root=args.data_dir,
@@ -630,19 +631,20 @@
         num_samples=args.train_num_samples,
     )
 
-    dataset_eval = create_dataset(
-        args.dataset,
-        root=args.data_dir,
-        split=args.val_split,
-        is_training=False,
-        class_map=args.class_map,
-        download=args.dataset_download,
-        batch_size=args.batch_size,
-        input_img_mode=input_img_mode,
-        input_key=args.input_key,
-        target_key=args.target_key,
-        num_samples=args.val_num_samples,
-    )
+    if args.val_split:
+        dataset_eval = create_dataset(
+            args.dataset,
+            root=args.data_dir,
+            split=args.val_split,
+            is_training=False,
+            class_map=args.class_map,
+            download=args.dataset_download,
+            batch_size=args.batch_size,
+            input_img_mode=input_img_mode,
+            input_key=args.input_key,
+            target_key=args.target_key,
+            num_samples=args.val_num_samples,
+        )
 
     # setup mixup / cutmix
     collate_fn = None
@@ -707,25 +709,27 @@
         worker_seeding=args.worker_seeding,
     )
 
-    eval_workers = args.workers
-    if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
-        # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
-        eval_workers = min(2, args.workers)
-    loader_eval = create_loader(
-        dataset_eval,
-        input_size=data_config['input_size'],
-        batch_size=args.validation_batch_size or args.batch_size,
-        is_training=False,
-        interpolation=data_config['interpolation'],
-        mean=data_config['mean'],
-        std=data_config['std'],
-        num_workers=eval_workers,
-        distributed=args.distributed,
-        crop_pct=data_config['crop_pct'],
-        pin_memory=args.pin_mem,
-        device=device,
-        use_prefetcher=args.prefetcher,
-    )
+    loader_eval = None
+    if args.val_split:
+        eval_workers = args.workers
+        if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
+            # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
+            eval_workers = min(2, args.workers)
+        loader_eval = create_loader(
+            dataset_eval,
+            input_size=data_config['input_size'],
+            batch_size=args.validation_batch_size or args.batch_size,
+            is_training=False,
+            interpolation=data_config['interpolation'],
+            mean=data_config['mean'],
+            std=data_config['std'],
+            num_workers=eval_workers,
+            distributed=args.distributed,
+            crop_pct=data_config['crop_pct'],
+            pin_memory=args.pin_mem,
+            device=device,
+            use_prefetcher=args.prefetcher,
+        )
 
     # setup loss function
     if args.jsd_loss:
@@ -757,7 +761,8 @@
     validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
 
     # setup checkpoint saver and eval metric tracking
-    eval_metric = args.eval_metric
+    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
+    decreasing_metric = eval_metric == 'loss'
     best_metric = None
     best_epoch = None
     saver = None
@@ -772,7 +777,6 @@
                 str(data_config['input_size'][-1])
             ])
         output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
-        decreasing = True if eval_metric == 'loss' else False
         saver = utils.CheckpointSaver(
             model=model,
             optimizer=optimizer,
@@ -781,7 +785,7 @@
             amp_scaler=loss_scaler,
             checkpoint_dir=output_dir,
             recovery_dir=output_dir,
-            decreasing=decreasing,
+            decreasing=decreasing_metric,
             max_history=args.checkpoint_hist
         )
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
@@ -799,7 +803,7 @@
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
     lr_scheduler, num_epochs = create_scheduler_v2(
         optimizer,
-        **scheduler_kwargs(args),
+        **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
         updates_per_epoch=updates_per_epoch,
     )
     start_epoch = 0
@@ -847,27 +851,30 @@
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
-            eval_metrics = validate(
-                model,
-                loader_eval,
-                validate_loss_fn,
-                args,
-                amp_autocast=amp_autocast,
-            )
-
-            if model_ema is not None and not args.model_ema_force_cpu:
-                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
-                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
-
-                ema_eval_metrics = validate(
-                    model_ema.module,
+            if loader_eval is not None:
+                eval_metrics = validate(
+                    model,
                     loader_eval,
                     validate_loss_fn,
                     args,
                     amp_autocast=amp_autocast,
-                    log_suffix=' (EMA)',
                 )
-                eval_metrics = ema_eval_metrics
+
+                if model_ema is not None and not args.model_ema_force_cpu:
+                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
+                        utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
+
+                    ema_eval_metrics = validate(
+                        model_ema.module,
+                        loader_eval,
+                        validate_loss_fn,
+                        args,
+                        amp_autocast=amp_autocast,
+                        log_suffix=' (EMA)',
+                    )
+                    eval_metrics = ema_eval_metrics
+            else:
+                eval_metrics = None
 
             if output_dir is not None:
                 lrs = [param_group['lr'] for param_group in optimizer.param_groups]
@@ -881,14 +888,18 @@
                     log_wandb=args.log_wandb and has_wandb,
                 )
 
+            if eval_metrics is not None:
+                latest_metric = eval_metrics[eval_metric]
+            else:
+                latest_metric = train_metrics[eval_metric]
+
             if saver is not None:
                 # save proper checkpoint with eval metric
-                save_metric = eval_metrics[eval_metric]
-                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
+                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)
 
             if lr_scheduler is not None:
                 # step LR for next epoch
-                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
+                lr_scheduler.step(epoch + 1, latest_metric)
 
             results.append({
                 'epoch': epoch,
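Note: the core behavioral change in the `scheduler_factory.py` hunk above is the plateau-mode selection, where an explicit `decreasing_metric` flag (passed from `train.py` as `eval_metric == 'loss'`) now takes precedence over the old substring check on the metric name. A minimal standalone sketch of that selection rule follows; the `plateau_mode` helper below is hypothetical, used only to illustrate the logic, and is not part of timm:

```python
from typing import Optional


def plateau_mode(eval_metric: str = 'top1', decreasing_metric: Optional[bool] = None) -> str:
    """Mirror of the selection logic added to scheduler_kwargs above."""
    if decreasing_metric is not None:
        # explicit override from the caller (train.py passes eval_metric == 'loss')
        return 'min' if decreasing_metric else 'max'
    # legacy fallback: infer the direction from the metric name
    return 'min' if 'loss' in eval_metric else 'max'


assert plateau_mode('top1') == 'max'
assert plateau_mode('loss') == 'min'
assert plateau_mode('top1', decreasing_metric=True) == 'min'  # override wins over the name check
```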