mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Allow training w/o validation split set

parent: be0944edae
commit: c50004db79
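
This commit gates eval dataset/loader creation on args.val_split and, when no eval loader exists, tracks training loss in place of an eval metric (e.g. launching train.py with --val-split '' so the guard is falsy; that invocation is an assumption, not shown in this diff). A condensed, standalone sketch of the new control flow:

    # Standalone sketch of the gating this commit threads through train.py.
    val_split = ''  # stands in for args.val_split, e.g. --val-split '' (hypothetical invocation)

    loader_eval = None
    if val_split:               # falsy -> skip create_dataset/create_loader for eval
        loader_eval = object()  # stands in for create_loader(dataset_eval, ...)

    # With no eval loader, metric tracking falls back to the training loss:
    eval_metric = 'top1' if loader_eval is not None else 'loss'
    decreasing_metric = eval_metric == 'loss'  # lower-is-better for checkpointing / plateau LR
    print(eval_metric, decreasing_metric)      # -> loss True
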
timm/scheduler/scheduler_factory.py

@@ -1,7 +1,7 @@
 """ Scheduler Factory
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import List, Union
+from typing import List, Optional, Union
 
 from torch.optim import Optimizer
 
@@ -13,12 +13,15 @@ from .step_lr import StepLRScheduler
 from .tanh_lr import TanhLRScheduler
 
 
-def scheduler_kwargs(cfg):
+def scheduler_kwargs(cfg, decreasing_metric: Optional[bool] = None):
     """ cfg/argparse to kwargs helper
     Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
     """
     eval_metric = getattr(cfg, 'eval_metric', 'top1')
-    plateau_mode = 'min' if 'loss' in eval_metric else 'max'
+    if decreasing_metric is not None:
+        plateau_mode = 'min' if decreasing_metric else 'max'
+    else:
+        plateau_mode = 'min' if 'loss' in eval_metric else 'max'
     kwargs = dict(
         sched=cfg.sched,
         num_epochs=getattr(cfg, 'epochs', 100),
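
The new decreasing_metric argument takes precedence over the old name-based heuristic, so train.py can force plateau mode to 'min' when it falls back to training loss. A quick sketch of that precedence (pick_plateau_mode is a hypothetical helper distilled from the lines above, not a timm API):

    from typing import Optional

    def pick_plateau_mode(eval_metric: str, decreasing_metric: Optional[bool] = None) -> str:
        # Explicit flag wins; otherwise fall back to the name-based heuristic.
        if decreasing_metric is not None:
            return 'min' if decreasing_metric else 'max'
        return 'min' if 'loss' in eval_metric else 'max'

    assert pick_plateau_mode('top1') == 'max'                          # accuracy: higher is better
    assert pick_plateau_mode('top1', decreasing_metric=True) == 'min'  # forced train-loss fallback
    assert pick_plateau_mode('loss') == 'min'                          # heuristic still applies
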
timm/utils/summary.py

@@ -38,7 +38,8 @@ def update_summary(
 ):
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
-    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
+    if eval_metrics:
+        rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
     if lr is not None:
         rowd['lr'] = lr
     if log_wandb:
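
With eval metrics now optional, update_summary writes rows containing only train_* columns when nothing was evaluated. A minimal sketch of the guarded row construction (standalone; build_row mirrors the logic above rather than calling timm's update_summary):

    from collections import OrderedDict

    def build_row(epoch, train_metrics, eval_metrics=None):
        rowd = OrderedDict(epoch=epoch)
        rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
        if eval_metrics:  # skipped entirely when None or empty
            rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
        return rowd

    print(build_row(0, {'loss': 1.9}))                  # train_loss column only
    print(build_row(0, {'loss': 1.9}, {'top1': 71.2}))  # adds eval_top1
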
train.py (121 changed lines)
@@ -614,6 +614,7 @@ def main():
         input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
     else:
         input_img_mode = args.input_img_mode
+
     dataset_train = create_dataset(
         args.dataset,
         root=args.data_dir,
@@ -630,19 +631,20 @@ def main():
         num_samples=args.train_num_samples,
     )
 
-    dataset_eval = create_dataset(
-        args.dataset,
-        root=args.data_dir,
-        split=args.val_split,
-        is_training=False,
-        class_map=args.class_map,
-        download=args.dataset_download,
-        batch_size=args.batch_size,
-        input_img_mode=input_img_mode,
-        input_key=args.input_key,
-        target_key=args.target_key,
-        num_samples=args.val_num_samples,
-    )
+    if args.val_split:
+        dataset_eval = create_dataset(
+            args.dataset,
+            root=args.data_dir,
+            split=args.val_split,
+            is_training=False,
+            class_map=args.class_map,
+            download=args.dataset_download,
+            batch_size=args.batch_size,
+            input_img_mode=input_img_mode,
+            input_key=args.input_key,
+            target_key=args.target_key,
+            num_samples=args.val_num_samples,
+        )
 
     # setup mixup / cutmix
     collate_fn = None
@@ -707,25 +709,27 @@ def main():
         worker_seeding=args.worker_seeding,
     )
 
-    eval_workers = args.workers
-    if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
-        # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
-        eval_workers = min(2, args.workers)
-    loader_eval = create_loader(
-        dataset_eval,
-        input_size=data_config['input_size'],
-        batch_size=args.validation_batch_size or args.batch_size,
-        is_training=False,
-        interpolation=data_config['interpolation'],
-        mean=data_config['mean'],
-        std=data_config['std'],
-        num_workers=eval_workers,
-        distributed=args.distributed,
-        crop_pct=data_config['crop_pct'],
-        pin_memory=args.pin_mem,
-        device=device,
-        use_prefetcher=args.prefetcher,
-    )
+    loader_eval = None
+    if args.val_split:
+        eval_workers = args.workers
+        if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
+            # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
+            eval_workers = min(2, args.workers)
+        loader_eval = create_loader(
+            dataset_eval,
+            input_size=data_config['input_size'],
+            batch_size=args.validation_batch_size or args.batch_size,
+            is_training=False,
+            interpolation=data_config['interpolation'],
+            mean=data_config['mean'],
+            std=data_config['std'],
+            num_workers=eval_workers,
+            distributed=args.distributed,
+            crop_pct=data_config['crop_pct'],
+            pin_memory=args.pin_mem,
+            device=device,
+            use_prefetcher=args.prefetcher,
+        )
 
     # setup loss function
     if args.jsd_loss:
@@ -757,7 +761,8 @@ def main():
     validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
 
     # setup checkpoint saver and eval metric tracking
-    eval_metric = args.eval_metric
+    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
+    decreasing_metric = eval_metric == 'loss'
     best_metric = None
     best_epoch = None
     saver = None
@@ -772,7 +777,6 @@ def main():
             str(data_config['input_size'][-1])
         ])
         output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
-        decreasing = True if eval_metric == 'loss' else False
         saver = utils.CheckpointSaver(
             model=model,
             optimizer=optimizer,
@@ -781,7 +785,7 @@ def main():
             amp_scaler=loss_scaler,
             checkpoint_dir=output_dir,
             recovery_dir=output_dir,
-            decreasing=decreasing,
+            decreasing=decreasing_metric,
             max_history=args.checkpoint_hist
         )
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
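
For context on what the renamed flag controls: CheckpointSaver ranks its kept checkpoints by the tracked metric, and decreasing flips the comparison so lower values win. A minimal sketch of that selection rule (an illustrative assumption about the comparator, not a copy of timm's CheckpointSaver internals):

    import operator

    def is_improvement(metric, best, decreasing):
        # decreasing=True: lower is better (loss); decreasing=False: higher is better (top1)
        cmp = operator.lt if decreasing else operator.gt
        return best is None or cmp(metric, best)

    assert is_improvement(0.9, 1.2, decreasing=True)         # loss dropped -> improvement
    assert not is_improvement(70.1, 71.5, decreasing=False)  # top1 fell -> no improvement
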
@@ -799,7 +803,7 @@ def main():
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
     lr_scheduler, num_epochs = create_scheduler_v2(
         optimizer,
-        **scheduler_kwargs(args),
+        **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
         updates_per_epoch=updates_per_epoch,
     )
     start_epoch = 0
@@ -847,27 +851,30 @@ def main():
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
-            eval_metrics = validate(
-                model,
-                loader_eval,
-                validate_loss_fn,
-                args,
-                amp_autocast=amp_autocast,
-            )
-
-            if model_ema is not None and not args.model_ema_force_cpu:
-                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
-                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
-
-                ema_eval_metrics = validate(
-                    model_ema.module,
-                    loader_eval,
-                    validate_loss_fn,
-                    args,
-                    amp_autocast=amp_autocast,
-                    log_suffix=' (EMA)',
-                )
-                eval_metrics = ema_eval_metrics
+            if loader_eval is not None:
+                eval_metrics = validate(
+                    model,
+                    loader_eval,
+                    validate_loss_fn,
+                    args,
+                    amp_autocast=amp_autocast,
+                )
+
+                if model_ema is not None and not args.model_ema_force_cpu:
+                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
+                        utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
+
+                    ema_eval_metrics = validate(
+                        model_ema.module,
+                        loader_eval,
+                        validate_loss_fn,
+                        args,
+                        amp_autocast=amp_autocast,
+                        log_suffix=' (EMA)',
+                    )
+                    eval_metrics = ema_eval_metrics
+            else:
+                eval_metrics = None
 
             if output_dir is not None:
                 lrs = [param_group['lr'] for param_group in optimizer.param_groups]
@@ -881,14 +888,18 @@ def main():
                 log_wandb=args.log_wandb and has_wandb,
             )
 
+            if eval_metrics is not None:
+                latest_metric = eval_metrics[eval_metric]
+            else:
+                latest_metric = train_metrics[eval_metric]
+
             if saver is not None:
                 # save proper checkpoint with eval metric
-                save_metric = eval_metrics[eval_metric]
-                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
+                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)
 
             if lr_scheduler is not None:
                 # step LR for next epoch
-                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
+                lr_scheduler.step(epoch + 1, latest_metric)
 
             results.append({
                 'epoch': epoch,
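
The tail of the epoch loop now routes a single latest_metric into both checkpoint saving and the plateau scheduler step, so runs without an eval loader are driven by training loss end to end. A condensed sketch of that flow (stand-in values; the real call sites are saver.save_checkpoint and lr_scheduler.step above):

    eval_metrics = None             # no val split this run
    train_metrics = {'loss': 2.31}
    eval_metric = 'loss'            # forced fallback when loader_eval is None

    if eval_metrics is not None:
        latest_metric = eval_metrics[eval_metric]
    else:
        latest_metric = train_metrics[eval_metric]  # train loss drives everything

    # latest_metric then feeds saver.save_checkpoint(epoch, metric=latest_metric)
    # and lr_scheduler.step(epoch + 1, latest_metric) for plateau-style schedules.
    print(latest_metric)  # -> 2.31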