mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
make wandb not required but rather optional as huggingface_hub
This commit is contained in:
parent
f13f7508a9
commit
f54897cc0b
@ -1,4 +1,3 @@
|
|||||||
torch>=1.4.0
|
torch>=1.4.0
|
||||||
torchvision>=0.5.0
|
torchvision>=0.5.0
|
||||||
pyyaml
|
pyyaml
|
||||||
wandb
|
|
@ -4,9 +4,11 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||||||
"""
|
"""
|
||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
import wandb
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
def get_outdir(path, *paths, inc=False):
|
def get_outdir(path, *paths, inc=False):
|
||||||
outdir = os.path.join(path, *paths)
|
outdir = os.path.join(path, *paths)
|
||||||
@ -28,8 +30,6 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
|
|||||||
rowd = OrderedDict(epoch=epoch)
|
rowd = OrderedDict(epoch=epoch)
|
||||||
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
|
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
|
||||||
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
|
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
|
||||||
if log_wandb:
|
|
||||||
wandb.log(rowd)
|
|
||||||
with open(filename, mode='a') as cf:
|
with open(filename, mode='a') as cf:
|
||||||
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
|
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
|
||||||
if write_header: # first iteration (epoch == 1 can't be used)
|
if write_header: # first iteration (epoch == 1 can't be used)
|
||||||
|
20
train.py
20
train.py
@ -23,8 +23,6 @@ from collections import OrderedDict
|
|||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import wandb
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision.utils
|
import torchvision.utils
|
||||||
@ -54,6 +52,12 @@ try:
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
has_wandb = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
has_wandb = False
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
_logger = logging.getLogger('train')
|
_logger = logging.getLogger('train')
|
||||||
|
|
||||||
@ -274,7 +278,7 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
|
|||||||
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||||
help='convert model torchscript for inference')
|
help='convert model torchscript for inference')
|
||||||
parser.add_argument('--log-wandb', action='store_true', default=False,
|
parser.add_argument('--log-wandb', action='store_true', default=False,
|
||||||
help='use wandb for training and validation logs')
|
help='log training and validation metrics to wandb')
|
||||||
|
|
||||||
|
|
||||||
def _parse_args():
|
def _parse_args():
|
||||||
@ -299,8 +303,12 @@ def main():
|
|||||||
args, args_text = _parse_args()
|
args, args_text = _parse_args()
|
||||||
|
|
||||||
if args.log_wandb:
|
if args.log_wandb:
|
||||||
wandb.init(project=args.experiment, config=args)
|
if has_wandb:
|
||||||
|
wandb.init(project=args.experiment, config=args)
|
||||||
|
else:
|
||||||
|
_logger.warning("You've requested to log metrics to wandb but package not found. "
|
||||||
|
"Metrics not being logged to wandb, try `pip install wandb`")
|
||||||
|
|
||||||
args.prefetcher = not args.no_prefetcher
|
args.prefetcher = not args.no_prefetcher
|
||||||
args.distributed = False
|
args.distributed = False
|
||||||
if 'WORLD_SIZE' in os.environ:
|
if 'WORLD_SIZE' in os.environ:
|
||||||
@ -600,7 +608,7 @@ def main():
|
|||||||
|
|
||||||
update_summary(
|
update_summary(
|
||||||
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
|
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
|
||||||
write_header=best_metric is None, log_wandb=args.log_wandb)
|
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
|
||||||
|
|
||||||
if saver is not None:
|
if saver is not None:
|
||||||
# save proper checkpoint with eval metric
|
# save proper checkpoint with eval metric
|
||||||
|
Loading…
x
Reference in New Issue
Block a user