Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add wandb support
commit 8e6fb861e4
parent 779107b693

train.py | 13 lines changed (8 insertions(+), 5 deletions(-))
@@ -23,6 +23,8 @@ from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
 
+import wandb
+
 import torch
 import torch.nn as nn
 import torchvision.utils
@@ -293,7 +295,8 @@ def _parse_args():
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
-
+    wandb.init(project='efficientnet_v2', config=args)
+    wandb.run.name = args.model
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
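For context, a minimal sketch of what the new initialization does, runnable on its own (the Namespace fields below are illustrative, and a prior `wandb login` is assumed). wandb.init() accepts an argparse.Namespace as config and records every hyperparameter with the run; assigning wandb.run.name afterwards replaces the auto-generated run name with the model name. Note the call sits before the WORLD_SIZE / distributed setup, so in a multi-process launch every rank would open its own run; guarding on rank 0 is the usual pattern, which this commit does not add.

    import argparse

    import wandb

    # Hypothetical stand-in for the Namespace returned by _parse_args().
    args = argparse.Namespace(model='tf_efficientnetv2_s', lr=0.1, epochs=300)

    # Start a run; wandb stores the full Namespace as the run config.
    wandb.init(project='efficientnet_v2', config=args)
    # Override the auto-generated run name with the model name.
    wandb.run.name = args.model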
@@ -572,14 +575,14 @@ def main():
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
-
+            wandb.log(train_metrics)
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
-
+            wandb.log(eval_metrics)
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
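One behavior worth noting: each wandb.log() call advances the run's step counter by default, so logging train_metrics and eval_metrics in two separate calls places them at different x-axis steps. A hypothetical variant (not what this commit does) merges both dicts into a single per-epoch call:

    # Hypothetical alternative: one log call per epoch so the train and
    # eval curves share the same step value.
    combined = {**train_metrics, **eval_metrics, 'epoch': epoch}
    wandb.log(combined)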
@@ -711,7 +714,7 @@ def train_one_epoch(
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()
 
-    return OrderedDict([('loss', losses_m.avg)])
+    return OrderedDict([('train_loss', losses_m.avg)])
 
 
 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
@@ -773,7 +776,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
                         log_name, batch_idx, last_idx, batch_time=batch_time_m,
                         loss=losses_m, top1=top1_m, top5=top5_m))
 
-    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
+    metrics = OrderedDict([('val_loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
 
     return metrics
 
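The key renames ('loss' -> 'train_loss' in train_one_epoch(), 'loss' -> 'val_loss' in validate()) keep the two dicts distinct once both flow into wandb.log(); with identical keys the train and eval losses would merge into a single 'loss' chart. A sketch of the resulting shape, with illustrative values, plus one caveat: train.py looks up its checkpoint metric via args.eval_metric (default 'top1'), so a run launched with --eval-metric loss would no longer find the old key.

    from collections import OrderedDict

    # Shape of validate()'s return value after this commit (values illustrative).
    eval_metrics = OrderedDict([('val_loss', 2.05), ('top1', 41.2), ('top5', 68.9)])

    best = eval_metrics['top1']   # the default --eval-metric still resolves
    # eval_metrics['loss'] would raise KeyError for runs using --eval-metric loss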