Add --torchcompile-mode args to train, validation, inference, benchmark scripts

This commit is contained in:
Ross Wightman 2024-10-02 15:14:33 -07:00
parent 14d55a7cf3
commit 4d4bdd64a9
4 changed files with 13 additions and 5 deletions

View File

@@ -120,6 +120,8 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
@@ -224,6 +226,7 @@ class BenchmarkRunner:
             device='cuda',
             torchscript=False,
             torchcompile=None,
+            torchcompile_mode=None,
             aot_autograd=False,
             reparam=False,
             precision='float32',
@@ -278,7 +281,7 @@ class BenchmarkRunner:
         elif torchcompile:
             assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
             torch._dynamo.reset()
-            self.model = torch.compile(self.model, backend=torchcompile)
+            self.model = torch.compile(self.model, backend=torchcompile, mode=torchcompile_mode)
             self.compiled = True
         elif aot_autograd:
             assert has_functorch, "functorch is needed for --aot-autograd"

View File

@@ -114,6 +114,8 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -216,7 +218,7 @@ def main():
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)

View File

@@ -161,6 +161,8 @@ group.add_argument('--head-init-scale', default=None, type=float,
                    help='Head initialization scale')
 group.add_argument('--head-init-bias', default=None, type=float,
                    help='Head initialization bias value')
+group.add_argument('--torchcompile-mode', type=str, default=None,
+                   help="torch.compile mode (default: None).")
 
 # scripting / codegen
 scripting_group = group.add_mutually_exclusive_group()
@@ -627,7 +629,7 @@ def main():
     if args.torchcompile:
         # torch compile should be done after DDP
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
 
     # create the train and eval datasets
     if args.data and not args.data_dir:

View File

@@ -139,7 +139,8 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -246,7 +247,7 @@ def validate(args):
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)