Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
Fix #566: summary.csv was being written to the current working directory on processes with local_rank != 0. Also tweak benchmark memory handling to see if it reduces the likelihood of 'bad' exceptions on OOM.
parent 1b0c8e7b01
commit e15e68d881
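Why the old default wrote to the working directory: as the train.py hunks below show, output_dir was initialized to '' and only assigned a real path inside the local_rank == 0 branch, yet update_summary() ran unconditionally on every rank. A short illustration of the path mechanics (standard-library behavior, not code from this repo):

    import os

    # Old default: '' still joins into a valid relative path, so ranks with
    # local_rank != 0 quietly dropped summary.csv into the current directory.
    print(os.path.join('', 'summary.csv'))   # -> 'summary.csv'

    # New default: None cannot be joined (os.path.join(None, 'summary.csv')
    # raises TypeError), and the added `if output_dir is not None:` guard
    # skips the write on those ranks entirely.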
@@ -374,14 +374,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
     batch_size = initial_batch_size
     results = dict()
     while batch_size >= 1:
+        torch.cuda.empty_cache()
         try:
             bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
             results = bench.run()
             return results
         except RuntimeError as e:
-            torch.cuda.empty_cache()
-            batch_size = decay_batch_exp(batch_size)
             print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
+            batch_size = decay_batch_exp(batch_size)
     return results
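Two changes here: torch.cuda.empty_cache() now runs at the top of every retry iteration rather than only inside the except handler, freeing cached blocks before each attempt, and the batch size decay moves after the error message. The decay helper itself is not shown in this diff; a minimal sketch of a geometric back-off consistent with the loop above (hypothetical, the real decay_batch_exp in timm may differ):

    def decay_batch_exp(batch_size, factor=0.5):
        # Hypothetical: halve the batch size on each OOM retry. Integer
        # truncation drives 1 -> 0, which ends `while batch_size >= 1:`.
        return int(batch_size * factor)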
train.py (11 changed lines)
@@ -560,7 +560,7 @@ def main():
     best_metric = None
     best_epoch = None
     saver = None
-    output_dir = ''
+    output_dir = None
     if args.local_rank == 0:
         if args.experiment:
             exp_name = args.experiment
@@ -606,9 +606,10 @@ def main():
                 # step LR for next epoch
                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

-            update_summary(
-                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-                write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
+            if output_dir is not None:
+                update_summary(
+                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
+                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

             if saver is not None:
                 # save proper checkpoint with eval metric
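Note that the new guard doubles as a rank check: output_dir stays None everywhere except the local_rank == 0 branch shown in the first hunk, so `if output_dir is not None:` means "only on the rank that owns the output directory". A condensed sketch of the pattern (the get_outdir call and argument names are assumed, not verbatim from train.py):

    output_dir = None
    if args.local_rank == 0:
        # Only rank 0 ever gets a real directory (assumed helper/arguments).
        output_dir = get_outdir(args.output, exp_name)

    if output_dir is not None:  # implicitly: rank 0 only
        update_summary(
            epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
            write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)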
@@ -623,7 +624,7 @@ def main():

 def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
+        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
         loss_scaler=None, model_ema=None, mixup_fn=None):

     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: