Allow training w/o validation split set

Ross Wightman 2024-01-06 10:27:08 -08:00 committed by Ross Wightman
parent be0944edae
commit c50004db79
3 changed files with 74 additions and 59 deletions

timm/scheduler/scheduler_factory.py

@@ -1,7 +1,7 @@
 """ Scheduler Factory
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import List, Union
+from typing import List, Optional, Union
 from torch.optim import Optimizer
@@ -13,12 +13,15 @@ from .step_lr import StepLRScheduler
 from .tanh_lr import TanhLRScheduler

-def scheduler_kwargs(cfg):
+def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
     """ cfg/argparse to kwargs helper
     Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
     """
     eval_metric = getattr(cfg, 'eval_metric', 'top1')
-    plateau_mode = 'min' if 'loss' in eval_metric else 'max'
+    if decreasing_metric is not None:
+        plateau_mode = 'min' if decreasing_metric else 'max'
+    else:
+        plateau_mode = 'min' if 'loss' in eval_metric else 'max'
     kwargs = dict(
         sched=cfg.sched,
         num_epochs=getattr(cfg, 'epochs', 100),
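For illustration, a standalone sketch (not part of the diff) of the precedence this change introduces: an explicit decreasing_metric overrides the old name-based inference, which is what lets train.py force 'min' plateau mode when it falls back to tracking train loss.

    def infer_plateau_mode(eval_metric, decreasing_metric=None):
        # mirrors scheduler_kwargs() above: the explicit flag wins over the metric name
        if decreasing_metric is not None:
            return 'min' if decreasing_metric else 'max'
        return 'min' if 'loss' in eval_metric else 'max'

    assert infer_plateau_mode('top1') == 'max'
    assert infer_plateau_mode('loss') == 'min'
    assert infer_plateau_mode('top1', decreasing_metric=True) == 'min'  # no val split: track train loss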

timm/utils/summary.py

@@ -38,7 +38,8 @@ def update_summary(
 ):
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
-    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
+    if eval_metrics:
+        rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
     if lr is not None:
         rowd['lr'] = lr
     if log_wandb:
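A hedged usage sketch (assuming the existing update_summary(epoch, train_metrics, eval_metrics, filename, ...) signature): eval_metrics may now be None or empty, in which case only the train_ columns are written to the CSV.

    from timm import utils

    utils.update_summary(
        0,                       # epoch
        {'loss': 2.31},          # train metrics
        None,                    # no eval metrics when training without a val split
        filename='summary.csv',
        lr=1e-3,
        write_header=True,
    )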

train.py

@@ -614,6 +614,7 @@ def main():
         input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
     else:
         input_img_mode = args.input_img_mode
+
     dataset_train = create_dataset(
         args.dataset,
         root=args.data_dir,
@@ -630,19 +631,20 @@ def main():
         num_samples=args.train_num_samples,
     )

-    dataset_eval = create_dataset(
-        args.dataset,
-        root=args.data_dir,
-        split=args.val_split,
-        is_training=False,
-        class_map=args.class_map,
-        download=args.dataset_download,
-        batch_size=args.batch_size,
-        input_img_mode=input_img_mode,
-        input_key=args.input_key,
-        target_key=args.target_key,
-        num_samples=args.val_num_samples,
-    )
+    if args.val_split:
+        dataset_eval = create_dataset(
+            args.dataset,
+            root=args.data_dir,
+            split=args.val_split,
+            is_training=False,
+            class_map=args.class_map,
+            download=args.dataset_download,
+            batch_size=args.batch_size,
+            input_img_mode=input_img_mode,
+            input_key=args.input_key,
+            target_key=args.target_key,
+            num_samples=args.val_num_samples,
+        )

     # setup mixup / cutmix
     collate_fn = None
@@ -707,25 +709,27 @@ def main():
         worker_seeding=args.worker_seeding,
     )

-    eval_workers = args.workers
-    if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
-        # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
-        eval_workers = min(2, args.workers)
-    loader_eval = create_loader(
-        dataset_eval,
-        input_size=data_config['input_size'],
-        batch_size=args.validation_batch_size or args.batch_size,
-        is_training=False,
-        interpolation=data_config['interpolation'],
-        mean=data_config['mean'],
-        std=data_config['std'],
-        num_workers=eval_workers,
-        distributed=args.distributed,
-        crop_pct=data_config['crop_pct'],
-        pin_memory=args.pin_mem,
-        device=device,
-        use_prefetcher=args.prefetcher,
-    )
+    loader_eval = None
+    if args.val_split:
+        eval_workers = args.workers
+        if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
+            # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
+            eval_workers = min(2, args.workers)
+        loader_eval = create_loader(
+            dataset_eval,
+            input_size=data_config['input_size'],
+            batch_size=args.validation_batch_size or args.batch_size,
+            is_training=False,
+            interpolation=data_config['interpolation'],
+            mean=data_config['mean'],
+            std=data_config['std'],
+            num_workers=eval_workers,
+            distributed=args.distributed,
+            crop_pct=data_config['crop_pct'],
+            pin_memory=args.pin_mem,
+            device=device,
+            use_prefetcher=args.prefetcher,
+        )

     # setup loss function
     if args.jsd_loss:
@@ -757,7 +761,8 @@ def main():
     validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

     # setup checkpoint saver and eval metric tracking
-    eval_metric = args.eval_metric
+    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
+    decreasing_metric = eval_metric == 'loss'
     best_metric = None
     best_epoch = None
     saver = None
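As a quick standalone check (illustrative only, not patch code), the fallback above behaves like this:

    def tracked_metric(cli_eval_metric, has_eval_loader):
        # same fallback as above: no eval loader -> track train loss, which is decreasing
        eval_metric = cli_eval_metric if has_eval_loader else 'loss'
        return eval_metric, eval_metric == 'loss'

    assert tracked_metric('top1', True) == ('top1', False)   # normal run with validation
    assert tracked_metric('top1', False) == ('loss', True)   # no validation split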
@@ -772,7 +777,6 @@ def main():
                 str(data_config['input_size'][-1])
             ])
         output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
-        decreasing = True if eval_metric == 'loss' else False
         saver = utils.CheckpointSaver(
             model=model,
             optimizer=optimizer,
@@ -781,7 +785,7 @@
             amp_scaler=loss_scaler,
             checkpoint_dir=output_dir,
             recovery_dir=output_dir,
-            decreasing=decreasing,
+            decreasing=decreasing_metric,
             max_history=args.checkpoint_hist
         )
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
@@ -799,7 +803,7 @@
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
     lr_scheduler, num_epochs = create_scheduler_v2(
         optimizer,
-        **scheduler_kwargs(args),
+        **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
         updates_per_epoch=updates_per_epoch,
     )
     start_epoch = 0
@@ -847,27 +851,30 @@
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

-            eval_metrics = validate(
-                model,
-                loader_eval,
-                validate_loss_fn,
-                args,
-                amp_autocast=amp_autocast,
-            )
-
-            if model_ema is not None and not args.model_ema_force_cpu:
-                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
-                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
-                ema_eval_metrics = validate(
-                    model_ema.module,
-                    loader_eval,
-                    validate_loss_fn,
-                    args,
-                    amp_autocast=amp_autocast,
-                    log_suffix=' (EMA)',
-                )
-                eval_metrics = ema_eval_metrics
+            if loader_eval is not None:
+                eval_metrics = validate(
+                    model,
+                    loader_eval,
+                    validate_loss_fn,
+                    args,
+                    amp_autocast=amp_autocast,
+                )
+
+                if model_ema is not None and not args.model_ema_force_cpu:
+                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
+                        utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
+                    ema_eval_metrics = validate(
+                        model_ema.module,
+                        loader_eval,
+                        validate_loss_fn,
+                        args,
+                        amp_autocast=amp_autocast,
+                        log_suffix=' (EMA)',
+                    )
+                    eval_metrics = ema_eval_metrics
+            else:
+                eval_metrics = None

             if output_dir is not None:
                 lrs = [param_group['lr'] for param_group in optimizer.param_groups]
@@ -881,14 +888,18 @@
                     log_wandb=args.log_wandb and has_wandb,
                 )

+            if eval_metrics is not None:
+                latest_metric = eval_metrics[eval_metric]
+            else:
+                latest_metric = train_metrics[eval_metric]
+
             if saver is not None:
                 # save proper checkpoint with eval metric
-                save_metric = eval_metrics[eval_metric]
-                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
+                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)

             if lr_scheduler is not None:
                 # step LR for next epoch
-                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
+                lr_scheduler.step(epoch + 1, latest_metric)

             results.append({
                 'epoch': epoch,
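Putting the train.py pieces together, a rough sketch (illustrative names and values, not the verbatim loop) of the per-epoch metric flow when no validation split is configured, e.g. when --val-split is given as an empty string so `if args.val_split:` is falsy:

    # With no eval loader: validation is skipped, the tracked metric falls back to
    # train loss, and checkpoint ranking / plateau stepping use that value.
    eval_metrics = None                      # validate() not called
    eval_metric = 'loss'                     # fallback metric name
    decreasing_metric = True                 # lower is better
    train_metrics = {'loss': 2.31}           # from train_one_epoch (example value)
    latest_metric = (eval_metrics or train_metrics)[eval_metric]
    # saver.save_checkpoint(epoch, metric=latest_metric)   -> ranks checkpoints by train loss
    # lr_scheduler.step(epoch + 1, latest_metric)           -> plateau scheduler sees train loss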