mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update scripts to support torch.compile(). Make the --results-file arg more consistent across benchmark/validate/inference. Fix #1570
This commit is contained in:
parent
05637a4bb0
commit
dbe7531aa3
64
benchmark.py
64
benchmark.py
@ -56,13 +56,7 @@ try:
|
||||
except ImportError as e:
|
||||
has_functorch = False
|
||||
|
||||
try:
|
||||
import torch._dynamo
|
||||
has_dynamo = True
|
||||
except ImportError:
|
||||
has_dynamo = False
|
||||
pass
|
||||
|
||||
has_compile = hasattr(torch, 'compile')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
@ -81,8 +75,10 @@ parser.add_argument('--detail', action='store_true', default=False,
|
||||
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
|
||||
parser.add_argument('--no-retry', action='store_true', default=False,
|
||||
help='Do not decay batch size and retry on error.')
|
||||
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
|
||||
parser.add_argument('--results-file', default='', type=str,
|
||||
help='Output csv file for validation results (summary)')
|
||||
parser.add_argument('--results-format', default='csv', type=str,
|
||||
help='Format for results file one of (csv, json) (default: csv).')
|
||||
parser.add_argument('--num-warm-iter', default=10, type=int,
|
||||
metavar='N', help='Number of warmup iterations (default: 10)')
|
||||
parser.add_argument('--num-bench-iter', default=40, type=int,
|
||||
@ -113,8 +109,6 @@ parser.add_argument('--precision', default='float32', type=str,
|
||||
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
|
||||
parser.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
parser.add_argument('--dynamo-backend', default=None, type=str,
|
||||
help="Select dynamo backend. Default: None")
|
||||
parser.add_argument('--fast-norm', default=False, action='store_true',
|
||||
help='enable experimental fast-norm')
|
||||
|
||||
@ -122,10 +116,11 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
|
||||
scripting_group = parser.add_mutually_exclusive_group()
|
||||
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='convert model torchscript for inference')
|
||||
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
|
||||
help="Enable compilation w/ specified backend (default: inductor).")
|
||||
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
|
||||
help="Enable AOT Autograd optimization.")
|
||||
scripting_group.add_argument('--dynamo', default=False, action='store_true',
|
||||
help="Enable Dynamo optimization.")
|
||||
|
||||
|
||||
# train optimizer parameters
|
||||
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
||||
@ -218,9 +213,8 @@ class BenchmarkRunner:
|
||||
detail=False,
|
||||
device='cuda',
|
||||
torchscript=False,
|
||||
torchcompile=None,
|
||||
aot_autograd=False,
|
||||
dynamo=False,
|
||||
dynamo_backend=None,
|
||||
precision='float32',
|
||||
fuser='',
|
||||
num_warm_iter=10,
|
||||
@ -259,20 +253,19 @@ class BenchmarkRunner:
|
||||
self.input_size = data_config['input_size']
|
||||
self.batch_size = kwargs.pop('batch_size', 256)
|
||||
|
||||
self.scripted = False
|
||||
self.compiled = False
|
||||
if torchscript:
|
||||
self.model = torch.jit.script(self.model)
|
||||
self.scripted = True
|
||||
elif dynamo:
|
||||
assert has_dynamo, "torch._dynamo is needed for --dynamo"
|
||||
self.compiled = True
|
||||
elif torchcompile:
|
||||
assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
|
||||
torch._dynamo.reset()
|
||||
if dynamo_backend is not None:
|
||||
self.model = torch._dynamo.optimize(dynamo_backend)(self.model)
|
||||
else:
|
||||
self.model = torch._dynamo.optimize()(self.model)
|
||||
self.model = torch.compile(self.model, backend=torchcompile)
|
||||
self.compiled = True
|
||||
elif aot_autograd:
|
||||
assert has_functorch, "functorch is needed for --aot-autograd"
|
||||
self.model = memory_efficient_fusion(self.model)
|
||||
self.compiled = True
|
||||
|
||||
self.example_inputs = None
|
||||
self.num_warm_iter = num_warm_iter
|
||||
@ -344,7 +337,7 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
|
||||
param_count=round(self.param_count / 1e6, 2),
|
||||
)
|
||||
|
||||
retries = 0 if self.scripted else 2 # skip profiling if model is scripted
|
||||
retries = 0 if self.compiled else 2 # skip profiling if model is compiled
|
||||
while retries:
|
||||
retries -= 1
|
||||
try:
|
||||
@ -642,7 +635,6 @@ def main():
|
||||
model_cfgs = [(n, None) for n in model_names]
|
||||
|
||||
if len(model_cfgs):
|
||||
results_file = args.results_file or './benchmark.csv'
|
||||
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
|
||||
results = []
|
||||
try:
|
||||
@ -663,22 +655,30 @@ def main():
|
||||
sort_key = 'infer_gmacs'
|
||||
results = filter(lambda x: sort_key in x, results)
|
||||
results = sorted(results, key=lambda x: x[sort_key], reverse=True)
|
||||
if len(results):
|
||||
write_results(results_file, results)
|
||||
else:
|
||||
results = benchmark(args)
|
||||
|
||||
if args.results_file:
|
||||
write_results(args.results_file, results, format=args.results_format)
|
||||
|
||||
# output results in JSON to stdout w/ delimiter for runner script
|
||||
print(f'--result\n{json.dumps(results, indent=4)}')
|
||||
|
||||
|
||||
def write_results(results_file, results):
|
||||
def write_results(results_file, results, format='csv'):
|
||||
with open(results_file, mode='w') as cf:
|
||||
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
|
||||
dw.writeheader()
|
||||
for r in results:
|
||||
dw.writerow(r)
|
||||
cf.flush()
|
||||
if format == 'json':
|
||||
json.dump(results, cf, indent=4)
|
||||
else:
|
||||
if not isinstance(results, (list, tuple)):
|
||||
results = [results]
|
||||
if not results:
|
||||
return
|
||||
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
|
||||
dw.writeheader()
|
||||
for r in results:
|
||||
dw.writerow(r)
|
||||
cf.flush()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
45
inference.py
45
inference.py
@ -8,6 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
@ -41,11 +42,7 @@ try:
|
||||
except ImportError as e:
|
||||
has_functorch = False
|
||||
|
||||
try:
|
||||
import torch._dynamo
|
||||
has_dynamo = True
|
||||
except ImportError:
|
||||
has_dynamo = False
|
||||
has_compile = hasattr(torch, 'compile')
|
||||
|
||||
|
||||
_FMT_EXT = {
|
||||
@ -60,14 +57,16 @@ _logger = logging.getLogger('inference')
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
|
||||
help='dataset type (default: ImageFolder/ImageTar if empty)')
|
||||
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
|
||||
help='path to dataset (*deprecated*, use --data-dir)')
|
||||
parser.add_argument('--data-dir', metavar='DIR',
|
||||
help='path to dataset (root dir)')
|
||||
parser.add_argument('--dataset', metavar='NAME', default='',
|
||||
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
|
||||
parser.add_argument('--split', metavar='NAME', default='validation',
|
||||
help='dataset split (default: validation)')
|
||||
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
|
||||
help='model architecture (default: dpn92)')
|
||||
parser.add_argument('--model', '-m', metavar='MODEL', default='resnet50',
|
||||
help='model architecture (default: resnet50)')
|
||||
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
@ -112,16 +111,14 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
|
||||
help='lower precision AMP dtype (default: float16)')
|
||||
parser.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
parser.add_argument('--dynamo-backend', default=None, type=str,
|
||||
help="Select dynamo backend. Default: None")
|
||||
|
||||
scripting_group = parser.add_mutually_exclusive_group()
|
||||
scripting_group.add_argument('--torchscript', default=False, action='store_true',
|
||||
help='torch.jit.script the full model')
|
||||
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
|
||||
help="Enable compilation w/ specified backend (default: inductor).")
|
||||
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
|
||||
help="Enable AOT Autograd support.")
|
||||
scripting_group.add_argument('--dynamo', default=False, action='store_true',
|
||||
help="Enable Dynamo optimization.")
|
||||
|
||||
parser.add_argument('--results-dir',type=str, default=None,
|
||||
help='folder for output results')
|
||||
@ -160,7 +157,6 @@ def main():
|
||||
device = torch.device(args.device)
|
||||
|
||||
# resolve AMP arguments based on PyTorch / Apex availability
|
||||
use_amp = None
|
||||
amp_autocast = suppress
|
||||
if args.amp:
|
||||
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
|
||||
@ -201,22 +197,20 @@ def main():
|
||||
|
||||
if args.torchscript:
|
||||
model = torch.jit.script(model)
|
||||
elif args.torchcompile:
|
||||
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
|
||||
torch._dynamo.reset()
|
||||
model = torch.compile(model, backend=args.torchcompile)
|
||||
elif args.aot_autograd:
|
||||
assert has_functorch, "functorch is needed for --aot-autograd"
|
||||
model = memory_efficient_fusion(model)
|
||||
elif args.dynamo:
|
||||
assert has_dynamo, "torch._dynamo is needed for --dynamo"
|
||||
torch._dynamo.reset()
|
||||
if args.dynamo_backend is not None:
|
||||
model = torch._dynamo.optimize(args.dynamo_backend)(model)
|
||||
else:
|
||||
model = torch._dynamo.optimize()(model)
|
||||
|
||||
if args.num_gpu > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
|
||||
|
||||
root_dir = args.data or args.data_dir
|
||||
dataset = create_dataset(
|
||||
root=args.data,
|
||||
root=root_dir,
|
||||
name=args.dataset,
|
||||
split=args.split,
|
||||
class_map=args.class_map,
|
||||
@ -304,6 +298,9 @@ def main():
|
||||
for fmt in args.results_format:
|
||||
save_results(df, results_filename, fmt)
|
||||
|
||||
print(f'--result')
|
||||
print(json.dumps(dict(filename=results_filename)))
|
||||
|
||||
|
||||
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
|
||||
results_filename += _FMT_EXT[results_format]
|
||||
|
39
train.py
39
train.py
@ -66,12 +66,7 @@ try:
|
||||
except ImportError as e:
|
||||
has_functorch = False
|
||||
|
||||
try:
|
||||
import torch._dynamo
|
||||
has_dynamo = True
|
||||
except ImportError:
|
||||
has_dynamo = False
|
||||
pass
|
||||
has_compile = hasattr(torch, 'compile')
|
||||
|
||||
|
||||
_logger = logging.getLogger('train')
|
||||
@ -88,10 +83,12 @@ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
||||
# Dataset parameters
|
||||
group = parser.add_argument_group('Dataset parameters')
|
||||
# Keep this argument outside of the dataset group because it is positional.
|
||||
parser.add_argument('data_dir', metavar='DIR',
|
||||
help='path to dataset')
|
||||
group.add_argument('--dataset', '-d', metavar='NAME', default='',
|
||||
help='dataset type (default: ImageFolder/ImageTar if empty)')
|
||||
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
|
||||
help='path to dataset (positional is *deprecated*, use --data-dir)')
|
||||
parser.add_argument('--data-dir', metavar='DIR',
|
||||
help='path to dataset (root dir)')
|
||||
parser.add_argument('--dataset', metavar='NAME', default='',
|
||||
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
|
||||
group.add_argument('--train-split', metavar='NAME', default='train',
|
||||
help='dataset train split (default: train)')
|
||||
group.add_argument('--val-split', metavar='NAME', default='validation',
|
||||
@ -143,16 +140,14 @@ group.add_argument('--grad-checkpointing', action='store_true', default=False,
|
||||
help='Enable gradient checkpointing through model blocks/stages')
|
||||
group.add_argument('--fast-norm', default=False, action='store_true',
|
||||
help='enable experimental fast-norm')
|
||||
parser.add_argument('--dynamo-backend', default=None, type=str,
|
||||
help="Select dynamo backend. Default: None")
|
||||
|
||||
scripting_group = group.add_mutually_exclusive_group()
|
||||
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='torch.jit.script the full model')
|
||||
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
|
||||
help="Enable compilation w/ specified backend (default: inductor).")
|
||||
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
|
||||
help="Enable AOT Autograd support.")
|
||||
scripting_group.add_argument('--dynamo', default=False, action='store_true',
|
||||
help="Enable Dynamo optimization.")
|
||||
|
||||
# Optimizer parameters
|
||||
group = parser.add_argument_group('Optimizer parameters')
|
||||
@ -377,6 +372,8 @@ def main():
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
if args.data and not args.data_dir:
|
||||
args.data_dir = args.data
|
||||
args.prefetcher = not args.no_prefetcher
|
||||
device = utils.init_distributed_device(args)
|
||||
if args.distributed:
|
||||
@ -485,18 +482,16 @@ def main():
|
||||
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
|
||||
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
|
||||
model = torch.jit.script(model)
|
||||
elif args.torchcompile:
|
||||
# FIXME dynamo might need move below DDP wrapping? TBD
|
||||
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
|
||||
torch._dynamo.reset()
|
||||
model = torch.compile(model, backend=args.torchcompile)
|
||||
elif args.aot_autograd:
|
||||
assert has_functorch, "functorch is needed for --aot-autograd"
|
||||
model = memory_efficient_fusion(model)
|
||||
elif args.dynamo:
|
||||
# FIXME dynamo might need move below DDP wrapping? TBD
|
||||
assert has_dynamo, "torch._dynamo is needed for --dynamo"
|
||||
if args.dynamo_backend is not None:
|
||||
model = torch._dynamo.optimize(args.dynamo_backend)(model)
|
||||
else:
|
||||
model = torch._dynamo.optimize()(model)
|
||||
|
||||
if args.lr is None:
|
||||
if not args.lr:
|
||||
global_batch_size = args.batch_size * args.world_size
|
||||
batch_ratio = global_batch_size / args.lr_base_size
|
||||
if not args.lr_base_scale:
|
||||
|
66
validate.py
66
validate.py
@ -26,12 +26,11 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa
|
||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
|
||||
decay_batch_step, check_batch_size_retry
|
||||
|
||||
has_apex = False
|
||||
try:
|
||||
from apex import amp
|
||||
has_apex = True
|
||||
except ImportError:
|
||||
pass
|
||||
has_apex = False
|
||||
|
||||
has_native_amp = False
|
||||
try:
|
||||
@ -46,21 +45,18 @@ try:
|
||||
except ImportError as e:
|
||||
has_functorch = False
|
||||
|
||||
try:
|
||||
import torch._dynamo
|
||||
has_dynamo = True
|
||||
except ImportError:
|
||||
has_dynamo = False
|
||||
pass
|
||||
has_compile = hasattr(torch, 'compile')
|
||||
|
||||
_logger = logging.getLogger('validate')
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
|
||||
help='dataset type (default: ImageFolder/ImageTar if empty)')
|
||||
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
|
||||
help='path to dataset (*deprecated*, use --data-dir)')
|
||||
parser.add_argument('--data-dir', metavar='DIR',
|
||||
help='path to dataset (root dir)')
|
||||
parser.add_argument('--dataset', metavar='NAME', default='',
|
||||
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
|
||||
parser.add_argument('--split', metavar='NAME', default='validation',
|
||||
help='dataset split (default: validation)')
|
||||
parser.add_argument('--dataset-download', action='store_true', default=False,
|
||||
@ -125,19 +121,19 @@ parser.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
parser.add_argument('--fast-norm', default=False, action='store_true',
|
||||
help='enable experimental fast-norm')
|
||||
parser.add_argument('--dynamo-backend', default=None, type=str,
|
||||
help="Select dynamo backend. Default: None")
|
||||
|
||||
scripting_group = parser.add_mutually_exclusive_group()
|
||||
scripting_group.add_argument('--torchscript', default=False, action='store_true',
|
||||
help='torch.jit.script the full model')
|
||||
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
|
||||
help="Enable compilation w/ specified backend (default: inductor).")
|
||||
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
|
||||
help="Enable AOT Autograd support.")
|
||||
scripting_group.add_argument('--dynamo', default=False, action='store_true',
|
||||
help="Enable Dynamo optimization.")
|
||||
|
||||
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
|
||||
help='Output csv file for validation results (summary)')
|
||||
parser.add_argument('--results-format', default='csv', type=str,
|
||||
help='Format for results file one of (csv, json) (default: csv).')
|
||||
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
|
||||
help='Real labels JSON file for imagenet evaluation')
|
||||
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
|
||||
@ -218,16 +214,13 @@ def validate(args):
|
||||
if args.torchscript:
|
||||
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
|
||||
model = torch.jit.script(model)
|
||||
elif args.torchcompile:
|
||||
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
|
||||
torch._dynamo.reset()
|
||||
model = torch.compile(model, backend=args.torchcompile)
|
||||
elif args.aot_autograd:
|
||||
assert has_functorch, "functorch is needed for --aot-autograd"
|
||||
model = memory_efficient_fusion(model)
|
||||
elif args.dynamo:
|
||||
assert has_dynamo, "torch._dynamo is needed for --dynamo"
|
||||
torch._dynamo.reset()
|
||||
if args.dynamo_backend is not None:
|
||||
model = torch._dynamo.optimize(args.dynamo_backend)(model)
|
||||
else:
|
||||
model = torch._dynamo.optimize()(model)
|
||||
|
||||
if use_amp == 'apex':
|
||||
model = amp.initialize(model, opt_level='O1')
|
||||
@ -407,7 +400,6 @@ def main():
|
||||
model_cfgs = [(n, None) for n in model_names if n]
|
||||
|
||||
if len(model_cfgs):
|
||||
results_file = args.results_file or './results-all.csv'
|
||||
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
|
||||
results = []
|
||||
try:
|
||||
@ -424,24 +416,34 @@ def main():
|
||||
except KeyboardInterrupt as e:
|
||||
pass
|
||||
results = sorted(results, key=lambda x: x['top1'], reverse=True)
|
||||
if len(results):
|
||||
write_results(results_file, results)
|
||||
else:
|
||||
if args.retry:
|
||||
results = _try_run(args, args.batch_size)
|
||||
else:
|
||||
results = validate(args)
|
||||
|
||||
if args.results_file:
|
||||
write_results(args.results_file, results, format=args.results_format)
|
||||
|
||||
# output results in JSON to stdout w/ delimiter for runner script
|
||||
print(f'--result\n{json.dumps(results, indent=4)}')
|
||||
|
||||
|
||||
def write_results(results_file, results):
|
||||
def write_results(results_file, results, format='csv'):
|
||||
with open(results_file, mode='w') as cf:
|
||||
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
|
||||
dw.writeheader()
|
||||
for r in results:
|
||||
dw.writerow(r)
|
||||
cf.flush()
|
||||
if format == 'json':
|
||||
json.dump(results, cf, indent=4)
|
||||
else:
|
||||
if not isinstance(results, (list, tuple)):
|
||||
results = [results]
|
||||
if not results:
|
||||
return
|
||||
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
|
||||
dw.writeheader()
|
||||
for r in results:
|
||||
dw.writerow(r)
|
||||
cf.flush()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user