mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add support for loading args from yaml file (and saving them with each experiment)
This commit is contained in:
parent
d3ba34ee7e
commit
187ecbafbe
@ -1,2 +1,3 @@
|
||||
torch>=1.1.0
|
||||
torchvision>=0.3.0
|
||||
pyyaml
|
||||
|
31
train.py
31
train.py
@ -2,6 +2,7 @@
|
||||
import argparse
|
||||
import time
|
||||
import logging
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
@ -26,6 +27,14 @@ import torchvision.utils
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
# The first arg parser parses out only the --config argument, this argument is used to
|
||||
# load a yaml file containing key-values that override the defaults for the main parser below
|
||||
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
|
||||
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
|
||||
help='YAML config file specifying default arguments')
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Training')
|
||||
# Dataset / Model parameters
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
@ -145,9 +154,27 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
|
||||
parser.add_argument("--local_rank", default=0, type=int)
|
||||
|
||||
|
||||
def _parse_args():
|
||||
# Do we have a config file to parse?
|
||||
args_config, remaining = config_parser.parse_known_args()
|
||||
if args_config.config:
|
||||
with open(args_config.config, 'r') as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
parser.set_defaults(**cfg)
|
||||
|
||||
# The main arg parser parses the rest of the args, the usual
|
||||
# defaults will have been overridden if config file specified.
|
||||
args = parser.parse_args(remaining)
|
||||
|
||||
# Cache the args as a text string to save them in the output dir later
|
||||
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
|
||||
return args, args_text
|
||||
|
||||
|
||||
def main():
|
||||
setup_default_logging()
|
||||
args = parser.parse_args()
|
||||
args, args_text = _parse_args()
|
||||
|
||||
args.prefetcher = not args.no_prefetcher
|
||||
args.distributed = False
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
@ -345,6 +372,8 @@ def main():
|
||||
output_dir = get_outdir(output_base, 'train', exp_name)
|
||||
decreasing = True if eval_metric == 'loss' else False
|
||||
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
||||
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
|
||||
f.write(args_text)
|
||||
|
||||
try:
|
||||
for epoch in range(start_epoch, num_epochs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user