mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Allow training w/o validation split set
This commit is contained in:
parent
be0944edae
commit
c50004db79
@ -1,7 +1,7 @@
|
||||
""" Scheduler Factory
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from torch.optim import Optimizer
|
||||
|
||||
@ -13,12 +13,15 @@ from .step_lr import StepLRScheduler
|
||||
from .tanh_lr import TanhLRScheduler
|
||||
|
||||
|
||||
def scheduler_kwargs(cfg):
|
||||
def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
|
||||
""" cfg/argparse to kwargs helper
|
||||
Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
|
||||
"""
|
||||
eval_metric = getattr(cfg, 'eval_metric', 'top1')
|
||||
plateau_mode = 'min' if 'loss' in eval_metric else 'max'
|
||||
if decreasing_metric is not None:
|
||||
plateau_mode = 'min' if decreasing_metric else 'max'
|
||||
else:
|
||||
plateau_mode = 'min' if 'loss' in eval_metric else 'max'
|
||||
kwargs = dict(
|
||||
sched=cfg.sched,
|
||||
num_epochs=getattr(cfg, 'epochs', 100),
|
||||
|
@ -38,7 +38,8 @@ def update_summary(
|
||||
):
|
||||
rowd = OrderedDict(epoch=epoch)
|
||||
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
|
||||
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
|
||||
if eval_metrics:
|
||||
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
|
||||
if lr is not None:
|
||||
rowd['lr'] = lr
|
||||
if log_wandb:
|
||||
|
121
train.py
121
train.py
@ -614,6 +614,7 @@ def main():
|
||||
input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
|
||||
else:
|
||||
input_img_mode = args.input_img_mode
|
||||
|
||||
dataset_train = create_dataset(
|
||||
args.dataset,
|
||||
root=args.data_dir,
|
||||
@ -630,19 +631,20 @@ def main():
|
||||
num_samples=args.train_num_samples,
|
||||
)
|
||||
|
||||
dataset_eval = create_dataset(
|
||||
args.dataset,
|
||||
root=args.data_dir,
|
||||
split=args.val_split,
|
||||
is_training=False,
|
||||
class_map=args.class_map,
|
||||
download=args.dataset_download,
|
||||
batch_size=args.batch_size,
|
||||
input_img_mode=input_img_mode,
|
||||
input_key=args.input_key,
|
||||
target_key=args.target_key,
|
||||
num_samples=args.val_num_samples,
|
||||
)
|
||||
if args.val_split:
|
||||
dataset_eval = create_dataset(
|
||||
args.dataset,
|
||||
root=args.data_dir,
|
||||
split=args.val_split,
|
||||
is_training=False,
|
||||
class_map=args.class_map,
|
||||
download=args.dataset_download,
|
||||
batch_size=args.batch_size,
|
||||
input_img_mode=input_img_mode,
|
||||
input_key=args.input_key,
|
||||
target_key=args.target_key,
|
||||
num_samples=args.val_num_samples,
|
||||
)
|
||||
|
||||
# setup mixup / cutmix
|
||||
collate_fn = None
|
||||
@ -707,25 +709,27 @@ def main():
|
||||
worker_seeding=args.worker_seeding,
|
||||
)
|
||||
|
||||
eval_workers = args.workers
|
||||
if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
|
||||
# FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
|
||||
eval_workers = min(2, args.workers)
|
||||
loader_eval = create_loader(
|
||||
dataset_eval,
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.validation_batch_size or args.batch_size,
|
||||
is_training=False,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=eval_workers,
|
||||
distributed=args.distributed,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
pin_memory=args.pin_mem,
|
||||
device=device,
|
||||
use_prefetcher=args.prefetcher,
|
||||
)
|
||||
loader_eval = None
|
||||
if args.val_split:
|
||||
eval_workers = args.workers
|
||||
if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
|
||||
# FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
|
||||
eval_workers = min(2, args.workers)
|
||||
loader_eval = create_loader(
|
||||
dataset_eval,
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.validation_batch_size or args.batch_size,
|
||||
is_training=False,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=eval_workers,
|
||||
distributed=args.distributed,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
pin_memory=args.pin_mem,
|
||||
device=device,
|
||||
use_prefetcher=args.prefetcher,
|
||||
)
|
||||
|
||||
# setup loss function
|
||||
if args.jsd_loss:
|
||||
@ -757,7 +761,8 @@ def main():
|
||||
validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
|
||||
|
||||
# setup checkpoint saver and eval metric tracking
|
||||
eval_metric = args.eval_metric
|
||||
eval_metric = args.eval_metric if loader_eval is not None else 'loss'
|
||||
decreasing_metric = eval_metric == 'loss'
|
||||
best_metric = None
|
||||
best_epoch = None
|
||||
saver = None
|
||||
@ -772,7 +777,6 @@ def main():
|
||||
str(data_config['input_size'][-1])
|
||||
])
|
||||
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
|
||||
decreasing = True if eval_metric == 'loss' else False
|
||||
saver = utils.CheckpointSaver(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
@ -781,7 +785,7 @@ def main():
|
||||
amp_scaler=loss_scaler,
|
||||
checkpoint_dir=output_dir,
|
||||
recovery_dir=output_dir,
|
||||
decreasing=decreasing,
|
||||
decreasing=decreasing_metric,
|
||||
max_history=args.checkpoint_hist
|
||||
)
|
||||
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
|
||||
@ -799,7 +803,7 @@ def main():
|
||||
updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
|
||||
lr_scheduler, num_epochs = create_scheduler_v2(
|
||||
optimizer,
|
||||
**scheduler_kwargs(args),
|
||||
**scheduler_kwargs(args, decreasing_metric=decreasing_metric),
|
||||
updates_per_epoch=updates_per_epoch,
|
||||
)
|
||||
start_epoch = 0
|
||||
@ -847,27 +851,30 @@ def main():
|
||||
_logger.info("Distributing BatchNorm running means and vars")
|
||||
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
||||
|
||||
eval_metrics = validate(
|
||||
model,
|
||||
loader_eval,
|
||||
validate_loss_fn,
|
||||
args,
|
||||
amp_autocast=amp_autocast,
|
||||
)
|
||||
|
||||
if model_ema is not None and not args.model_ema_force_cpu:
|
||||
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
||||
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
||||
|
||||
ema_eval_metrics = validate(
|
||||
model_ema.module,
|
||||
if loader_eval is not None:
|
||||
eval_metrics = validate(
|
||||
model,
|
||||
loader_eval,
|
||||
validate_loss_fn,
|
||||
args,
|
||||
amp_autocast=amp_autocast,
|
||||
log_suffix=' (EMA)',
|
||||
)
|
||||
eval_metrics = ema_eval_metrics
|
||||
|
||||
if model_ema is not None and not args.model_ema_force_cpu:
|
||||
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
||||
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
||||
|
||||
ema_eval_metrics = validate(
|
||||
model_ema.module,
|
||||
loader_eval,
|
||||
validate_loss_fn,
|
||||
args,
|
||||
amp_autocast=amp_autocast,
|
||||
log_suffix=' (EMA)',
|
||||
)
|
||||
eval_metrics = ema_eval_metrics
|
||||
else:
|
||||
eval_metrics = None
|
||||
|
||||
if output_dir is not None:
|
||||
lrs = [param_group['lr'] for param_group in optimizer.param_groups]
|
||||
@ -881,14 +888,18 @@ def main():
|
||||
log_wandb=args.log_wandb and has_wandb,
|
||||
)
|
||||
|
||||
if eval_metrics is not None:
|
||||
latest_metric = eval_metrics[eval_metric]
|
||||
else:
|
||||
latest_metric = train_metrics[eval_metric]
|
||||
|
||||
if saver is not None:
|
||||
# save proper checkpoint with eval metric
|
||||
save_metric = eval_metrics[eval_metric]
|
||||
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
|
||||
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)
|
||||
|
||||
if lr_scheduler is not None:
|
||||
# step LR for next epoch
|
||||
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
|
||||
lr_scheduler.step(epoch + 1, latest_metric)
|
||||
|
||||
results.append({
|
||||
'epoch': epoch,
|
||||
|
Loading…
x
Reference in New Issue
Block a user