Allow training w/o validation split set

Ross Wightman authored 2024-01-06 10:27:08 -08:00; committed by Ross Wightman
parent be0944edae
commit c50004db79
3 changed files with 74 additions and 59 deletions
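
In practice, train.py can now run with the validation split disabled (for example via an empty --val-split value, assuming the script's existing argparse option): the eval dataset and loader are simply not built, the tracked metric falls back to training loss, and checkpoint saving, plateau LR scheduling, and the summary CSV all operate on training metrics alone. The per-file changes follow.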

@@ -1,7 +1,7 @@
 """ Scheduler Factory
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import List, Union
+from typing import List, Optional, Union

 from torch.optim import Optimizer
@@ -13,11 +13,14 @@ from .step_lr import StepLRScheduler
 from .tanh_lr import TanhLRScheduler


-def scheduler_kwargs(cfg):
+def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
     """ cfg/argparse to kwargs helper
     Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
     """
     eval_metric = getattr(cfg, 'eval_metric', 'top1')
-    plateau_mode = 'min' if 'loss' in eval_metric else 'max'
+    if decreasing_metric is not None:
+        plateau_mode = 'min' if decreasing_metric else 'max'
+    else:
+        plateau_mode = 'min' if 'loss' in eval_metric else 'max'
     kwargs = dict(
         sched=cfg.sched,
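
For reference, a tiny standalone sketch of the new plateau-mode selection; pick_plateau_mode is a hypothetical helper that merely mirrors the updated scheduler_kwargs() logic above and is not part of timm:

from typing import Optional

def pick_plateau_mode(eval_metric: str, decreasing_metric: Optional[bool] = None) -> str:
    # An explicit decreasing_metric now overrides the old name-based heuristic
    # ('loss' in the metric name => minimize) used for the plateau scheduler.
    if decreasing_metric is not None:
        return 'min' if decreasing_metric else 'max'
    return 'min' if 'loss' in eval_metric else 'max'

assert pick_plateau_mode('top1') == 'max'
assert pick_plateau_mode('loss') == 'min'
# Training without a validation split passes an explicit decreasing_metric=True:
assert pick_plateau_mode('top1', decreasing_metric=True) == 'min'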

@@ -38,6 +38,7 @@ def update_summary(
 ):
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
-    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
+    if eval_metrics:
+        rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
     if lr is not None:
         rowd['lr'] = lr
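
A minimal standalone illustration of the guarded row construction; build_summary_row is a hypothetical stand-in that mirrors the body of update_summary(), not the real function:

from collections import OrderedDict

def build_summary_row(epoch, train_metrics, eval_metrics=None, lr=None):
    # eval_ columns are only added when eval metrics exist.
    rowd = OrderedDict(epoch=epoch)
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    if eval_metrics:
        rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    if lr is not None:
        rowd['lr'] = lr
    return rowd

# With no validation split, eval_metrics is None and the row carries only
# the epoch, train_ columns, and lr -- no eval_ columns.
print(build_summary_row(0, {'loss': 2.31}, eval_metrics=None, lr=1e-3))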

@@ -614,6 +614,7 @@ def main():
         input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
     else:
         input_img_mode = args.input_img_mode
+
     dataset_train = create_dataset(
         args.dataset,
         root=args.data_dir,
@@ -630,6 +631,7 @@ def main():
         num_samples=args.train_num_samples,
     )

+    if args.val_split:
         dataset_eval = create_dataset(
             args.dataset,
             root=args.data_dir,
@@ -707,6 +709,8 @@ def main():
         worker_seeding=args.worker_seeding,
     )

+    loader_eval = None
+    if args.val_split:
         eval_workers = args.workers
         if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
             # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
@@ -757,7 +761,8 @@ def main():
         validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

     # setup checkpoint saver and eval metric tracking
-    eval_metric = args.eval_metric
+    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
+    decreasing_metric = eval_metric == 'loss'
     best_metric = None
     best_epoch = None
     saver = None
@@ -772,7 +777,6 @@ def main():
             str(data_config['input_size'][-1])
         ])
         output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
-        decreasing = True if eval_metric == 'loss' else False
         saver = utils.CheckpointSaver(
             model=model,
             optimizer=optimizer,
@@ -781,7 +785,7 @@ def main():
             amp_scaler=loss_scaler,
             checkpoint_dir=output_dir,
             recovery_dir=output_dir,
-            decreasing=decreasing,
+            decreasing=decreasing_metric,
             max_history=args.checkpoint_hist
         )
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
@@ -799,7 +803,7 @@ def main():
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
     lr_scheduler, num_epochs = create_scheduler_v2(
         optimizer,
-        **scheduler_kwargs(args),
+        **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
         updates_per_epoch=updates_per_epoch,
     )
     start_epoch = 0
@@ -847,6 +851,7 @@ def main():
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

+            if loader_eval is not None:
                 eval_metrics = validate(
                     model,
                     loader_eval,
@@ -868,6 +873,8 @@ def main():
                         log_suffix=' (EMA)',
                     )
                     eval_metrics = ema_eval_metrics
+            else:
+                eval_metrics = None

             if output_dir is not None:
                 lrs = [param_group['lr'] for param_group in optimizer.param_groups]
@@ -881,14 +888,18 @@ def main():
                     log_wandb=args.log_wandb and has_wandb,
                 )

+            if eval_metrics is not None:
+                latest_metric = eval_metrics[eval_metric]
+            else:
+                latest_metric = train_metrics[eval_metric]
+
             if saver is not None:
                 # save proper checkpoint with eval metric
-                save_metric = eval_metrics[eval_metric]
-                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
+                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)

             if lr_scheduler is not None:
                 # step LR for next epoch
-                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
+                lr_scheduler.step(epoch + 1, latest_metric)

             results.append({
                 'epoch': epoch,
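
Putting the train.py pieces together, a condensed sketch of the per-epoch metric handling on the no-validation path; plain dicts stand in for the real metric dicts, and the saver/scheduler calls are left as comments since their construction is unchanged:

# Assumed starting point: args.val_split was empty, so no eval loader was built.
loader_eval = None
eval_metric = 'loss' if loader_eval is None else 'top1'   # falls back to training loss
decreasing_metric = eval_metric == 'loss'                 # True: lower is better

train_metrics = {'loss': 1.87}   # returned by train_one_epoch() in the real script
eval_metrics = None              # validate() is skipped entirely

if eval_metrics is not None:
    latest_metric = eval_metrics[eval_metric]
else:
    latest_metric = train_metrics[eval_metric]

# Checkpoint ranking and plateau LR scheduling now key off training loss:
# saver.save_checkpoint(epoch, metric=latest_metric)
# lr_scheduler.step(epoch + 1, latest_metric)
print(latest_metric, decreasing_metric)  # 1.87 True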