Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Allow training w/o validation split set
This commit is contained in:
parent be0944edae
commit c50004db79
timm/scheduler/scheduler_factory.py

@@ -1,7 +1,7 @@
 """ Scheduler Factory
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import List, Union
+from typing import List, Optional, Union
 
 from torch.optim import Optimizer
 
@@ -13,11 +13,14 @@ from .step_lr import StepLRScheduler
 from .tanh_lr import TanhLRScheduler
 
 
-def scheduler_kwargs(cfg):
+def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
     """ cfg/argparse to kwargs helper
     Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
     """
     eval_metric = getattr(cfg, 'eval_metric', 'top1')
+    if decreasing_metric is not None:
+        plateau_mode = 'min' if decreasing_metric else 'max'
+    else:
         plateau_mode = 'min' if 'loss' in eval_metric else 'max'
     kwargs = dict(
         sched=cfg.sched,
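
The new decreasing_metric argument lets the caller state directly whether the tracked metric should go down, instead of relying on the 'loss'-in-the-name heuristic. A minimal standalone sketch of the resulting precedence (the helper name _plateau_mode is illustrative, not part of timm):

    from typing import Optional

    def _plateau_mode(eval_metric: str, decreasing_metric: Optional[bool] = None) -> str:
        # An explicit decreasing_metric wins; otherwise fall back to the
        # metric-name heuristic used before this commit.
        if decreasing_metric is not None:
            return 'min' if decreasing_metric else 'max'
        return 'min' if 'loss' in eval_metric else 'max'

    assert _plateau_mode('top1') == 'max'                          # accuracy: higher is better
    assert _plateau_mode('loss') == 'min'                          # loss: lower is better
    assert _plateau_mode('top1', decreasing_metric=True) == 'min'  # explicit flag overrides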
timm/utils/summary.py

@@ -38,6 +38,7 @@ def update_summary(
 ):
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
+    if eval_metrics:
         rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
     if lr is not None:
         rowd['lr'] = lr
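
With the guard in place, a summary row can be built even when no evaluation ran. A toy sketch of just the row-assembly step (build_row is a hypothetical name; the real update_summary also writes the CSV and can log to wandb):

    from collections import OrderedDict

    def build_row(epoch, train_metrics, eval_metrics=None, lr=None):
        rowd = OrderedDict(epoch=epoch)
        rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
        if eval_metrics:  # skipped entirely when no validation split is configured
            rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
        if lr is not None:
            rowd['lr'] = lr
        return rowd

    print(build_row(0, {'loss': 2.31}, eval_metrics=None, lr=0.05))
    # OrderedDict([('epoch', 0), ('train_loss', 2.31), ('lr', 0.05)])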
train.py (25 changed lines)
@@ -614,6 +614,7 @@ def main():
         input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
     else:
         input_img_mode = args.input_img_mode
+
     dataset_train = create_dataset(
         args.dataset,
         root=args.data_dir,
@@ -630,6 +631,7 @@ def main():
         num_samples=args.train_num_samples,
     )
 
+    if args.val_split:
         dataset_eval = create_dataset(
             args.dataset,
             root=args.data_dir,
@@ -707,6 +709,8 @@ def main():
         worker_seeding=args.worker_seeding,
     )
 
+    loader_eval = None
+    if args.val_split:
         eval_workers = args.workers
         if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
             # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
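
In practice this means train.py can now run with the validation split disabled. A hypothetical invocation (the flags are real train.py arguments; paths and model choice are made up):

    # Passing an empty --val-split makes args.val_split falsy, so neither
    # dataset_eval nor loader_eval is created and training metrics take over.
    python train.py --data-dir /data/mydataset --val-split '' \
        --model resnet50 --sched cosine --epochs 100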
@@ -757,7 +761,8 @@ def main():
     validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
 
     # setup checkpoint saver and eval metric tracking
-    eval_metric = args.eval_metric
+    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
+    decreasing_metric = eval_metric == 'loss'
     best_metric = None
     best_epoch = None
     saver = None
@@ -772,7 +777,6 @@ def main():
                 str(data_config['input_size'][-1])
             ])
         output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
-        decreasing = True if eval_metric == 'loss' else False
         saver = utils.CheckpointSaver(
             model=model,
             optimizer=optimizer,
@@ -781,7 +785,7 @@ def main():
             amp_scaler=loss_scaler,
             checkpoint_dir=output_dir,
             recovery_dir=output_dir,
-            decreasing=decreasing,
+            decreasing=decreasing_metric,
             max_history=args.checkpoint_hist
         )
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
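
The decreasing flag essentially tells CheckpointSaver which direction counts as an improvement when it ranks checkpoints. A toy version of that comparison rule (is_better is an illustrative name, not timm API):

    def is_better(new: float, best: float, decreasing: bool) -> bool:
        # Lower is better for loss-like metrics, higher for accuracy-like ones.
        return new < best if decreasing else new > best

    assert is_better(0.82, 0.91, decreasing=True)        # lower loss: improvement
    assert not is_better(71.2, 76.5, decreasing=False)   # lower top1: not an improvement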
@@ -799,7 +803,7 @@ def main():
         updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
         lr_scheduler, num_epochs = create_scheduler_v2(
             optimizer,
-            **scheduler_kwargs(args),
+            **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
             updates_per_epoch=updates_per_epoch,
         )
         start_epoch = 0
@@ -847,6 +851,7 @@ def main():
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
+            if loader_eval is not None:
                 eval_metrics = validate(
                     model,
                     loader_eval,
@@ -868,6 +873,8 @@ def main():
                         log_suffix=' (EMA)',
                     )
                     eval_metrics = ema_eval_metrics
+            else:
+                eval_metrics = None
 
             if output_dir is not None:
                 lrs = [param_group['lr'] for param_group in optimizer.param_groups]
@@ -881,14 +888,18 @@ def main():
                     log_wandb=args.log_wandb and has_wandb,
                 )
 
+            if eval_metrics is not None:
+                latest_metric = eval_metrics[eval_metric]
+            else:
+                latest_metric = train_metrics[eval_metric]
+
             if saver is not None:
                 # save proper checkpoint with eval metric
-                save_metric = eval_metrics[eval_metric]
-                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
+                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)
 
             if lr_scheduler is not None:
                 # step LR for next epoch
-                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
+                lr_scheduler.step(epoch + 1, latest_metric)
 
             results.append({
                 'epoch': epoch,
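
Taken together, checkpointing and plateau scheduling now key off whichever metrics exist. A toy, runnable demonstration of that selection (all values made up):

    train_metrics = {'loss': 0.91}
    eval_metrics = None      # as when args.val_split is empty
    eval_metric = 'loss'     # the forced fallback when loader_eval is None

    if eval_metrics is not None:
        latest_metric = eval_metrics[eval_metric]
    else:
        latest_metric = train_metrics[eval_metric]

    print(latest_metric)  # 0.91, fed to saver.save_checkpoint() and lr_scheduler.step()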