commit
e685618f45
|
@ -5,7 +5,10 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||
import csv
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def get_outdir(path, *paths, inc=False):
|
||||
outdir = os.path.join(path, *paths)
|
||||
|
@ -23,10 +26,12 @@ def get_outdir(path, *paths, inc=False):
|
|||
return outdir
|
||||
|
||||
|
||||
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
|
||||
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
|
||||
rowd = OrderedDict(epoch=epoch)
|
||||
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
|
||||
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
|
||||
if log_wandb:
|
||||
wandb.log(rowd)
|
||||
with open(filename, mode='a') as cf:
|
||||
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
|
||||
if write_header: # first iteration (epoch == 1 can't be used)
|
||||
|
|
19
train.py
19
train.py
|
@ -52,6 +52,12 @@ try:
|
|||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import wandb
|
||||
has_wandb = True
|
||||
except ImportError:
|
||||
has_wandb = False
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
_logger = logging.getLogger('train')
|
||||
|
||||
|
@ -271,6 +277,8 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
|
|||
help='use the multi-epochs-loader to save time at the beginning of every epoch')
|
||||
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='convert model torchscript for inference')
|
||||
parser.add_argument('--log-wandb', action='store_true', default=False,
|
||||
help='log training and validation metrics to wandb')
|
||||
|
||||
|
||||
def _parse_args():
|
||||
|
@ -293,7 +301,14 @@ def _parse_args():
|
|||
def main():
|
||||
setup_default_logging()
|
||||
args, args_text = _parse_args()
|
||||
|
||||
|
||||
if args.log_wandb:
|
||||
if has_wandb:
|
||||
wandb.init(project=args.experiment, config=args)
|
||||
else:
|
||||
_logger.warning("You've requested to log metrics to wandb but package not found. "
|
||||
"Metrics not being logged to wandb, try `pip install wandb`")
|
||||
|
||||
args.prefetcher = not args.no_prefetcher
|
||||
args.distributed = False
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
|
@ -593,7 +608,7 @@ def main():
|
|||
|
||||
update_summary(
|
||||
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
|
||||
write_header=best_metric is None)
|
||||
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
|
||||
|
||||
if saver is not None:
|
||||
# save proper checkpoint with eval metric
|
||||
|
|
Loading…
Reference in New Issue