From 4d4bdd64a996bf7b5919ec62f20af4a1c07d5848 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 2 Oct 2024 15:14:33 -0700
Subject: [PATCH] Add --torchcompile-mode args to train, validation, inference,
 benchmark scripts

---
 benchmark.py | 5 ++++-
 inference.py | 4 +++-
 train.py     | 4 +++-
 validate.py  | 5 +++--
 4 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/benchmark.py b/benchmark.py
index c31708f5..422b9a35 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -120,6 +120,8 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
@@ -224,6 +226,7 @@ class BenchmarkRunner:
             device='cuda',
             torchscript=False,
             torchcompile=None,
+            torchcompile_mode=None,
             aot_autograd=False,
             reparam=False,
             precision='float32',
@@ -278,7 +281,7 @@ class BenchmarkRunner:
         elif torchcompile:
             assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
             torch._dynamo.reset()
-            self.model = torch.compile(self.model, backend=torchcompile)
+            self.model = torch.compile(self.model, backend=torchcompile, mode=torchcompile_mode)
             self.compiled = True
         elif aot_autograd:
             assert has_functorch, "functorch is needed for --aot-autograd"
diff --git a/inference.py b/inference.py
index d8cb3dee..e6bd4ae1 100755
--- a/inference.py
+++ b/inference.py
@@ -114,6 +114,8 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -216,7 +218,7 @@ def main():
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
diff --git a/train.py b/train.py
index 09968494..ebd9bc80 100755
--- a/train.py
+++ b/train.py
@@ -161,6 +161,8 @@ group.add_argument('--head-init-scale', default=None, type=float,
                    help='Head initialization scale')
 group.add_argument('--head-init-bias', default=None, type=float,
                    help='Head initialization bias value')
+group.add_argument('--torchcompile-mode', type=str, default=None,
+                   help="torch.compile mode (default: None).")
 
 # scripting / codegen
 scripting_group = group.add_mutually_exclusive_group()
@@ -627,7 +629,7 @@ def main():
     if args.torchcompile:
         # torch compile should be done after DDP
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
 
     # create the train and eval datasets
     if args.data and not args.data_dir:
diff --git a/validate.py b/validate.py
index cb71a062..6115de7a 100755
--- a/validate.py
+++ b/validate.py
@@ -139,7 +139,8 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
-
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -246,7 +247,7 @@ def validate(args):
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
    elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
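
Note (not part of the patch itself): in each script the new --torchcompile-mode value is passed straight through to the mode parameter of torch.compile(). Below is a minimal standalone sketch of the resulting behaviour; the stand-in nn.Linear model and the particular backend/mode strings are illustrative choices, not anything the patch requires.

# Minimal sketch, assuming a torch build with torch.compile() support.
# The model and the example values below are illustrative only.
import torch
import torch.nn as nn

torchcompile = 'inductor'           # e.g. --torchcompile inductor
torchcompile_mode = 'max-autotune'  # e.g. --torchcompile-mode max-autotune

model = nn.Linear(8, 8)             # stand-in for a timm model
if torchcompile:
    torch._dynamo.reset()
    # mode=None (flag unset) preserves the previous default behaviour
    model = torch.compile(model, backend=torchcompile, mode=torchcompile_mode)

out = model(torch.randn(2, 8))

Mode strings accepted by the inductor backend include 'default', 'reduce-overhead', and 'max-autotune'; leaving --torchcompile-mode unset keeps the prior behaviour (mode=None).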